55import torch
66import torch .nn as nn
77import torch .nn .functional as F
8- from torch .jit .annotations import Optional
8+ from torch .jit .annotations import Optional , Tuple
99from torch import Tensor
1010from .utils import load_state_dict_from_url
1111
@@ -63,34 +63,42 @@ def googlenet(pretrained=False, progress=True, **kwargs):
6363class GoogLeNet (nn .Module ):
6464 __constants__ = ['aux_logits' , 'transform_input' ]
6565
66- def __init__ (self , num_classes = 1000 , aux_logits = True , transform_input = False , init_weights = True ):
66+ def __init__ (self , num_classes = 1000 , aux_logits = True , transform_input = False , init_weights = True ,
67+ blocks = None ):
6768 super (GoogLeNet , self ).__init__ ()
69+ if blocks is None :
70+ blocks = [BasicConv2d , Inception , InceptionAux ]
71+ assert len (blocks ) == 3
72+ conv_block = blocks [0 ]
73+ inception_block = blocks [1 ]
74+ inception_aux_block = blocks [2 ]
75+
6876 self .aux_logits = aux_logits
6977 self .transform_input = transform_input
7078
71- self .conv1 = BasicConv2d (3 , 64 , kernel_size = 7 , stride = 2 , padding = 3 )
79+ self .conv1 = conv_block (3 , 64 , kernel_size = 7 , stride = 2 , padding = 3 )
7280 self .maxpool1 = nn .MaxPool2d (3 , stride = 2 , ceil_mode = True )
73- self .conv2 = BasicConv2d (64 , 64 , kernel_size = 1 )
74- self .conv3 = BasicConv2d (64 , 192 , kernel_size = 3 , padding = 1 )
81+ self .conv2 = conv_block (64 , 64 , kernel_size = 1 )
82+ self .conv3 = conv_block (64 , 192 , kernel_size = 3 , padding = 1 )
7583 self .maxpool2 = nn .MaxPool2d (3 , stride = 2 , ceil_mode = True )
7684
77- self .inception3a = Inception (192 , 64 , 96 , 128 , 16 , 32 , 32 )
78- self .inception3b = Inception (256 , 128 , 128 , 192 , 32 , 96 , 64 )
85+ self .inception3a = inception_block (192 , 64 , 96 , 128 , 16 , 32 , 32 )
86+ self .inception3b = inception_block (256 , 128 , 128 , 192 , 32 , 96 , 64 )
7987 self .maxpool3 = nn .MaxPool2d (3 , stride = 2 , ceil_mode = True )
8088
81- self .inception4a = Inception (480 , 192 , 96 , 208 , 16 , 48 , 64 )
82- self .inception4b = Inception (512 , 160 , 112 , 224 , 24 , 64 , 64 )
83- self .inception4c = Inception (512 , 128 , 128 , 256 , 24 , 64 , 64 )
84- self .inception4d = Inception (512 , 112 , 144 , 288 , 32 , 64 , 64 )
85- self .inception4e = Inception (528 , 256 , 160 , 320 , 32 , 128 , 128 )
89+ self .inception4a = inception_block (480 , 192 , 96 , 208 , 16 , 48 , 64 )
90+ self .inception4b = inception_block (512 , 160 , 112 , 224 , 24 , 64 , 64 )
91+ self .inception4c = inception_block (512 , 128 , 128 , 256 , 24 , 64 , 64 )
92+ self .inception4d = inception_block (512 , 112 , 144 , 288 , 32 , 64 , 64 )
93+ self .inception4e = inception_block (528 , 256 , 160 , 320 , 32 , 128 , 128 )
8694 self .maxpool4 = nn .MaxPool2d (2 , stride = 2 , ceil_mode = True )
8795
88- self .inception5a = Inception (832 , 256 , 160 , 320 , 32 , 128 , 128 )
89- self .inception5b = Inception (832 , 384 , 192 , 384 , 48 , 128 , 128 )
96+ self .inception5a = inception_block (832 , 256 , 160 , 320 , 32 , 128 , 128 )
97+ self .inception5b = inception_block (832 , 384 , 192 , 384 , 48 , 128 , 128 )
9098
9199 if aux_logits :
92- self .aux1 = InceptionAux (512 , num_classes )
93- self .aux2 = InceptionAux (528 , num_classes )
100+ self .aux1 = inception_aux_block (512 , num_classes )
101+ self .aux2 = inception_aux_block (528 , num_classes )
94102
95103 self .avgpool = nn .AdaptiveAvgPool2d ((1 , 1 ))
96104 self .dropout = nn .Dropout (0.2 )
@@ -112,14 +120,17 @@ def _initialize_weights(self):
112120 nn .init .constant_ (m .weight , 1 )
113121 nn .init .constant_ (m .bias , 0 )
114122
115- def forward (self , x ):
116- # type: (Tensor) -> GoogLeNetOutputs
123+ def _transform_input (self , x ):
124+ # type: (Tensor) -> Tensor
117125 if self .transform_input :
118126 x_ch0 = torch .unsqueeze (x [:, 0 ], 1 ) * (0.229 / 0.5 ) + (0.485 - 0.5 ) / 0.5
119127 x_ch1 = torch .unsqueeze (x [:, 1 ], 1 ) * (0.224 / 0.5 ) + (0.456 - 0.5 ) / 0.5
120128 x_ch2 = torch .unsqueeze (x [:, 2 ], 1 ) * (0.225 / 0.5 ) + (0.406 - 0.5 ) / 0.5
121129 x = torch .cat ((x_ch0 , x_ch1 , x_ch2 ), 1 )
130+ return x
122131
132+ def _forward (self , x ):
133+ # type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
123134 # N x 3 x 224 x 224
124135 x = self .conv1 (x )
125136 # N x 64 x 112 x 112
@@ -173,12 +184,7 @@ def forward(self, x):
173184 x = self .dropout (x )
174185 x = self .fc (x )
175186 # N x 1000 (num_classes)
176- if torch .jit .is_scripting ():
177- if not aux_defined :
178- warnings .warn ("Scripted GoogleNet always returns GoogleNetOutputs Tuple" )
179- return GoogLeNetOutputs (x , aux2 , aux1 )
180- else :
181- return self .eager_outputs (x , aux2 , aux1 )
187+ return x , aux2 , aux1
182188
183189 @torch .jit .unused
184190 def eager_outputs (self , x , aux2 , aux1 ):
@@ -188,45 +194,65 @@ def eager_outputs(self, x, aux2, aux1):
188194 else :
189195 return x
190196
def forward(self, x):
    # type: (Tensor) -> GoogLeNetOutputs
    # Entry point: apply the optional input transform, run the network body,
    # then package the main logits together with the two auxiliary outputs.
    x = self._transform_input(x)
    # ``_forward`` returns ``(x, aux2, aux1)``; unpack in that exact order so
    # the names below refer to the matching auxiliary output. Unpacking as
    # ``x, aux1, aux2`` would silently swap the two auxiliary results.
    x, aux2, aux1 = self._forward(x)
    aux_defined = self.training and self.aux_logits
    if torch.jit.is_scripting():
        # Under scripting the output tuple is always returned; warn when the
        # auxiliary outputs are not active so the caller is not surprised.
        if not aux_defined:
            warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
        return GoogLeNetOutputs(x, aux2, aux1)
    else:
        # In eager mode the result is delegated to ``eager_outputs``.
        return self.eager_outputs(x, aux2, aux1)
