Commit 91e8dc5

DINOv3 rotary position embedding impl

1 parent 5d5707a commit 91e8dc5

2 files changed: +309 −8 lines changed

timm/layers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -133,7 +133,9 @@
     RotaryEmbedding,
     RotaryEmbeddingCat,
     RotaryEmbeddingMixed,
+    RotaryEmbeddingDinoV3,
     get_mixed_freqs,
+    create_rope_embed,
 )
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
timm/layers/pos_embed_sincos.py

Lines changed: 307 additions & 8 deletions

@@ -214,24 +214,73 @@ def forward(self, x):
 
 
 def rot(x):
+    # x: [ x0 x1 x2 x3 x4 x5]
+    # out: [-x1 x0 -x3 x2 -x5 x4]
     return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
 
 
-def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
-    if sin_emb.ndim == 3:
-        return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
-    return x * cos_emb + rot(x) * sin_emb
+def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
+    # x: [ x0 x1 x2 x3 x4 x5]
+    # out: [-x3 -x4 -x5 x0 x1 x2]
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat([-x2, x1], dim=-1)
 
 
-def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
+def apply_rot_embed(
+        x: torch.Tensor,
+        sin_emb: torch.Tensor,
+        cos_emb: torch.Tensor,
+        half: bool = False,
+) -> torch.Tensor:
+    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
+    if half:
+        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
+        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
+        # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
+        return x * cos_emb + rope_rotate_half(x) * sin_emb
+    else:
+        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
+        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
+        # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
+        return x * cos_emb + rot(x) * sin_emb
+
+
+def apply_rot_embed_list(
+        x: List[torch.Tensor],
+        sin_emb: torch.Tensor,
+        cos_emb: torch.Tensor,
+        half: bool = False,
+) -> List[torch.Tensor]:
     if isinstance(x, torch.Tensor):
         x = [x]
-    return [t * cos_emb + rot(t) * sin_emb for t in x]
+    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
+    if half:
+        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
+        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
+        # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
+        return [t * cos_emb + rope_rotate_half(t) * sin_emb for t in x]
+    else:
+        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
+        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
+        # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
+        return [t * cos_emb + rot(t) * sin_emb for t in x]
 
 
-def apply_rot_embed_cat(x: torch.Tensor, emb):
+def apply_rot_embed_cat(
+        x: torch.Tensor,
+        emb: torch.Tensor,
+        half: bool = False,
+) -> torch.Tensor:
     sin_emb, cos_emb = emb.tensor_split(2, -1)
-    return x * cos_emb + rot(x) * sin_emb
+    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
+    if half:
+        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
+        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
+        # rope_rotate_half(x), eg [-x3, -x4, -x5, x0, x1, x2]
+        return x * cos_emb + rope_rotate_half(x) * sin_emb
+    else:
+        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
+        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
+        # rot(x), eg [-x1, x0, -x3, x2, -x5, x4]
+        return x * cos_emb + rot(x) * sin_emb
 
 
 def apply_keep_indices_nlc(
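
The 'half' and interleaved rotation layouts above are equivalent up to a fixed channel permutation: de-interleaving the input channels and rotating with the half layout matches permuting the output of the interleaved rotation. A minimal standalone check (redefining the two helpers from this hunk):

    import torch

    def rot(x):
        # interleaved pairing: (x0, x1), (x2, x3), ...
        return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)

    def rope_rotate_half(x):
        # half pairing: (x0, x4), (x1, x5), ... for D = 8
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    D = 8
    x = torch.randn(D)
    theta = torch.rand(D // 2)
    sin_i, cos_i = theta.sin().repeat_interleave(2), theta.cos().repeat_interleave(2)  # interleaved layout
    sin_h, cos_h = theta.sin().tile(2), theta.cos().tile(2)                            # half layout

    perm = torch.cat([torch.arange(0, D, 2), torch.arange(1, D, 2)])  # de-interleave channels
    out_i = x * cos_i + rot(x) * sin_i
    out_h = x[perm] * cos_h + rope_rotate_half(x[perm]) * sin_h
    assert torch.allclose(out_h, out_i[perm])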
@@ -834,3 +883,253 @@ def forward(self, x):
     def no_weight_decay(self):
         """Exclude frequency parameters from weight decay."""
         return {'freqs'}
+
+
+class RotaryEmbeddingDinoV3(nn.Module):
+    """RoPE for the timm DinoV3 port, numerically matching the original.
+
+    Math is aligned to the original DinoV3 RopePositionEmbedding at https://github.com/facebookresearch/dinov3:
+    - 0.5-centered coords normalized by H/W (or min/max), mapped to [-1, 1]
+    - training-time augmentations (shift/jitter/rescale)
+    - period schedule driven by the RoPE temperature (base) or a min/max period range
+    """
+
+    def __init__(
+            self,
+            dim: int,
+            temperature: Optional[float] = 100.0,
+            min_period: Optional[float] = 0.5,
+            max_period: Optional[float] = 90.,
+            feat_shape: Optional[List[int]] = None,
+            ref_feat_shape: Optional[List[int]] = None,
+            normalize_coords: str = "separate",  # 'min', 'max', 'separate'
+            grid_offset: float = 0.0,
+            grid_indexing: str = "ij",
+            rotate_half: bool = True,
+            shift_coords: Optional[float] = None,
+            jitter_coords: Optional[float] = None,  # interpreted as factor J >= 1
+            rescale_coords: Optional[float] = None,  # interpreted as factor R >= 1
+    ):
+        super().__init__()
+
+        # Dimensions / output format
+        self.dim = dim  # equal to head_dim for most ViT applications
+        self.rotate_half = rotate_half
+
+        # Period schedule parameters
+        self.temperature = float(temperature) if temperature is not None else None
+        self.min_period = min_period
+        self.max_period = max_period
+
+        # Coord processing + augs
+        self.normalize_coords = normalize_coords
+        self.shift_coords = shift_coords
+        self.jitter_coords = jitter_coords
+        self.rescale_coords = rescale_coords
+        self.aug_active = any([a is not None for a in [self.shift_coords, self.jitter_coords, self.rescale_coords]])
+
+        # Grid config
+        self.feat_shape = feat_shape
+        self.ref_feat_shape = ref_feat_shape
+        self.grid_offset = grid_offset
+        self.grid_indexing = grid_indexing
+
+        # Precompute periods
+        periods = self._compute_periods()
+        self.register_buffer("periods", periods, persistent=False)
+
+        if feat_shape is not None:
+            self._cache_embed(feat_shape)
+        else:
+            self.register_buffer("pos_embed_cached", None, persistent=False)
+            self.feat_shape = None
+
+    def _compute_periods(self, device='cpu', dtype=torch.float32) -> torch.Tensor:
+        """Construct periods from either min/max periods or temperature."""
+        dim = self.dim // 4
+
+        if self.min_period is not None and self.max_period is not None:
+            exponents = torch.linspace(0, 1, dim, dtype=torch.float32)
+            periods = self.min_period * ((self.max_period / self.min_period) ** exponents)
+        else:
+            if self.temperature is None:
+                raise ValueError("Provide either min/max periods or `temperature`.")
+            exponents = 2.0 * torch.arange(dim, device=device, dtype=dtype) / (self.dim // 2)
+            periods = self.temperature ** exponents
+
+        # NOTE: the original downcasts periods to bfloat16 in its persistent buffers, so
+        # loaded original and timm models may differ slightly here
+
+        return periods
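
Both branches of _compute_periods produce a geometric progression of dim // 4 periods per spatial axis; the min/max form pins the endpoints exactly, while the temperature form follows the usual RoPE base schedule. A standalone sketch of the two schedules using the class defaults (dim = 64 is illustrative):

    import torch

    dim = 64      # per-head dim; dim // 4 frequencies per spatial axis
    n = dim // 4

    # temperature (base) schedule: periods = T ** (2i / (dim // 2))
    T = 100.0
    periods_t = T ** (2.0 * torch.arange(n) / (dim // 2))

    # min/max schedule: geometric interpolation between the endpoints
    min_p, max_p = 0.5, 90.0
    periods_mm = min_p * (max_p / min_p) ** torch.linspace(0, 1, n)

    assert torch.allclose(periods_mm[0], torch.tensor(min_p))
    assert torch.allclose(periods_mm[-1], torch.tensor(max_p))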
+
+    def _make_coords(
+            self,
+            height: int,
+            width: int,
+            device: torch.device = 'cpu',
+            dtype: torch.dtype = torch.float32,
+    ) -> torch.Tensor:
+        """Make a coordinate grid matching the offset and normalization of the original.
+
+        Returns: coords with shape (HW, 2) in [-1, 1].
+        """
+        # 0.5-centered indices with optional offset
+        coords_h = torch.arange(0.5, height, device=device, dtype=dtype) + self.grid_offset
+        coords_w = torch.arange(0.5, width, device=device, dtype=dtype) + self.grid_offset
+
+        # Normalization denominators
+        if self.normalize_coords == "max":
+            denom = float(max(height, width))
+            h_denom = denom
+            w_denom = denom
+        elif self.normalize_coords == "min":
+            denom = float(min(height, width))
+            h_denom = denom
+            w_denom = denom
+        elif self.normalize_coords == "separate":
+            h_denom = float(height)
+            w_denom = float(width)
+        else:
+            raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
+
+        # Normalize to [0, 1]
+        coords_h = coords_h / h_denom
+        coords_w = coords_w / w_denom
+
+        # Create grid then map to [-1, 1]
+        if self.grid_indexing == "xy":
+            grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy")
+            coords = torch.stack([grid_h, grid_w], dim=-1)  # (H, W, 2) in (h, w) order
+        else:
+            coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # (H, W, 2)
+        coords = coords.flatten(0, 1)  # (HW, 2)
+        coords = 2.0 * coords - 1.0  # (HW, 2) in [-1, 1]
+        return coords
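
For intuition, with normalize_coords='separate' and no grid_offset, a 2x2 grid yields per-axis coordinates [0.25, 0.75] in [0, 1], mapped to [-0.5, 0.5]; larger grids push the extreme tokens toward ±1. A one-line check of that mapping:

    import torch

    h = 2
    coords_h = 2.0 * (torch.arange(0.5, h) / h) - 1.0  # tensor([-0.5000, 0.5000])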
+
+    def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
+        """Apply shift/jitter/rescale train-time augmentations."""
+        if not self.training or not self.aug_active:
+            return coords
+
+        device = coords.device
+        dtype = coords.dtype
+
+        # Shift per-axis in [-s, +s]
+        if self.shift_coords is not None:
+            shift = float(self.shift_coords)
+            shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, +shift)
+            coords = coords + shift_hw[None, :]
+
+        # Jitter: per-axis log-uniform factor in [1/J, J]
+        if self.jitter_coords is not None:
+            jitter_factor = float(self.jitter_coords)
+            if jitter_factor <= 0:
+                raise ValueError("jitter_coords must be > 0 (interpreted as a multiplicative factor).")
+            jitter_max = math.log(jitter_factor)
+            jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, +jitter_max).exp()
+            coords = coords * jitter_hw[None, :]
+
+        # Rescale: shared scalar log-uniform factor in [1/R, R]
+        if self.rescale_coords is not None:
+            rescale_factor = float(self.rescale_coords)
+            if rescale_factor <= 0:
+                raise ValueError("rescale_coords must be > 0 (interpreted as a multiplicative factor).")
+            rescale_max = math.log(rescale_factor)
+            rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, +rescale_max).exp()
+            coords = coords * rescale
+
+        return coords
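
The jitter and rescale factors are sampled log-uniformly, so a factor J = 2.0 scales each axis by a value in [0.5, 2.0] with a halving as likely as a doubling, e.g.:

    import math
    import torch

    J = 2.0
    factors = torch.empty(2).uniform_(-math.log(J), math.log(J)).exp()  # per-axis factors in [1/J, J]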
+
+    def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Return sin/cos embeddings with either 'half' or 'interleaved' layout."""
+        # coords: (HW, 2); periods: (dim // 4,)
+        dim = self.dim // 4
+        device = self.periods.device
+        dtype = self.periods.dtype
+        assert self.periods.numel() == dim
+
+        # NOTE this is a slightly later device/dtype switch than the original
+        coords = coords[:, :, None].to(device=device, dtype=dtype)
+        angles = 2 * math.pi * coords / self.periods[None, None, :]
+        angles = angles.flatten(1)  # (HW, dim // 2)
+
+        if self.rotate_half:
+            # Tile (half layout): (HW, dim // 2) -> (HW, dim)
+            angles = angles.tile(2)
+        else:
+            # Interleaved layout: (HW, dim // 2) -> (HW, dim)
+            angles = angles.repeat_interleave(2, dim=-1)
+
+        sin = torch.sin(angles)
+        cos = torch.cos(angles)
+        return sin, cos
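
The only difference between the two layouts is how the dim // 2 angles are expanded to dim channels:

    import torch

    a = torch.tensor([0., 1., 2.])
    a.tile(2)               # half layout:        tensor([0., 1., 2., 0., 1., 2.])
    a.repeat_interleave(2)  # interleaved layout: tensor([0., 0., 1., 1., 2., 2.])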
+
+    def _create_embed(self, feat_shape: List[int], no_aug: bool = False) -> torch.Tensor:
+        H, W = feat_shape
+        coords = self._make_coords(H, W)  # (HW, 2)
+        if not no_aug:
+            coords = self._apply_coord_augs(coords)
+        sin, cos = self._get_pos_embed_from_coords(coords)  # 2 * (HW, dim)
+        rope_embed = torch.cat([sin, cos], dim=-1)  # (HW, 2 * dim)
+        return rope_embed
+
+    def _cache_embed(self, feat_shape: List[int]):
+        rope_embed = self._create_embed(feat_shape, no_aug=True)  # create non-augmented embeds for cache
+        self.register_buffer("pos_embed_cached", rope_embed, persistent=False)
+        self.feat_shape = feat_shape
+
+    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
+        """Generate rope_embed matching DINOv3 RopePositionEmbedding numerics.
+
+        Returns: (HW, 2 * dim) with last dim = [sin, cos] concatenated.
+        """
+        if shape is not None:
+            rope_embed = self._create_embed(shape)
+        else:
+            need_create = self.pos_embed_cached is None or (self.training and self.aug_active)
+            if need_create:
+                assert self.feat_shape is not None, 'feature shape must be cached on create'
+                rope_embed = self._create_embed(self.feat_shape)
+            else:
+                rope_embed = self.pos_embed_cached
+
+        return rope_embed
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Get and apply rotary embeddings to x"""
+        # assumes a channel-first tensor where spatial dims are >= 2
+        pos_embed = self.get_embed(x.shape[2:])
+        return apply_rot_embed_cat(x, pos_embed, half=self.rotate_half)
+
+
+def create_rope_embed(
+        rope_type: str = 'cat',
+        dim: int = 768,
+        num_heads: int = 12,
+        **kwargs,
+) -> nn.Module:
+    """Factory function for creating rotary position embeddings.
+
+    Args:
+        rope_type: Type of RoPE to create. Options:
+            - 'base': basic RotaryEmbedding
+            - 'cat': RotaryEmbeddingCat (concatenated sin/cos)
+            - 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies)
+            - 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms)
+        dim: Total embedding dimension
+        num_heads: Number of attention heads
+        **kwargs: Additional arguments passed to the specific RoPE class
+
+    Returns:
+        Rotary embedding module
+    """
+    if rope_type == 'base':
+        return RotaryEmbedding(dim=dim // num_heads, **kwargs)
+    elif rope_type == 'cat':
+        return RotaryEmbeddingCat(dim=dim // num_heads, **kwargs)
+    elif rope_type == 'mixed':
+        # Mixed requires a depth parameter; generates different embeddings per layer and head
+        return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs)
+    elif rope_type == 'dinov3':
+        return RotaryEmbeddingDinoV3(dim=dim // num_heads, **kwargs)
+    else:
+        raise ValueError(f"Unknown RoPE type: {rope_type}")
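
A usage sketch of the new factory with the DINOv3 variant (shapes and values are illustrative, and assume a build that includes this commit):

    import torch
    from timm.layers import create_rope_embed
    from timm.layers.pos_embed_sincos import apply_rot_embed_cat

    rope = create_rope_embed(
        rope_type='dinov3',
        dim=768,              # total embed dim -> head_dim 64 with 12 heads
        num_heads=12,
        feat_shape=[14, 14],  # cache non-augmented embeds for a 14x14 token grid
    )
    rope.eval()  # eval mode: the cached, non-augmented embedding is used

    embed = rope.get_embed()         # (196, 2 * 64), [sin, cos] concatenated
    q = torch.randn(2, 12, 196, 64)  # (B, num_heads, N, head_dim)
    q = apply_rot_embed_cat(q, embed, half=True)  # half=True matches the rotate_half=True default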
