Skip to content

Commit e2b0dcb

Browse files
committed
Save float16 model.
1 parent 378ba54 commit e2b0dcb

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

examples/restoreformer/main.swift

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,17 +394,17 @@ func RestoreFormer(
394394
return (reader, Model([x, embedding], [out]))
395395
}
396396

397-
var initImg = Tensor<Float>(.CPU, .NCHW(1, 3, 512, 512))
397+
var initImg = Tensor<Float16>(.CPU, .NCHW(1, 3, 512, 512))
398398
if let image = try PNG.Data.Rectangular.decompress(
399399
path: "/home/liu/workspace/GFPGAN/inputs/cropped_faces/Adele_crop.png")
400400
{
401401
let rgba = image.unpack(as: PNG.RGBA<UInt8>.self)
402402
for y in 0..<512 {
403403
for x in 0..<512 {
404404
let pixel = rgba[y * 512 + x]
405-
initImg[0, 0, y, x] = Float(pixel.r) / 255 * 2 - 1
406-
initImg[0, 1, y, x] = Float(pixel.g) / 255 * 2 - 1
407-
initImg[0, 2, y, x] = Float(pixel.b) / 255 * 2 - 1
405+
initImg[0, 0, y, x] = Float16(Float(pixel.r) / 255 * 2 - 1)
406+
initImg[0, 1, y, x] = Float16(Float(pixel.g) / 255 * 2 - 1)
407+
initImg[0, 2, y, x] = Float16(Float(pixel.b) / 255 * 2 - 1)
408408
}
409409
}
410410
}
@@ -422,9 +422,26 @@ graph.withNoGrad {
422422
reader(state_dict)
423423
let result = restoreFormer(inputs: croppedFaceTensor, embeddingTensor)[0].as(of: Float.self)
424424
debugPrint(result)
425-
let restoreImg = restoreFormer(inputs: graph.variable(initImg).toGPU(0), embeddingTensor)[0].as(
426-
of: Float.self
425+
graph.openStore("/home/liu/workspace/swift-diffusion/restoreformer_v1.0.ckpt") {
426+
$0.write("embedding", variable: embeddingTensor)
427+
$0.write("restoreformer", model: restoreFormer)
428+
}
429+
let (_, restoreFormerf16) = RestoreFormer(
430+
nEmbed: 1024, embedDim: 256, ch: 64, chMult: [1, 2, 2, 4, 4, 8], zChannels: 256, numHeads: 8,
431+
numResBlocks: 2)
432+
let embeddingTensorf16 = DynamicGraph.Tensor<Float16>(from: embeddingTensor)
433+
let initImgTensor = graph.variable(initImg).toGPU(0)
434+
restoreFormerf16.compile(inputs: initImgTensor, embeddingTensorf16)
435+
graph.openStore("/home/liu/workspace/swift-diffusion/restoreformer_v1.0.ckpt") {
436+
$0.read("restoreformer", model: restoreFormerf16)
437+
}
438+
let restoreImg = restoreFormerf16(inputs: initImgTensor, embeddingTensorf16)[0].as(
439+
of: Float16.self
427440
).toCPU()
441+
graph.openStore("/home/liu/workspace/swift-diffusion/restoreformer_v1.0_f16.ckpt") {
442+
$0.write("embedding", variable: embeddingTensorf16)
443+
$0.write("restoreformer", model: restoreFormerf16)
444+
}
428445
var rgba = [PNG.RGBA<UInt8>](repeating: .init(0), count: 512 * 512)
429446
for y in 0..<512 {
430447
for x in 0..<512 {

0 commit comments

Comments
 (0)