Commit 90a35c8

Add dd factory kwargs to eva, resnet
1 parent 7560ca5

3 files changed: 115 additions & 57 deletions

timm/models/eva.py

Lines changed: 54 additions & 26 deletions
@@ -121,6 +121,8 @@ def __init__(
             qk_norm: bool = False,
             scale_norm: bool = True,
             rotate_half: bool = False,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -139,6 +141,7 @@ def __init__(
             scale_norm: Enable normalization (scaling) of attention output with norm_layer
             rotate_half: Use half rotation layout instead of interleaved
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         if scale_norm or qk_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
@@ -154,25 +157,25 @@ def __init__(
         self.rotate_half = rotate_half

         if qkv_fused:
-            self.qkv = nn.Linear(dim, attn_dim * 3, bias=False)
+            self.qkv = nn.Linear(dim, attn_dim * 3, bias=False, **dd)
             self.q_proj = self.k_proj = self.v_proj = None
             if qkv_bias:
-                self.q_bias = nn.Parameter(torch.zeros(attn_dim))
-                self.register_buffer('k_bias', torch.zeros(attn_dim), persistent=False)
-                self.v_bias = nn.Parameter(torch.zeros(attn_dim))
+                self.q_bias = nn.Parameter(torch.zeros(attn_dim, **dd))
+                self.register_buffer('k_bias', torch.zeros(attn_dim, **dd), persistent=False)
+                self.v_bias = nn.Parameter(torch.zeros(attn_dim, **dd))
             else:
                 self.q_bias = self.k_bias = self.v_bias = None
         else:
-            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
-            self.k_proj = nn.Linear(dim, attn_dim, bias=False)
-            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.k_proj = nn.Linear(dim, attn_dim, bias=False, **dd)
+            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
             self.qkv = None
             self.q_bias = self.k_bias = self.v_bias = None
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(attn_dim, dim)
+        self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(attn_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(
@@ -263,6 +266,8 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
             **kwargs,
     ):
         """ Initialize the EVA transformer block.
@@ -286,8 +291,10 @@ def __init__(
             norm_layer: Normalization layer constructor
             attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.norm1 = norm_layer(dim)
+
+        self.norm1 = norm_layer(dim, **dd)
         attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
         self.attn = attn_cls(
             dim,
@@ -301,11 +308,12 @@ def __init__(
             norm_layer=norm_layer,
             scale_norm=scale_attn_inner,
             rotate_half=rotate_half,
+            **dd,
         )
-        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         hidden_features = int(dim * mlp_ratio)
         if swiglu_mlp:
             if scale_mlp or swiglu_align_to:
@@ -316,6 +324,7 @@ def __init__(
                     norm_layer=norm_layer if scale_mlp else None,
                     drop=proj_drop,
                     align_to=swiglu_align_to,
+                    **dd,
                 )
             else:
                 # w/o any extra norm, an impl with packed weights is used
@@ -326,6 +335,7 @@ def __init__(
                     act_layer=nn.SiLU,
                     gate_last=False,
                     drop=proj_drop,
+                    **dd,
                 )
         else:
             self.mlp = Mlp(
@@ -334,8 +344,9 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
-        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(
@@ -376,6 +387,8 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = nn.LayerNorm,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
     ):
         """ Initialize the post-norm EVA transformer block.

@@ -398,7 +411,9 @@ def __init__(
             norm_layer: Normalization layer constructor
             attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
+
         attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
         self.attn = attn_cls(
             dim,
@@ -412,8 +427,9 @@ def __init__(
             norm_layer=norm_layer,
             scale_norm=scale_attn_inner,
             rotate_half=rotate_half,
+            **dd,
         )
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         hidden_features = int(dim * mlp_ratio)
@@ -426,6 +442,7 @@ def __init__(
                     norm_layer=norm_layer if scale_mlp else None,
                     drop=proj_drop,
                     align_to=swiglu_align_to,
+                    **dd,
                 )
             else:
                 # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
@@ -436,6 +453,7 @@ def __init__(
                     act_layer=nn.SiLU,
                     gate_last=False,
                     drop=proj_drop,
+                    **dd,
                 )
         else:
             self.mlp = Mlp(
@@ -444,8 +462,9 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(
@@ -513,6 +532,8 @@ def __init__(
             dynamic_img_pad: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
+            device=None,
+            dtype=None,
     ):
         """Initialize the EVA Vision Transformer model.

@@ -562,6 +583,7 @@ def __init__(
             head_init_scale: Initialization scale for classification head weights
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -594,16 +616,17 @@ def __init__(
             dynamic_img_pad=dynamic_img_pad,
             bias=not use_pre_transformer_norm,
             **embed_args,
+            **dd,
         )
         num_patches = self.patch_embed.num_patches
         r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
-        self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None
+        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd)) if class_token else None
+        self.reg_token = nn.Parameter(torch.empty(1, num_reg_tokens, embed_dim, **dd)) if num_reg_tokens else None
         self.cls_embed = class_token and self.reg_token is None

         num_pos_tokens = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos_tokens, embed_dim)) if use_abs_pos_emb else None
+        self.pos_embed = nn.Parameter(torch.empty(1, num_pos_tokens, embed_dim, **dd)) if use_abs_pos_emb else None
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropoutWithIndices(patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens)
@@ -621,6 +644,7 @@ def __init__(
                 feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
                 temperature=rope_temperature,
                 grid_indexing=rope_grid_indexing,
+                **dd,
             )
             if rope_type == 'mixed':
                 rope_kwargs.update(dict(depth=depth))
@@ -636,7 +660,7 @@ def __init__(
         else:
             self.rope = None

-        self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity()
+        self.norm_pre = norm_layer(embed_dim, **dd) if activate_pre_norm else nn.Identity()

         dpr = calculate_drop_path_rates(drop_path_rate, depth)  # stochastic depth decay rule
         block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
@@ -659,12 +683,13 @@ def __init__(
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 init_values=init_values,
+                **dd,
             )
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

-        self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim, **dd) if activate_post_norm else nn.Identity()

         if global_pool == 'map':
             self.attn_pool = AttentionPoolLatent(
@@ -673,23 +698,26 @@ def __init__(
                 mlp_ratio=attn_pool_mlp_ratio or mlp_ratio,
                 norm_layer=norm_layer,
                 act_layer=nn.GELU,
+                **dd,
             )
         else:
             self.attn_pool = None
-        self.fc_norm = norm_layer(embed_dim) if activate_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(embed_dim, **dd) if activate_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

+        self.init_weights(head_init_scale=head_init_scale)
+
+    def init_weights(self, head_init_scale=None):
         self.apply(self._init_weights)
         if self.pos_embed is not None:
             trunc_normal_(self.pos_embed, std=.02)
         if self.cls_token is not None:
             trunc_normal_(self.cls_token, std=.02)
         if self.reg_token is not None:
             trunc_normal_(self.reg_token, std=.02)
-
         self.fix_init_weight()
-        if isinstance(self.head, nn.Linear):
+        if head_init_scale and isinstance(self.head, nn.Linear):
             trunc_normal_(self.head.weight, std=.02)
             self.head.weight.data.mul_(head_init_scale)
             self.head.bias.data.mul_(head_init_scale)
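For context on the pattern this diff applies throughout: each constructor now accepts PyTorch's standard factory kwargs `device` and `dtype`, collects them once into `dd = {'device': device, 'dtype': dtype}`, and forwards `**dd` to every submodule, parameter, and buffer it creates. That lets weights be materialized directly on the target device and in the target dtype, including `device='meta'` for shape-only construction before loading a checkpoint, instead of building on CPU in float32 and converting afterwards. A minimal sketch of the idea, using a hypothetical `ToyBlock` module rather than the actual timm classes:

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical block showing the device/dtype factory-kwargs ('dd') pattern."""

    def __init__(self, dim: int, device=None, dtype=None):
        # Collect the factory kwargs once and forward them to every submodule,
        # parameter, and buffer constructor, mirroring the `dd` dict in this commit.
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.norm = nn.LayerNorm(dim, **dd)
        self.proj = nn.Linear(dim, dim, **dd)
        # torch.empty + **dd allocates directly on the target device/dtype;
        # a separate init pass (cf. init_weights above) fills in real values.
        self.gamma = nn.Parameter(torch.empty(dim, **dd))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.gamma * self.proj(self.norm(x))


# Construct weights straight in bfloat16, or on the meta device for
# shape-only instantiation (e.g. before loading a checkpoint).
block_bf16 = ToyBlock(64, dtype=torch.bfloat16)
block_meta = ToyBlock(64, device='meta')

Note also the switch from torch.zeros to torch.empty for cls_token, reg_token, and pos_embed in the diff above, presumably because init_weights() now runs at the end of __init__ and overwrites those tensors with trunc_normal_, so zero-filling them first would be redundant.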
