Skip to content

Commit 0bfce7e

Browse files
committed
Reduce RAM consumption by ordering zq conv results.
1 parent 4cca4fe commit 0bfce7e

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

examples/kandinsky/main.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,9 +985,13 @@ func SpatialNorm(prefix: String, channels: Int, heightScale: Float, widthScale:
985 985
let normLayer = GroupNorm(axis: 1, groups: 32, epsilon: 1e-6, reduce: [2, 3])
986 986
var out = normLayer(x)
987 987
let zqOut = Upsample(.nearest, widthScale: widthScale, heightScale: heightScale)(zq)
988 +
zqOut.add(dependencies: [out])
988 989
let convY = Convolution(groups: 1, filters: channels, filterSize: [1, 1])
990 +
out = out .* convY(zqOut)
989 991
let convB = Convolution(groups: 1, filters: channels, filterSize: [1, 1])
990 -
out = out .* convY(zqOut) + convB(zqOut)
992 +
let bias = convB(zqOut)
993 +
bias.add(dependencies: [out])
994 +
out = out + bias
991 995
let reader: (PythonObject) -> Void = { state_dict in
992 996
let norm_layer_weight = state_dict["\(prefix).norm_layer.weight"].type(torch.float).cpu()
993 997
.numpy()

examples/kandinsky2/main.swift

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,13 @@ func SpatialNorm(prefix: String, channels: Int, heightScale: Float, widthScale:
561 561
let normLayer = GroupNorm(axis: 1, groups: 32, epsilon: 1e-6, reduce: [2, 3])
562 562
var out = normLayer(x)
563 563
let zqOut = Upsample(.nearest, widthScale: widthScale, heightScale: heightScale)(zq)
564 +
zqOut.add(dependencies: [out])
564 565
let convY = Convolution(groups: 1, filters: channels, filterSize: [1, 1])
566 +
out = out .* convY(zqOut)
565 567
let convB = Convolution(groups: 1, filters: channels, filterSize: [1, 1])
566 -
out = out .* convY(zqOut) + convB(zqOut)
568 +
let bias = convB(zqOut)
569 +
bias.add(dependencies: [out])
570 +
out = out + bias
567 571
return Model([x, zq], [out])
568 572
}
569 573

@@ -969,11 +973,13 @@ graph.withNoGrad {
969 973
}
970 974
let zeroOut = vit(inputs: vitInput, classEmbedding, vitPositionalEmbedding)[0].as(
971 975
of: FloatType.self)
976 +
/*
972 977
graph.openStore("/home/liu/workspace/swift-diffusion/image_vit_l14_f16.ckpt") {
973 978
$0.write("vit", model: vit)
974 979
$0.write("class_embedding", variable: classEmbedding)
975 980
$0.write("positional_embedding", variable: vitPositionalEmbedding)
976 981
}
982 +
*/
977 983
debugPrint(zeroOut)
978 984
debugPrint(zeroImgEmbGPU)
979 985
for (i, timestep) in [0, 250, 500, 749, 999].enumerated().reversed() {
@@ -1082,7 +1088,9 @@ graph.withNoGrad {
1082 1088
$0.as(of: FloatType.self)
1083 1089
}
1084 1090
let xfProj = outputs[0]
1091 +
debugPrint(xfProj)
1085 1092
let xfOutGPU = outputs[1]
1093 +
debugPrint(xfOutGPU)
1086 1094
let timesteps = graph.variable(
1087 1095
Tensor<FloatType>(
1088 1096
from: timeEmbedding(timestep: 999, batchSize: 2, embeddingSize: 384, maxPeriod: 10_000).toGPU(
@@ -1156,7 +1164,7 @@ graph.withNoGrad {
1156 1164
zChannels: 4, channels: 128, channelMult: [1, 2, 2, 4], numResBlocks: 2, startHeight: 96,
1157 1165
startWidth: 96, attnResolutions: Set([32]))
1158 1166
movq.compile(inputs: image)
1159 -
graph.openStore("/home/liu/workspace/swift-diffusion/kandinsky_movq_f16.ckpt") {
1167 +
graph.openStore("/home/liu/workspace/swift-diffusion/kandinsky_movq_f32.ckpt") {
1160 1168
$0.read("movq", model: movq)
1161 1169
}
1162 1170
var result = movq(inputs: image)[0].as(of: FloatType.self)

0 commit comments

Comments (0)