Skip to content

Commit 8242235

Browse files
Merge pull request NVIDIA#908 from NVIDIA/gh/release
EfficientNets
2 parents 402fcbc + 4a66a00 commit 8242235

File tree

113 files changed

+10260
-686
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

113 files changed

+10260
-686
lines changed

PyTorch/Classification/ConvNets/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ The following table provides links to where you can find additional information
2525
| resnet50 | [README](./resnet50v1.5/README.md) |
2626
| resnext101-32x4d | [README](./resnext101-32x4d/README.md) |
2727
| se-resnext101-32x4d | [README](./se-resnext101-32x4d/README.md) |
28+
| EfficientNet-B0 | [README](./efficientnet/README.md) |
2829

2930
## Validation accuracy results
3031

PyTorch/Classification/ConvNets/classify.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,37 @@
1818
import torch
1919
from torch.cuda.amp import autocast
2020
import torch.backends.cudnn as cudnn
21+
22+
from image_classification import models
2123
import torchvision.transforms as transforms
22-
import image_classification.resnet as models
23-
from image_classification.dataloaders import load_jpeg_from_file
2424

25+
from image_classification.models import (
26+
resnet50,
27+
resnext101_32x4d,
28+
se_resnext101_32x4d,
29+
efficientnet_b0,
30+
efficientnet_b4,
31+
efficientnet_widese_b0,
32+
efficientnet_widese_b4,
33+
)
34+
35+
def available_models():
    """Return a mapping from architecture name to its model entry point.

    Every supported entry point imported from ``image_classification.models``
    is registered under its own ``name`` attribute.
    """
    supported = (
        resnet50,
        resnext101_32x4d,
        se_resnext101_32x4d,
        efficientnet_b0,
        efficientnet_b4,
        efficientnet_widese_b0,
        efficientnet_widese_b4,
    )

    registry = {}
    for entry in supported:
        # Later entries with a duplicate name would overwrite earlier ones,
        # exactly as in a dict comprehension.
        registry[entry.name] = entry
    return registry
2549

2650
def add_parser_arguments(parser):
27-
model_names = models.resnet_versions.keys()
28-
model_configs = models.resnet_configs.keys()
51+
model_names = available_models().keys()
2952
parser.add_argument("--image-size", default="224", type=int)
3053
parser.add_argument(
3154
"--arch",
@@ -35,39 +58,49 @@ def add_parser_arguments(parser):
3558
choices=model_names,
3659
help="model architecture: " + " | ".join(model_names) + " (default: resnet50)",
3760
)
38-
parser.add_argument(
39-
"--model-config",
40-
"-c",
41-
metavar="CONF",
42-
default="classic",
43-
choices=model_configs,
44-
help="model configs: " + " | ".join(model_configs) + "(default: classic)",
45-
)
46-
parser.add_argument("--weights", metavar="<path>", help="file with model weights")
4761
parser.add_argument(
4862
"--precision", metavar="PREC", default="AMP", choices=["AMP", "FP32"]
4963
)
64+
parser.add_argument("--cpu", action="store_true", help="perform inference on CPU")
5065
parser.add_argument("--image", metavar="<path>", help="path to classified image")
5166

5267

53-
def main(args):
54-
imgnet_classes = np.array(json.load(open("./LOC_synset_mapping.json", "r")))
55-
model = models.build_resnet(args.arch, args.model_config, 1000, verbose=False)
68+
def load_jpeg_from_file(path, image_size, cuda=True):
    """Load a single image file and return it as a normalized input batch.

    The image is resized with ``Resize(image_size + 32)``, center-cropped to
    ``image_size``, converted to a float tensor in the [0, 1] range and then
    normalized per channel with the ImageNet mean and standard deviation.

    Args:
        path: Path of the image file to read.
        image_size: Side length of the center crop fed to the model.
        cuda: When true, the returned tensor lives on the current CUDA device.

    Returns:
        A float tensor with a leading batch dimension of size 1 and three
        channels, ready to be passed to the model.
    """
    img_transforms = transforms.Compose(
        [
            transforms.Resize(image_size + 32),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    )

    # Open through a context manager so the file handle is released, and
    # force three channels: greyscale / CMYK / RGBA files would otherwise
    # produce a channel count that cannot be normalized in place against the
    # 3-channel mean and std below.
    with Image.open(path) as pil_img:
        img = img_transforms(pil_img.convert("RGB"))

    with torch.no_grad():
        # mean and std are not multiplied by 255 as they are in training script
        # torch dataloader reads data into bytes whereas loading directly
        # through PIL creates a tensor with floats in [0,1] range
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        if cuda:
            mean = mean.cuda()
            std = std.cuda()
            img = img.cuda()
        img = img.float()

        # Add the batch dimension, then normalize in place.
        batch = img.unsqueeze(0).sub_(mean).div_(std)

    return batch
94+
95+
96+
def main(args, model_args):
97+
imgnet_classes = np.array(json.load(open("./LOC_synset_mapping.json", "r")))
98+
model = available_models()[args.arch](**model_args.__dict__)
99+
if not args.cpu:
100+
model = model.cuda()
66101
model.eval()
67102

68-
input = load_jpeg_from_file(
69-
args.image, cuda=True
70-
)
103+
input = load_jpeg_from_file(args.image, args.image_size, cuda=not args.cpu)
71104

72105
with torch.no_grad(), autocast(enabled = args.precision == "AMP"):
73106
output = torch.nn.functional.softmax(model(input), dim=1)
@@ -81,11 +114,14 @@ def main(args):
81114

82115

83116
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch ImageNet Classification")

    add_parser_arguments(parser)
    # Two-stage parsing: the generic options are consumed first, and whatever
    # is left over is handed to the parser exposed by the selected
    # architecture (presumably an argparse parser for model-specific options).
    args, rest = parser.parse_known_args()
    model_args, rest = available_models()[args.arch].parser().parse_known_args(rest)

    # Reject unrecognized arguments explicitly: an ``assert`` is stripped when
    # Python runs with -O, which would let typos pass silently.
    if rest:
        parser.error(f"Unknown args passed: {rest}")

    cudnn.benchmark = True

    main(args, model_args)

0 commit comments

Comments
 (0)