This repository was archived by the owner on Aug 19, 2023. It is now read-only.

Commit e760c25

command line configuration
1 parent 7927038 commit e760c25

6 files changed, with 288 additions and 162 deletions.

README.md

Lines changed: 19 additions & 3 deletions
@@ -5,6 +5,7 @@

 Pytorch implementation of RetinaNet object detection as described in [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollár.

+This implementation is primarily designed to be easy to read and simple to modify.

 ## Results
 Currently, this repo achieves 33.7% mAP at 600px resolution with a Resnet-50 backbone. The published result is 34.0% mAP. The difference is likely due to the use of Adam optimizer instead of SGD with weight decay.
@@ -51,18 +52,33 @@ cd ../
 The network can be trained using the `train.py` script. Currently, two dataloaders are available: COCO and CSV. For training on coco, use

 ```
-python train.py coco <path/to/coco>
+python train.py --dataset coco --coco_path ../coco --depth 50
 ```

 For training using a custom dataset, with annotations in CSV format (see below), use

 ```
-python train.py csv <path/to/annotations.csv> <path/to/classes.csv>
+python train.py --dataset csv --csv_train <path/to/train_annots.csv> --csv_classes <path/to/train/class_list.csv> --csv_val <path/to/val_annots.csv>
 ```

+Note that the --csv_val argument is optional, in which case no validation will be performed.
+
 ## Visualization

-To visualize the network detection, use `test.py`.
+To visualize the network detection, use `visualize.py`:
+
+```
+python visualize.py --dataset coco --coco_path ../coco --model <path/to/model.pt>
+```
+This will visualize bounding boxes on the validation set. To visualise with a CSV dataset, use:
+
+```
+python visualize.py --dataset csv --csv_classes <path/to/train/class_list.csv> --csv_val <path/to/val_annots.csv> --model <path/to/model.pt>
+```
+
+## Model
+
+The retinanet model uses a resnet backbone. You can set the depth of the resnet model using the --depth argument. Depth must be one of 18, 34, 50, 101 or 152. Note that deeper models are more accurate but are slower and use more memory.

 ## CSV datasets
 The `CSVGenerator` provides an easy way to define your own datasets.
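
The README changes above name the new training flags (`--dataset`, `--coco_path`, `--csv_train`, `--csv_classes`, `--csv_val`, `--depth`), but the `train.py` diff itself is not shown on this page. The following is a minimal sketch of how such flags could be declared with `argparse`. The function names `build_parser` and `parse_args`, the default depth of 50, and the validation rules are assumptions for illustration, not the repository's actual code.

```python
import argparse

# Depths listed in the README; everything else in this sketch is hypothetical.
VALID_DEPTHS = (18, 34, 50, 101, 152)


def build_parser():
    """Declare the flags that the README mentions for train.py."""
    parser = argparse.ArgumentParser(description="Sketch of RetinaNet training options.")
    parser.add_argument("--dataset", choices=["coco", "csv"], required=True,
                        help="Dataset type")
    parser.add_argument("--coco_path", default=None,
                        help="Path to the COCO directory")
    parser.add_argument("--csv_train", default=None,
                        help="Path to the CSV file with training annotations")
    parser.add_argument("--csv_classes", default=None,
                        help="Path to the CSV file with the class list")
    parser.add_argument("--csv_val", default=None,
                        help="Optional CSV file with validation annotations")
    # The default of 50 is an assumption, chosen to match the README example.
    parser.add_argument("--depth", type=int, choices=VALID_DEPTHS, default=50,
                        help="ResNet depth")
    return parser


def parse_args(argv=None):
    """Parse argv and check the dataset-specific required flags."""
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.dataset == "coco" and args.coco_path is None:
        parser.error("--coco_path is required when --dataset is coco")
    if args.dataset == "csv" and (args.csv_train is None or args.csv_classes is None):
        parser.error("--csv_train and --csv_classes are required when --dataset is csv")
    return args


# Equivalent to: python train.py --dataset coco --coco_path ../coco --depth 50
example = parse_args(["--dataset", "coco", "--coco_path", "../coco", "--depth", "50"])
print(example.dataset, example.coco_path, example.depth)
```

Leaving `--csv_val` at `None` when it is omitted gives the training loop a simple way to skip validation, which matches the optional behaviour the README describes.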

dataloader.py

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ def load_annotations(self, image_index):
         # parse annotations
         coco_annotations = self.coco.loadAnns(annotations_ids)
         for idx, a in enumerate(coco_annotations):
+
             # some annotations have basically no width / height, skip them
             if a['bbox'][2] < 1 or a['bbox'][3] < 1:
                 continue
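
The context lines above skip COCO annotations whose box is narrower or shorter than one pixel. COCO stores each `bbox` as `[x, y, width, height]`, so indices 2 and 3 are the width and height. A minimal standalone sketch of that filter follows. The helper name `coco_boxes_to_corners` and the conversion to corner format are assumptions added for illustration and are not shown in this diff.

```python
def coco_boxes_to_corners(annotations):
    """Drop degenerate boxes and convert [x, y, w, h] to [x1, y1, x2, y2].

    `annotations` is a list of dicts with a COCO-style "bbox" entry.
    The corner conversion is an illustrative assumption.
    """
    boxes = []
    for a in annotations:
        x, y, w, h = a["bbox"]
        # Same test as in the diff: skip boxes with basically no width or height.
        if w < 1 or h < 1:
            continue
        boxes.append([x, y, x + w, y + h])
    return boxes


sample = [
    {"bbox": [10.0, 20.0, 30.0, 40.0]},  # kept
    {"bbox": [5.0, 5.0, 0.5, 12.0]},     # skipped: width below one pixel
]
print(coco_boxes_to_corners(sample))  # [[10.0, 20.0, 40.0, 60.0]]
```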

model.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,11 @@
88

99
from lib.nms.pth_nms import pth_nms
1010

11-
import pdb
12-
1311
def nms(dets, thresh):
1412
"Dispatch to either CPU or GPU NMS implementations.\
1513
Accept dets as tensor"""
1614
return pth_nms(dets, thresh)
1715

18-
1916
model_urls = {
2017
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
2118
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
@@ -310,3 +307,24 @@ def resnet50(num_classes, pretrained=False, **kwargs):
     if pretrained:
         model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], model_dir='.'), strict=False)
     return model
+
+def resnet101(num_classes, pretrained=False, **kwargs):
+    """Constructs a ResNet-101 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(num_classes, Bottleneck, [3, 4, 23, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet101'], model_dir='.'), strict=False)
+    return model
+
+
+def resnet152(num_classes, pretrained=False, **kwargs):
+    """Constructs a ResNet-152 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(num_classes, Bottleneck, [3, 8, 36, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet152'], model_dir='.'), strict=False)
+    return model
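
The block counts passed to `ResNet` above determine the nominal depth of each variant, and the README restricts `--depth` to 18, 34, 50, 101 or 152. The sketch below checks that arithmetic and shows one possible way to validate a depth value. The 101 and 152 entries come from this diff; the 18, 34 and 50 entries are the standard ResNet configurations and are assumed here, as are the names `RESNET_CONFIGS`, `nominal_depth` and `config_for_depth`.

```python
# Block type and per-stage block counts for each supported depth.
# 101 and 152 match the diff; 18, 34 and 50 are the standard ResNet values.
RESNET_CONFIGS = {
    18: ("BasicBlock", [2, 2, 2, 2]),
    34: ("BasicBlock", [3, 4, 6, 3]),
    50: ("Bottleneck", [3, 4, 6, 3]),
    101: ("Bottleneck", [3, 4, 23, 3]),
    152: ("Bottleneck", [3, 8, 36, 3]),
}

# A BasicBlock has 2 conv layers; a Bottleneck has 3.
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def nominal_depth(block_name, layers):
    """Count weighted layers: convs in all blocks, plus the stem conv and the
    final fully connected layer of the original classification ResNet."""
    return CONVS_PER_BLOCK[block_name] * sum(layers) + 2


def config_for_depth(depth):
    """Return (block_name, layers) for a supported depth, else raise ValueError."""
    if depth not in RESNET_CONFIGS:
        raise ValueError(
            f"Unsupported depth {depth}; expected one of {sorted(RESNET_CONFIGS)}"
        )
    return RESNET_CONFIGS[depth]


# Every configuration reproduces its own name, e.g. 3 * (3 + 4 + 23 + 3) + 2 = 101.
for depth, (block_name, layers) in RESNET_CONFIGS.items():
    assert nominal_depth(block_name, layers) == depth
```

A lookup table like this keeps the list of valid depths in one place, so an argument parser and a model factory could both refer to it.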

test.py

Lines changed: 0 additions & 82 deletions
This file was deleted.
