Skip to content

Commit b09e565

Browse files
committed
Update formatter.
1 parent a270b4a commit b09e565

File tree

2 files changed

+35
-21
lines changed

2 files changed

+35
-21
lines changed

WORKSPACE

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,41 +86,41 @@ load("@rules_python//python:pip.bzl", "pip_install")
8686
new_git_repository(
8787
name = "SwiftArgumentParser",
8888
build_file = "swift-argument-parser.BUILD",
89-
commit = "82905286cc3f0fa8adc4674bf49437cab65a8373",
89+
commit = "9f39744e025c7d377987f30b03770805dcb0bcd1",
9090
remote = "https://github.com/apple/swift-argument-parser.git",
91-
shallow_since = "1647436700 -0500",
91+
shallow_since = "1661571047 -0500",
9292
)
9393

9494
new_git_repository(
9595
name = "SwiftSystem",
9696
build_file = "swift-system.BUILD",
97-
commit = "836bc4557b74fe6d2660218d56e3ce96aff76574",
97+
commit = "025bcb1165deab2e20d4eaba79967ce73013f496",
9898
remote = "https://github.com/apple/swift-system.git",
99-
shallow_since = "1638472952 -0800",
99+
shallow_since = "1654977448 -0700",
100100
)
101101

102102
new_git_repository(
103103
name = "SwiftToolsSupportCore",
104104
build_file = "swift-tools-support-core.BUILD",
105-
commit = "b7667f3e266af621e5cc9c77e74cacd8e8c00cb4",
105+
commit = "4f07be3dc201f6e2ee85b6942d0c220a16926811",
106106
remote = "https://github.com/apple/swift-tools-support-core.git",
107-
shallow_since = "1643831290 -0800",
107+
shallow_since = "1659981427 -0700",
108108
)
109109

110110
new_git_repository(
111111
name = "SwiftSyntax",
112112
build_file = "swift-syntax.BUILD",
113-
commit = "0b6c22b97f8e9320bca62e82cdbee601cf37ad3f",
113+
commit = "72d3da66b085c2299dd287c2be3b92b5ebd226de",
114114
remote = "https://github.com/apple/swift-syntax.git",
115-
shallow_since = "1647591231 +0100",
115+
shallow_since = "1664965455 +0200",
116116
)
117117

118118
new_git_repository(
119119
name = "SwiftFormat",
120120
build_file = "swift-format.BUILD",
121-
commit = "e6b8c60c7671066d229e30efa1e31acf57be412e",
121+
commit = "5f184220d032a019a63df457cdea4b9c8241e911",
122122
remote = "https://github.com/apple/swift-format.git",
123-
shallow_since = "1647972246 -0700",
123+
shallow_since = "1665415355 -0700",
124124
)
125125

126126
load("@s4nnc//:deps.bzl", "s4nnc_extra_deps")

