Commit 8319e0c

Add file docstring to std_conv.py
1 parent 0020268 commit 8319e0c

1 file changed: +22 -10 lines changed

timm/models/layers/std_conv.py

Lines changed: 22 additions & 10 deletions
@@ -1,16 +1,28 @@
+""" Convolution with Weight Standardization (StdConv and ScaledStdConv)
+
+StdConv:
+@article{weightstandardization,
+  author    = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
+  title     = {Weight Standardization},
+  journal   = {arXiv preprint arXiv:1903.10520},
+  year      = {2019},
+}
+Code: https://github.com/joe-siyuan-qiao/WeightStandardization
+
+ScaledStdConv:
+Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+    - https://arxiv.org/abs/2101.08692
+Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
+
+Hacked together by / copyright Ross Wightman, 2021.
+"""
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from .padding import get_padding, get_padding_value, pad_same
 
 
-def get_weight(module):
-    std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-    weight = (module.weight - mean) / (std + module.eps)
-    return weight
-
-
 class StdConv2d(nn.Conv2d):
     """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
@@ -30,7 +42,7 @@ def __init__(
     def forward(self, x):
         weight = F.batch_norm(
             self.weight.view(1, self.out_channels, -1), None, None,
-            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
         x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return x
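This and the remaining hunks all make the same cosmetic change: the F.batch_norm keyword arguments are reordered. The underlying trick is unchanged: viewing the weight as a (1, out_channels, fan_in) batch lets F.batch_norm standardize each filter over its in_channels * kh * kw elements, which is what the removed get_weight helper did by hand. A minimal sketch (not part of the commit) comparing the two formulations; they are close but not bit-identical, since get_weight divided by (std + eps) while batch_norm divides by sqrt(var + eps):

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, 3)
eps = 1e-6

# Removed helper: standardize each filter over its (in, kh, kw) elements.
std, mean = torch.std_mean(conv.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_manual = (conv.weight - mean) / (std + eps)

# batch_norm trick: a (1, out_channels, fan_in) view makes batch_norm's
# per-channel statistics coincide with per-filter statistics. training=True
# forces batch statistics; running_mean/running_var are None, so momentum=0.
# has no running-stat side effects to worry about.
w_bn = F.batch_norm(
    conv.weight.view(1, conv.out_channels, -1), None, None,
    training=True, momentum=0., eps=eps).reshape_as(conv.weight)

print(torch.allclose(w_manual, w_bn, atol=1e-3))  # True: close, not identical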

@@ -56,7 +68,7 @@ def forward(self, x):
         x = pad_same(x, self.kernel_size, self.stride, self.dilation)
         weight = F.batch_norm(
             self.weight.view(1, self.out_channels, -1), None, None,
-            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
         x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return x

@@ -86,7 +98,7 @@ def forward(self, x):
         weight = F.batch_norm(
             self.weight.view(1, self.out_channels, -1), None, None,
             weight=(self.gain * self.scale).view(-1),
-            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
         return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
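The two ScaledStdConv2d hunks additionally pass weight=(self.gain * self.scale).view(-1), folding the per-filter gain into the same fused op: F.batch_norm multiplies each normalized channel by its affine weight. A sketch under assumed shapes (gain as an (out_channels, 1, 1, 1) parameter and scale as a fan-in-based scalar, roughly gamma * fan_in ** -0.5; both are hypothetical standalone tensors here, standing in for the layer's attributes):

import torch
import torch.nn.functional as F

out_ch, in_ch, kh, kw = 8, 3, 3, 3
w = torch.randn(out_ch, in_ch, kh, kw)
gain = torch.ones(out_ch, 1, 1, 1)   # stands in for the learnable per-filter gain
scale = (in_ch * kh * kw) ** -0.5    # assumed gamma * fan_in ** -0.5, gamma = 1
eps = 1e-6

# Fused form from the diff: gain * scale rides along as batch_norm's affine weight.
fused = F.batch_norm(
    w.view(1, out_ch, -1), None, None,
    weight=(gain * scale).view(-1),
    training=True, momentum=0., eps=eps).reshape_as(w)

# Equivalent two-step form: standardize, then scale.
var, mean = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
two_step = (w - mean) / (var + eps).sqrt() * (gain * scale)

print(torch.allclose(fused, two_step, atol=1e-5))  # True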

@@ -117,5 +129,5 @@ def forward(self, x):
         weight = F.batch_norm(
             self.weight.view(1, self.out_channels, -1), None, None,
             weight=(self.gain * self.scale).view(-1),
-            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
         return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
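Finally, a hedged usage sketch, assuming a timm checkout at this commit where both layers are exported from timm.models.layers; constructor defaults beyond channels and kernel_size are not visible in this diff:

import torch
from timm.models.layers import StdConv2d, ScaledStdConv2d

x = torch.randn(2, 3, 32, 32)

conv = StdConv2d(3, 16, kernel_size=3)  # weights are re-standardized on every forward
print(conv(x).shape)

nf_conv = ScaledStdConv2d(3, 16, kernel_size=3)  # adds the learnable gain (NFNets)
print(nf_conv(x).shape)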
