Skip to content

Commit c72d1d6

Browse files
damn, copied over wrong train file for ProGAN (will check this more thoroughly before the video is up too
1 parent bd6db84 commit c72d1d6

File tree

1 file changed

+89
-71
lines changed

1 file changed

+89
-71
lines changed

ML/Pytorch/GANs/ProGAN/train.py

Lines changed: 89 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,49 @@
11
""" Training of ProGAN using WGAN-GP loss"""
22

33
import torch
4-
import torch.nn as nn
54
import torch.optim as optim
6-
import torchvision
75
import torchvision.datasets as datasets
86
import torchvision.transforms as transforms
97
from torch.utils.data import DataLoader
108
from torch.utils.tensorboard import SummaryWriter
11-
from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
9+
from utils import (
10+
gradient_penalty,
11+
plot_to_tensorboard,
12+
save_checkpoint,
13+
load_checkpoint,
14+
)
1215
from model import Discriminator, Generator
1316
from math import log2
1417
from tqdm import tqdm
15-
import time
1618
import config
1719

1820
torch.backends.cudnn.benchmarks = True
1921

22+
2023
def get_loader(image_size):
2124
transform = transforms.Compose(
2225
[
2326
transforms.Resize((image_size, image_size)),
2427
transforms.ToTensor(),
28+
transforms.RandomHorizontalFlip(p=0.5),
2529
transforms.Normalize(
2630
[0.5 for _ in range(config.CHANNELS_IMG)],
2731
[0.5 for _ in range(config.CHANNELS_IMG)],
2832
),
2933
]
3034
)
31-
batch_size = config.BATCH_SIZES[int(log2(image_size/4))]
32-
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
33-
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True)
35+
batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
36+
dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
37+
loader = DataLoader(
38+
dataset,
39+
batch_size=batch_size,
40+
shuffle=True,
41+
num_workers=config.NUM_WORKERS,
42+
pin_memory=True,
43+
)
3444
return loader, dataset
3545

