11""" Training of ProGAN using WGAN-GP loss"""
22
33import torch
4+ import torch .nn as nn
45import torch .optim as optim
6+ import torchvision
57import torchvision .datasets as datasets
68import torchvision .transforms as transforms
79from torch .utils .data import DataLoader
810from torch .utils .tensorboard import SummaryWriter
9- from utils import (
10- gradient_penalty ,
11- plot_to_tensorboard ,
12- save_checkpoint ,
13- load_checkpoint ,
14- generate_examples ,
15- )
11+ from utils import gradient_penalty , plot_to_tensorboard , save_checkpoint , load_checkpoint
1612from model import Discriminator , Generator
1713from math import log2
1814from tqdm import tqdm
15+ import time
1916import config
2017
# Enable the cuDNN autotuner. The flag is named `benchmark` (singular);
# assigning to `benchmarks` only creates an unused attribute and leaves the
# autotuner switched off.
torch.backends.cudnn.benchmark = True
2219
23-
def get_loader(image_size, root="celeb_dataset"):
    """Build the DataLoader and dataset for one training resolution.

    Args:
        image_size: Side length in pixels. Every image is resized to
            ``(image_size, image_size)``. The value is also used to pick the
            batch size via ``config.BATCH_SIZES[int(log2(image_size / 4))]``
            (4 -> index 0, 8 -> 1, 16 -> 2, ...).
        root: Directory handed to ``datasets.ImageFolder``. Defaults to
            ``"celeb_dataset"``, the path that used to be hard-coded, so
            existing callers behave exactly as before.

    Returns:
        A ``(loader, dataset)`` tuple: the shuffled ``DataLoader`` and the
        ``ImageFolder`` dataset it wraps.
    """
    # Per-channel mean/std of 0.5, one entry per image channel.
    channel_mean = [0.5] * config.CHANNELS_IMG
    channel_std = [0.5] * config.CHANNELS_IMG
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )

    # Batch size is looked up per resolution step.
    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=root, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset
