@@ -88,7 +88,7 @@ def block(in_feat, out_feat, normalize=True):
8888 layers = [nn .Linear (in_feat , out_feat )]
8989 if normalize :
9090 layers .append (nn .BatchNorm1d (out_feat , 0.8 ))
91- layers .append (nn .LeakyReLU (0.2 , inplace = True ))
91+ layers .append (nn .LeakyReLU (0.01 , inplace = True ))
9292 return layers
9393
9494 self .model = nn .Sequential (
@@ -193,15 +193,15 @@ def training_step(self, batch):
193193 # log sampled images
194194 sample_imgs = self .generated_imgs [:6 ]
195195 grid = torchvision .utils .make_grid (sample_imgs )
196- self .logger .experiment .add_image ("generated_images" , grid , 0 )
196+ self .logger .experiment .add_image ("train/ generated_images" , grid , self . current_epoch )
197197
198198 # ground truth result (ie: all fake)
199199 # put on GPU because we created this tensor inside training_loop
200200 valid = torch .ones (imgs .size (0 ), 1 )
201201 valid = valid .type_as (imgs )
202202
203203 # adversarial loss is binary cross-entropy
204- g_loss = self .adversarial_loss (self .discriminator (self ( z ) ), valid )
204+ g_loss = self .adversarial_loss (self .discriminator (self . generated_imgs ), valid )
205205 self .log ("g_loss" , g_loss , prog_bar = True )
206206 self .manual_backward (g_loss )
207207 optimizer_g .step ()
@@ -222,7 +222,7 @@ def training_step(self, batch):
222222 fake = torch .zeros (imgs .size (0 ), 1 )
223223 fake = fake .type_as (imgs )
224224
225- fake_loss = self .adversarial_loss (self .discriminator (self ( z ) .detach ()), fake )
225+ fake_loss = self .adversarial_loss (self .discriminator (self . generated_imgs .detach ()), fake )
226226
227227 # discriminator loss is the average of these
228228 d_loss = (real_loss + fake_loss ) / 2
@@ -232,6 +232,9 @@ def training_step(self, batch):
232232 optimizer_d .zero_grad ()
233233 self .untoggle_optimizer (optimizer_d )
234234
235+ def validation_step (self , batch , batch_idx ):
236+ pass
237+
235238 def configure_optimizers (self ):
236239 lr = self .hparams .lr
237240 b1 = self .hparams .b1
@@ -247,7 +250,7 @@ def on_validation_epoch_end(self):
247250 # log sampled images
248251 sample_imgs = self (z )
249252 grid = torchvision .utils .make_grid (sample_imgs )
250- self .logger .experiment .add_image ("generated_images" , grid , self .current_epoch )
253+ self .logger .experiment .add_image ("validation/ generated_images" , grid , self .current_epoch )
251254
252255
253256# %%
@@ -263,4 +266,4 @@ def on_validation_epoch_end(self):
263266# %%
264267# Start tensorboard.
265268# %load_ext tensorboard
266- # %tensorboard --logdir lightning_logs/
269+ # %tensorboard --logdir lightning_logs/ --samples_per_plugin=images=60
0 commit comments