1+ """ Convolution with Weight Standardization (StdConv and ScaledStdConv)
2+
3+ StdConv:
4+ @article{weightstandardization,
5+ author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
6+ title = {Weight Standardization},
7+ journal = {arXiv preprint arXiv:1903.10520},
8+ year = {2019},
9+ }
10+ Code: https://github.com/joe-siyuan-qiao/WeightStandardization
11+
12+ ScaledStdConv:
13+ Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
14+ - https://arxiv.org/abs/2101.08692
15+ Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
16+
17+ Hacked together by / copyright Ross Wightman, 2021.
18+ """
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from .padding import get_padding, get_padding_value, pad_same
 
 
-def get_weight(module):
-    std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-    weight = (module.weight - mean) / (std + module.eps)
-    return weight
-
-
 class StdConv2d(nn.Conv2d):
     """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
 
@@ -30,7 +42,7 @@ def __init__(
     def forward(self, x):
         weight = F.batch_norm(
             self.weight.view(1, self.out_channels, -1), None, None,
-            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
         x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return x
 
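
The F.batch_norm call above is a fused reformulation of the removed get_weight() helper: viewing the kernel as a (1, out_channels, fan_in) tensor and running batch_norm in training mode standardizes each output filter by its own mean and variance. A minimal equivalence sketch (shapes and eps are illustrative only; note that batch_norm adds eps to the variance, whereas the old helper added it to the std):

    import torch
    import torch.nn.functional as F

    out_channels, in_channels, k = 8, 4, 3
    w = torch.randn(out_channels, in_channels, k, k)
    eps = 1e-6

    # batch_norm trick: each output filter becomes one "channel" of a
    # (1, C, fan_in) tensor, so training-mode batch_norm standardizes it
    # over its fan-in with biased variance.
    w_bn = F.batch_norm(
        w.view(1, out_channels, -1), None, None,
        training=True, momentum=0., eps=eps).reshape_as(w)

    # Explicit per-filter standardization, with eps inside the sqrt so it
    # matches batch_norm exactly.
    var, mean = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
    w_ref = (w - mean) / torch.sqrt(var + eps)

    assert torch.allclose(w_bn, w_ref, atol=1e-5)
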
@@ -56,7 +68,7 @@ def forward(self, x):
         x = pad_same(x, self.kernel_size, self.stride, self.dilation)
         weight = F.batch_norm(
             self.weight.view(1, self.out_channels, -1), None, None,
-            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
         x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
         return x
 
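
In the Same variant above, the convolution itself uses zero static padding and pad_same applies TF-style 'SAME' padding at runtime, sized from the actual input. A sketch of the standard SAME computation under the usual definition (hypothetical helper names, not necessarily the library's exact implementation):

    import math
    import torch.nn.functional as F

    def _same_pad_amount(size, k, s, d):
        # TF 'SAME': choose total padding so output length == ceil(size / s)
        return max((math.ceil(size / s) - 1) * s + (k - 1) * d + 1 - size, 0)

    def pad_same_sketch(x, kernel_size, stride, dilation):
        ih, iw = x.shape[-2:]
        ph = _same_pad_amount(ih, kernel_size[0], stride[0], dilation[0])
        pw = _same_pad_amount(iw, kernel_size[1], stride[1], dilation[1])
        # split any odd total, with the extra pixel on the bottom/right
        return F.pad(x, [pw // 2, pw - pw // 2, ph // 2, ph - ph // 2])
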
@@ -86,7 +98,7 @@ def forward(self, x):
         weight = F.batch_norm(
             self.weight.view(1, self.out_channels, -1), None, None,
             weight=(self.gain * self.scale).view(-1),
-            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
         return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
 
 
@@ -117,5 +129,5 @@ def forward(self, x):
         weight = F.batch_norm(
             self.weight.view(1, self.out_channels, -1), None, None,
             weight=(self.gain * self.scale).view(-1),
-            eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
+            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
         return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
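
In the ScaledStdConv variants, the per-channel product (gain * scale) is passed as batch_norm's affine weight, so the standardized filter is also multiplied by gain * scale in the same fused call. A sketch of what that computes, assuming (as in the NF-ResNet paper; the actual definitions live outside this diff) that scale is gamma * fan_in ** -0.5 and gain is a learnable per-output-channel parameter:

    import torch
    import torch.nn.functional as F

    out_channels, in_channels, k = 8, 4, 3
    w = torch.randn(out_channels, in_channels, k, k)
    eps = 1e-6

    gamma = 1.0                                # activation-dependent constant
    scale = gamma * (in_channels * k * k) ** -0.5
    gain = torch.ones(out_channels, 1, 1, 1)   # an nn.Parameter in the real module

    # Fused form, as in the diff: fold gain * scale into batch_norm's weight.
    w_fused = F.batch_norm(
        w.view(1, out_channels, -1), None, None,
        weight=(gain * scale).view(-1),
        training=True, momentum=0., eps=eps).reshape_as(w)

    # Unfused form: standardize each filter, then rescale per channel.
    var, mean = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
    w_ref = gain * scale * (w - mean) / torch.sqrt(var + eps)

    assert torch.allclose(w_fused, w_ref, atol=1e-5)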