1818import torch
1919from torch .cuda .amp import autocast
2020import torch .backends .cudnn as cudnn
21+
22+ from image_classification import models
2123import torchvision .transforms as transforms
22- import image_classification .resnet as models
23- from image_classification .dataloaders import load_jpeg_from_file
2424
25+ from image_classification .models import (
26+ resnet50 ,
27+ resnext101_32x4d ,
28+ se_resnext101_32x4d ,
29+ efficientnet_b0 ,
30+ efficientnet_b4 ,
31+ efficientnet_widese_b0 ,
32+ efficientnet_widese_b4 ,
33+ )
34+
35+ def available_models ():
36+ models = {
37+ m .name : m
38+ for m in [
39+ resnet50 ,
40+ resnext101_32x4d ,
41+ se_resnext101_32x4d ,
42+ efficientnet_b0 ,
43+ efficientnet_b4 ,
44+ efficientnet_widese_b0 ,
45+ efficientnet_widese_b4 ,
46+ ]
47+ }
48+ return models
2549
2650def add_parser_arguments (parser ):
27- model_names = models .resnet_versions .keys ()
28- model_configs = models .resnet_configs .keys ()
51+ model_names = available_models ().keys ()
2952 parser .add_argument ("--image-size" , default = "224" , type = int )
3053 parser .add_argument (
3154 "--arch" ,
@@ -35,39 +58,49 @@ def add_parser_arguments(parser):
3558 choices = model_names ,
3659 help = "model architecture: " + " | " .join (model_names ) + " (default: resnet50)" ,
3760 )
38- parser .add_argument (
39- "--model-config" ,
40- "-c" ,
41- metavar = "CONF" ,
42- default = "classic" ,
43- choices = model_configs ,
44- help = "model configs: " + " | " .join (model_configs ) + "(default: classic)" ,
45- )
46- parser .add_argument ("--weights" , metavar = "<path>" , help = "file with model weights" )
4761 parser .add_argument (
4862 "--precision" , metavar = "PREC" , default = "AMP" , choices = ["AMP" , "FP32" ]
4963 )
64+ parser .add_argument ("--cpu" , action = "store_true" , help = "perform inference on CPU" )
5065 parser .add_argument ("--image" , metavar = "<path>" , help = "path to classified image" )
5166
5267
53- def main (args ):
54- imgnet_classes = np .array (json .load (open ("./LOC_synset_mapping.json" , "r" )))
55- model = models .build_resnet (args .arch , args .model_config , 1000 , verbose = False )
68+ def load_jpeg_from_file (path , image_size , cuda = True ):
69+ img_transforms = transforms .Compose (
70+ [
71+ transforms .Resize (image_size + 32 ),
72+ transforms .CenterCrop (image_size ),
73+ transforms .ToTensor (),
74+ ]
75+ )
76+
77+ img = img_transforms (Image .open (path ))
78+ with torch .no_grad ():
79+ # mean and std are not multiplied by 255 as they are in training script
80+ # torch dataloader reads data into bytes whereas loading directly
81+ # through PIL creates a tensor with floats in [0,1] range
82+ mean = torch .tensor ([0.485 , 0.456 , 0.406 ]).view (1 , 3 , 1 , 1 )
83+ std = torch .tensor ([0.229 , 0.224 , 0.225 ]).view (1 , 3 , 1 , 1 )
5684
57- if args .weights is not None :
58- weights = torch .load (args .weights )
59- # Temporary fix to allow NGC checkpoint loading
60- weights = {
61- k .replace ("module." , "" ): v for k , v in weights .items ()
62- }
63- model .load_state_dict (weights )
85+ if cuda :
86+ mean = mean .cuda ()
87+ std = std .cuda ()
88+ img = img .cuda ()
89+ img = img .float ()
6490
65- model = model .cuda ()
91+ input = img .unsqueeze (0 ).sub_ (mean ).div_ (std )
92+
93+ return input
94+
95+
96+ def main (args , model_args ):
97+ imgnet_classes = np .array (json .load (open ("./LOC_synset_mapping.json" , "r" )))
98+ model = available_models ()[args .arch ](** model_args .__dict__ )
99+ if not args .cpu :
100+ model = model .cuda ()
66101 model .eval ()
67102
68- input = load_jpeg_from_file (
69- args .image , cuda = True
70- )
103+ input = load_jpeg_from_file (args .image , args .image_size , cuda = not args .cpu )
71104
72105 with torch .no_grad (), autocast (enabled = args .precision == "AMP" ):
73106 output = torch .nn .functional .softmax (model (input ), dim = 1 )
@@ -81,11 +114,14 @@ def main(args):
81114
82115
83116if __name__ == "__main__" :
84- parser = argparse .ArgumentParser (description = "PyTorch ImageNet Training " )
117+ parser = argparse .ArgumentParser (description = "PyTorch ImageNet Classification " )
85118
86119 add_parser_arguments (parser )
87- args = parser .parse_args ()
120+ args , rest = parser .parse_known_args ()
121+ model_args , rest = available_models ()[args .arch ].parser ().parse_known_args (rest )
122+
123+ assert len (rest ) == 0 , f"Unknown args passed: { rest } "
88124
89125 cudnn .benchmark = True
90126
91- main (args )
127+ main (args , model_args )
0 commit comments