46+
3647
def train_fn(
3748
critic,
3849
gen,
@@ -47,91 +58,96 @@ def train_fn(
4758
scaler_gen,
4859
scaler_critic,
4960
):
50-
start = time.time()
51-
total_time = 0
5261
loop = tqdm(loader, leave=True)
53-
losses_critic = []
54-
62+
# critic_losses = []
63+
reals = 0
64+
fakes = 0
5565
for batch_idx, (real, _) in enumerate(loop):
5666
real = real.to(config.DEVICE)
5767
cur_batch_size = real.shape[0]
58-
model_start = time.time()
59-
60-
for _ in range(4):
61-
# Train Critic: max E[critic(real)] - E[critic(fake)]
62-
# which is equivalent to minimizing the negative of the expression
63-
for _ in range(config.CRITIC_ITERATIONS):
64-
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
65-
66-
with torch.cuda.amp.autocast():
67-
fake = gen(noise, alpha, step)
68-
critic_real = critic(real, alpha, step).reshape(-1)
69-
critic_fake = critic(fake, alpha, step).reshape(-1)
70-
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
71-
loss_critic = (
72-
-(torch.mean(critic_real) - torch.mean(critic_fake))
73-
+ config.LAMBDA_GP * gp
74-
)
75-
76-
losses_critic.append(loss_critic.item())
77-
opt_critic.zero_grad()
78-
scaler_critic.scale(loss_critic).backward()
79-
scaler_critic.step(opt_critic)
80-
scaler_critic.update()
81-
#loss_critic.backward(retain_graph=True)
82-
#opt_critic.step()
83-
84-
# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
85-
with torch.cuda.amp.autocast():
86-
fake = gen(noise, alpha, step)
87-
gen_fake = critic(fake, alpha, step).reshape(-1)
88-
loss_gen = -torch.mean(gen_fake)
89-
90-
opt_gen.zero_grad()
91-
scaler_gen.scale(loss_gen).backward()
92-
scaler_gen.step(opt_gen)
93-
scaler_gen.update()
94-
#gen.zero_grad()
95-
#loss_gen.backward()
96-
#opt_gen.step()
68+
69+
# Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
70+
# which is equivalent to minimizing the negative of the expression
71+
noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
72+
73+
with torch.cuda.amp.autocast():
74+
fake = gen(noise, alpha, step)
75+
critic_real = critic(real, alpha, step)
76+
critic_fake = critic(fake.detach(), alpha, step)
77+
reals += critic_real.mean().item()
78+
fakes += critic_fake.mean().item()
79+
gp = gradient_penalty(critic, real, fake, device=config.DEVICE)
80+
loss_critic = (
81+
-(torch.mean(critic_real) - torch.mean(critic_fake))
82+
+ config.LAMBDA_GP * gp
83+
+ (0.001 * torch.mean(critic_real ** 2))
84+
)
85+
86+
opt_critic.zero_grad()
87+
scaler_critic.scale(loss_critic).backward()
88+
scaler_critic.step(opt_critic)
89+
scaler_critic.update()
90+
91+
# Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
92+
with torch.cuda.amp.autocast():
93+
gen_fake = critic(fake, alpha, step)
94+
loss_gen = -torch.mean(gen_fake)
95+
96+
opt_gen.zero_grad()
97+
scaler_gen.scale(loss_gen).backward()
98+
scaler_gen.step(opt_gen)
99+
scaler_gen.update()
97100

98101
# Update alpha and ensure less than 1
99102
alpha += cur_batch_size / (
100-
(config.PROGRESSIVE_EPOCHS[step]*0.5) * len(dataset) # - step
103+
(config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
101104
)
102105
alpha = min(alpha, 1)
103-
total_time += time.time()-model_start
104106

105-
if batch_idx % 10 == 0:
106-
print(alpha)
107+
if batch_idx % 500 == 0:
107108
with torch.no_grad():
108-
fixed_fakes = gen(config.FIXED_NOISE, alpha, step)
109+
fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
109110
plot_to_tensorboard(
110-
writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
111+
writer,
112+
loss_critic.item(),
113+
loss_gen.item(),
114+
real.detach(),
115+
fixed_fakes.detach(),
116+
tensorboard_step,
111117
)
112118
tensorboard_step += 1
113119

114-
mean_loss = sum(losses_critic) / len(losses_critic)
115-
loop.set_postfix(loss=mean_loss)
120+
loop.set_postfix(
121+
reals=reals / (batch_idx + 1),
122+
fakes=fakes / (batch_idx + 1),
123+
gp=gp.item(),
124+
loss_critic=loss_critic.item(),
125+
)
116126

117-
print(f'Fraction spent on model training: {total_time/(time.time()-start)}')
118127
return tensorboard_step, alpha
119128

120129

121130
def main():
122131
# initialize gen and disc, note: discriminator should be called critic,
123132
# according to WGAN paper (since it no longer outputs between [0, 1])
124-
gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_size=config.IMAGE_SIZE, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
125-
critic = Discriminator(config.IMAGE_SIZE, config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
126-
127-
# initializate optimizer
133+
# but really who cares..
134+
gen = Generator(
135+
config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
136+
).to(config.DEVICE)
137+
critic = Discriminator(
138+
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
139+
).to(config.DEVICE)
140+
141+
# initialize optimizers and scalers for FP16 training
128142
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
129-
opt_critic = optim.Adam(critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
143+
opt_critic = optim.Adam(
144+
critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
145+
)
130146
scaler_critic = torch.cuda.amp.GradScaler()
131147
scaler_gen = torch.cuda.amp.GradScaler()
132148

133149
# for tensorboard plotting
134-
writer = SummaryWriter(f"logs/gan")
150+
writer = SummaryWriter(f"logs/gan1")
135151

136152
if config.LOAD_MODEL:
137153
load_checkpoint(
@@ -145,12 +161,13 @@ def main():
145161
critic.train()
146162

147163
tensorboard_step = 0
148-
step = int(log2(config.START_TRAIN_AT_IMG_SIZE/4))
149-
164+
# start at step that corresponds to img size that we set in config
165+
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
150166
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
151-
alpha = 0.01
152-
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3
153-
print(f"Current image size: {4*2**step}")
167+
alpha = 1e-5 # start with very low alpha
168+
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
169+
print(f"Current image size: {4 * 2 ** step}")
170+
154171
for epoch in range(num_epochs):
155172
print(f"Epoch [{epoch+1}/{num_epochs}]")
156173
tensorboard_step, alpha = train_fn(
@@ -172,7 +189,8 @@ def main():
172189
save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
173190
save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)
174191

175-
step += 1
192+
step += 1 # progress to the next img size
193+
176194

177195
if __name__ == "__main__":
178196
main()

0 commit comments

Comments
 (0)