@@ -107,13 +107,17 @@ def forward(self, x):
107107
108108 out = self .output (out )
109109
110- return out .view (x .shape [0 ], - 1 , 4 )
110+ # out is B x C x W x H, with C = 4*num_anchors
111+ out = out .permute (0 , 2 , 3 , 1 )
112+
113+ return out .contiguous ().view (out .shape [0 ], - 1 , 4 )
111114
112115class ClassificationModel (nn .Module ):
113116 def __init__ (self , num_features_in , num_anchors = 9 , num_classes = 80 , prior = 0.01 , feature_size = 128 ):
114117 super (ClassificationModel , self ).__init__ ()
115118
116119 self .num_classes = num_classes
120+ self .num_anchors = num_anchors
117121
118122 self .conv1 = nn .Conv2d (num_features_in , feature_size , kernel_size = 3 , padding = 1 )
119123 self .act1 = nn .ReLU ()
@@ -147,7 +151,14 @@ def forward(self, x):
147151 out = self .output (out )
148152 out = self .output_act (out )
149153
150- return out .view (x .shape [0 ], - 1 , self .num_classes )
154+ # out is B x C x W x H, with C = n_classes * n_anchors
155+ out1 = out .permute (0 , 2 , 3 , 1 )
156+
157+ batch_size , width , height , channels = out1 .shape
158+
159+ out2 = out1 .view (batch_size , width , height , self .num_anchors , self .num_classes )
160+
161+ return out2 .contiguous ().view (x .shape [0 ], - 1 , self .num_classes )
151162
152163class ResNet (nn .Module ):
153164
@@ -163,10 +174,9 @@ def __init__(self, block, layers, num_classes=1000):
163174 self .layer2 = self ._make_layer (block , 128 , layers [1 ], stride = 2 )
164175 self .layer3 = self ._make_layer (block , 256 , layers [2 ], stride = 2 )
165176 self .layer4 = self ._make_layer (block , 512 , layers [3 ], stride = 2 )
166- self .avgpool = nn .AvgPool2d (7 , stride = 1 )
167- self .fc = nn .Linear (512 * block .expansion , num_classes )
168-
169- self .fpn = PyramidFeatures (128 , 256 , 512 )
177+ #import pdb
178+ #pdb.set_trace()
179+ self .fpn = PyramidFeatures (512 , 1024 , 2048 )
170180
171181 self .regressionModel = RegressionModel (128 )
172182 self .classificationModel = ClassificationModel (128 )
@@ -246,10 +256,12 @@ def forward(self, img_batch):
246256 anchors_nms_idx = nms (torch .cat ([transformed_anchors , scores ], dim = 2 )[0 , :, :], 0.5 )
247257
248258 nms_scores , nms_class = classification [0 , anchors_nms_idx , :].max (dim = 1 )
249- #import pdb
250- #pdb.set_trace()
251259 return [nms_scores , nms_class , transformed_anchors [0 , anchors_nms_idx , :]]
252260
261+ #nms_scores, nms_class = classification[0, :, :].max(dim=1)
262+ #return [nms_scores, nms_class, transformed_anchors[0, :, :]]
263+
264+
253265
254266def resnet18 (pretrained = False , ** kwargs ):
255267 """Constructs a ResNet-18 model.
@@ -269,7 +281,7 @@ def resnet34(pretrained=False, **kwargs):
269281 """
270282 model = ResNet (BasicBlock , [3 , 4 , 6 , 3 ], ** kwargs )
271283 if pretrained :
272- model .load_state_dict (model_zoo .load_url (model_urls ['resnet34' ]) )
284+ model .load_state_dict (model_zoo .load_url (model_urls ['resnet34' ], model_dir = '.' ), strict = False )
273285 return model
274286
275287
@@ -280,5 +292,5 @@ def resnet50(pretrained=False, **kwargs):
280292 """
281293 model = ResNet (Bottleneck , [3 , 4 , 6 , 3 ], ** kwargs )
282294 if pretrained :
283- model .load_state_dict (model_zoo .load_url (model_urls ['resnet50' ]) )
295+ model .load_state_dict (model_zoo .load_url (model_urls ['resnet50' ], model_dir = '.' ), strict = False )
284296 return model
0 commit comments