examples/t2i-adapter/main.swift

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ func Adapter(channels: [Int], numRepeat: Int) -> ((PythonObject) -> Void, Model)
4949
var outs = [Model.IO]()
5050
for (i, channel) in channels.enumerated() {
5151
for j in 0..<numRepeat {
52-
let (skipModel, inLayerConv2d, outLayerConv2d, resnetBlock) = ResnetBlock(outChannels: channel, inConv: previousChannel != channel)
52+
let (skipModel, inLayerConv2d, outLayerConv2d, resnetBlock) = ResnetBlock(
53+
outChannels: channel, inConv: previousChannel != channel)
5354
previousChannel = channel
5455
out = resnetBlock(out)
5556
let reader: (PythonObject) -> Void = { state_dict in
@@ -107,9 +108,12 @@ func ResnetBlockLight(outChannels: Int) -> (
107108
)
108109
}
109110

110-
func Extractor(prefix: String, channel: Int, innerChannel: Int, numRepeat: Int, downsample: Bool) -> ((PythonObject) -> Void, Model) {
111+
func Extractor(prefix: String, channel: Int, innerChannel: Int, numRepeat: Int, downsample: Bool)
112+
-> ((PythonObject) -> Void, Model)
113+
{
111114
let x = Input()
112-
let inConv = Convolution(groups: 1, filters: innerChannel, filterSize: [1, 1], hint: Hint(stride: [1, 1]))
115+
let inConv = Convolution(
116+
groups: 1, filters: innerChannel, filterSize: [1, 1], hint: Hint(stride: [1, 1]))
113117
var out = inConv(x)
114118
var readers = [(PythonObject) -> Void]()
115119
for i in 0..<numRepeat {
@@ -127,7 +131,8 @@ func Extractor(prefix: String, channel: Int, innerChannel: Int, numRepeat: Int,
127131
}
128132
readers.append(reader)
129133
}
130-
let outConv = Convolution(groups: 1, filters: channel, filterSize: [1, 1], hint: Hint(stride: [1, 1]))
134+
let outConv = Convolution(
135+
groups: 1, filters: channel, filterSize: [1, 1], hint: Hint(stride: [1, 1]))
131136
out = outConv(out)
132137
if downsample {
133138
let downsample = AveragePool(filterSize: [2, 2], hint: Hint(stride: [2, 2]))
@@ -155,7 +160,9 @@ func AdapterLight(channels: [Int], numRepeat: Int) -> ((PythonObject) -> Void, M
155160
var out: Model.IO = x
156161
var outs = [Model.IO]()
157162
for (i, channel) in channels.enumerated() {
158-
let (reader, extractor) = Extractor(prefix: "\(i)", channel: channel, innerChannel: channel / 4, numRepeat: numRepeat, downsample: i != 0)
163+
let (reader, extractor) = Extractor(
164+
prefix: "\(i)", channel: channel, innerChannel: channel / 4, numRepeat: numRepeat,
165+
downsample: i != 0)
159166
out = extractor(out)
160167
outs.append(out)
161168
readers.append(reader)
@@ -246,14 +253,16 @@ func CLIPResidualAttentionBlock(prefix: String, k: Int, h: Int, b: Int, t: Int)
246253
return (reader, Model([x], [out]))
247254
}
248255

249-
func StyleAdapter(width: Int, outputDim: Int, layers: Int, heads: Int, tokens: Int, batchSize: Int) -> ((PythonObject) -> Void, Model)
256+
func StyleAdapter(width: Int, outputDim: Int, layers: Int, heads: Int, tokens: Int, batchSize: Int)
257+
-> ((PythonObject) -> Void, Model)
250258
{
251259
let x = Input()
252260
let lnPre = LayerNorm(epsilon: 1e-5, axis: [2])
253261
var out = lnPre(x)
254262
var readers = [(PythonObject) -> Void]()
255263
for i in 0..<layers {
256-
let (reader, block) = CLIPResidualAttentionBlock(prefix: "transformer_layes.\(i)", k: width / heads, h: heads, b: batchSize, t: 257 + tokens)
264+
let (reader, block) = CLIPResidualAttentionBlock(
265+
prefix: "transformer_layes.\(i)", k: width / heads, h: heads, b: batchSize, t: 257 + tokens)
257266
out = block(out.reshaped([batchSize, 257 + tokens, width]))
258267
readers.append(reader)
259268
}
@@ -295,21 +304,26 @@ let hint = torch.randn([2, 3, 512, 512])
295304
// let adapter = ldm_modules_encoders_adapter.Adapter(cin: 64, channels: [320, 640, 1280, 1280], nums_rb: 2, ksize: 1, sk: true, use_conv: false).to(torch.device("cpu"))
296305
// let adapterLight = ldm_modules_encoders_adapter.Adapter_light(cin: 64 * 3, channels: [320, 640, 1280, 1280], nums_rb: 4).to(torch.device("cpu"))
297306
let style = torch.randn([1, 257, 1024])
298-
let styleAdapter = ldm_modules_encoders_adapter.StyleAdapter(width: 1024, context_dim: 768, num_head: 8, n_layes: 3, num_token: 8).to(torch.device("cpu"))
299-
styleAdapter.load_state_dict(torch.load("/home/liu/workspace/T2I-Adapter/models/t2iadapter_style_sd14v1.pth"))
307+
let styleAdapter = ldm_modules_encoders_adapter.StyleAdapter(
308+
width: 1024, context_dim: 768, num_head: 8, n_layes: 3, num_token: 8
309+
).to(torch.device("cpu"))
310+
styleAdapter.load_state_dict(
311+
torch.load("/home/liu/workspace/T2I-Adapter/models/t2iadapter_style_sd14v1.pth"))
300312
let state_dict = styleAdapter.state_dict()
301313
print(state_dict.keys())
302314
let ret = styleAdapter(style)
303315
print(ret)
304316

305-
let styleEmbed = try Tensor<Float>(numpy: state_dict["style_embedding"].type(torch.float).cpu().numpy())
317+
let styleEmbed = try Tensor<Float>(
318+
numpy: state_dict["style_embedding"].type(torch.float).cpu().numpy())
306319

307320
let graph = DynamicGraph()
308321
let hintTensor = graph.variable(try! Tensor<Float>(numpy: hint.numpy())).toGPU(0)
309322
let styleTensor = graph.variable(try! Tensor<Float>(numpy: style.numpy())).toGPU(0)
310323
// let (reader, adapternet) = Adapter(channels: [320, 640, 1280, 1280], numRepeat: 2)
311324
// let (reader, adapternet) = AdapterLight(channels: [320, 640, 1280, 1280], numRepeat: 4)
312-
let (reader, styleadapternet) = StyleAdapter(width: 1024, outputDim: 768, layers: 3, heads: 8, tokens: 8, batchSize: 1)
325+
let (reader, styleadapternet) = StyleAdapter(
326+
width: 1024, outputDim: 768, layers: 3, heads: 8, tokens: 8, batchSize: 1)
313327
graph.workspaceSize = 1_024 * 1_024 * 1_024
314328
graph.withNoGrad {
315329
// let hintIn = hintTensor.reshaped(format: .NCHW, shape: [2, 3, 64, 8, 64, 8]).permuted(0, 1, 3, 5, 2, 4).copied().reshaped(.NCHW(2, 64 * 3, 64, 64))

0 commit comments

Comments
 (0)