4635
47-
4836def train_fn (
4937 critic ,
5038 gen ,
@@ -59,96 +47,91 @@ def train_fn(
5947 scaler_gen ,
6048 scaler_critic ,
6149):
50+ start = time .time ()
51+ total_time = 0
6252 loop = tqdm (loader , leave = True )
63- # critic_losses = []
64- reals = 0
65- fakes = 0
53+ losses_critic = []
54+
6655 for batch_idx , (real , _ ) in enumerate (loop ):
6756 real = real .to (config .DEVICE )
6857 cur_batch_size = real .shape [0 ]
69-
70- # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
71- # which is equivalent to minimizing the negative of the expression
72- noise = torch .randn (cur_batch_size , config .Z_DIM , 1 , 1 ).to (config .DEVICE )
73-
74- with torch .cuda .amp .autocast ():
75- fake = gen (noise , alpha , step )
76- critic_real = critic (real , alpha , step )
77- critic_fake = critic (fake .detach (), alpha , step )
78- reals += critic_real .mean ().item ()
79- fakes += critic_fake .mean ().item ()
80- gp = gradient_penalty (critic , real , fake , alpha , step , device = config .DEVICE )
81- loss_critic = (
82- - (torch .mean (critic_real ) - torch .mean (critic_fake ))
83- + config .LAMBDA_GP * gp
84- + (0.001 * torch .mean (critic_real ** 2 ))
85- )
86-
87- opt_critic .zero_grad ()
88- scaler_critic .scale (loss_critic ).backward ()
89- scaler_critic .step (opt_critic )
90- scaler_critic .update ()
91-
92- # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
93- with torch .cuda .amp .autocast ():
94- gen_fake = critic (fake , alpha , step )
95- loss_gen = - torch .mean (gen_fake )
96-
97- opt_gen .zero_grad ()
98- scaler_gen .scale (loss_gen ).backward ()
99- scaler_gen .step (opt_gen )
100- scaler_gen .update ()
58+ model_start = time .time ()
59+
60+ for _ in range (4 ):
61+ # Train Critic: max E[critic(real)] - E[critic(fake)]
62+ # which is equivalent to minimizing the negative of the expression
63+ for _ in range (config .CRITIC_ITERATIONS ):
64+ noise = torch .randn (cur_batch_size , config .Z_DIM , 1 , 1 ).to (config .DEVICE )
65+
66+ with torch .cuda .amp .autocast ():
67+ fake = gen (noise , alpha , step )
68+ critic_real = critic (real , alpha , step ).reshape (- 1 )
69+ critic_fake = critic (fake , alpha , step ).reshape (- 1 )
70+ gp = gradient_penalty (critic , real , fake , alpha , step , device = config .DEVICE )
71+ loss_critic = (
72+ - (torch .mean (critic_real ) - torch .mean (critic_fake ))
73+ + config .LAMBDA_GP * gp
74+ )
75+
76+ losses_critic .append (loss_critic .item ())
77+ opt_critic .zero_grad ()
78+ scaler_critic .scale (loss_critic ).backward ()
79+ scaler_critic .step (opt_critic )
80+ scaler_critic .update ()
81+ #loss_critic.backward(retain_graph=True)
82+ #opt_critic.step()
83+
84+ # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
85+ with torch .cuda .amp .autocast ():
86+ fake = gen (noise , alpha , step )
87+ gen_fake = critic (fake , alpha , step ).reshape (- 1 )
88+ loss_gen = - torch .mean (gen_fake )
89+
90+ opt_gen .zero_grad ()
91+ scaler_gen .scale (loss_gen ).backward ()
92+ scaler_gen .step (opt_gen )
93+ scaler_gen .update ()
94+ #gen.zero_grad()
95+ #loss_gen.backward()
96+ #opt_gen.step()
10197
10298 # Update alpha and ensure less than 1
10399 alpha += cur_batch_size / (
104- (config .PROGRESSIVE_EPOCHS [step ] * 0.5 ) * len (dataset )
100+ (config .PROGRESSIVE_EPOCHS [step ]* 0.5 ) * len (dataset ) # - step
105101 )
106102 alpha = min (alpha , 1 )
103+ total_time += time .time ()- model_start
107104
108- if batch_idx % 500 == 0 :
105+ if batch_idx % 10 == 0 :
106+ print (alpha )
109107 with torch .no_grad ():
110- fixed_fakes = gen (config .FIXED_NOISE , alpha , step ) * 0.5 + 0.5
108+ fixed_fakes = gen (config .FIXED_NOISE , alpha , step )
111109 plot_to_tensorboard (
112- writer ,
113- loss_critic .item (),
114- loss_gen .item (),
115- real .detach (),
116- fixed_fakes .detach (),
117- tensorboard_step ,
110+ writer , loss_critic , loss_gen , real , fixed_fakes , tensorboard_step
118111 )
119112 tensorboard_step += 1
120113
121- loop .set_postfix (
122- reals = reals / (batch_idx + 1 ),
123- fakes = fakes / (batch_idx + 1 ),
124- gp = gp .item (),
125- loss_critic = loss_critic .item (),
126- )
114+ mean_loss = sum (losses_critic ) / len (losses_critic )
115+ loop .set_postfix (loss = mean_loss )
127116
117+ print (f'Fraction spent on model training: { total_time / (time .time ()- start )} ' )
128118 return tensorboard_step , alpha
129119
130120
131121def main ():
132122 # initialize gen and disc, note: discriminator should be called critic,
133123 # according to WGAN paper (since it no longer outputs between [0, 1])
134- # but really who cares..
135- gen = Generator (
136- config .Z_DIM , config .IN_CHANNELS , img_channels = config .CHANNELS_IMG
137- ).to (config .DEVICE )
138- critic = Discriminator (
139- config .Z_DIM , config .IN_CHANNELS , img_channels = config .CHANNELS_IMG
140- ).to (config .DEVICE )
141-
142- # initialize optimizers and scalers for FP16 training
124+ gen = Generator (config .Z_DIM , config .IN_CHANNELS , img_size = config .IMAGE_SIZE , img_channels = config .CHANNELS_IMG ).to (config .DEVICE )
125+ critic = Discriminator (config .IMAGE_SIZE , config .Z_DIM , config .IN_CHANNELS , img_channels = config .CHANNELS_IMG ).to (config .DEVICE )
126+
127+ # initialize optimizers and AMP gradient scalers
143128 opt_gen = optim .Adam (gen .parameters (), lr = config .LEARNING_RATE , betas = (0.0 , 0.99 ))
144- opt_critic = optim .Adam (
145- critic .parameters (), lr = config .LEARNING_RATE , betas = (0.0 , 0.99 )
146- )
129+ opt_critic = optim .Adam (critic .parameters (), lr = config .LEARNING_RATE , betas = (0.0 , 0.99 ))
147130 scaler_critic = torch .cuda .amp .GradScaler ()
148131 scaler_gen = torch .cuda .amp .GradScaler ()
149132
150133 # for tensorboard plotting
151- writer = SummaryWriter (f"logs/gan1 " )
134+ writer = SummaryWriter (f"logs/gan " )
152135
153136 if config .LOAD_MODEL :
154137 load_checkpoint (
@@ -162,13 +145,12 @@ def main():
162145 critic .train ()
163146
164147 tensorboard_step = 0
165- # start at step that corresponds to img size that we set in config
166- step = int (log2 (config .START_TRAIN_AT_IMG_SIZE / 4 ))
167- for num_epochs in config .PROGRESSIVE_EPOCHS [step :]:
168- alpha = 1e-5 # start with very low alpha
169- loader , dataset = get_loader (4 * 2 ** step ) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
170- print (f"Current image size: { 4 * 2 ** step } " )
148+ step = int (log2 (config .START_TRAIN_AT_IMG_SIZE / 4 ))
171149
150+ for num_epochs in config .PROGRESSIVE_EPOCHS [step :]:
151+ alpha = 0.01
152+ loader , dataset = get_loader (4 * 2 ** step ) # 4->0, 8->1, 16->2, 32->3
153+ print (f"Current image size: { 4 * 2 ** step } " )
172154 for epoch in range (num_epochs ):
173155 print (f"Epoch [{ epoch + 1 } /{ num_epochs } ]" )
174156 tensorboard_step , alpha = train_fn (
@@ -190,8 +172,7 @@ def main():
190172 save_checkpoint (gen , opt_gen , filename = config .CHECKPOINT_GEN )
191173 save_checkpoint (critic , opt_critic , filename = config .CHECKPOINT_CRITIC )
192174
193- step += 1 # progress to the next img size
194-
175+ step += 1
195176
196177if __name__ == "__main__" :
197- main ()
178+ main ()
0 commit comments