Skip to content

Commit 9cdc814

Browse files
hx89fmassa
authored andcommitted
Quantizable googlenet, inceptionv3 and shufflenetv2 models (pytorch#1503)
* quantizable googlenet * Minor improvements * Rename basic_conv2d with conv_block plus additional fixes * More renamings and fixes * Bugfix * Fix missing import for mypy * Add pretrained weights
1 parent b438d32 commit 9cdc814

File tree

7 files changed

+739
-115
lines changed

7 files changed

+739
-115
lines changed

torchvision/models/googlenet.py

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
import torch.nn as nn
77
import torch.nn.functional as F
8-
from torch.jit.annotations import Optional
8+
from torch.jit.annotations import Optional, Tuple
99
from torch import Tensor
1010
from .utils import load_state_dict_from_url
1111

@@ -63,34 +63,42 @@ def googlenet(pretrained=False, progress=True, **kwargs):
6363
class GoogLeNet(nn.Module):
6464
__constants__ = ['aux_logits', 'transform_input']
6565

66-
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
66+
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True,
67+
blocks=None):
6768
super(GoogLeNet, self).__init__()
69+
if blocks is None:
70+
blocks = [BasicConv2d, Inception, InceptionAux]
71+
assert len(blocks) == 3
72+
conv_block = blocks[0]
73+
inception_block = blocks[1]
74+
inception_aux_block = blocks[2]
75+
6876
self.aux_logits = aux_logits
6977
self.transform_input = transform_input
7078

71-
self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
79+
self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
7280
self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
73-
self.conv2 = BasicConv2d(64, 64, kernel_size=1)
74-
self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
81+
self.conv2 = conv_block(64, 64, kernel_size=1)
82+
self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
7583
self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
7684

77-
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
78-
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
85+
self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
86+
self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
7987
self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
8088

81-
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
82-
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
83-
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
84-
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
85-
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
89+
self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
90+
self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
91+
self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
92+
self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
93+
self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
8694
self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
8795

88-
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
89-
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
96+
self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
97+
self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)
9098

9199
if aux_logits:
92-
self.aux1 = InceptionAux(512, num_classes)
93-
self.aux2 = InceptionAux(528, num_classes)
100+
self.aux1 = inception_aux_block(512, num_classes)
101+
self.aux2 = inception_aux_block(528, num_classes)
94102

95103
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
96104
self.dropout = nn.Dropout(0.2)
@@ -112,14 +120,17 @@ def _initialize_weights(self):
112120
nn.init.constant_(m.weight, 1)
113121
nn.init.constant_(m.bias, 0)
114122

115-
def forward(self, x):
116-
# type: (Tensor) -> GoogLeNetOutputs
123+
def _transform_input(self, x):
124+
# type: (Tensor) -> Tensor
117125
if self.transform_input:
118126
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
119127
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
120128
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
121129
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
130+
return x
122131

132+
def _forward(self, x):
133+
# type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
123134
# N x 3 x 224 x 224
124135
x = self.conv1(x)
125136
# N x 64 x 112 x 112
@@ -173,12 +184,7 @@ def forward(self, x):
173184
x = self.dropout(x)
174185
x = self.fc(x)
175186
# N x 1000 (num_classes)
176-
if torch.jit.is_scripting():
177-
if not aux_defined:
178-
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
179-
return GoogLeNetOutputs(x, aux2, aux1)
180-
else:
181-
return self.eager_outputs(x, aux2, aux1)
187+
return x, aux2, aux1
182188

183189
@torch.jit.unused
184190
def eager_outputs(self, x, aux2, aux1):
@@ -188,45 +194,65 @@ def eager_outputs(self, x, aux2, aux1):
188194
else:
189195
return x
190196

197+
def forward(self, x):
198+
# type: (Tensor) -> GoogLeNetOutputs
199+
x = self._transform_input(x)
200+
x, aux1, aux2 = self._forward(x)
201+
aux_defined = self.training and self.aux_logits
202+
if torch.jit.is_scripting():
203+
if not aux_defined:
204+
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
205+
return GoogLeNetOutputs(x, aux2, aux1)
206+
else:
207+
return self.eager_outputs(x, aux2, aux1)
208+
191209

192210
class Inception(nn.Module):
193211
__constants__ = ['branch2', 'branch3', 'branch4']
194212

195-
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
213+
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
214+
conv_block=None):
196215
super(Inception, self).__init__()
197-
198-
self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
216+
if conv_block is None:
217+
conv_block = BasicConv2d
218+
self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
199219

200220
self.branch2 = nn.Sequential(
201-
BasicConv2d(in_channels, ch3x3red, kernel_size=1),
202-
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
221+
conv_block(in_channels, ch3x3red, kernel_size=1),
222+
conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
203223
)
204224

205225
self.branch3 = nn.Sequential(
206-
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
207-
BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
226+
conv_block(in_channels, ch5x5red, kernel_size=1),
227+
conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1)
208228
)
209229

210230
self.branch4 = nn.Sequential(
211231
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
212-
BasicConv2d(in_channels, pool_proj, kernel_size=1)
232+
conv_block(in_channels, pool_proj, kernel_size=1)
213233
)
214234

215-
def forward(self, x):
235+
def _forward(self, x):
216236
branch1 = self.branch1(x)
217237
branch2 = self.branch2(x)
218238
branch3 = self.branch3(x)
219239
branch4 = self.branch4(x)
220240

221241
outputs = [branch1, branch2, branch3, branch4]
242+
return outputs
243+
244+
def forward(self, x):
245+
outputs = self._forward(x)
222246
return torch.cat(outputs, 1)
223247

224248

225249
class InceptionAux(nn.Module):
226250

227-
def __init__(self, in_channels, num_classes):
251+
def __init__(self, in_channels, num_classes, conv_block=None):
228252
super(InceptionAux, self).__init__()
229-
self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
253+
if conv_block is None:
254+
conv_block = BasicConv2d
255+
self.conv = conv_block(in_channels, 128, kernel_size=1)
230256

231257
self.fc1 = nn.Linear(2048, 1024)
232258
self.fc2 = nn.Linear(1024, num_classes)

0 commit comments

Comments
 (0)