Skip to content

Commit c4d4455

Browse files
committed
Add one model for t2i-adapter
1 parent 31274a1 commit c4d4455

File tree

3 files changed

+134
-2
lines changed

3 files changed

+134
-2
lines changed

WORKSPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
33

44
git_repository(
55
name = "s4nnc",
6-
commit = "908eb21fdb9ea78094a2d31720140e2ad1cdbd39",
6+
commit = "7861c230e48e72ec0752c7845a99288b8c286c6d",
77
remote = "https://github.com/liuliu/s4nnc.git",
8-
shallow_since = "1678557607 -0500",
8+
shallow_since = "1679025289 -0400",
99
)
1010

1111
load("@s4nnc//:deps.bzl", "s4nnc_deps")

examples/BUILD.bazel

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ swift_binary(
5555
],
5656
)
5757

58+
swift_binary(
59+
name = "t2i-adapter",
60+
srcs = ["t2i-adapter/main.swift"],
61+
deps = [
62+
"@PythonKit",
63+
"@SwiftNumerics//:Numerics",
64+
"@s4nnc//nnc",
65+
"@s4nnc//nnc:nnc_python",
66+
],
67+
)
68+
5869
swift_binary(
5970
name = "decoder",
6071
srcs = ["decoder/main.swift"],

examples/t2i-adapter/main.swift

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import Foundation
2+
import NNC
3+
import NNCPythonConversion
4+
import PythonKit
5+
6+
let ldm_modules_encoders_adapter = Python.import("ldm.modules.encoders.adapter")
7+
let torch = Python.import("torch")
8+
let random = Python.import("random")
9+
let numpy = Python.import("numpy")
10+
11+
func ResnetBlock(outChannels: Int, inConv: Bool) -> (
12+
Model?, Model, Model, Model
13+
) {
14+
let x = Input()
15+
let outX: Model.IO
16+
var skipModel: Model? = nil
17+
if inConv {
18+
let skip = Convolution(
19+
groups: 1, filters: outChannels, filterSize: [1, 1],
20+
hint: Hint(stride: [1, 1]))
21+
outX = skip(x)
22+
skipModel = skip
23+
} else {
24+
outX = x
25+
}
26+
let inLayerConv2d = Convolution(
27+
groups: 1, filters: outChannels, filterSize: [3, 3],
28+
hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
29+
var out = inLayerConv2d(outX)
30+
out = ReLU()(out)
31+
// Dropout if needed in the future (for training).
32+
let outLayerConv2d = Convolution(
33+
groups: 1, filters: outChannels, filterSize: [1, 1],
34+
hint: Hint(stride: [1, 1]))
35+
out = outLayerConv2d(out) + outX
36+
return (
37+
skipModel, inLayerConv2d, outLayerConv2d, Model([x], [out])
38+
)
39+
}
40+
41+
func Adapter(
42+
channels: [Int], numRepeat: Int
43+
) -> ((PythonObject) -> Void, Model) {
44+
let x = Input()
45+
let convIn = Convolution(
46+
groups: 1, filters: channels[0], filterSize: [3, 3],
47+
hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
48+
var out = convIn(x)
49+
var readers = [(PythonObject) -> Void]()
50+
var previousChannel = channels[0]
51+
var outs = [Model.IO]()
52+
for (i, channel) in channels.enumerated() {
53+
for j in 0..<numRepeat {
54+
let (skipModel, inLayerConv2d, outLayerConv2d, resnetBlock) = ResnetBlock(outChannels: channel, inConv: previousChannel != channel)
55+
previousChannel = channel
56+
out = resnetBlock(out)
57+
let reader: (PythonObject) -> Void = { state_dict in
58+
let block1_weight = state_dict["body.\(i * numRepeat + j).block1.weight"].numpy()
59+
let block1_bias = state_dict["body.\(i * numRepeat + j).block1.bias"].numpy()
60+
inLayerConv2d.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: block1_weight))
61+
inLayerConv2d.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: block1_bias))
62+
let block2_weight = state_dict["body.\(i * numRepeat + j).block2.weight"].numpy()
63+
let block2_bias = state_dict["body.\(i * numRepeat + j).block2.bias"].numpy()
64+
outLayerConv2d.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: block2_weight))
65+
outLayerConv2d.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: block2_bias))
66+
if let skipModel = skipModel {
67+
let in_conv_weight = state_dict["body.\(i * numRepeat + j).in_conv.weight"].numpy()
68+
let in_conv_bias = state_dict["body.\(i * numRepeat + j).in_conv.bias"].numpy()
69+
skipModel.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: in_conv_weight))
70+
skipModel.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: in_conv_bias))
71+
}
72+
}
73+
readers.append(reader)
74+
}
75+
outs.append(out)
76+
if i != channels.count - 1 {
77+
let downsample = AveragePool(filterSize: [2, 2], hint: Hint(stride: [2, 2]))
78+
out = downsample(out)
79+
}
80+
}
81+
let reader: (PythonObject) -> Void = { state_dict in
82+
let conv_in_weight = state_dict["conv_in.weight"].numpy()
83+
let conv_in_bias = state_dict["conv_in.bias"].numpy()
84+
convIn.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: conv_in_weight))
85+
convIn.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: conv_in_bias))
86+
for reader in readers {
87+
reader(state_dict)
88+
}
89+
}
90+
return (reader, Model([x], outs))
91+
}
92+
93+
random.seed(42)
94+
numpy.random.seed(42)
95+
torch.manual_seed(42)
96+
torch.cuda.manual_seed_all(42)
97+
98+
let hint = torch.randn([2, 1, 512, 512])
99+
100+
let adapter = ldm_modules_encoders_adapter.Adapter(cin: 64, channels: [320, 640, 1280, 1280], nums_rb: 2, ksize: 1, sk: true, use_conv: false).to(torch.device("cpu"))
101+
adapter.load_state_dict(torch.load("/home/liu/workspace/T2I-Adapter/models/t2iadapter_canny_sd14v1.pth"))
102+
let state_dict = adapter.state_dict()
103+
let ret = adapter(hint)
104+
print(ret[0])
105+
106+
let graph = DynamicGraph()
107+
let hintTensor = graph.variable(try! Tensor<Float>(numpy: hint.numpy())).toGPU(0)
108+
let (reader, adapternet) = Adapter(channels: [320, 640, 1280, 1280], numRepeat: 2)
109+
graph.workspaceSize = 1_024 * 1_024 * 1_024
110+
graph.withNoGrad {
111+
let hintIn = hintTensor.reshaped(format: .NCHW, shape: [2, 1, 64, 8, 64, 8]).permuted(0, 1, 3, 5, 2, 4).copied().reshaped(.NCHW(2, 64, 64, 64))
112+
var controls = adapternet(inputs: hintIn).map { $0.as(of: Float.self) }
113+
reader(state_dict)
114+
controls = adapternet(inputs: hintIn).map { $0.as(of: Float.self) }
115+
debugPrint(controls[0])
116+
/*
117+
graph.openStore("/home/liu/workspace/swift-diffusion/adapter.ckpt") {
118+
$0.write("adapter", model: adapter)
119+
}
120+
*/
121+
}

0 commit comments

Comments
 (0)