Skip to content

Commit 0f80472

Browse files
committed
Fixed bf16 issues.
1 parent 336c77a commit 0f80472

File tree

2 files changed

+34
-17
lines changed

2 files changed

+34
-17
lines changed

examples/lora/main.swift

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public struct Storage {
99
var name: String
1010
var size: Int
1111
var dataType: DataType
12+
var BF16: Bool
1213
}
1314

1415
public struct TensorDescriptor {
@@ -46,7 +47,7 @@ public final class SafeTensors {
4647
guard !(shape.contains { $0 <= 0 }) else { continue }
4748
guard
4849
dtype == "f32" || dtype == "f16" || dtype == "float16" || dtype == "float32"
49-
|| dtype == "float" || dtype == "half"
50+
|| dtype == "float" || dtype == "half" || dtype == "bf16"
5051
else { continue }
5152
let dataType: DataType =
5253
dtype == "f32" || dtype == "float32" || dtype == "float" ? .Float32 : .Float16
@@ -58,7 +59,7 @@ public final class SafeTensors {
5859
}
5960
strides.reverse()
6061
let tensorDescriptor = TensorDescriptor(
61-
storage: Storage(name: key, size: offsetEnd - offsetStart, dataType: dataType),
62+
storage: Storage(name: key, size: offsetEnd - offsetStart, dataType: dataType, BF16: dtype == "bf16"),
6263
storageOffset: offsetStart, shape: shape, strides: strides)
6364
states[key] = tensorDescriptor
6465
}
@@ -75,12 +76,27 @@ public final class SafeTensors {
7576
guard let address = $0.baseAddress else { fatalError() }
7677
let tensor: AnyTensor
7778
if tensorDescriptor.storage.dataType == .Float16 {
78-
tensor = Tensor<Float16>(
79-
.CPU, format: .NCHW, shape: TensorShape(tensorDescriptor.shape),
80-
unsafeMutablePointer: (address + bufferStart + tensorDescriptor.storageOffset)
81-
.assumingMemoryBound(
82-
to: Float16.self), bindLifetimeOf: self
83-
)
79+
if tensorDescriptor.storage.BF16 {
80+
let count = tensorDescriptor.strides[0] * tensorDescriptor.shape[0]
81+
let u16 = UnsafeMutablePointer<UInt16>.allocate(capacity: count * 2)
82+
let bf16 = (address + bufferStart + tensorDescriptor.storageOffset).assumingMemoryBound(to: UInt16.self)
83+
for i in 0..<count {
84+
u16[i * 2] = 0
85+
u16[i * 2 + 1] = bf16[i]
86+
}
87+
tensor = Tensor<Float>(
88+
.CPU, format: .NCHW, shape: TensorShape(tensorDescriptor.shape),
89+
unsafeMutablePointer: UnsafeMutableRawPointer(u16).assumingMemoryBound(to: Float.self), bindLifetimeOf: self
90+
).copied()
91+
u16.deallocate()
92+
} else {
93+
tensor = Tensor<Float16>(
94+
.CPU, format: .NCHW, shape: TensorShape(tensorDescriptor.shape),
95+
unsafeMutablePointer: (address + bufferStart + tensorDescriptor.storageOffset)
96+
.assumingMemoryBound(
97+
to: Float16.self), bindLifetimeOf: self
98+
)
99+
}
84100
} else {
85101
tensor = Tensor<Float>(
86102
.CPU, format: .NCHW, shape: TensorShape(tensorDescriptor.shape),
@@ -94,7 +110,7 @@ public final class SafeTensors {
94110
}
95111
}
96112

97-
let filename = "/home/liu/workspace/swift-diffusion/lucyCyberpunk_35Epochs.safetensors"
113+
let filename = "/home/liu/workspace/swift-diffusion/openjourneyLora_v1.safetensors"
98114
/*
99115
let archive = Archive(url: URL(fileURLWithPath: filename), accessMode: .read)!
100116
let entry = archive["archive/data.pkl"]!
@@ -174,6 +190,7 @@ for key in keys {
174190
keysSet.remove(key)
175191
}
176192
}
193+
print(keysSet)
177194
var unetMapCount = [String: Int]()
178195
for i in stride(from: 0, to: unetMap.count, by: 2) {
179196
unetMapCount[unetMap[i]] = unetMapCount[unetMap[i], default: 0] + 1

examples/txt2img/main.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,16 @@ graph.withNoGrad {
160160
let positionTensorGPU = positionTensor.toGPU(0)
161161
let casualAttentionMaskGPU = casualAttentionMask.toGPU(0)
162162
textModel.compile(inputs: tokensTensorGPU, positionTensorGPU, casualAttentionMaskGPU)
163-
graph.openStore(workDir + "/lora.ckpt") { lora in
163+
graph.openStore(workDir + "/moxin_v1.0_lora_f16.ckpt") { lora in
164164
let keys = Set(lora.keys)
165-
graph.openStore(workDir + "/sd-v1.4.ckpt") { store in
165+
graph.openStore(workDir + "/sd-v1.5.ckpt") { store in
166166
store.read("text_model", model: textModel) { name, _, _, _ in
167167
if keys.contains(name + "__up__") {
168168
let original = graph.variable(Tensor<UseFloatingPoint>(from: store.read(name)!)).toGPU(0)
169169
let up = graph.variable(Tensor<UseFloatingPoint>(lora.read(name + "__up__")!)).toGPU(0)
170170
let down = graph.variable(Tensor<UseFloatingPoint>(lora.read(name + "__down__")!)).toGPU(0)
171-
let final = original + 0.6 * (up * down)
172-
return .final(final.rawValue)
171+
let final = original + 0.8 * (up * down)
172+
return .final(final.rawValue.toCPU())
173173
}
174174
return .continue(name)
175175
}
@@ -186,9 +186,9 @@ graph.withNoGrad {
186186
let ts = timeEmbedding(timestep: 0, batchSize: 2, embeddingSize: 320, maxPeriod: 10_000).toGPU(0)
187187
unet.compile(inputs: xIn, graph.variable(Tensor<UseFloatingPoint>(from: ts)), c)
188188
decoder.compile(inputs: x)
189-
graph.openStore(workDir + "/lora.ckpt") { lora in
189+
graph.openStore(workDir + "/moxin_v1.0_lora_f16.ckpt") { lora in
190190
let keys = Set(lora.keys)
191-
graph.openStore(workDir + "/sd-v1.4.ckpt") { store in
191+
graph.openStore(workDir + "/sd-v1.5.ckpt") { store in
192192
store.read("unet", model: unet) { name, _, _, _ in
193193
if keys.contains(name + "__up__") {
194194
let original = graph.variable(Tensor<UseFloatingPoint>(from: store.read(name)!)).toGPU(0)
@@ -200,11 +200,11 @@ graph.withNoGrad {
200200
up = graph.variable(loraUp.reshaped(.NC(loraUp.shape[0], loraUp.shape[1] * loraUp.shape[2] * loraUp.shape[3]))).toGPU(0)
201201
let loraDown = Tensor<UseFloatingPoint>(lora.read(name + "__down__")!)
202202
down = graph.variable(loraDown.reshaped(.NC(loraDown.shape[0], loraDown.shape[1] * loraDown.shape[2] * loraDown.shape[3]))).toGPU(0)
203-
result = original + 0.6 * (up * down).reshaped(format: .NCHW, shape: original.shape)
203+
result = original + 0.8 * (up * down).reshaped(format: .NCHW, shape: original.shape)
204204
} else {
205205
up = graph.variable(Tensor<UseFloatingPoint>(lora.read(name + "__up__")!)).toGPU(0)
206206
down = graph.variable(Tensor<UseFloatingPoint>(lora.read(name + "__down__")!)).toGPU(0)
207-
result = original + 0.6 * (up * down)
207+
result = original + 0.8 * (up * down)
208208
}
209209
return .final(result.rawValue)
210210
}

0 commit comments

Comments (0)