Skip to content

Commit 8104d3c

Browse files
committed
Further filtering of 7b models in tests. Remove ref_feat_shape from DINOv3 RoPE, as its coordinates are normalized. Fix TorchScript issue with the unary + operator.
1 parent d5547b4 commit 8104d3c

File tree

3 files changed

+43
-62
lines changed

3 files changed

+43
-62
lines changed

tests/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@
7878
'*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*']
7979
NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*', '*_7b_*']
8080
else:
81-
EXCLUDE_FILTERS = ['*enormous*']
81+
EXCLUDE_FILTERS = ['*enormous*', '*_7b_*']
8282
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*', '*_7b_*']
8383

84-
EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*']
84+
EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*', '*_7b_*']
8585

8686
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
8787
TARGET_BWD_SIZE = 128

timm/layers/pos_embed_sincos.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,6 @@ def __init__(
901901
min_period: Optional[float] = None,
902902
max_period: Optional[float] = None,
903903
feat_shape: Optional[List[int]] = None,
904-
ref_feat_shape: Optional[List[int]] = None,
905904
normalize_coords: str = "separate", # 'min', 'max', 'separate'
906905
grid_offset: float = 0.0,
907906
grid_indexing: str = "ij",
@@ -930,7 +929,6 @@ def __init__(
930929

931930
# Grid config
932931
self.feat_shape = feat_shape
933-
self.ref_feat_shape = ref_feat_shape
934932
self.grid_offset = grid_offset
935933
self.grid_indexing = grid_indexing
936934

@@ -944,7 +942,7 @@ def __init__(
944942
self.register_buffer("pos_embed_cached", None, persistent=False)
945943
self.feat_shape = None
946944

947-
def _compute_periods(self, device='cpu', dtype=torch.float32) -> torch.Tensor:
945+
def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = torch.float32) -> torch.Tensor:
948946
"""Construct periods from either min/max or temperature."""
949947
dim = self.dim // 4
950948

@@ -1016,7 +1014,7 @@ def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
10161014
# Shift per-axis in [-s, +s]
10171015
if self.shift_coords is not None:
10181016
shift = float(self.shift_coords)
1019-
shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, +shift)
1017+
shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, shift)
10201018
coords = coords + shift_hw[None, :]
10211019

10221020
# Jitter: per-axis log-uniform factor in [1/J, J]
@@ -1025,7 +1023,7 @@ def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
10251023
if jitter_factor <= 0:
10261024
raise ValueError("jitter_coords must be > 0 (interpreted as multiplicative factor).")
10271025
jitter_max = math.log(jitter_factor)
1028-
jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, +jitter_max).exp()
1026+
jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, jitter_max).exp()
10291027
coords = coords * jitter_hw[None, :]
10301028

10311029
# Rescale: shared scalar log-uniform factor in [1/R, R]
@@ -1034,7 +1032,7 @@ def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
10341032
if rescale_factor <= 0:
10351033
raise ValueError("rescale_coords must be > 0 (interpreted as multiplicative factor).")
10361034
rescale_max = math.log(rescale_factor)
1037-
rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, +rescale_max).exp()
1035+
rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, rescale_max).exp()
10381036
coords = coords * rescale
10391037

10401038
return coords

timm/models/eva.py

