Skip to content

Commit 59b1de7

Browse files
updated progan
1 parent c72d1d6 commit 59b1de7

File tree

5 files changed

+29
-9
lines changed

5 files changed

+29
-9
lines changed

ML/Pytorch/GANs/ProGAN/README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@ A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to
44
## Results
55
The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough as it would take quite a while to train to 1024^2.
66

7-
|First is 64 random examples (not cherry picked) and second is more cherry picked examples. |
7+
|First is some more cherrypicked examples and second is just sampled from random latent vectors|
88
|:---:|
9-
|![](results/64_examples.png)|
109
|![](results/result1.png)|
10+
|![](results/64_examples.png)|
1111

1212

1313
### Celeb-HQ dataset
1414
The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/lamsimon/celebahq).
1515

16-
1716
### Download pretrained weights
1817
Pretrained weights [here](https://github.com/aladdinpersson/Machine-Learning-Collection/releases/download/1.0/ProGAN_weights.zip).
1918

ML/Pytorch/GANs/ProGAN/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
CHECKPOINT_CRITIC = "critic.pth"
99
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
1010
SAVE_MODEL = True
11-
LOAD_MODEL = True
11+
LOAD_MODEL = False
1212
LEARNING_RATE = 1e-3
1313
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
1414
CHANNELS_IMG = 3

ML/Pytorch/GANs/ProGAN/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def forward(self, x, alpha, steps):
134134

135135

136136
class Discriminator(nn.Module):
137-
def __init__(self, z_dim, in_channels, img_channels=3):
137+
def __init__(self, in_channels, img_channels=3):
138138
super(Discriminator, self).__init__()
139139
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
140140
self.leaky = nn.LeakyReLU(0.2)

ML/Pytorch/GANs/ProGAN/train.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
plot_to_tensorboard,
1212
save_checkpoint,
1313
load_checkpoint,
14+
generate_examples,
1415
)
1516
from model import Discriminator, Generator
1617
from math import log2
@@ -130,9 +131,8 @@ def train_fn(
130131
def main():
131132
# initialize gen and disc, note: discriminator should be called critic,
132133
# according to WGAN paper (since it no longer outputs between [0, 1])
133-
# but really who cares..
134134
gen = Generator(
135-
config.Z_DIM, config.W_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
135+
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
136136
).to(config.DEVICE)
137137
critic = Discriminator(
138138
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
@@ -147,7 +147,7 @@ def main():
147147
scaler_gen = torch.cuda.amp.GradScaler()
148148

149149
# for tensorboard plotting
150-
writer = SummaryWriter(f"logs/gan1")
150+
writer = SummaryWriter(f"logs/gan")
151151

152152
if config.LOAD_MODEL:
153153
load_checkpoint(
@@ -163,6 +163,10 @@ def main():
163163
tensorboard_step = 0
164164
# start at step that corresponds to img size that we set in config
165165
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
166+
167+
generate_examples(gen, step)
168+
import sys
169+
sys.exit()
166170
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
167171
alpha = 1e-5 # start with very low alpha
168172
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4

ML/Pytorch/GANs/ProGAN/utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import os
55
import torchvision
66
import torch.nn as nn
7+
import config
8+
from torchvision.utils import save_image
9+
from scipy.stats import truncnorm
710

811
# Print losses occasionally and print to tensorboard
912
def plot_to_tensorboard(
@@ -12,7 +15,7 @@ def plot_to_tensorboard(
1215
writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
1316

1417
with torch.no_grad():
15-
# take out (up to) 32 examples
18+
# take out (up to) 8 examples to plot
1619
img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
1720
img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
1821
writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
@@ -72,4 +75,18 @@ def seed_everything(seed=42):
7275
torch.backends.cudnn.deterministic = True
7376
torch.backends.cudnn.benchmark = False
7477

78+
def generate_examples(gen, steps, truncation=0.7, n=100):
79+
"""
80+
Tried using truncation trick here but not sure it actually helped anything, you can
81+
remove it if you like and just sample from torch.randn
82+
"""
83+
gen.eval()
84+
alpha = 1.0
85+
for i in range(n):
86+
with torch.no_grad():
87+
noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)), device=config.DEVICE, dtype=torch.float32)
88+
img = gen(noise, alpha, steps)
89+
save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
90+
gen.train()
91+
7592

0 commit comments

Comments
 (0)