
Commit 481a693

Add factory kwargs for fastvit, convnext, mambaout
1 parent 8721e1a commit 481a693
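
The change threads PyTorch's factory kwargs (device, dtype) through every module constructor so parameters and buffers are created directly on the target device and in the target dtype. The sketch below is illustrative only and is not code from this commit; it shows the pattern the diff repeats: collect the kwargs into a dd dict once, then pass **dd to every submodule and parameter.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Minimal sketch of the factory-kwargs pattern (not from this commit)."""

    def __init__(self, dim: int, device=None, dtype=None):
        # Collect factory kwargs once, then forward them to every submodule and
        # parameter so nothing is first allocated on the default device/dtype.
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.fc = nn.Linear(dim, dim, **dd)
        self.norm = nn.LayerNorm(dim, **dd)
        self.gamma = nn.Parameter(torch.ones(dim, **dd))

    def forward(self, x):
        return self.norm(self.fc(x)) * self.gamma

# Parameters are created directly in the requested dtype, no post-hoc .to() copy.
block = ToyBlock(64, dtype=torch.bfloat16)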

3 files changed: +220 -82 lines

timm/models/convnext.py

Lines changed: 64 additions & 18 deletions
@@ -44,10 +44,27 @@
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, calculate_drop_path_rates, Mlp, GlobalResponseNormMlp, \
-    LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
-from timm.layers import SimpleNorm2d, SimpleNorm
-from timm.layers import NormMlpClassifierHead, ClassifierHead
+from timm.layers import (
+    trunc_normal_,
+    AvgPool2dSame,
+    DropPath,
+    calculate_drop_path_rates,
+    Mlp,
+    GlobalResponseNormMlp,
+    LayerNorm2d,
+    LayerNorm,
+    RmsNorm2d,
+    RmsNorm,
+    SimpleNorm2d,
+    SimpleNorm,
+    create_conv2d,
+    get_act_layer,
+    get_norm_layer,
+    make_divisible,
+    to_ntuple,
+    NormMlpClassifierHead,
+    ClassifierHead,
+)
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
@@ -59,7 +76,15 @@
 class Downsample(nn.Module):
     """Downsample module for ConvNeXt."""
 
-    def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1) -> None:
+    def __init__(
+            self,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            device=None,
+            dtype=None,
+    ) -> None:
         """Initialize Downsample module.
 
         Args:
@@ -68,6 +93,7 @@ def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1
             stride: Stride for downsampling.
             dilation: Dilation rate.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         avg_stride = stride if dilation == 1 else 1
         if stride > 1 or dilation > 1:
@@ -77,7 +103,7 @@ def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1
             self.pool = nn.Identity()
 
         if in_chs != out_chs:
-            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1, **dd)
         else:
             self.conv = nn.Identity()
 
@@ -115,6 +141,8 @@ def __init__(
             act_layer: Union[str, Callable] = 'gelu',
             norm_layer: Optional[Callable] = None,
             drop_path: float = 0.,
+            device=None,
+            dtype=None,
     ):
         """
 
@@ -133,6 +161,7 @@ def __init__(
             norm_layer: Normalization layer (defaults to LN if not specified).
             drop_path: Stochastic depth probability.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         out_chs = out_chs or in_chs
         dilation = to_ntuple(2)(dilation)
@@ -149,12 +178,18 @@ def __init__(
             dilation=dilation[0],
             depthwise=True,
             bias=conv_bias,
+            **dd,
+        )
+        self.norm = norm_layer(out_chs, **dd)
+        self.mlp = mlp_layer(
+            out_chs,
+            int(mlp_ratio * out_chs),
+            act_layer=act_layer,
+            **dd,
         )
-        self.norm = norm_layer(out_chs)
-        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
-        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs, **dd)) if ls_init_value is not None else None
         if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
-            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0], **dd)
         else:
             self.shortcut = nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -196,7 +231,9 @@ def __init__(
             use_grn: bool = False,
             act_layer: Union[str, Callable] = 'gelu',
             norm_layer: Optional[Callable] = None,
-            norm_layer_cl: Optional[Callable] = None
+            norm_layer_cl: Optional[Callable] = None,
+            device=None,
+            dtype=None,
     ) -> None:
         """Initialize ConvNeXt stage.
 
@@ -216,14 +253,15 @@ def __init__(
             norm_layer: Normalization layer.
             norm_layer_cl: Normalization layer for channels last.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         self.grad_checkpointing = False
 
         if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
             ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
             pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
             self.downsample = nn.Sequential(
-                norm_layer(in_chs),
+                norm_layer(in_chs, **dd),
                 create_conv2d(
                     in_chs,
                     out_chs,
@@ -232,6 +270,7 @@ def __init__(
                     dilation=dilation[0],
                     padding=pad,
                     bias=conv_bias,
+                    **dd,
                 ),
             )
             in_chs = out_chs
@@ -253,6 +292,7 @@ def __init__(
                 use_grn=use_grn,
                 act_layer=act_layer,
                 norm_layer=norm_layer if conv_mlp else norm_layer_cl,
+                **dd,
             ))
             in_chs = out_chs
         self.blocks = nn.Sequential(*stage_blocks)
@@ -324,6 +364,8 @@ def __init__(
             norm_eps: Optional[float] = None,
             drop_rate: float = 0.,
             drop_path_rate: float = 0.,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -349,6 +391,7 @@ def __init__(
             drop_path_rate: Stochastic depth drop rate.
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
         norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps)
@@ -362,17 +405,17 @@ def __init__(
         if stem_type == 'patch':
             # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
-                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
-                norm_layer(dims[0]),
+                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd),
+                norm_layer(dims[0], **dd),
             )
             stem_stride = patch_size
         else:
             mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
             self.stem = nn.Sequential(*filter(None, [
-                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                 act_layer() if 'act' in stem_type else None,
-                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
-                norm_layer(dims[0]),
+                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
+                norm_layer(dims[0], **dd),
             ]))
             stem_stride = 4
 
@@ -406,6 +449,7 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer,
                 norm_layer_cl=norm_layer_cl,
+                **dd,
             ))
             prev_chs = out_chs
             # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
@@ -417,12 +461,13 @@ def __init__(
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
         if head_norm_first:
             assert not head_hidden_size
-            self.norm_pre = norm_layer(self.num_features)
+            self.norm_pre = norm_layer(self.num_features, **dd)
             self.head = ClassifierHead(
                 self.num_features,
                 num_classes,
                 pool_type=global_pool,
                 drop_rate=self.drop_rate,
+                **dd,
             )
         else:
             self.norm_pre = nn.Identity()
@@ -434,6 +479,7 @@ def __init__(
                 drop_rate=self.drop_rate,
                 norm_layer=norm_layer,
                 act_layer='gelu',
+                **dd,
             )
         self.head_hidden_size = self.head.num_features
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
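
With the constructor changes above, device and dtype can be supplied at build time instead of materializing the model in fp32 on the default device and converting afterwards. A rough usage sketch follows; the argument values are illustrative (the depths/dims values correspond to ConvNeXt-Tiny), and meta-device construction is expected to work only as long as every parameter-creating call goes through the forwarded **dd kwargs, as in the diff.

import torch
from timm.models.convnext import ConvNeXt

# Build with parameters created directly in bfloat16 (assumes the weight-init
# functions used by the model support this dtype).
model = ConvNeXt(
    depths=(3, 3, 9, 3),
    dims=(96, 192, 384, 768),
    num_classes=1000,
    dtype=torch.bfloat16,
)

# Shape-only construction for deferred initialization, e.g. before loading a
# checkpoint, should likewise be possible with the meta device.
meta_model = ConvNeXt(device='meta')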