Lines changed: 37 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -624,12 +624,7 @@ def __init__(
624624
if rope_type == 'mixed':
625625
rope_kwargs.update(dict(depth=depth))
626626
self.rope_mixed = True
627-
elif rope_type == 'dinov3':
628-
rope_kwargs.update(dict(
629-
grid_offset=rope_grid_offset,
630-
ref_feat_shape=ref_feat_shape,
631-
))
632-
else: # 'cat' or 'base'
627+
elif rope_type == 'cat':
633628
rope_kwargs.update(dict(
634629
in_pixels=False,
635630
grid_offset=rope_grid_offset,
@@ -1558,160 +1553,148 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
15581553
# RoPE-ViT models from Naver
15591554
'vit_small_patch16_rope_224.naver_in1k': _cfg(
15601555
hf_hub_id='timm/',
1561-
mean=IMAGENET_DEFAULT_MEAN,
1562-
std=IMAGENET_DEFAULT_STD,
1556+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
15631557
license='apache-2.0',
15641558
),
15651559
'vit_base_patch16_rope_224.naver_in1k': _cfg(
15661560
hf_hub_id='timm/',
1567-
mean=IMAGENET_DEFAULT_MEAN,
1568-
std=IMAGENET_DEFAULT_STD,
1561+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
15691562
license='apache-2.0',
15701563
),
15711564
'vit_large_patch16_rope_224.naver_in1k': _cfg(
15721565
hf_hub_id='timm/',
1573-
mean=IMAGENET_DEFAULT_MEAN,
1574-
std=IMAGENET_DEFAULT_STD,
1566+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
15751567
license='apache-2.0',
15761568
),
15771569
'vit_small_patch16_rope_mixed_224.naver_in1k': _cfg(
15781570
hf_hub_id='timm/',
1579-
mean=IMAGENET_DEFAULT_MEAN,
1580-
std=IMAGENET_DEFAULT_STD,
1571+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
15811572
license='apache-2.0',
15821573
),
15831574
'vit_base_patch16_rope_mixed_224.naver_in1k': _cfg(
15841575
hf_hub_id='timm/',
1585-
mean=IMAGENET_DEFAULT_MEAN,
1586-
std=IMAGENET_DEFAULT_STD,
1576+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
15871577
license='apache-2.0',
15881578
),
15891579
'vit_large_patch16_rope_mixed_224.naver_in1k': _cfg(
15901580
hf_hub_id='timm/',
1591-
mean=IMAGENET_DEFAULT_MEAN,
1592-
std=IMAGENET_DEFAULT_STD,
1581+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
15931582
license='apache-2.0',
15941583
),
15951584
'vit_small_patch16_rope_ape_224.naver_in1k': _cfg(
15961585
hf_hub_id='timm/',
1597-
mean=IMAGENET_DEFAULT_MEAN,
1598-
std=IMAGENET_DEFAULT_STD,
1586+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
15991587
license='apache-2.0',
16001588
),
16011589
'vit_base_patch16_rope_ape_224.naver_in1k': _cfg(
16021590
hf_hub_id='timm/',
1603-
mean=IMAGENET_DEFAULT_MEAN,
1604-
std=IMAGENET_DEFAULT_STD,
1591+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
16051592
license='apache-2.0',
16061593
),
16071594
'vit_large_patch16_rope_ape_224.naver_in1k': _cfg(
16081595
hf_hub_id='timm/',
1609-
mean=IMAGENET_DEFAULT_MEAN,
1610-
std=IMAGENET_DEFAULT_STD,
1596+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
16111597
license='apache-2.0',
16121598
),
16131599
'vit_small_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
16141600
hf_hub_id='timm/',
1615-
mean=IMAGENET_DEFAULT_MEAN,
1616-
std=IMAGENET_DEFAULT_STD,
1601+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
16171602
license='apache-2.0',
16181603
),
16191604
'vit_base_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
16201605
hf_hub_id='timm/',
1621-
mean=IMAGENET_DEFAULT_MEAN,
1622-
std=IMAGENET_DEFAULT_STD,
1606+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
16231607
license='apache-2.0',
16241608
),
16251609
'vit_large_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
16261610
hf_hub_id='timm/',
1627-
mean=IMAGENET_DEFAULT_MEAN,
1628-
std=IMAGENET_DEFAULT_STD,
1611+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
16291612
license='apache-2.0',
16301613
),
16311614

