@@ -19,17 +19,22 @@ class StdConv2d(nn.Conv2d):
1919 """
2020 def __init__ (
2121 self , in_channel , out_channels , kernel_size , stride = 1 , padding = None , dilation = 1 ,
22- groups = 1 , bias = False , eps = 1e-5 ):
22+ groups = 1 , bias = False , eps = 1e-5 , use_layernorm = True ):
2323 if padding is None :
2424 padding = get_padding (kernel_size , stride , dilation )
2525 super ().__init__ (
2626 in_channel , out_channels , kernel_size , stride = stride ,
2727 padding = padding , dilation = dilation , groups = groups , bias = bias )
2828 self .eps = eps
29+ self .use_layernorm = use_layernorm
2930
3031 def get_weight (self ):
31- std , mean = torch .std_mean (self .weight , dim = [1 , 2 , 3 ], keepdim = True , unbiased = False )
32- weight = (self .weight - mean ) / (std + self .eps )
32+ if self .use_layernorm :
33+ # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
34+ weight = F .layer_norm (self .weight , self .weight .shape [1 :], eps = self .eps )
35+ else :
36+ std , mean = torch .std_mean (self .weight , dim = [1 , 2 , 3 ], keepdim = True , unbiased = False )
37+ weight = (self .weight - mean ) / (std + self .eps )
3338 return weight
3439
3540 def forward (self , x ):
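A note on the LN path added above: with `normalized_shape=self.weight.shape[1:]`, `F.layer_norm` standardizes each output filter over its full (in/groups, kH, kW) extent, computing `(w - mean) / sqrt(var + eps)` in one fused kernel. Below is a minimal standalone sanity check (my sketch, not part of the commit) comparing it against the explicit `std_mean` path; since eps lands inside the sqrt for layer_norm but is added to std in the manual path, the two agree closely but not bit-exactly.

```python
# Sanity sketch: compare the fused LayerNorm path with the explicit path.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
w = torch.randn(64, 32, 3, 3)  # (out_channels, in_channels, kH, kW)

# LN path: (w - mean) / sqrt(var + eps), one fused op per output filter
w_ln = F.layer_norm(w, w.shape[1:], eps=eps)

# Explicit path: (w - mean) / (std + eps)
std, mean = torch.std_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = (w - mean) / (std + eps)

# eps placement differs (inside vs. outside the sqrt), so the match is
# approximate, on the order of eps for unit-variance weights.
print((w_ln - w_std).abs().max())
```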
@@ -45,17 +50,22 @@ class StdConv2dSame(nn.Conv2d):
4550 """
4651 def __init__ (
4752 self , in_channel , out_channels , kernel_size , stride = 1 , padding = 'SAME' , dilation = 1 ,
48- groups = 1 , bias = False , eps = 1e-5 ):
53+ groups = 1 , bias = False , eps = 1e-5 , use_layernorm = True ):
4954 padding , is_dynamic = get_padding_value (padding , kernel_size , stride = stride , dilation = dilation )
5055 super ().__init__ (
5156 in_channel , out_channels , kernel_size , stride = stride , padding = padding , dilation = dilation ,
5257 groups = groups , bias = bias )
5358 self .same_pad = is_dynamic
5459 self .eps = eps
60+ self .use_layernorm = use_layernorm
5561
5662 def get_weight (self ):
57- std , mean = torch .std_mean (self .weight , dim = [1 , 2 , 3 ], keepdim = True , unbiased = False )
58- weight = (self .weight - mean ) / (std + self .eps )
63+ if self .use_layernorm :
64+ # NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
65+ weight = F .layer_norm (self .weight , self .weight .shape [1 :], eps = self .eps )
66+ else :
67+ std , mean = torch .std_mean (self .weight , dim = [1 , 2 , 3 ], keepdim = True , unbiased = False )
68+ weight = (self .weight - mean ) / (std + self .eps )
5969 return weight
6070
6171 def forward (self , x ):
@@ -76,24 +86,25 @@ class ScaledStdConv2d(nn.Conv2d):
 
     def __init__(
             self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
-            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
+            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
         self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
-        self.eps = eps ** 2 if use_layernorm else eps
+        self.eps = eps
         self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel
 
     def get_weight(self):
         if self.use_layernorm:
-            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / sqrt(var + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
         else:
             std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = self.scale * (self.weight - mean) / (std + self.eps)
-        return self.gain * weight
+            weight = (self.weight - mean) / (std + self.eps)
+        return weight.mul_(self.gain * self.scale)
 
     def forward(self, x):
         return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
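One behavioral detail of the refactor above: the per-channel `gain` and the scalar `scale` are now folded into a single broadcasted `mul_` on the freshly standardized tensor, instead of two separate multiplies (`self.scale * ...` followed by `self.gain * ...`). The in-place op is safe because `weight` is a new tensor at that point, not `self.weight`. A quick equivalence check under that assumption (my sketch with hypothetical shapes, not from the commit):

```python
# Check: weight.mul_(gain * scale) matches gain * (scale * weight), up to
# floating-point associativity.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_channels, eps = 8, 1e-5
w = torch.randn(out_channels, 4, 3, 3)
gain = torch.full((out_channels, 1, 1, 1), 1.5)
scale = 1.0 * w[0].numel() ** -0.5  # gamma / sqrt(fan-in), gamma = 1.0

std_w = F.layer_norm(w, w.shape[1:], eps=eps)
old = gain * (scale * std_w)             # previous two-multiply form
new = std_w.clone().mul_(gain * scale)   # refactored fused in-place form
print(torch.allclose(old, new))          # True
```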
@@ -110,32 +121,25 @@ class ScaledStdConv2dSame(nn.Conv2d):
 
     def __init__(
             self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
-            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
+            bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
         padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
         self.scale = gamma * self.weight[0].numel() ** -0.5
         self.same_pad = is_dynamic
-        self.eps = eps ** 2 if use_layernorm else eps
+        self.eps = eps
         self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel
 
-    # NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
-    # to make much numerical difference (+/- .002 to .004) in top-1 during eval.
-    # def get_weight(self):
-    #     var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-    #     scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
-    #     weight = (self.weight - mean) * scale
-    #     return self.gain * weight
-
     def get_weight(self):
         if self.use_layernorm:
-            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+            # NOTE F.layer_norm is being used to compute (self.weight - mean) / sqrt(var + self.eps) in one op
+            weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
         else:
             std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-            weight = self.scale * (self.weight - mean) / (std + self.eps)
-        return self.gain * weight
+            weight = (self.weight - mean) / (std + self.eps)
+        return weight.mul_(self.gain * self.scale)
 
     def forward(self, x):
         if self.same_pad:
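On the commented-out alternate formulation deleted in this hunk: the Haiku-style `rsqrt((N * var).clamp_(eps))` is near-equivalent to the kept form, since `rsqrt(N * var) = 1 / (sqrt(N) * std)` and the kept path multiplies by `scale = gamma / sqrt(N)`; the two differ only in eps handling, consistent with the deleted NOTE's claim of a negligible top-1 impact. (As an aside, the deleted snippet also applied `self.gain` twice, which looks like a latent bug in the dead code.) A hedged spot-check of that identity, not part of the commit:

```python
# Spot-check: Haiku-style rsqrt(N * var) vs. (1 / sqrt(N)) / (std + eps), gamma = 1.
import torch

torch.manual_seed(0)
eps = 1e-5
w = torch.randn(8, 4, 3, 3)
n = w[0].numel()  # fan-in per output filter

var, mean = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
haiku = (w - mean) * torch.rsqrt((n * var).clamp(eps))  # deleted formulation

std = var.sqrt()
kept = (n ** -0.5) * (w - mean) / (std + eps)           # retained formulation

print((haiku - kept).abs().max())  # tiny; only the eps placement differs
```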
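For reference, a usage sketch of the patched classes (assuming they are exported from `timm.models.layers` as elsewhere in the repo; shapes are illustrative):

```python
import torch
from timm.models.layers import StdConv2d, ScaledStdConv2d

x = torch.randn(2, 32, 56, 56)

# New default: standardization runs through the fused LayerNorm kernel
conv = StdConv2d(32, 64, kernel_size=3)
y = conv(x)

# The explicit std/mean path is still reachable via the flag
sconv = ScaledStdConv2d(32, 64, kernel_size=3, use_layernorm=False)
z = sconv(x)

print(y.shape, z.shape)  # torch.Size([2, 64, 56, 56]) twice
```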