Skip to content
This repository was archived by the owner on Aug 19, 2023. It is now read-only.

Commit 37ceef6

Browse files
committed
fix bug in dimension ordering
1 parent b644323 commit 37ceef6

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

model.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,17 @@ def forward(self, x):
107107

108108
out = self.output(out)
109109

110-
return out.view(x.shape[0], -1, 4)
110+
# out is B x C x W x H, with C = 4*num_anchors
111+
out = out.permute(0, 2, 3, 1)
112+
113+
return out.contiguous().view(out.shape[0], -1, 4)
111114

112115
class ClassificationModel(nn.Module):
113116
def __init__(self, num_features_in, num_anchors=9, num_classes=80, prior=0.01, feature_size=128):
114117
super(ClassificationModel, self).__init__()
115118

116119
self.num_classes = num_classes
120+
self.num_anchors = num_anchors
117121

118122
self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
119123
self.act1 = nn.ReLU()
@@ -147,7 +151,14 @@ def forward(self, x):
147151
out = self.output(out)
148152
out = self.output_act(out)
149153

150-
return out.view(x.shape[0], -1, self.num_classes)
154+
# out is B x C x W x H, with C = n_anchors * n_classes
155+
out1 = out.permute(0, 2, 3, 1)
156+
157+
batch_size, width, height, channels = out1.shape
158+
159+
out2 = out1.view(batch_size, width, height, self.num_anchors, self.num_classes)
160+
161+
return out2.contiguous().view(x.shape[0], -1, self.num_classes)
151162

152163
class ResNet(nn.Module):
153164

@@ -163,10 +174,9 @@ def __init__(self, block, layers, num_classes=1000):
163174
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
164175
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
165176
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
166-
self.avgpool = nn.AvgPool2d(7, stride=1)
167-
self.fc = nn.Linear(512 * block.expansion, num_classes)
168-
169-
self.fpn = PyramidFeatures(128, 256, 512)
177+
#import pdb
178+
#pdb.set_trace()
179+
self.fpn = PyramidFeatures(512, 1024, 2048)
170180

171181
self.regressionModel = RegressionModel(128)
172182
self.classificationModel = ClassificationModel(128)
@@ -246,10 +256,12 @@ def forward(self, img_batch):
246256
anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :], 0.5)
247257

248258
nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1)
249-
#import pdb
250-
#pdb.set_trace()
251259
return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]]
252260

261+
#nms_scores, nms_class = classification[0, :, :].max(dim=1)
262+
#return [nms_scores, nms_class, transformed_anchors[0, :, :]]
263+
264+
253265

254266
def resnet18(pretrained=False, **kwargs):
255267
"""Constructs a ResNet-18 model.
@@ -269,7 +281,7 @@ def resnet34(pretrained=False, **kwargs):
269281
"""
270282
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
271283
if pretrained:
272-
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
284+
model.load_state_dict(model_zoo.load_url(model_urls['resnet34'], model_dir='.'), strict=False)
273285
return model
274286

275287

@@ -280,5 +292,5 @@ def resnet50(pretrained=False, **kwargs):
280292
"""
281293
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
282294
if pretrained:
283-
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
295+
model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir='.'), strict=False)
284296
return model

0 commit comments

Comments
 (0)