@@ -121,6 +121,8 @@ def __init__(
             qk_norm: bool = False,
             scale_norm: bool = True,
             rotate_half: bool = False,
+            device=None,
+            dtype=None,
     ):
         """
         Args:
@@ -139,6 +141,7 @@ def __init__(
             scale_norm: Enable normalization (scaling) of attention output with norm_layer
             rotate_half: Use half rotation layout instead of interleaved
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         if scale_norm or qk_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
@@ -154,25 +157,25 @@ def __init__(
         self.rotate_half = rotate_half

         if qkv_fused:
-            self.qkv = nn.Linear(dim, attn_dim * 3, bias=False)
+            self.qkv = nn.Linear(dim, attn_dim * 3, bias=False, **dd)
             self.q_proj = self.k_proj = self.v_proj = None
             if qkv_bias:
-                self.q_bias = nn.Parameter(torch.zeros(attn_dim))
-                self.register_buffer('k_bias', torch.zeros(attn_dim), persistent=False)
-                self.v_bias = nn.Parameter(torch.zeros(attn_dim))
+                self.q_bias = nn.Parameter(torch.zeros(attn_dim, **dd))
+                self.register_buffer('k_bias', torch.zeros(attn_dim, **dd), persistent=False)
+                self.v_bias = nn.Parameter(torch.zeros(attn_dim, **dd))
             else:
                 self.q_bias = self.k_bias = self.v_bias = None
         else:
-            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
-            self.k_proj = nn.Linear(dim, attn_dim, bias=False)
-            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.k_proj = nn.Linear(dim, attn_dim, bias=False, **dd)
+            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
             self.qkv = None
             self.q_bias = self.k_bias = self.v_bias = None
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(attn_dim, dim)
+        self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(attn_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(
@@ -263,6 +266,8 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
             **kwargs,
     ):
         """ Initialize the EVA transformer block.
@@ -286,8 +291,10 @@ def __init__(
             norm_layer: Normalization layer constructor
             attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.norm1 = norm_layer(dim)
+
+        self.norm1 = norm_layer(dim, **dd)
         attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
         self.attn = attn_cls(
             dim,
@@ -301,11 +308,12 @@ def __init__(
             norm_layer=norm_layer,
             scale_norm=scale_attn_inner,
             rotate_half=rotate_half,
+            **dd,
         )
-        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         hidden_features = int(dim * mlp_ratio)
         if swiglu_mlp:
             if scale_mlp or swiglu_align_to:
@@ -316,6 +324,7 @@ def __init__(
                     norm_layer=norm_layer if scale_mlp else None,
                     drop=proj_drop,
                     align_to=swiglu_align_to,
+                    **dd,
                 )
             else:
                 # w/o any extra norm, an impl with packed weights is used
@@ -326,6 +335,7 @@ def __init__(
                     act_layer=nn.SiLU,
                     gate_last=False,
                     drop=proj_drop,
+                    **dd,
                 )
         else:
             self.mlp = Mlp(
@@ -334,8 +344,9 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
-        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(
@@ -376,6 +387,8 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = nn.LayerNorm,
             attn_head_dim: Optional[int] = None,
+            device=None,
+            dtype=None,
     ):
         """ Initialize the post-norm EVA transformer block.

@@ -398,7 +411,9 @@ def __init__(
             norm_layer: Normalization layer constructor
             attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
+
         attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
         self.attn = attn_cls(
             dim,
@@ -412,8 +427,9 @@ def __init__(
             norm_layer=norm_layer,
             scale_norm=scale_attn_inner,
             rotate_half=rotate_half,
+            **dd,
         )
-        self.norm1 = norm_layer(dim)
+        self.norm1 = norm_layer(dim, **dd)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         hidden_features = int(dim * mlp_ratio)
@@ -426,6 +442,7 @@ def __init__(
                     norm_layer=norm_layer if scale_mlp else None,
                     drop=proj_drop,
                     align_to=swiglu_align_to,
+                    **dd,
                 )
             else:
                 # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
@@ -436,6 +453,7 @@ def __init__(
                     act_layer=nn.SiLU,
                     gate_last=False,
                     drop=proj_drop,
+                    **dd,
                 )
         else:
             self.mlp = Mlp(
@@ -444,8 +462,9 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer if scale_mlp else None,
                 drop=proj_drop,
+                **dd,
             )
-        self.norm2 = norm_layer(dim)
+        self.norm2 = norm_layer(dim, **dd)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(
@@ -513,6 +532,8 @@ def __init__(
             dynamic_img_pad: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
+            device=None,
+            dtype=None,
     ):
         """Initialize the EVA Vision Transformer model.

@@ -562,6 +583,7 @@ def __init__(
             head_init_scale: Initialization scale for classification head weights
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -594,16 +616,17 @@ def __init__(
             dynamic_img_pad=dynamic_img_pad,
             bias=not use_pre_transformer_norm,
             **embed_args,
+            **dd,
         )
         num_patches = self.patch_embed.num_patches
         r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
-        self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None
+        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd)) if class_token else None
+        self.reg_token = nn.Parameter(torch.empty(1, num_reg_tokens, embed_dim, **dd)) if num_reg_tokens else None
         self.cls_embed = class_token and self.reg_token is None

         num_pos_tokens = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos_tokens, embed_dim)) if use_abs_pos_emb else None
+        self.pos_embed = nn.Parameter(torch.empty(1, num_pos_tokens, embed_dim, **dd)) if use_abs_pos_emb else None
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropoutWithIndices(patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens)
@@ -621,6 +644,7 @@ def __init__(
                 feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
                 temperature=rope_temperature,
                 grid_indexing=rope_grid_indexing,
+                **dd,
             )
             if rope_type == 'mixed':
                 rope_kwargs.update(dict(depth=depth))
@@ -636,7 +660,7 @@ def __init__(
         else:
            self.rope = None

-        self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity()
+        self.norm_pre = norm_layer(embed_dim, **dd) if activate_pre_norm else nn.Identity()

         dpr = calculate_drop_path_rates(drop_path_rate, depth)  # stochastic depth decay rule
         block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
@@ -659,12 +683,13 @@ def __init__(
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 init_values=init_values,
+                **dd,
             )
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

-        self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim, **dd) if activate_post_norm else nn.Identity()

         if global_pool == 'map':
             self.attn_pool = AttentionPoolLatent(
@@ -673,23 +698,26 @@ def __init__(
                 mlp_ratio=attn_pool_mlp_ratio or mlp_ratio,
                 norm_layer=norm_layer,
                 act_layer=nn.GELU,
+                **dd,
             )
         else:
             self.attn_pool = None
-        self.fc_norm = norm_layer(embed_dim) if activate_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(embed_dim, **dd) if activate_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

+        self.init_weights(head_init_scale=head_init_scale)
+
+    def init_weights(self, head_init_scale=None):
         self.apply(self._init_weights)
         if self.pos_embed is not None:
             trunc_normal_(self.pos_embed, std=.02)
         if self.cls_token is not None:
             trunc_normal_(self.cls_token, std=.02)
         if self.reg_token is not None:
             trunc_normal_(self.reg_token, std=.02)
-
         self.fix_init_weight()
-        if isinstance(self.head, nn.Linear):
+        if head_init_scale and isinstance(self.head, nn.Linear):
             trunc_normal_(self.head.weight, std=.02)
             self.head.weight.data.mul_(head_init_scale)
             self.head.bias.data.mul_(head_init_scale)
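The diff threads PyTorch's factory kwargs (`device`, `dtype`) through every constructor via a `dd = {'device': device, 'dtype': dtype}` dict, so parameters and buffers can be created directly on the target device and in the target dtype instead of being built on CPU in float32 and converted afterwards. A minimal sketch of the same pattern outside of this file (the `TinyBlock` class here is illustrative, not part of the diff):

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Illustrative module using the device/dtype factory-kwargs pattern."""

    def __init__(self, dim: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Forwarding **dd means the weights are materialized with the
        # requested device/dtype at construction time.
        self.norm = nn.LayerNorm(dim, **dd)
        self.proj = nn.Linear(dim, dim, **dd)
        self.gamma = nn.Parameter(torch.ones(dim, **dd))

    def forward(self, x):
        return x + self.gamma * self.proj(self.norm(x))


# Usage: parameters land directly in bfloat16 (or on a given device),
# avoiding a full float32 CPU init followed by .to(...).
blk = TinyBlock(64, dtype=torch.bfloat16)
print(next(blk.parameters()).dtype)  # torch.bfloat16
```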