@@ -9,6 +9,7 @@ public struct Storage {
99 var name : String
1010 var size : Int
1111 var dataType : DataType
12+ var BF16 : Bool
1213}
1314
1415public struct TensorDescriptor {
@@ -46,7 +47,7 @@ public final class SafeTensors {
4647 guard !( shape. contains { $0 <= 0 } ) else { continue }
4748 guard
4849 dtype == " f32 " || dtype == " f16 " || dtype == " float16 " || dtype == " float32 "
49- || dtype == " float " || dtype == " half "
50+ || dtype == " float " || dtype == " half " || dtype == " bf16 "
5051 else { continue }
5152 let dataType : DataType =
5253 dtype == " f32 " || dtype == " float32 " || dtype == " float " ? . Float32 : . Float16
@@ -58,7 +59,7 @@ public final class SafeTensors {
5859 }
5960 strides. reverse ( )
6061 let tensorDescriptor = TensorDescriptor (
61- storage: Storage ( name: key, size: offsetEnd - offsetStart, dataType: dataType) ,
62+ storage: Storage ( name: key, size: offsetEnd - offsetStart, dataType: dataType, BF16 : dtype == " bf16 " ) ,
6263 storageOffset: offsetStart, shape: shape, strides: strides)
6364 states [ key] = tensorDescriptor
6465 }
@@ -75,12 +76,27 @@ public final class SafeTensors {
7576 guard let address = $0. baseAddress else { fatalError ( ) }
7677 let tensor : AnyTensor
7778 if tensorDescriptor. storage. dataType == . Float16 {
78- tensor = Tensor < Float16 > (
79- . CPU, format: . NCHW, shape: TensorShape ( tensorDescriptor. shape) ,
80- unsafeMutablePointer: ( address + bufferStart + tensorDescriptor. storageOffset)
81- . assumingMemoryBound (
82- to: Float16 . self) , bindLifetimeOf: self
83- )
79+ if tensorDescriptor. storage. BF16 {
80+ let count = tensorDescriptor. strides [ 0 ] * tensorDescriptor. shape [ 0 ]
81+ let u16 = UnsafeMutablePointer< UInt16> . allocate( capacity: count * 2 )
82+ let bf16 = ( address + bufferStart + tensorDescriptor. storageOffset) . assumingMemoryBound ( to: UInt16 . self)
83+ for i in 0 ..< count {
84+ u16 [ i * 2 ] = 0
85+ u16 [ i * 2 + 1 ] = bf16 [ i]
86+ }
87+ tensor = Tensor < Float > (
88+ . CPU, format: . NCHW, shape: TensorShape ( tensorDescriptor. shape) ,
89+ unsafeMutablePointer: UnsafeMutableRawPointer ( u16) . assumingMemoryBound ( to: Float . self) , bindLifetimeOf: self
90+ ) . copied ( )
91+ u16. deallocate ( )
92+ } else {
93+ tensor = Tensor < Float16 > (
94+ . CPU, format: . NCHW, shape: TensorShape ( tensorDescriptor. shape) ,
95+ unsafeMutablePointer: ( address + bufferStart + tensorDescriptor. storageOffset)
96+ . assumingMemoryBound (
97+ to: Float16 . self) , bindLifetimeOf: self
98+ )
99+ }
84100 } else {
85101 tensor = Tensor < Float > (
86102 . CPU, format: . NCHW, shape: TensorShape ( tensorDescriptor. shape) ,
@@ -94,7 +110,7 @@ public final class SafeTensors {
94110 }
95111}
96112
97- let filename = " /home/liu/workspace/swift-diffusion/lucyCyberpunk_35Epochs .safetensors "
113+ let filename = " /home/liu/workspace/swift-diffusion/openjourneyLora_v1 .safetensors "
98114/*
99115let archive = Archive(url: URL(fileURLWithPath: filename), accessMode: .read)!
100116let entry = archive["archive/data.pkl"]!
@@ -174,6 +190,7 @@ for key in keys {
174190 keysSet. remove ( key)
175191 }
176192}
193+ print ( keysSet)
177194var unetMapCount = [ String: Int] ( )
178195for i in stride ( from: 0 , to: unetMap. count, by: 2 ) {
179196 unetMapCount [ unetMap [ i] ] = unetMapCount [ unetMap [ i] , default: 0 ] + 1
0 commit comments