# --- command-line options (training hyper-parameters and paths) ---
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
# no default: when omitted it stays None and a random seed is drawn below
parser.add_argument('--manualSeed', type=int, help='manual seed')

opt = parser.parse_args()
print(opt)
3940 os .makedirs (opt .outf )
4041except OSError :
4142 pass
# Honor a user-supplied --manualSeed; otherwise draw a random one so the
# run is still reproducible from the printed value.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
# also seed every visible GPU so CUDA-side RNG matches the printed seed
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

# let cudnn auto-tune conv algorithms (input sizes are fixed here)
cudnn.benchmark = True
4853
ndf = int(opt.ndf)  # base feature-map width — presumably the discriminator's, per DCGAN convention; confirm against netD
nc = 3              # number of image channels (3 = color images)
8691
92+
8793# custom weights initialization called on netG and netD
8894def weights_init (m ):
8995 classname = m .__class__ .__name__
@@ -93,6 +99,7 @@ def weights_init(m):
9399 m .weight .data .normal_ (1.0 , 0.02 )
94100 m .bias .data .fill_ (0 )
95101
102+
96103class _netG (nn .Module ):
97104 def __init__ (self , ngpu ):
98105 super (_netG , self ).__init__ ()
@@ -119,18 +126,21 @@ def __init__(self, ngpu):
119126 nn .Tanh ()
120127 # state size. (nc) x 64 x 64
121128 )
129+
122130 def forward (self , input ):
123131 gpu_ids = None
124132 if isinstance (input .data , torch .cuda .FloatTensor ) and self .ngpu > 1 :
125133 gpu_ids = range (self .ngpu )
126134 return nn .parallel .data_parallel (self .main , input , gpu_ids )
127135
136+
netG = _netG(ngpu)
netG.apply(weights_init)  # custom init (defined above) applied to every submodule
if opt.netG != '':
    # resume training from a serialized state_dict checkpoint
    netG.load_state_dict(torch.load(opt.netG))
print(netG)
133142
143+
134144class _netD (nn .Module ):
135145 def __init__ (self , ngpu ):
136146 super (_netD , self ).__init__ ()
@@ -155,13 +165,15 @@ def __init__(self, ngpu):
155165 nn .Conv2d (ndf * 8 , 1 , 4 , 1 , 0 , bias = False ),
156166 nn .Sigmoid ()
157167 )
168+
158169 def forward (self , input ):
159170 gpu_ids = None
160171 if isinstance (input .data , torch .cuda .FloatTensor ) and self .ngpu > 1 :
161172 gpu_ids = range (self .ngpu )
162173 output = nn .parallel .data_parallel (self .main , input , gpu_ids )
163174 return output .view (- 1 , 1 )
164175
176+
netD = _netD(ngpu)
netD.apply(weights_init)  # same custom init as the generator
167179if opt .netD != '' :
@@ -190,8 +202,8 @@ def forward(self, input):
# wrap in an autograd Variable (legacy pre-0.4 torch API); this batch is
# held fixed across epochs, presumably to visualize generator progress
fixed_noise = Variable(fixed_noise)
191203
# setup optimizer
# Adam with the DCGAN-recommended beta1 (default 0.5); beta2 kept at 0.999
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
195207
196208for epoch in range (opt .niter ):
197209 for i , data in enumerate (dataloader , 0 ):
@@ -226,7 +238,7 @@ def forward(self, input):
226238 # (2) Update G network: maximize log(D(G(z)))
227239 ###########################
228240 netG .zero_grad ()
229- label .data .fill_ (real_label ) # fake labels are real for generator cost
241+ label .data .fill_ (real_label ) # fake labels are real for generator cost
230242 output = netD (fake )
231243 errG = criterion (output , label )
232244 errG .backward ()
# (page-scrape residue: "0 commit comments" — not part of the source file)