@@ -214,24 +214,73 @@ def forward(self, x):


def rot(x):
+    # x: [ x0 x1 x2 x3 x4 x5]
+    # out: [-x1 x0 -x3 x2 -x5 x4]
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


-def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
-    if sin_emb.ndim == 3:
-        return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
-    return x * cos_emb + rot(x) * sin_emb
+def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
+    # x: [ x0 x1 x2 x3 x4 x5]
+    # out: [-x3 -x4 -x5 x0 x1 x2]
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat([-x2, x1], dim=-1)
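
# A quick numeric check of the two rotation layouts (illustrative sketch using the
# `rot` and `rope_rotate_half` helpers above; toy values only):
import torch

x = torch.arange(6.)          # tensor([0., 1., 2., 3., 4., 5.])
print(rot(x))                 # interleaved:  [-1.,  0., -3.,  2., -5.,  4.]
print(rope_rotate_half(x))    # half layout:  [-3., -4., -5.,  0.,  1.,  2.]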


-def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
+def apply_rot_embed(
+        x: torch.Tensor,
+        emb: torch.Tensor,
+        half: bool = False,
+) -> torch.Tensor:
+    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
+    sin_emb, cos_emb = emb.tensor_split(2, -1)  # emb is [sin | cos] concatenated on the last dim
+    if half:
+        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
+        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
+        # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
+        return x * cos_emb + rope_rotate_half(x) * sin_emb
+    else:
+        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
+        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
+        # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
+        return x * cos_emb + rot(x) * sin_emb
+
+
+def apply_rot_embed_list(
+        x: List[torch.Tensor],
+        emb: torch.Tensor,
+        half: bool = False,
+) -> List[torch.Tensor]:
    if isinstance(x, torch.Tensor):
        x = [x]
-    return [t * cos_emb + rot(t) * sin_emb for t in x]
+    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
+    sin_emb, cos_emb = emb.tensor_split(2, -1)  # emb is [sin | cos] concatenated on the last dim
+    if half:
+        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
+        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
+        # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2]
+        return [t * cos_emb + rope_rotate_half(t) * sin_emb for t in x]
+    else:
+        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
+        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
+        # rot(x): eg [-x1, x0, -x3, x2, -x5, x4]
+        return [t * cos_emb + rot(t) * sin_emb for t in x]


-def apply_rot_embed_cat(x: torch.Tensor, emb):
+def apply_rot_embed_cat(
+        x: torch.Tensor,
+        emb: torch.Tensor,
+        half: bool = False,
+) -> torch.Tensor:
    sin_emb, cos_emb = emb.tensor_split(2, -1)
-    return x * cos_emb + rot(x) * sin_emb
+    # x: [..., D], eg [x0, x1, x2, x3, x4, x5]
+    if half:
+        # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
+        # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
+        # rope_rotate_half(x), eg [-x3, -x4, -x5, x0, x1, x2]
+        return x * cos_emb + rope_rotate_half(x) * sin_emb
+    else:
+        # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2]
+        # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2]
+        # rot(x), eg [-x1, x0, -x3, x2, -x5, x4]
+        return x * cos_emb + rot(x) * sin_emb
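
# Illustrative usage sketch for apply_rot_embed_cat; the shapes below are assumptions
# for demonstration, not library requirements:
import torch

B, H, N, D = 2, 12, 196, 64            # batch, heads, tokens, head_dim
q = torch.randn(B, H, N, D)
emb = torch.randn(N, 2 * D)            # stand-in for a precomputed [sin | cos] table
q_rot = apply_rot_embed_cat(q, emb)    # sin/cos broadcast over the batch and head dims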


def apply_keep_indices_nlc(
@@ -834,3 +883,253 @@ def forward(self, x):
    def no_weight_decay(self):
        """Exclude frequency parameters from weight decay."""
        return {'freqs'}
+
+
+class RotaryEmbeddingDinoV3(nn.Module):
+    """RoPE for the timm DinoV3 port, numerically matching the original.
+
+    Math is aligned to the original DinoV3 RopePositionEmbedding at https://github.com/facebookresearch/dinov3:
+    - 0.5-centered coords normalized by H/W (or min/max), mapped to [-1, 1]
+    - training-time augmentations (shift/jitter/rescale)
+    - period schedule derived from either RoPE's temperature (base) or a min/max period range
+    """
+
+    def __init__(
+            self,
+            dim: int,
+            temperature: Optional[float] = 100.0,
+            min_period: Optional[float] = 0.5,
+            max_period: Optional[float] = 90.,
+            feat_shape: Optional[List[int]] = None,
+            ref_feat_shape: Optional[List[int]] = None,
+            normalize_coords: str = "separate",  # 'min', 'max', 'separate'
+            grid_offset: float = 0.0,
+            grid_indexing: str = "ij",
+            rotate_half: bool = True,
+            shift_coords: Optional[float] = None,
+            jitter_coords: Optional[float] = None,  # interpreted as factor J >= 1
+            rescale_coords: Optional[float] = None,  # interpreted as factor R >= 1
+    ):
+        super().__init__()
+
+        # Dimensions / output format
+        self.dim = dim  # equal to head_dim for most ViT applications
+        self.rotate_half = rotate_half
+
+        # Period schedule parameters
+        self.temperature = float(temperature) if temperature is not None else None
+        self.min_period = min_period
+        self.max_period = max_period
+
+        # Coord processing + augs
+        self.normalize_coords = normalize_coords
+        self.shift_coords = shift_coords
+        self.jitter_coords = jitter_coords
+        self.rescale_coords = rescale_coords
+        self.aug_active = any(a is not None for a in (self.shift_coords, self.jitter_coords, self.rescale_coords))
+
+        # Grid config
+        self.feat_shape = feat_shape
+        self.ref_feat_shape = ref_feat_shape
+        self.grid_offset = grid_offset
+        self.grid_indexing = grid_indexing
+
+        # Precompute periods
+        periods = self._compute_periods()
+        self.register_buffer("periods", periods, persistent=False)
+
+        if feat_shape is not None:
+            self._cache_embed(feat_shape)
+        else:
+            self.register_buffer("pos_embed_cached", None, persistent=False)
+            self.feat_shape = None
+
+    def _compute_periods(self, device='cpu', dtype=torch.float32) -> torch.Tensor:
+        """Construct periods from either min/max periods or temperature."""
+        dim = self.dim // 4
+
+        if self.min_period is not None and self.max_period is not None:
+            exponents = torch.linspace(0, 1, dim, device=device, dtype=dtype)
+            periods = self.min_period * ((self.max_period / self.min_period) ** exponents)
+        else:
+            if self.temperature is None:
+                raise ValueError("Provide either min/max periods or `temperature`.")
+            exponents = 2.0 * torch.arange(dim, device=device, dtype=dtype) / (self.dim // 2)
+            periods = self.temperature ** exponents
+
+        # NOTE: the original impl downcasts periods to bfloat16 in persistent buffers, so models
+        # loaded from original checkpoints may differ slightly from timm here.
+
+        return periods
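
# Worked sketch of the two period schedules above, assuming dim=64 (so dim // 4 == 16);
# the values are illustrative, not library defaults:
import torch

d4 = 64 // 4
# min/max: geometric interpolation from min_period to max_period
periods_minmax = 0.5 * ((90. / 0.5) ** torch.linspace(0, 1, d4))
# temperature: classic RoPE schedule, base ** (2i / (dim // 2))
periods_temp = 100.0 ** (2.0 * torch.arange(d4) / (64 // 2))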
+
+    def _make_coords(
+            self,
+            height: int,
+            width: int,
+            device: torch.device = 'cpu',
+            dtype: torch.dtype = torch.float32,
+    ) -> torch.Tensor:
+        """Make a coordinate grid matching the offset and normalization of the original.
+
+        Returns: coords with shape (HW, 2) in [-1, 1].
+        """
+        # 0.5-centered indices with optional offset
+        coords_h = torch.arange(0.5, height, device=device, dtype=dtype) + self.grid_offset
+        coords_w = torch.arange(0.5, width, device=device, dtype=dtype) + self.grid_offset
+
+        # Normalization denominators
+        if self.normalize_coords == "max":
+            denom = float(max(height, width))
+            h_denom = denom
+            w_denom = denom
+        elif self.normalize_coords == "min":
+            denom = float(min(height, width))
+            h_denom = denom
+            w_denom = denom
+        elif self.normalize_coords == "separate":
+            h_denom = float(height)
+            w_denom = float(width)
+        else:
+            raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
+
+        # Normalize to [0, 1]
+        coords_h = coords_h / h_denom
+        coords_w = coords_w / w_denom
+
+        # Create the grid, flatten, then map to [-1, 1]
+        if self.grid_indexing == "xy":
+            grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy")
+            coords = torch.stack([grid_h, grid_w], dim=-1)  # (H, W, 2) in (h, w) order
+        else:
+            coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # (H, W, 2)
+        coords = coords.flatten(0, 1)  # (HW, 2)
+        coords = 2.0 * coords - 1.0  # (HW, 2) in [-1, 1]
+        return coords
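
# Worked sketch of the coordinate construction above for a 2x3 grid with 'separate'
# normalization and grid_offset=0 (standalone, values chosen for illustration):
import torch

h, w = 2, 3
ch = torch.arange(0.5, h) / h                    # [0.25, 0.75]
cw = torch.arange(0.5, w) / w                    # [0.1667, 0.5000, 0.8333]
grid = torch.stack(torch.meshgrid(ch, cw, indexing='ij'), dim=-1)
coords = 2.0 * grid.flatten(0, 1) - 1.0          # (6, 2): h in {-0.5, 0.5}, w in {-2/3, 0, 2/3}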
+
+    def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
+        """Apply shift/jitter/rescale train-time augmentations."""
+        if not self.training or not self.aug_active:
+            return coords
+
+        device = coords.device
+        dtype = coords.dtype
+
+        # Shift per-axis in [-s, +s]
+        if self.shift_coords is not None:
+            shift = float(self.shift_coords)
+            shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, +shift)
+            coords = coords + shift_hw[None, :]
+
+        # Jitter: per-axis log-uniform factor in [1/J, J]
+        if self.jitter_coords is not None:
+            jitter_factor = float(self.jitter_coords)
+            if jitter_factor <= 0:
+                raise ValueError("jitter_coords must be > 0 (interpreted as multiplicative factor).")
+            jitter_max = math.log(jitter_factor)
+            jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, +jitter_max).exp()
+            coords = coords * jitter_hw[None, :]
+
+        # Rescale: shared scalar log-uniform factor in [1/R, R]
+        if self.rescale_coords is not None:
+            rescale_factor = float(self.rescale_coords)
+            if rescale_factor <= 0:
+                raise ValueError("rescale_coords must be > 0 (interpreted as multiplicative factor).")
+            rescale_max = math.log(rescale_factor)
+            rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, +rescale_max).exp()
+            coords = coords * rescale
+
+        return coords
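
# Sketch of the log-uniform sampling used for jitter/rescale above, with an assumed
# factor J = 2.0; exponentiating a uniform sample in [-log J, log J] gives a
# multiplicative factor distributed symmetrically (in log space) over [1/J, J]:
import math
import torch

J = 2.0
jitter_hw = torch.empty(2).uniform_(-math.log(J), math.log(J)).exp()  # per-axis factor in [0.5, 2.0]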
+
+    def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Return sin/cos embeddings in either 'half' or 'interleaved' layout."""
+        # coords: (HW, 2); periods: (self.dim // 4,)
+        dim = self.dim // 4
+        device = self.periods.device
+        dtype = self.periods.dtype
+        assert self.periods.numel() == dim
+
+        # NOTE this is a slightly later device/dtype switch than the original
+        coords = coords[:, :, None].to(device=device, dtype=dtype)
+        angles = 2 * math.pi * coords / self.periods[None, None, :]  # (HW, 2, self.dim // 4)
+        angles = angles.flatten(1)  # (HW, self.dim // 2)
+
+        if self.rotate_half:
+            # Tile (half layout): (HW, self.dim // 2) -> (HW, self.dim)
+            angles = angles.tile(2)
+        else:
+            # Interleaved layout: (HW, self.dim // 2) -> (HW, self.dim)
+            angles = angles.repeat_interleave(2, dim=-1)
+
+        sin = torch.sin(angles)
+        cos = torch.cos(angles)
+        return sin, cos
+
+    def _create_embed(self, feat_shape: List[int], no_aug: bool = False) -> torch.Tensor:
+        H, W = feat_shape
+        coords = self._make_coords(H, W)  # (HW, 2)
+        if not no_aug:
+            coords = self._apply_coord_augs(coords)
+        sin, cos = self._get_pos_embed_from_coords(coords)  # 2 * (HW, dim)
+        rope_embed = torch.cat([sin, cos], dim=-1)  # (HW, 2 * dim)
+        return rope_embed
+
+    def _cache_embed(self, feat_shape: List[int]):
+        rope_embed = self._create_embed(feat_shape, no_aug=True)  # create non-augmented embeds for cache
+        self.register_buffer("pos_embed_cached", rope_embed, persistent=False)
+        self.feat_shape = feat_shape
+
+    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
+        """Generate rope_embed matching DINOv3 RopePositionEmbedding numerics.
+
+        Returns: (HW, 2 * dim) with the last dim a [sin, cos] concatenation.
+        """
+        if shape is not None:
+            rope_embed = self._create_embed(shape)
+        else:
+            need_create = self.pos_embed_cached is None or (self.training and self.aug_active)
+            if need_create:
+                assert self.feat_shape is not None, 'feat_shape must be set at init when no shape is passed'
+                rope_embed = self._create_embed(self.feat_shape)
+            else:
+                rope_embed = self.pos_embed_cached
+
+        return rope_embed
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Get and apply rotary embeddings to x."""
+        # assumes a channel-first tensor whose spatial dims start at index 2
+        pos_embed = self.get_embed(x.shape[2:])
+        return apply_rot_embed_cat(x, pos_embed, half=self.rotate_half)
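
# Usage sketch (values are illustrative assumptions): DinoV3-style RoPE for a 14x14
# feature map with 64-dim heads, using the temperature schedule instead of min/max periods:
import torch

rope = RotaryEmbeddingDinoV3(dim=64, temperature=100.0, min_period=None, max_period=None)
emb = rope.get_embed([14, 14])                    # (196, 128): [sin | cos] per position
q = torch.randn(2, 12, 196, 64)                   # (batch, heads, HW, head_dim)
q_rot = apply_rot_embed_cat(q, emb, half=True)    # half=True matches the rotate_half=True default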
+
+
+def create_rope_embed(
+        rope_type: str = 'cat',
+        dim: int = 768,
+        num_heads: int = 12,
+        **kwargs,
+) -> nn.Module:
+    """Factory function for creating rotary position embeddings.
+
+    Args:
+        rope_type: Type of RoPE to create. Options:
+            - 'base': Basic RotaryEmbedding
+            - 'cat': RotaryEmbeddingCat (concatenated sin/cos)
+            - 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies)
+            - 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms)
+        dim: Total embedding dimension
+        num_heads: Number of attention heads
+        **kwargs: Additional arguments passed to the specific RoPE class
+
+    Returns:
+        Rotary embedding module
+    """
+    if rope_type == 'base':
+        return RotaryEmbedding(dim=dim // num_heads, **kwargs)
+    elif rope_type == 'cat':
+        return RotaryEmbeddingCat(dim=dim // num_heads, **kwargs)
+    elif rope_type == 'mixed':
+        # Mixed requires a depth parameter, generates different embeddings per layer and head
+        return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs)
+    elif rope_type == 'dinov3':
+        return RotaryEmbeddingDinoV3(dim=dim // num_heads, **kwargs)
+    else:
+        raise ValueError(f"Unknown RoPE type: {rope_type}")
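
# Factory usage sketch (argument values are illustrative): the factory passes the
# per-head dim (768 // 12 = 64) to the embedding class, plus any extra kwargs:
rope = create_rope_embed('dinov3', dim=768, num_heads=12, normalize_coords='separate')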