
Commit 8e4ac35

All ScaledStdConv and StdConv uses now default to F.layer_norm so that they work with PyTorch XLA. eps value tweaking is a WIP.
1 parent 54a6cca commit 8e4ac35

3 files changed: +34 −27 lines changed

timm/models/layers/std_conv.py

Lines changed: 28 additions & 24 deletions
@@ -19,17 +19,22 @@ class StdConv2d(nn.Conv2d):
     """
     def __init__(
             self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
-            groups=1, bias=False, eps=1e-5):
+            groups=1, bias=False, eps=1e-5, use_layernorm=True):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
             in_channel, out_channels, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=groups, bias=bias)
         self.eps = eps
+        self.use_layernorm = use_layernorm

     def get_weight(self):
-        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        weight = (self.weight - mean) / (std + self.eps)
+        if self.use_layernorm:
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+        else:
+            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+            weight = (self.weight - mean) / (std + self.eps)
         return weight

     def forward(self, x):
@@ -45,17 +50,22 @@ class StdConv2dSame(nn.Conv2d):
     """
     def __init__(
             self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
-            groups=1, bias=False, eps=1e-5):
+            groups=1, bias=False, eps=1e-5, use_layernorm=True):
         padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
             in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.same_pad = is_dynamic
         self.eps = eps
+        self.use_layernorm = use_layernorm

     def get_weight(self):
-        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        weight = (self.weight - mean) / (std + self.eps)
+        if self.use_layernorm:
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+        else:
+            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+            weight = (self.weight - mean) / (std + self.eps)
         return weight

     def forward(self, x):
@@ -76,24 +86,25 @@ class ScaledStdConv2d(nn.Conv2d):
     """

     def __init__(
             self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
-            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
+            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
         self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
-        self.eps = eps ** 2 if use_layernorm else eps
+        self.eps = eps
         self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel

     def get_weight(self):
         if self.use_layernorm:
-            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
         else:
             std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = self.scale * (self.weight - mean) / (std + self.eps)
-        return self.gain * weight
+            weight = (self.weight - mean) / (std + self.eps)
+        return weight.mul_(self.gain * self.scale)

     def forward(self, x):
         return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
@@ -110,32 +121,25 @@ class ScaledStdConv2dSame(nn.Conv2d):
     """

     def __init__(
             self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
-            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
+            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
         padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
         self.scale = gamma * self.weight[0].numel() ** -0.5
         self.same_pad = is_dynamic
-        self.eps = eps ** 2 if use_layernorm else eps
+        self.eps = eps
         self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel

-    # NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
-    # to make much numerical difference (+/- .002 to .004) in top-1 during eval.
-    # def get_weight(self):
-    #     var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-    #     scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
-    #     weight = (self.weight - mean) * scale
-    #     return self.gain * weight
-
     def get_weight(self):
         if self.use_layernorm:
-            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
         else:
             std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = self.scale * (self.weight - mean) / (std + self.eps)
-        return self.gain * weight
+            weight = (self.weight - mean) / (std + self.eps)
+        return weight.mul_(self.gain * self.scale)

     def forward(self, x):
         if self.same_pad:
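The NOTE comments in get_weight() describe F.layer_norm as computing the standardized weight in a single op. Below is a minimal standalone sketch (not part of the commit) comparing the layer-norm path against the manual torch.std_mean path. One caveat worth keeping in mind: F.layer_norm adds eps inside the square root, i.e. (w - mean) / sqrt(var + eps), whereas the manual branch divides by (std + eps), so the two paths agree only approximately for small eps; this is presumably the eps tweaking the commit message calls a WIP.

```python
# Standalone equivalence check for the two weight-standardization paths.
# F.layer_norm over weight.shape[1:] normalizes each output channel over
# (in_channels, kh, kw), matching the dim=[1, 2, 3] reduction of the manual path.
import torch
import torch.nn.functional as F

eps = 1e-5
w = torch.randn(64, 32, 3, 3)  # (out_channels, in_channels, kh, kw)

# layer-norm path: (w - mean) / sqrt(var + eps), fused in one kernel
ln_path = F.layer_norm(w, w.shape[1:], eps=eps)

# manual path: (w - mean) / (std + eps)
std, mean = torch.std_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
manual_path = (w - mean) / (std + eps)

print((ln_path - manual_path).abs().max())  # small, but not exactly zero
```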

timm/models/nfnet.py

Lines changed: 5 additions & 2 deletions
@@ -166,6 +166,8 @@ class NfCfg:
     extra_conv: bool = False  # extra 3x3 bottleneck convolution for NFNet models
     gamma_in_act: bool = False
     same_padding: bool = False
+    std_conv_eps: float = 1e-5
+    std_conv_ln: bool = True  # use layer-norm impl to normalize in std-conv, works in PyTorch XLA, slightly faster
     skipinit: bool = False  # disabled by default, non-trivial performance impact
     zero_init_fc: bool = False
     act_layer: str = 'silu'
@@ -482,10 +484,11 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
         conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
         if cfg.gamma_in_act:
             act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
-            conv_layer = partial(conv_layer, eps=1e-4)  # DM weights better with higher eps
+            conv_layer = partial(conv_layer, eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln)
         else:
             act_layer = get_act_layer(cfg.act_layer)
-            conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer])
+            conv_layer = partial(
+                conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln)
         attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None

         stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
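The two new NfCfg fields reach every standardized conv through functools.partial, which pre-binds eps and use_layernorm on the conv class before the network is assembled. Here is a rough sketch of that plumbing, assuming ScaledStdConv2d is importable from timm.models.layers as nfnet.py does; the concrete values are illustrative only.

```python
# Sketch of the config -> layer plumbing shown in the diff above.
import torch
from functools import partial
from timm.models.layers import ScaledStdConv2d

# stand-ins for the new NfCfg.std_conv_eps / NfCfg.std_conv_ln fields
std_conv_eps = 1e-5
std_conv_ln = True

# pre-bind the standardization settings; every conv built from conv_layer shares them
conv_layer = partial(ScaledStdConv2d, eps=std_conv_eps, use_layernorm=std_conv_ln)

conv = conv_layer(32, 64, kernel_size=3, stride=1)
y = conv(torch.randn(1, 32, 8, 8))
print(y.shape)  # torch.Size([1, 64, 8, 8])
```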

timm/models/vision_transformer_hybrid.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
     padding_same = kwargs.get('padding_same', True)
     if padding_same:
         stem_type = 'same'
-        conv_layer = StdConv2dSame
+        conv_layer = partial(StdConv2dSame, eps=1e-5)
     else:
         stem_type = ''
         conv_layer = StdConv2d
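With this change the hybrid ViT ResNet stem builds its 'same'-padded convs with an explicit eps=1e-5 rather than relying on the class default. A quick smoke test of the bound layer, assuming StdConv2dSame is exported from timm.models.layers:

```python
# Illustrative check that the partially-applied conv still behaves like a stem conv.
import torch
from functools import partial
from timm.models.layers import StdConv2dSame

conv_layer = partial(StdConv2dSame, eps=1e-5)
stem_conv = conv_layer(3, 64, kernel_size=7, stride=2)  # dynamic 'SAME' padding

out = stem_conv(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 64, 112, 112])
```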
