11""" Training of ProGAN using WGAN-GP loss"""
22
33import torch
4- import torch .nn as nn
54import torch .optim as optim
6- import torchvision
75import torchvision .datasets as datasets
86import torchvision .transforms as transforms
97from torch .utils .data import DataLoader
108from torch .utils .tensorboard import SummaryWriter
11- from utils import gradient_penalty , plot_to_tensorboard , save_checkpoint , load_checkpoint
9+ from utils import (
10+ gradient_penalty ,
11+ plot_to_tensorboard ,
12+ save_checkpoint ,
13+ load_checkpoint ,
14+ )
1215from model import Discriminator , Generator
1316from math import log2
1417from tqdm import tqdm
15- import time
1618import config
1719
 torch.backends.cudnn.benchmark = True
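+# benchmark=True lets cuDNN autotune convolution algorithms; input shapes stay fixed
+# within each progressive-growing stage, so the re-tuning cost is paid once per resolution.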

+
 def get_loader(image_size):
     transform = transforms.Compose(
         [
             transforms.Resize((image_size, image_size)),
             transforms.ToTensor(),
+            transforms.RandomHorizontalFlip(p=0.5),
             transforms.Normalize(
                 [0.5 for _ in range(config.CHANNELS_IMG)],
                 [0.5 for _ in range(config.CHANNELS_IMG)],
             ),
         ]
     )
-    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
-    dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
-    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True)
+    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
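+    # int(log2(image_size / 4)) maps 4 -> 0, 8 -> 1, 16 -> 2, ..., so BATCH_SIZES
+    # must hold one batch size per resolution, largest resolution last.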
+    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
+    loader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=config.NUM_WORKERS,
+        pin_memory=True,
+    )
     return loader, dataset

+
 def train_fn(
     critic,
     gen,
@@ -47,91 +58,96 @@ def train_fn(
     scaler_gen,
     scaler_critic,
 ):
-    start = time.time()
-    total_time = 0
     loop = tqdm(loader, leave=True)
-    losses_critic = []
-
+    # critic_losses = []
+    reals = 0
+    fakes = 0
     for batch_idx, (real, _) in enumerate(loop):
         real = real.to(config.DEVICE)
         cur_batch_size = real.shape[0]
-        model_start = time.time()
-
-        for _ in range(4):
-            # Train Critic: max E[critic(real)] - E[critic(fake)]
-            # which is equivalent to minimizing the negative of the expression
-            for _ in range(config.CRITIC_ITERATIONS):
-                noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
-
-                with torch.cuda.amp.autocast():
-                    fake = gen(noise, alpha, step)
-                    critic_real = critic(real, alpha, step).reshape(-1)
-                    critic_fake = critic(fake, alpha, step).reshape(-1)
-                    gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
-                    loss_critic = (
-                        -(torch.mean(critic_real) - torch.mean(critic_fake))
-                        + config.LAMBDA_GP * gp
-                    )
-
-                losses_critic.append(loss_critic.item())
-                opt_critic.zero_grad()
-                scaler_critic.scale(loss_critic).backward()
-                scaler_critic.step(opt_critic)
-                scaler_critic.update()
-                # loss_critic.backward(retain_graph=True)
-                # opt_critic.step()
-
-            # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
-            with torch.cuda.amp.autocast():
-                fake = gen(noise, alpha, step)
-                gen_fake = critic(fake, alpha, step).reshape(-1)
-                loss_gen = -torch.mean(gen_fake)
-
-            opt_gen.zero_grad()
-            scaler_gen.scale(loss_gen).backward()
-            scaler_gen.step(opt_gen)
-            scaler_gen.update()
-            # gen.zero_grad()
-            # loss_gen.backward()
-            # opt_gen.step()
+
+        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
+        # which is equivalent to minimizing the negative of the expression
+        noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
+
+        with torch.cuda.amp.autocast():
+            fake = gen(noise, alpha, step)
+            critic_real = critic(real, alpha, step)
+            critic_fake = critic(fake.detach(), alpha, step)
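+            # fake is detached for the critic loss so this update does not backprop
+            # into the generator; the same (non-detached) fake is reused for the
+            # generator step further below.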
+            reals += critic_real.mean().item()
+            fakes += critic_fake.mean().item()
+            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
+            loss_critic = (
+                -(torch.mean(critic_real) - torch.mean(critic_fake))
+                + config.LAMBDA_GP * gp
+                + (0.001 * torch.mean(critic_real ** 2))
+            )
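+            # the last term is the small "drift" penalty from the ProGAN paper
+            # (eps_drift * E[critic(real)^2], eps_drift = 0.001); it keeps the
+            # critic's output scale from drifting away from zero.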
+
+        opt_critic.zero_grad()
+        scaler_critic.scale(loss_critic).backward()
+        scaler_critic.step(opt_critic)
+        scaler_critic.update()
+
+        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
+        with torch.cuda.amp.autocast():
+            gen_fake = critic(fake, alpha, step)
+            loss_gen = -torch.mean(gen_fake)
+
+        opt_gen.zero_grad()
+        scaler_gen.scale(loss_gen).backward()
+        scaler_gen.step(opt_gen)
+        scaler_gen.update()

         # Update alpha and ensure less than 1
         alpha += cur_batch_size / (
-            (config.PROGRESSIVE_EPOCHS[step]*0.5) * len(dataset)  # - step
+            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
         )
         alpha = min(alpha, 1)
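+        # alpha grows by cur_batch_size / (0.5 * PROGRESSIVE_EPOCHS[step] * len(dataset))
+        # per batch, so the fade-in finishes after roughly half of this resolution's
+        # epochs; the remaining epochs train with the new layers fully blended (alpha = 1).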
-        total_time += time.time() - model_start

-        if batch_idx % 10 == 0:
-            print(alpha)
+        if batch_idx % 500 == 0:
             with torch.no_grad():
-                fixed_fakes = gen(config.FIXED_NOISE, alpha, step)
+                fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
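+                # * 0.5 + 0.5 undoes Normalize(mean=0.5, std=0.5), mapping the samples
+                # back to [0, 1] so they display correctly in TensorBoard.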
             plot_to_tensorboard(
-                writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
+                writer,
+                loss_critic.item(),
+                loss_gen.item(),
+                real.detach(),
+                fixed_fakes.detach(),
+                tensorboard_step,
             )
             tensorboard_step += 1

-        mean_loss = sum(losses_critic) / len(losses_critic)
-        loop.set_postfix(loss=mean_loss)
+        loop.set_postfix(
+            reals=reals / (batch_idx + 1),
+            fakes=fakes / (batch_idx + 1),
+            gp=gp.item(),
+            loss_critic=loss_critic.item(),
+        )

-    print(f'Fraction spent on model training: {total_time / (time.time() - start)}')
     return tensorboard_step, alpha


 def main():
     # initialize gen and disc, note: discriminator should be called critic,
     # according to WGAN paper (since it no longer outputs between [0, 1])
-    gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_size=config.IMAGE_SIZE, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
-    critic = Discriminator(config.IMAGE_SIZE, config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
-
-    # initializate optimizer
+    # but really who cares..
+    gen = Generator(
+        config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
+    ).to(config.DEVICE)
+    critic = Discriminator(
+        config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
+    ).to(config.DEVICE)
+
+    # initialize optimizers and scalers for FP16 training
     opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
-    opt_critic = optim.Adam(critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
+    opt_critic = optim.Adam(
+        critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
+    )
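+    # beta1 = 0.0, beta2 = 0.99 are the Adam settings used in the ProGAN paper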
     scaler_critic = torch.cuda.amp.GradScaler()
     scaler_gen = torch.cuda.amp.GradScaler()

     # for tensorboard plotting
-    writer = SummaryWriter(f"logs/gan")
+    writer = SummaryWriter(f"logs/gan1")

     if config.LOAD_MODEL:
         load_checkpoint(
@@ -145,12 +161,13 @@ def main():
     critic.train()

     tensorboard_step = 0
-    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
-
+    # start at step that corresponds to img size that we set in config
+    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
     for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
-        alpha = 0.01
-        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3
-        print(f"Current image size: {4 * 2 ** step}")
+        alpha = 1e-5  # start with very low alpha
+        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
+        print(f"Current image size: {4 * 2 ** step}")
+
         for epoch in range(num_epochs):
             print(f"Epoch [{epoch + 1}/{num_epochs}]")
             tensorboard_step, alpha = train_fn(
@@ -172,7 +189,8 @@ def main():
                 save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                 save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)

-        step += 1
+        step += 1  # progress to the next img size
+

 if __name__ == "__main__":
     main()