
Commit ff8a4f3

Add adapter light.
1 parent c4d4455 commit ff8a4f3

File tree

1 file changed: +91 -10

examples/t2i-adapter/main.swift

Lines changed: 91 additions & 10 deletions
@@ -38,9 +38,7 @@ func ResnetBlock(outChannels: Int, inConv: Bool) -> (
   )
 }

-func Adapter(
-  channels: [Int], numRepeat: Int
-) -> ((PythonObject) -> Void, Model) {
+func Adapter(channels: [Int], numRepeat: Int) -> ((PythonObject) -> Void, Model) {
   let x = Input()
   let convIn = Convolution(
     groups: 1, filters: channels[0], filterSize: [3, 3],
@@ -90,25 +88,108 @@ func Adapter(
   return (reader, Model([x], outs))
 }

+func ResnetBlockLight(outChannels: Int) -> (
+  Model, Model, Model
+) {
+  let x = Input()
+  let inLayerConv2d = Convolution(
+    groups: 1, filters: outChannels, filterSize: [3, 3],
+    hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
+  var out = inLayerConv2d(x)
+  out = ReLU()(out)
+  // Dropout if needed in the future (for training).
+  let outLayerConv2d = Convolution(
+    groups: 1, filters: outChannels, filterSize: [3, 3],
+    hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
+  out = outLayerConv2d(out) + x
+  return (
+    inLayerConv2d, outLayerConv2d, Model([x], [out])
+  )
+}
+
+func Extractor(prefix: String, channel: Int, innerChannel: Int, numRepeat: Int, downsample: Bool) -> ((PythonObject) -> Void, Model) {
+  let x = Input()
+  let inConv = Convolution(groups: 1, filters: innerChannel, filterSize: [1, 1], hint: Hint(stride: [1, 1]))
+  var out = inConv(x)
+  var readers = [(PythonObject) -> Void]()
+  for i in 0..<numRepeat {
+    let (inLayerConv2d, outLayerConv2d, resnetBlock) = ResnetBlockLight(outChannels: innerChannel)
+    out = resnetBlock(out)
+    let reader: (PythonObject) -> Void = { state_dict in
+      let block1_weight = state_dict["body.\(prefix).body.\(i).block1.weight"].numpy()
+      let block1_bias = state_dict["body.\(prefix).body.\(i).block1.bias"].numpy()
+      inLayerConv2d.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: block1_weight))
+      inLayerConv2d.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: block1_bias))
+      let block2_weight = state_dict["body.\(prefix).body.\(i).block2.weight"].numpy()
+      let block2_bias = state_dict["body.\(prefix).body.\(i).block2.bias"].numpy()
+      outLayerConv2d.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: block2_weight))
+      outLayerConv2d.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: block2_bias))
+    }
+    readers.append(reader)
+  }
+  let outConv = Convolution(groups: 1, filters: channel, filterSize: [1, 1], hint: Hint(stride: [1, 1]))
+  out = outConv(out)
+  if downsample {
+    let downsample = AveragePool(filterSize: [2, 2], hint: Hint(stride: [2, 2]))
+    out = downsample(out)
+  }
+  let reader: (PythonObject) -> Void = { state_dict in
+    let in_conv_weight = state_dict["body.\(prefix).in_conv.weight"].numpy()
+    let in_conv_bias = state_dict["body.\(prefix).in_conv.bias"].numpy()
+    inConv.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: in_conv_weight))
+    inConv.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: in_conv_bias))
+    let out_conv_weight = state_dict["body.\(prefix).out_conv.weight"].numpy()
+    let out_conv_bias = state_dict["body.\(prefix).out_conv.bias"].numpy()
+    outConv.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: out_conv_weight))
+    outConv.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: out_conv_bias))
+    for reader in readers {
+      reader(state_dict)
+    }
+  }
+  return (reader, Model([x], [out]))
+}
+
+func AdapterLight(channels: [Int], numRepeat: Int) -> ((PythonObject) -> Void, Model) {
+  var readers = [(PythonObject) -> Void]()
+  let x = Input()
+  var out: Model.IO = x
+  var outs = [Model.IO]()
+  for (i, channel) in channels.enumerated() {
+    let (reader, extractor) = Extractor(prefix: "\(i)", channel: channel, innerChannel: channel / 4, numRepeat: numRepeat, downsample: i != 0)
+    out = extractor(out)
+    outs.append(out)
+    readers.append(reader)
+  }
+  let reader: (PythonObject) -> Void = { state_dict in
+    for reader in readers {
+      reader(state_dict)
+    }
+  }
+  return (reader, Model([x], outs))
+}
+
 random.seed(42)
 numpy.random.seed(42)
 torch.manual_seed(42)
 torch.cuda.manual_seed_all(42)

-let hint = torch.randn([2, 1, 512, 512])
+let hint = torch.randn([2, 3, 512, 512])

-let adapter = ldm_modules_encoders_adapter.Adapter(cin: 64, channels: [320, 640, 1280, 1280], nums_rb: 2, ksize: 1, sk: true, use_conv: false).to(torch.device("cpu"))
-adapter.load_state_dict(torch.load("/home/liu/workspace/T2I-Adapter/models/t2iadapter_canny_sd14v1.pth"))
-let state_dict = adapter.state_dict()
-let ret = adapter(hint)
+// let adapter = ldm_modules_encoders_adapter.Adapter(cin: 64, channels: [320, 640, 1280, 1280], nums_rb: 2, ksize: 1, sk: true, use_conv: false).to(torch.device("cpu"))
+let adapterLight = ldm_modules_encoders_adapter.Adapter_light(cin: 64 * 3, channels: [320, 640, 1280, 1280], nums_rb: 4).to(torch.device("cpu"))
+adapterLight.load_state_dict(torch.load("/home/liu/workspace/T2I-Adapter/models/t2iadapter_color_sd14v1.pth"))
+let state_dict = adapterLight.state_dict()
+let ret = adapterLight(hint)
+print(adapterLight)
 print(ret[0])

 let graph = DynamicGraph()
 let hintTensor = graph.variable(try! Tensor<Float>(numpy: hint.numpy())).toGPU(0)
-let (reader, adapternet) = Adapter(channels: [320, 640, 1280, 1280], numRepeat: 2)
+// let (reader, adapternet) = Adapter(channels: [320, 640, 1280, 1280], numRepeat: 2)
+let (reader, adapternet) = AdapterLight(channels: [320, 640, 1280, 1280], numRepeat: 4)
 graph.workspaceSize = 1_024 * 1_024 * 1_024
 graph.withNoGrad {
-  let hintIn = hintTensor.reshaped(format: .NCHW, shape: [2, 1, 64, 8, 64, 8]).permuted(0, 1, 3, 5, 2, 4).copied().reshaped(.NCHW(2, 64, 64, 64))
+  let hintIn = hintTensor.reshaped(format: .NCHW, shape: [2, 3, 64, 8, 64, 8]).permuted(0, 1, 3, 5, 2, 4).copied().reshaped(.NCHW(2, 64 * 3, 64, 64))
   var controls = adapternet(inputs: hintIn).map { $0.as(of: Float.self) }
   reader(state_dict)
   controls = adapternet(inputs: hintIn).map { $0.as(of: Float.self) }
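Note on the hint preprocessing above: the reshape/permute chain that produces hintIn is effectively a space-to-depth (pixel-unshuffle) step with factor 8. Each 512-pixel spatial axis of the [2, 3, 512, 512] hint is split into 64 blocks of 8, and the two block axes are folded into the channel dimension, so the adapter sees a [2, 64 * 3, 64, 64] tensor, matching Adapter_light(cin: 64 * 3, ...) on the PyTorch side. A sketch of the same step broken out, using only calls that already appear in the diff and assuming the hintTensor variable above (the intermediate names blocks and reordered are illustrative only):

  // Split each 512 spatial axis into 64 blocks of 8: [2, 3, 64, 8, 64, 8].
  let blocks = hintTensor.reshaped(format: .NCHW, shape: [2, 3, 64, 8, 64, 8])
  // Move the two 8-sized block axes next to the channel axis: [2, 3, 8, 8, 64, 64].
  let reordered = blocks.permuted(0, 1, 3, 5, 2, 4).copied()
  // Fold 3 * 8 * 8 = 192 into channels: [2, 192, 64, 64], i.e. .NCHW(2, 64 * 3, 64, 64).
  let hintIn = reordered.reshaped(.NCHW(2, 64 * 3, 64, 64))

adapternet is run once before reader(state_dict) and once after, presumably so the second set of controls (now using the converted weights) can be compared against the PyTorch ret printed earlier.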