import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, calculate_drop_path_rates, Mlp, GlobalResponseNormMlp, \
-    LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
-from timm.layers import SimpleNorm2d, SimpleNorm
-from timm.layers import NormMlpClassifierHead, ClassifierHead
+from timm.layers import (
+    trunc_normal_,
+    AvgPool2dSame,
+    DropPath,
+    calculate_drop_path_rates,
+    Mlp,
+    GlobalResponseNormMlp,
+    LayerNorm2d,
+    LayerNorm,
+    RmsNorm2d,
+    RmsNorm,
+    SimpleNorm2d,
+    SimpleNorm,
+    create_conv2d,
+    get_act_layer,
+    get_norm_layer,
+    make_divisible,
+    to_ntuple,
+    NormMlpClassifierHead,
+    ClassifierHead,
+)
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq
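Every hunk that follows applies the same factory-kwargs pattern: `device` and `dtype` are collected into a `dd` dict and forwarded to each submodule, parameter, and tensor the module constructs, so weights are created directly on the target device and in the target dtype rather than being moved or cast afterwards. A minimal self-contained sketch of the pattern, using only stock PyTorch (the `ToyBlock` module below is illustrative, not part of timm):

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Illustrative module showing the device/dtype factory-kwargs pattern."""

    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Submodules and parameters are created with the requested device/dtype
        # up front, so no post-hoc .to() copy is needed.
        self.proj = nn.Linear(dim, dim, **dd)
        self.gamma = nn.Parameter(torch.ones(dim, **dd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x) * self.gamma


blk = ToyBlock(8, device='cpu', dtype=torch.float32)
print(blk.gamma.dtype)  # torch.float32
```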
@@ -59,7 +76,15 @@
class Downsample(nn.Module):
    """Downsample module for ConvNeXt."""

-    def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1) -> None:
+    def __init__(
+            self,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            device=None,
+            dtype=None,
+    ) -> None:
        """Initialize Downsample module.

        Args:
@@ -68,6 +93,7 @@ def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1
            stride: Stride for downsampling.
            dilation: Dilation rate.
        """
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
@@ -77,7 +103,7 @@ def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1
            self.pool = nn.Identity()

        if in_chs != out_chs:
-            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1, **dd)
        else:
            self.conv = nn.Identity()
@@ -115,6 +141,8 @@ def __init__(
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Callable] = None,
            drop_path: float = 0.,
+            device=None,
+            dtype=None,
    ):
        """

@@ -133,6 +161,7 @@ def __init__(
            norm_layer: Normalization layer (defaults to LN if not specified).
            drop_path: Stochastic depth probability.
        """
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        out_chs = out_chs or in_chs
        dilation = to_ntuple(2)(dilation)
@@ -149,12 +178,18 @@ def __init__(
            dilation=dilation[0],
            depthwise=True,
            bias=conv_bias,
+            **dd,
+        )
+        self.norm = norm_layer(out_chs, **dd)
+        self.mlp = mlp_layer(
+            out_chs,
+            int(mlp_ratio * out_chs),
+            act_layer=act_layer,
+            **dd,
        )
-        self.norm = norm_layer(out_chs)
-        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
-        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs, **dd)) if ls_init_value is not None else None
        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
-            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0], **dd)
        else:
            self.shortcut = nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -196,7 +231,9 @@ def __init__(
            use_grn: bool = False,
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Callable] = None,
-            norm_layer_cl: Optional[Callable] = None
+            norm_layer_cl: Optional[Callable] = None,
+            device=None,
+            dtype=None,
    ) -> None:
        """Initialize ConvNeXt stage.

@@ -216,14 +253,15 @@ def __init__(
            norm_layer: Normalization layer.
            norm_layer_cl: Normalization layer for channels last.
        """
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False

        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
            pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
            self.downsample = nn.Sequential(
-                norm_layer(in_chs),
+                norm_layer(in_chs, **dd),
                create_conv2d(
                    in_chs,
                    out_chs,
@@ -232,6 +270,7 @@ def __init__(
                    dilation=dilation[0],
                    padding=pad,
                    bias=conv_bias,
+                    **dd,
                ),
            )
            in_chs = out_chs
@@ -253,6 +292,7 @@ def __init__(
                use_grn=use_grn,
                act_layer=act_layer,
                norm_layer=norm_layer if conv_mlp else norm_layer_cl,
+                **dd,
            ))
            in_chs = out_chs
        self.blocks = nn.Sequential(*stage_blocks)
@@ -324,6 +364,8 @@ def __init__(
            norm_eps: Optional[float] = None,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
+            device=None,
+            dtype=None,
    ):
        """
        Args:
@@ -349,6 +391,7 @@ def __init__(
            drop_path_rate: Stochastic depth drop rate.
        """
        super().__init__()
+        dd = {'device': device, 'dtype': dtype}
        assert output_stride in (8, 16, 32)
        kernel_sizes = to_ntuple(4)(kernel_sizes)
        norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps)
@@ -362,17 +405,17 @@ def __init__(
        if stem_type == 'patch':
            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
            self.stem = nn.Sequential(
-                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
-                norm_layer(dims[0]),
+                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd),
+                norm_layer(dims[0], **dd),
            )
            stem_stride = patch_size
        else:
            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
            self.stem = nn.Sequential(*filter(None, [
-                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                act_layer() if 'act' in stem_type else None,
-                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
-                norm_layer(dims[0]),
+                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
+                norm_layer(dims[0], **dd),
            ]))
            stem_stride = 4
@@ -406,6 +449,7 @@ def __init__(
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
+                **dd,
            ))
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
@@ -417,12 +461,13 @@ def __init__(
        # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
        if head_norm_first:
            assert not head_hidden_size
-            self.norm_pre = norm_layer(self.num_features)
+            self.norm_pre = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
+                **dd,
            )
        else:
            self.norm_pre = nn.Identity()
@@ -434,6 +479,7 @@ def __init__(
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
                act_layer='gelu',
+                **dd,
            )
            self.head_hidden_size = self.head.num_features
        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
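Taken together, these changes let a ConvNeXt be constructed with its parameters placed and typed up front. A minimal usage sketch, assuming the diff above is applied and that `ConvNeXt` is importable from `timm.models.convnext` as in current timm; the depths/dims shown are just the library defaults (a Tiny-sized config):

```python
import torch
from timm.models.convnext import ConvNeXt

# Use a real accelerator if present; 'cpu' keeps the sketch runnable anywhere.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Parameters are allocated directly on `device` with the requested dtype,
# instead of being materialized on CPU and converted afterwards.
model = ConvNeXt(
    depths=(3, 3, 9, 3),
    dims=(96, 192, 384, 768),
    num_classes=1000,
    device=device,
    dtype=torch.float32,
)

x = torch.randn(2, 3, 224, 224, device=device)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([2, 1000])
```

Passing a lower-precision dtype such as `torch.bfloat16` follows the same path; whether weight init and inference behave well at that precision depends on the PyTorch build and hardware.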