2626parser .add_argument ('--niter' , type = int , default = 25 , help = 'number of epochs to train for' )
2727parser .add_argument ('--lr' , type = float , default = 0.0002 , help = 'learning rate, default=0.0002' )
2828parser .add_argument ('--beta1' , type = float , default = 0.5 , help = 'beta1 for adam. default=0.5' )
29- parser .add_argument ('--cuda'   , action = 'store_true' , help = 'enables cuda' )
30- parser .add_argument ('--ngpu'   , type = int , default = 1 , help = 'number of GPUs to use' )
29+ parser .add_argument ('--cuda' , action = 'store_true' , help = 'enables cuda' )
30+ parser .add_argument ('--ngpu' , type = int , default = 1 , help = 'number of GPUs to use' )
3131parser .add_argument ('--netG' , default = '' , help = "path to netG (to continue training)" )
3232parser .add_argument ('--netD' , default = '' , help = "path to netD (to continue training)" )
3333parser .add_argument ('--outf' , default = '.' , help = 'folder to output images and model checkpoints' )
34+ parser .add_argument ('--manualSeed' , type = int , help = 'manual seed' )
3435
3536opt  =  parser .parse_args ()
3637print (opt )
3940 os .makedirs (opt .outf )
4041except  OSError :
4142 pass 
42- opt .manualSeed  =  random .randint (1 , 10000 ) # fix seed 
43+ 
44+ if  opt .manualSeed  is  None :
45+  opt .manualSeed  =  random .randint (1 , 10000 )
4346print ("Random Seed: " , opt .manualSeed )
4447random .seed (opt .manualSeed )
4548torch .manual_seed (opt .manualSeed )
49+ if  opt .cuda :
50+  torch .cuda .manual_seed_all (opt .manualSeed )
4651
4752cudnn .benchmark  =  True 
4853
8489ndf  =  int (opt .ndf )
8590nc  =  3 
8691
92+ 
8793# custom weights initialization called on netG and netD 
8894def  weights_init (m ):
8995 classname  =  m .__class__ .__name__ 
@@ -93,6 +99,7 @@ def weights_init(m):
9399 m .weight .data .normal_ (1.0 , 0.02 )
94100 m .bias .data .fill_ (0 )
95101
102+ 
96103class  _netG (nn .Module ):
97104 def  __init__ (self , ngpu ):
98105 super (_netG , self ).__init__ ()
@@ -119,18 +126,21 @@ def __init__(self, ngpu):
119126 nn .Tanh ()
120127 # state size. (nc) x 64 x 64 
121128 )
129+ 
122130 def  forward (self , input ):
123131 gpu_ids  =  None 
124132 if  isinstance (input .data , torch .cuda .FloatTensor ) and  self .ngpu  >  1 :
125133 gpu_ids  =  range (self .ngpu )
126134 return  nn .parallel .data_parallel (self .main , input , gpu_ids )
127135
136+ 
128137netG  =  _netG (ngpu )
129138netG .apply (weights_init )
130139if  opt .netG  !=  '' :
131140 netG .load_state_dict (torch .load (opt .netG ))
132141print (netG )
133142
143+ 
134144class  _netD (nn .Module ):
135145 def  __init__ (self , ngpu ):
136146 super (_netD , self ).__init__ ()
@@ -155,13 +165,15 @@ def __init__(self, ngpu):
155165 nn .Conv2d (ndf  *  8 , 1 , 4 , 1 , 0 , bias = False ),
156166 nn .Sigmoid ()
157167 )
168+ 
158169 def  forward (self , input ):
159170 gpu_ids  =  None 
160171 if  isinstance (input .data , torch .cuda .FloatTensor ) and  self .ngpu  >  1 :
161172 gpu_ids  =  range (self .ngpu )
162173 output  =  nn .parallel .data_parallel (self .main , input , gpu_ids )
163174 return  output .view (- 1 , 1 )
164175
176+ 
165177netD  =  _netD (ngpu )
166178netD .apply (weights_init )
167179if  opt .netD  !=  '' :
@@ -190,8 +202,8 @@ def forward(self, input):
190202fixed_noise  =  Variable (fixed_noise )
191203
192204# setup optimizer 
193- optimizerD  =  optim .Adam (netD .parameters (), lr   =   opt .lr , betas   =   (opt .beta1 , 0.999 ))
194- optimizerG  =  optim .Adam (netG .parameters (), lr   =   opt .lr , betas   =   (opt .beta1 , 0.999 ))
205+ optimizerD  =  optim .Adam (netD .parameters (), lr = opt .lr , betas = (opt .beta1 , 0.999 ))
206+ optimizerG  =  optim .Adam (netG .parameters (), lr = opt .lr , betas = (opt .beta1 , 0.999 ))
195207
196208for  epoch  in  range (opt .niter ):
197209 for  i , data  in  enumerate (dataloader , 0 ):
@@ -226,7 +238,7 @@ def forward(self, input):
226238 # (2) Update G network: maximize log(D(G(z))) 
227239 ########################### 
228240 netG .zero_grad ()
229-  label .data .fill_ (real_label ) # fake labels are real for generator cost 
241+  label .data .fill_ (real_label )   # fake labels are real for generator cost 
230242 output  =  netD (fake )
231243 errG  =  criterion (output , label )
232244 errG .backward ()
0 commit comments