208+
191209
class Inception(nn.Module):
    # Module with four parallel branches that all consume the same input;
    # ``forward`` concatenates the four branch outputs along dim 1.
    #
    # Constructor arguments:
    #   in_channels         -- input channel count fed to every branch
    #   ch1x1               -- output channels of branch1 (single 1x1 conv)
    #   ch3x3red, ch3x3     -- branch2: 1x1 conv followed by a 3x3 conv (padding 1)
    #   ch5x5red, ch5x5     -- branch3: 1x1 conv followed by a second conv
    #   pool_proj           -- branch4: max-pool followed by a 1x1 conv
    #   conv_block          -- callable used to build every conv layer; falls
    #                          back to ``BasicConv2d`` when left as ``None``
    __constants__ = ['branch2', 'branch3', 'branch4']

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
                 conv_block=None):
        super(Inception, self).__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Layers are created in the same order as they are registered so that
        # parameter initialisation order stays stable.
        self.branch1 = make_conv(in_channels, ch1x1, kernel_size=1)

        reduce_b2 = make_conv(in_channels, ch3x3red, kernel_size=1)
        expand_b2 = make_conv(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_b2, expand_b2)

        # NOTE: despite the ``ch5x5*`` argument names, the second conv of this
        # branch is built with kernel_size=3 / padding=1.
        reduce_b3 = make_conv(in_channels, ch5x5red, kernel_size=1)
        expand_b3 = make_conv(ch5x5red, ch5x5, kernel_size=3, padding=1)
        self.branch3 = nn.Sequential(reduce_b3, expand_b3)

        pool_b4 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
        project_b4 = make_conv(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool_b4, project_b4)

    def _forward(self, x):
        # Run every branch on the same input and return the four results as a
        # list, in branch order, without concatenating them.
        out1 = self.branch1(x)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        out4 = self.branch4(x)
        return [out1, out2, out3, out4]

    def forward(self, x):
        # Concatenate the per-branch outputs along dim 1.
        return torch.cat(self._forward(x), 1)
223247
224248
225249class InceptionAux (nn .Module ):
226250
227- def __init__ (self , in_channels , num_classes ):
251+ def __init__ (self , in_channels , num_classes , conv_block = None ):
228252 super (InceptionAux , self ).__init__ ()
229- self .conv = BasicConv2d (in_channels , 128 , kernel_size = 1 )
253+ if conv_block is None :
254+ conv_block = BasicConv2d
255+ self .conv = conv_block (in_channels , 128 , kernel_size = 1 )
230256
231257 self .fc1 = nn .Linear (2048 , 1024 )
232258 self .fc2 = nn .Linear (1024 , num_classes )
0 commit comments