Commit 06bd820

update to progan
1 parent dc7f4f4 commit 06bd820

File tree: 5 files changed (+169, -96 lines)

ML/Pytorch/GANs/ProGAN/README.md

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # ProGAN
-A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation I would say is on par with the paper, I'll include some examples results below.
+A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation are, I would say, pretty close to the original paper (I'll include some example results below), but because of time limitations I only trained up to 256x256 and with a smaller model than in the paper. Increasing the number of channels from 256 to 512, as in the paper, would probably make the results even better :)
 
 ## Results
 The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough, as it would take quite a while to train to 1024^2.

ML/Pytorch/GANs/ProGAN/config.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import cv2
+import torch
+from math import log2
+
+START_TRAIN_AT_IMG_SIZE = 4
+DATASET = 'celeb_dataset'
+CHECKPOINT_GEN = "generator.pth"
+CHECKPOINT_CRITIC = "critic.pth"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+SAVE_MODEL = True
+LOAD_MODEL = True
+LEARNING_RATE = 1e-3
+BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
+CHANNELS_IMG = 3
+Z_DIM = 256  # should be 512 in original paper
+IN_CHANNELS = 256  # should be 512 in original paper
+CRITIC_ITERATIONS = 1
+LAMBDA_GP = 10
+PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
+FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)
+NUM_WORKERS = 4
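
For context on how these values are consumed: BATCH_SIZES is indexed in train.py by the resolution step, int(log2(image_size / 4)), so each entry corresponds to one doubling of the image size starting at 4x4. A small standalone sketch of that mapping (assuming the nine entries are meant to cover 4x4 up to 1024x1024, as in the paper):

from math import log2

BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]

for step, batch_size in enumerate(BATCH_SIZES):
    image_size = 4 * 2 ** step                 # step 0 -> 4x4, step 8 -> 1024x1024
    assert int(log2(image_size / 4)) == step   # same index get_loader computes
    print(f"{image_size}x{image_size}: batch size {batch_size}")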

ML/Pytorch/GANs/ProGAN/test.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

ML/Pytorch/GANs/ProGAN/train.py

Lines changed: 72 additions & 91 deletions
@@ -1,50 +1,38 @@
 """ Training of ProGAN using WGAN-GP loss"""
 
 import torch
+import torch.nn as nn
 import torch.optim as optim
+import torchvision
 import torchvision.datasets as datasets
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-from utils import (
-    gradient_penalty,
-    plot_to_tensorboard,
-    save_checkpoint,
-    load_checkpoint,
-    generate_examples,
-)
+from utils import gradient_penalty, plot_to_tensorboard, save_checkpoint, load_checkpoint
 from model import Discriminator, Generator
 from math import log2
 from tqdm import tqdm
+import time
 import config
 
 torch.backends.cudnn.benchmarks = True
 
-
 def get_loader(image_size):
     transform = transforms.Compose(
         [
             transforms.Resize((image_size, image_size)),
             transforms.ToTensor(),
-            transforms.RandomHorizontalFlip(p=0.5),
             transforms.Normalize(
                 [0.5 for _ in range(config.CHANNELS_IMG)],
                 [0.5 for _ in range(config.CHANNELS_IMG)],
             ),
         ]
     )
-    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
-    dataset = datasets.ImageFolder(root=config.DATASET, transform=transform)
-    loader = DataLoader(
-        dataset,
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=config.NUM_WORKERS,
-        pin_memory=True,
-    )
+    batch_size = config.BATCH_SIZES[int(log2(image_size / 4))]
+    dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
+    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True)
     return loader, dataset
 
-
 def train_fn(
     critic,
     gen,
@@ -59,96 +47,91 @@ def train_fn(
     scaler_gen,
     scaler_critic,
 ):
+    start = time.time()
+    total_time = 0
     loop = tqdm(loader, leave=True)
-    # critic_losses = []
-    reals = 0
-    fakes = 0
+    losses_critic = []
+
     for batch_idx, (real, _) in enumerate(loop):
         real = real.to(config.DEVICE)
         cur_batch_size = real.shape[0]
-
-        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
-        # which is equivalent to minimizing the negative of the expression
-        noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
-
-        with torch.cuda.amp.autocast():
-            fake = gen(noise, alpha, step)
-            critic_real = critic(real, alpha, step)
-            critic_fake = critic(fake.detach(), alpha, step)
-            reals += critic_real.mean().item()
-            fakes += critic_fake.mean().item()
-            gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
-            loss_critic = (
-                -(torch.mean(critic_real) - torch.mean(critic_fake))
-                + config.LAMBDA_GP * gp
-                + (0.001 * torch.mean(critic_real ** 2))
-            )
-
-        opt_critic.zero_grad()
-        scaler_critic.scale(loss_critic).backward()
-        scaler_critic.step(opt_critic)
-        scaler_critic.update()
-
-        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
-        with torch.cuda.amp.autocast():
-            gen_fake = critic(fake, alpha, step)
-            loss_gen = -torch.mean(gen_fake)
-
-        opt_gen.zero_grad()
-        scaler_gen.scale(loss_gen).backward()
-        scaler_gen.step(opt_gen)
-        scaler_gen.update()
+        model_start = time.time()
+
+        for _ in range(4):
+            # Train Critic: max E[critic(real)] - E[critic(fake)]
+            # which is equivalent to minimizing the negative of the expression
+            for _ in range(config.CRITIC_ITERATIONS):
+                noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
+
+                with torch.cuda.amp.autocast():
+                    fake = gen(noise, alpha, step)
+                    critic_real = critic(real, alpha, step).reshape(-1)
+                    critic_fake = critic(fake, alpha, step).reshape(-1)
+                    gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
+                    loss_critic = (
+                        -(torch.mean(critic_real) - torch.mean(critic_fake))
+                        + config.LAMBDA_GP * gp
+                    )
+
+                losses_critic.append(loss_critic.item())
+                opt_critic.zero_grad()
+                scaler_critic.scale(loss_critic).backward()
+                scaler_critic.step(opt_critic)
+                scaler_critic.update()
+                #loss_critic.backward(retain_graph=True)
+                #opt_critic.step()
+
+            # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
+            with torch.cuda.amp.autocast():
+                fake = gen(noise, alpha, step)
+                gen_fake = critic(fake, alpha, step).reshape(-1)
+                loss_gen = -torch.mean(gen_fake)
+
+            opt_gen.zero_grad()
+            scaler_gen.scale(loss_gen).backward()
+            scaler_gen.step(opt_gen)
+            scaler_gen.update()
+            #gen.zero_grad()
+            #loss_gen.backward()
+            #opt_gen.step()
 
         # Update alpha and ensure less than 1
         alpha += cur_batch_size / (
-            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
+            (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)  # - step
         )
         alpha = min(alpha, 1)
+        total_time += time.time() - model_start
 
-        if batch_idx % 500 == 0:
+        if batch_idx % 10 == 0:
+            print(alpha)
             with torch.no_grad():
-                fixed_fakes = gen(config.FIXED_NOISE, alpha, step) * 0.5 + 0.5
+                fixed_fakes = gen(config.FIXED_NOISE, alpha, step)
             plot_to_tensorboard(
-                writer,
-                loss_critic.item(),
-                loss_gen.item(),
-                real.detach(),
-                fixed_fakes.detach(),
-                tensorboard_step,
+                writer, loss_critic, loss_gen, real, fixed_fakes, tensorboard_step
            )
            tensorboard_step += 1
 
-        loop.set_postfix(
-            reals=reals / (batch_idx + 1),
-            fakes=fakes / (batch_idx + 1),
-            gp=gp.item(),
-            loss_critic=loss_critic.item(),
-        )
+        mean_loss = sum(losses_critic) / len(losses_critic)
+        loop.set_postfix(loss=mean_loss)
 
+    print(f'Fraction spent on model training: {total_time / (time.time() - start)}')
     return tensorboard_step, alpha
 
 
 def main():
     # initialize gen and disc, note: discriminator should be called critic,
     # according to WGAN paper (since it no longer outputs between [0, 1])
-    # but really who cares..
-    gen = Generator(
-        config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
-    ).to(config.DEVICE)
-    critic = Discriminator(
-        config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
-    ).to(config.DEVICE)
-
-    # initialize optimizers and scalers for FP16 training
+    gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_size=config.IMAGE_SIZE, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
+    critic = Discriminator(config.IMAGE_SIZE, config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG).to(config.DEVICE)
+
+    # initialize optimizer
     opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
-    opt_critic = optim.Adam(
-        critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
-    )
+    opt_critic = optim.Adam(critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
     scaler_critic = torch.cuda.amp.GradScaler()
     scaler_gen = torch.cuda.amp.GradScaler()
 
     # for tensorboard plotting
-    writer = SummaryWriter(f"logs/gan1")
+    writer = SummaryWriter(f"logs/gan")
 
     if config.LOAD_MODEL:
         load_checkpoint(
@@ -162,13 +145,12 @@ def main():
     critic.train()
 
     tensorboard_step = 0
-    # start at step that corresponds to img size that we set in config
-    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
-    for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
-        alpha = 1e-5  # start with very low alpha
-        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
-        print(f"Current image size: {4 * 2 ** step}")
+    step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
 
+    for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
+        alpha = 0.01
+        loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3
+        print(f"Current image size: {4 * 2 ** step}")
         for epoch in range(num_epochs):
             print(f"Epoch [{epoch+1}/{num_epochs}]")
             tensorboard_step, alpha = train_fn(
@@ -190,8 +172,7 @@ def main():
             save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
             save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)
 
-        step += 1  # progress to the next img size
-
+        step += 1
 
 if __name__ == "__main__":
-    main()
+    main()
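
A note on the fade-in schedule in train_fn above: alpha is increased by cur_batch_size / (0.5 * PROGRESSIVE_EPOCHS[step] * len(dataset)) per batch and clamped at 1, so the new resolution is faded in over roughly the first half of the epochs at each image size. A minimal standalone sketch of that arithmetic, with hypothetical numbers (30 epochs, 30,000 images, batch size 16), not part of the repo:

epochs_per_step = 30
dataset_len = 30_000
batch_size = 16

alpha = 0.01  # starting value used in main() above
batches_per_epoch = dataset_len // batch_size
increment = batch_size / (0.5 * epochs_per_step * dataset_len)

batches_until_full = 0
while alpha < 1:
    alpha = min(alpha + increment, 1)
    batches_until_full += 1

# alpha reaches 1 after about half the epochs at this resolution step
print(batches_until_full / batches_per_epoch)  # ~15 epochs out of 30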

ML/Pytorch/GANs/ProGAN/utils.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+import torch
+import random
+import numpy as np
+import os
+import torchvision
+import torch.nn as nn
+
+# Print losses occasionally and print to tensorboard
+def plot_to_tensorboard(
+    writer, loss_critic, loss_gen, real, fake, tensorboard_step
+):
+    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
+
+    with torch.no_grad():
+        # take out (up to) 8 examples
+        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
+        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
+        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
+        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)
+
+
+def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
+    BATCH_SIZE, C, H, W = real.shape
+    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
+    interpolated_images = real * beta + fake.detach() * (1 - beta)
+    interpolated_images.requires_grad_(True)
+
+    # Calculate critic scores
+    mixed_scores = critic(interpolated_images, alpha, train_step)
+
+    # Take the gradient of the scores with respect to the images
+    gradient = torch.autograd.grad(
+        inputs=interpolated_images,
+        outputs=mixed_scores,
+        grad_outputs=torch.ones_like(mixed_scores),
+        create_graph=True,
+        retain_graph=True,
+    )[0]
+    gradient = gradient.view(gradient.shape[0], -1)
+    gradient_norm = gradient.norm(2, dim=1)
+    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
+    return gradient_penalty
+
+
+def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
+    print("=> Saving checkpoint")
+    checkpoint = {
+        "state_dict": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+    }
+    torch.save(checkpoint, filename)
+
+
+def load_checkpoint(checkpoint_file, model, optimizer, lr):
+    print("=> Loading checkpoint")
+    checkpoint = torch.load(checkpoint_file, map_location="cuda")
+    model.load_state_dict(checkpoint["state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer"])
+
+    # If we don't do this then it will just keep the learning rate of the old checkpoint,
+    # and that will lead to many hours of debugging \:
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+def seed_everything(seed=42):
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
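
For reference, gradient_penalty above implements the WGAN-GP term: the squared deviation of the critic's gradient norm from 1, evaluated on random interpolates between real and fake batches, which train.py scales by LAMBDA_GP and adds to the critic loss. A minimal standalone usage sketch; toy_critic below is hypothetical (the real Discriminator from model.py is what train.py passes in), it only serves to make the call signature concrete:

import torch

from utils import gradient_penalty


def toy_critic(x, alpha, train_step):
    # Stand-in critic: ignores alpha/train_step and maps an image batch to one score per sample.
    return x.mean(dim=[1, 2, 3])


real = torch.randn(4, 3, 16, 16)
fake = torch.randn(4, 3, 16, 16)

gp = gradient_penalty(toy_critic, real, fake, alpha=0.5, train_step=2, device="cpu")
print(gp.item())  # scalar penalty; train.py adds LAMBDA_GP * gp to the critic loss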