16321615
# DINOv3 weights are under a specific license with redistribution terms; please see
16331616
# https://github.com/facebookresearch/dinov3/blob/main/LICENSE.md
16341617
'vit_small_patch16_dinov3_224.lvdm_1689m': _cfg(
16351618
# hf_hub_id='timm/',
1636-
mean=IMAGENET_DEFAULT_MEAN,
1637-
std=IMAGENET_DEFAULT_STD,
1619+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1620+
crop_pct=1.0,
16381621
num_classes=0,
16391622
license='dinov3',
16401623
),
16411624
'vit_small_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg(
16421625
# hf_hub_id='timm/',
1643-
mean=IMAGENET_DEFAULT_MEAN,
1644-
std=IMAGENET_DEFAULT_STD,
1626+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1627+
crop_pct=1.0,
16451628
num_classes=0,
16461629
license='dinov3',
16471630
),
16481631
'vit_small_plus_patch16_dinov3_224.lvdm_1689m': _cfg(
16491632
# hf_hub_id='timm/',
1650-
mean=IMAGENET_DEFAULT_MEAN,
1651-
std=IMAGENET_DEFAULT_STD,
1633+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1634+
crop_pct=1.0,
16521635
num_classes=0,
16531636
license='dinov3',
16541637
),
16551638
'vit_small_plus_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg(
16561639
# hf_hub_id='timm/',
1657-
mean=IMAGENET_DEFAULT_MEAN,
1658-
std=IMAGENET_DEFAULT_STD,
1640+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1641+
crop_pct=1.0,
16591642
num_classes=0,
16601643
license='dinov3',
16611644
),
16621645
'vit_base_patch16_dinov3_224.lvdm_1689m': _cfg(
16631646
# hf_hub_id='timm/',
1664-
mean=IMAGENET_DEFAULT_MEAN,
1665-
std=IMAGENET_DEFAULT_STD,
1647+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1648+
crop_pct=1.0,
16661649
num_classes=0,
16671650
license='dinov3',
16681651
),
16691652
'vit_base_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg(
16701653
# hf_hub_id='timm/',
1671-
mean=IMAGENET_DEFAULT_MEAN,
1672-
std=IMAGENET_DEFAULT_STD,
1654+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1655+
crop_pct=1.0,
16731656
num_classes=0,
16741657
license='dinov3',
16751658
),
16761659
'vit_large_patch16_dinov3_224.lvdm_1689m': _cfg(
16771660
# hf_hub_id='timm/',
1678-
mean=IMAGENET_DEFAULT_MEAN,
1679-
std=IMAGENET_DEFAULT_STD,
1661+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1662+
crop_pct=1.0,
16801663
num_classes=0,
16811664
license='dinov3',
16821665
),
16831666
'vit_large_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg(
16841667
# hf_hub_id='timm/',
1685-
mean=IMAGENET_DEFAULT_MEAN,
1686-
std=IMAGENET_DEFAULT_STD,
1668+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1669+
crop_pct=1.0,
16871670
num_classes=0,
16881671
license='dinov3',
16891672
),
16901673
'vit_large_patch16_dinov3_224.sat_493m': _cfg(
16911674
# hf_hub_id='timm/',
1692-
mean=(0.430, 0.411, 0.296),
1693-
std=(0.213, 0.156, 0.143),
1675+
mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143),
1676+
crop_pct=1.0,
16941677
num_classes=0,
16951678
license='dinov3',
16961679
),
16971680
'vit_huge_plus_patch16_dinov3_224.lvdm_1689m': _cfg(
16981681
# hf_hub_id='timm/',
1699-
mean=IMAGENET_DEFAULT_MEAN,
1700-
std=IMAGENET_DEFAULT_STD,
1682+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1683+
crop_pct=1.0,
17011684
num_classes=0,
17021685
license='dinov3',
17031686
),
17041687
'vit_7b_patch16_dinov3_224.lvdm_1689m': _cfg(
17051688
# hf_hub_id='timm/',
1706-
mean=IMAGENET_DEFAULT_MEAN,
1707-
std=IMAGENET_DEFAULT_STD,
1689+
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
1690+
crop_pct=1.0,
17081691
num_classes=0,
17091692
license='dinov3',
17101693
),
17111694
'vit_7b_patch16_dinov3_224.sat_493m': _cfg(
17121695
# hf_hub_id='timm/',
1713-
mean=(0.430, 0.411, 0.296),
1714-
std=(0.213, 0.156, 0.143),
1696+
mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143),
1697+
crop_pct=1.0,
17151698
num_classes=0,
17161699
license='dinov3',
17171700
),

0 commit comments

Comments
 (0)