@@ -38,9 +38,7 @@ func ResnetBlock(outChannels: Int, inConv: Bool) -> (
   )
 }
 
-func Adapter(
-  channels: [Int], numRepeat: Int
-) -> ((PythonObject) -> Void, Model) {
+func Adapter(channels: [Int], numRepeat: Int) -> ((PythonObject) -> Void, Model) {
   let x = Input()
   let convIn = Convolution(
     groups: 1, filters: channels[0], filterSize: [3, 3],
@@ -90,25 +88,108 @@ func Adapter(
   return (reader, Model([x], outs))
 }
 
+func ResnetBlockLight(outChannels: Int) -> (
+  Model, Model, Model
+) {
+  let x = Input()
+  let inLayerConv2d = Convolution(
+    groups: 1, filters: outChannels, filterSize: [3, 3],
+    hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
+  var out = inLayerConv2d(x)
+  out = ReLU()(out)
+  // Dropout if needed in the future (for training).
+  let outLayerConv2d = Convolution(
+    groups: 1, filters: outChannels, filterSize: [3, 3],
+    hint: Hint(stride: [1, 1], border: Hint.Border(begin: [1, 1], end: [1, 1])))
+  out = outLayerConv2d(out) + x
+  return (
+    inLayerConv2d, outLayerConv2d, Model([x], [out])
+  )
+}
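+// ResnetBlockLight mirrors T2I-Adapter's ResnetBlock_light: 3x3 conv -> ReLU -> 3x3 conv with an
+// identity skip connection. The two convolutions are returned alongside the Model so a reader
+// closure can copy the block1/block2 weights from the PyTorch state_dict.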
+
+func Extractor(prefix: String, channel: Int, innerChannel: Int, numRepeat: Int, downsample: Bool) -> ((PythonObject) -> Void, Model) {
+  let x = Input()
+  let inConv = Convolution(groups: 1, filters: innerChannel, filterSize: [1, 1], hint: Hint(stride: [1, 1]))
+  var out = inConv(x)
+  var readers = [(PythonObject) -> Void]()
+  for i in 0..<numRepeat {
+    let (inLayerConv2d, outLayerConv2d, resnetBlock) = ResnetBlockLight(outChannels: innerChannel)
+    out = resnetBlock(out)
+    let reader: (PythonObject) -> Void = { state_dict in
+      let block1_weight = state_dict["body.\(prefix).body.\(i).block1.weight"].numpy()
+      let block1_bias = state_dict["body.\(prefix).body.\(i).block1.bias"].numpy()
+      inLayerConv2d.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: block1_weight))
+      inLayerConv2d.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: block1_bias))
+      let block2_weight = state_dict["body.\(prefix).body.\(i).block2.weight"].numpy()
+      let block2_bias = state_dict["body.\(prefix).body.\(i).block2.bias"].numpy()
+      outLayerConv2d.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: block2_weight))
+      outLayerConv2d.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: block2_bias))
+    }
+    readers.append(reader)
+  }
+  let outConv = Convolution(groups: 1, filters: channel, filterSize: [1, 1], hint: Hint(stride: [1, 1]))
+  out = outConv(out)
+  if downsample {
+    let downsample = AveragePool(filterSize: [2, 2], hint: Hint(stride: [2, 2]))
+    out = downsample(out)
+  }
+  let reader: (PythonObject) -> Void = { state_dict in
+    let in_conv_weight = state_dict["body.\(prefix).in_conv.weight"].numpy()
+    let in_conv_bias = state_dict["body.\(prefix).in_conv.bias"].numpy()
+    inConv.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: in_conv_weight))
+    inConv.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: in_conv_bias))
+    let out_conv_weight = state_dict["body.\(prefix).out_conv.weight"].numpy()
+    let out_conv_bias = state_dict["body.\(prefix).out_conv.bias"].numpy()
+    outConv.parameters(for: .weight).copy(from: try! Tensor<Float>(numpy: out_conv_weight))
+    outConv.parameters(for: .bias).copy(from: try! Tensor<Float>(numpy: out_conv_bias))
+    for reader in readers {
+      reader(state_dict)
+    }
+  }
+  return (reader, Model([x], [out]))
+}
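+// Extractor builds one `body.<prefix>` stage of Adapter_light: a 1x1 projection down to
+// innerChannel, numRepeat lightweight residual blocks, a 1x1 expansion back to channel, and,
+// for every stage but the first, a 2x2 average pool that halves the spatial resolution.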
+
+func AdapterLight(channels: [Int], numRepeat: Int) -> ((PythonObject) -> Void, Model) {
+  var readers = [(PythonObject) -> Void]()
+  let x = Input()
+  var out: Model.IO = x
+  var outs = [Model.IO]()
+  for (i, channel) in channels.enumerated() {
+    let (reader, extractor) = Extractor(prefix: "\(i)", channel: channel, innerChannel: channel / 4, numRepeat: numRepeat, downsample: i != 0)
+    out = extractor(out)
+    outs.append(out)
+    readers.append(reader)
+  }
+  let reader: (PythonObject) -> Void = { state_dict in
+    for reader in readers {
+      reader(state_dict)
+    }
+  }
+  return (reader, Model([x], outs))
+}
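+// Minimal usage sketch, mirroring the driver code below (shapes assume the [2, 64 * 3, 64, 64]
+// pixel-unshuffled hint):
+//   let (reader, adapternet) = AdapterLight(channels: [320, 640, 1280, 1280], numRepeat: 4)
+//   var controls = adapternet(inputs: hintIn).map { $0.as(of: Float.self) }  // first run materializes parameters
+//   reader(state_dict)  // copy Adapter_light weights from the PyTorch state_dict
+//   controls = adapternet(inputs: hintIn).map { $0.as(of: Float.self) }
+// The four outputs are [2, 320, 64, 64], [2, 640, 32, 32], [2, 1280, 16, 16] and [2, 1280, 8, 8],
+// matching the feature resolutions of the SD v1 UNet blocks they condition.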
+
 random.seed(42)
 numpy.random.seed(42)
 torch.manual_seed(42)
 torch.cuda.manual_seed_all(42)
 
-let hint = torch.randn([2, 1, 512, 512])
+let hint = torch.randn([2, 3, 512, 512])
 
-let adapter = ldm_modules_encoders_adapter.Adapter(cin: 64, channels: [320, 640, 1280, 1280], nums_rb: 2, ksize: 1, sk: true, use_conv: false).to(torch.device("cpu"))
-adapter.load_state_dict(torch.load("/home/liu/workspace/T2I-Adapter/models/t2iadapter_canny_sd14v1.pth"))
-let state_dict = adapter.state_dict()
-let ret = adapter(hint)
+// let adapter = ldm_modules_encoders_adapter.Adapter(cin: 64, channels: [320, 640, 1280, 1280], nums_rb: 2, ksize: 1, sk: true, use_conv: false).to(torch.device("cpu"))
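+// cin is 64 * 3 = 192: Adapter_light pixel-unshuffles the 3-channel, 512x512 hint by a factor
+// of 8 before its first extractor, so it sees 3 * 8 * 8 channels at 64x64 resolution.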
+let adapterLight = ldm_modules_encoders_adapter.Adapter_light(cin: 64 * 3, channels: [320, 640, 1280, 1280], nums_rb: 4).to(torch.device("cpu"))
+adapterLight.load_state_dict(torch.load("/home/liu/workspace/T2I-Adapter/models/t2iadapter_color_sd14v1.pth"))
+let state_dict = adapterLight.state_dict()
+let ret = adapterLight(hint)
+print(adapterLight)
 print(ret[0])
 
 let graph = DynamicGraph()
 let hintTensor = graph.variable(try! Tensor<Float>(numpy: hint.numpy())).toGPU(0)
-let (reader, adapternet) = Adapter(channels: [320, 640, 1280, 1280], numRepeat: 2)
+// let (reader, adapternet) = Adapter(channels: [320, 640, 1280, 1280], numRepeat: 2)
+let (reader, adapternet) = AdapterLight(channels: [320, 640, 1280, 1280], numRepeat: 4)
 graph.workspaceSize = 1_024 * 1_024 * 1_024
 graph.withNoGrad {
-  let hintIn = hintTensor.reshaped(format: .NCHW, shape: [2, 1, 64, 8, 64, 8]).permuted(0, 1, 3, 5, 2, 4).copied().reshaped(.NCHW(2, 64, 64, 64))
+  let hintIn = hintTensor.reshaped(format: .NCHW, shape: [2, 3, 64, 8, 64, 8]).permuted(0, 1, 3, 5, 2, 4).copied().reshaped(.NCHW(2, 64 * 3, 64, 64))
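+  // This reshape/permute/reshape reproduces PixelUnshuffle(8): every 8x8 patch of the
+  // [2, 3, 512, 512] hint becomes 64 extra channels, yielding [2, 192, 64, 64] in the same
+  // channel order (c * 64 + i * 8 + j) as the PyTorch reference.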
   var controls = adapternet(inputs: hintIn).map { $0.as(of: Float.self) }
   reader(state_dict)
   controls = adapternet(inputs: hintIn).map { $0.as(of: Float.self) }