Merged
Changes from 1 commit
Commits
21 commits
7560ca5
Adding dd factory_kwargs to modules in timm/layers, initial model WIP…
rwightman Sep 26, 2025
90a35c8
Add dd factory kwargs to eva, resnet
rwightman Sep 27, 2025
325a6cc
Add dd to other ResNet based models, Res2Net, ResNeSt, SKNet
rwightman Sep 27, 2025
b94c221
Add dd factory kwargs to maxxvit and regnet
rwightman Sep 27, 2025
ee751ef
Add dd factory kwargs to nfnet and resnetv2
rwightman Sep 28, 2025
10e7020
dd factory kwargs for fastvit, convnext, mambaout
rwightman Sep 28, 2025
60db539
Add dd factory kwargs to all EfficientNetBuilder models, MobileNet V1…
rwightman Sep 28, 2025
4d19b34
Fix typo for s2d norm
rwightman Sep 28, 2025
f15f7c9
Add dd factory kwargs to byobnet, cspnet, davit, edgenext
rwightman Sep 29, 2025
4c35b78
Add device/dtype factory kwargs to beit, efficientformer*, efficientv…
rwightman Sep 29, 2025
3a85ed4
avg pool should not have been passed dd
rwightman Sep 29, 2025
8cbbf39
Fix DarkStage device kwargs
rwightman Sep 29, 2025
1e172a0
dd kwargs for naflexvit, needs revisit for nn.Parameters
rwightman Sep 29, 2025
a7dc50f
A whack of classic convnets converted with dd factory kwargs. densene…
rwightman Sep 29, 2025
068e6d4
Remove **dd from two inception reset_classifier calls
rwightman Sep 29, 2025
6a3342c
dd factory kwargs added to a bunch of vit/vit-hybrids. cait, coat, co…
rwightman Sep 30, 2025
c7955eb
Add dd factory kwargs to all swin transformers and volo
rwightman Sep 30, 2025
53caeb0
Add some more dd kwarg updates, crossvit, ghostnet, rdnet, repghost, …
rwightman Oct 1, 2025
21b1ae7
More dd factory kwargs updates. hiera, hieradet_sam2, metaformer, mlp…
rwightman Oct 1, 2025
5cadf13
More dd arg conversions. fasternet, gcvit, hgnet, nextvit, starnet, v…
rwightman Oct 1, 2025
d3fdea8
Typing, super(), buffer dtype fixes for timm/layers and timm/models
rwightman Oct 2, 2025
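The commits above all apply the same device/dtype ("dd") factory-kwargs pattern: each constructor accepts optional device and dtype arguments, packs them into a dict, and forwards that dict to every submodule, parameter, and buffer it creates, so nothing is allocated on the default device first and then moved. Below is a minimal, self-contained sketch of the pattern; ToyBlock is a hypothetical illustration, not timm code.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # Hypothetical illustration of the device/dtype factory-kwargs pattern.
    def __init__(self, dim: int, device=None, dtype=None):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        # Submodules receive the same factory kwargs, so their weights are
        # created directly on the target device and in the target dtype.
        self.norm = nn.LayerNorm(dim, **dd)
        self.fc = nn.Linear(dim, dim, **dd)
        # Parameters and buffers are created with the same kwargs.
        self.gamma = nn.Parameter(torch.ones(dim, **dd))
        self.register_buffer('scale', torch.ones((), **dd))

    def forward(self, x):
        return self.scale * (x + self.gamma * self.fc(self.norm(x)))


# Build directly in bfloat16, or on the meta device (shape/dtype only,
# no real storage) for cheap construction and later materialization.
block_bf16 = ToyBlock(64, dtype=torch.bfloat16)
block_meta = ToyBlock(64, device=torch.device('meta'))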
dd factory kwargs added to a bunch of vit/vit-hybrids. cait, coat, convit, convmixer, deit, mvitv2, nest, pit, pvt_v2, tiny_vit, tnt, twins, visformer, xcit
rwightman committed Sep 30, 2025
commit 6a3342ca8a4e60fb49b6cd6c4ae271fab7cd35f2
184 changes: 110 additions & 74 deletions timm/models/cait.py
@@ -9,7 +9,7 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Type, Any

import torch
import torch.nn as nn
@@ -29,18 +29,28 @@ class ClassAttn(nn.Module):
# with slight modifications to do CA
fused_attn: torch.jit.Final[bool]

def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
device=None,
dtype=None,
):
super().__init__()
dd = {'device': device, 'dtype': dtype}
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.fused_attn = use_fused_attn()

self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd)
self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj = nn.Linear(dim, dim, **dd)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
@@ -73,39 +83,44 @@ class LayerScaleBlockClassAttn(nn.Module):
# with slight modifications to add CA and LayerScale
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
attn_block=ClassAttn,
mlp_block=Mlp,
init_values=1e-4,
dim: int,
num_heads: int,
mlp_ratio: float = 4.,
qkv_bias: bool = False,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: Type[nn.Module] = nn.GELU,
norm_layer: Type[nn.Module] = nn.LayerNorm,
attn_block: Type[nn.Module] = ClassAttn,
mlp_block: Type[nn.Module] = Mlp,
init_values: float = 1e-4,
device=None,
dtype=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
dd = {'device': device, 'dtype': dtype}
self.norm1 = norm_layer(dim, **dd)
self.attn = attn_block(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
**dd,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.norm2 = norm_layer(dim, **dd)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = mlp_block(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=proj_drop,
**dd,
)
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))

def forward(self, x, x_cls):
u = torch.cat((x_cls, x), dim=1)
@@ -117,22 +132,32 @@ def forward(self, x, x_cls):
class TalkingHeadAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
device=None,
dtype=None,
):
super().__init__()
dd = {'device': device, 'dtype': dtype}

self.num_heads = num_heads

head_dim = dim // num_heads

self.scale = head_dim ** -0.5

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
self.attn_drop = nn.Dropout(attn_drop)

self.proj = nn.Linear(dim, dim)
self.proj = nn.Linear(dim, dim, **dd)

self.proj_l = nn.Linear(num_heads, num_heads)
self.proj_w = nn.Linear(num_heads, num_heads)
self.proj_l = nn.Linear(num_heads, num_heads, **dd)
self.proj_w = nn.Linear(num_heads, num_heads, **dd)

self.proj_drop = nn.Dropout(proj_drop)

@@ -161,39 +186,44 @@ class LayerScaleBlock(nn.Module):
# with slight modifications to add layerScale
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
attn_block=TalkingHeadAttn,
mlp_block=Mlp,
init_values=1e-4,
dim: int,
num_heads: int,
mlp_ratio: float = 4.,
qkv_bias: bool = False,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: Type[nn.Module] = nn.GELU,
norm_layer: Type[nn.Module] = nn.LayerNorm,
attn_block: Type[nn.Module] = TalkingHeadAttn,
mlp_block: Type[nn.Module] = Mlp,
init_values: float = 1e-4,
device=None,
dtype=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
dd = {'device': device, 'dtype': dtype}
self.norm1 = norm_layer(dim, **dd)
self.attn = attn_block(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
**dd,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.norm2 = norm_layer(dim, **dd)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = mlp_block(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=proj_drop,
**dd,
)
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))

def forward(self, x):
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
@@ -206,35 +236,38 @@ class Cait(nn.Module):
# with slight modifications to adapt to our cait models
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
block_layers=LayerScaleBlock,
block_layers_token=LayerScaleBlockClassAttn,
patch_layer=PatchEmbed,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
attn_block=TalkingHeadAttn,
mlp_block=Mlp,
init_values=1e-4,
attn_block_token_only=ClassAttn,
mlp_block_token_only=Mlp,
depth_token_only=2,
mlp_ratio_token_only=4.0
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
block_layers: Type[nn.Module] = LayerScaleBlock,
block_layers_token: Type[nn.Module] = LayerScaleBlockClassAttn,
patch_layer: Type[nn.Module] = PatchEmbed,
norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
act_layer: Type[nn.Module] = nn.GELU,
attn_block: Type[nn.Module] = TalkingHeadAttn,
mlp_block: Type[nn.Module] = Mlp,
init_values: float = 1e-4,
attn_block_token_only: Type[nn.Module] = ClassAttn,
mlp_block_token_only: Type[nn.Module] = Mlp,
depth_token_only: int = 2,
mlp_ratio_token_only: float = 4.0,
device=None,
dtype=None,
):
super().__init__()
dd = {'device': device, 'dtype': dtype}
assert global_pool in ('', 'token', 'avg')

self.num_classes = num_classes
@@ -247,12 +280,13 @@ def __init__(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
**dd,
)
num_patches = self.patch_embed.num_patches
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd))
self.pos_drop = nn.Dropout(p=pos_drop_rate)

dpr = [drop_path_rate for i in range(depth)]
@@ -269,6 +303,7 @@ def __init__(
attn_block=attn_block,
mlp_block=mlp_block,
init_values=init_values,
**dd,
) for i in range(depth)])
self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]

@@ -282,12 +317,13 @@ def __init__(
attn_block=attn_block_token_only,
mlp_block=mlp_block_token_only,
init_values=init_values,
**dd,
) for _ in range(depth_token_only)])

self.norm = norm_layer(embed_dim)
self.norm = norm_layer(embed_dim, **dd)

self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
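With this commit applied, the cait modules shown above can be built directly in a target dtype or on the meta device. A brief usage sketch follows; it assumes a timm install that already includes these changes, uses only constructor arguments visible in the diff, and the meta-device behaviour of the weight-init calls is assumed rather than verified here.

import torch
from timm.models.cait import Cait, ClassAttn

# Attention block with weights allocated directly in bfloat16.
attn = ClassAttn(dim=768, num_heads=12, dtype=torch.bfloat16)
print(attn.q.weight.dtype)  # torch.bfloat16

# Full model on the meta device: parameters carry shape/dtype only, so
# construction is cheap and weights can be materialized or loaded later.
model = Cait(embed_dim=192, depth=4, num_heads=4, device=torch.device('meta'))
print(model.pos_embed.device)  # meta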