Skip to content

Commit 74597aa

Browse files
updated progan
1 parent 59b1de7 commit 74597aa

File tree

5 files changed

+15
-26
lines changed

5 files changed

+15
-26
lines changed

ML/Pytorch/GANs/ProGAN/README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
# ProGAN
2-
A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation I would say is pretty close to the original paper (I'll include some examples results below) but because of time limitation I only trained to 256x256 and on lower model size than they did in the paper. Making the number of channels to 512 instead of 256 as I trained it would probably make the results even better :)
2+
A clean, simple and readable implementation of ProGAN in PyTorch. I've tried to replicate the original paper as closely as possible, so if you read the paper the implementation should be pretty much identical. The results from this implementation I would say is on par with the paper, I'll include some examples results below.
33

44
## Results
5-
The model was trained on the Celeb-HQ dataset up to 256x256 image size. After that point I felt it was enough as it would take quite a while to train to 1024^2.
5+
The model was trained on the Maps dataset and for fun I also tried using it to colorize anime.
66

7-
|First is some more cherrypicked examples and second is just sampled from random latent vectors|
7+
||
88
|:---:|
9-
|![](results/result1.png)|
109
|![](results/64_examples.png)|
10+
|![](results/result1.png)|
1111

1212

1313
### Celeb-HQ dataset
1414
The dataset can be downloaded from Kaggle: [link](https://www.kaggle.com/lamsimon/celebahq).
1515

16+
1617
### Download pretrained weights
17-
Pretrained weights [here](https://github.com/aladdinpersson/Machine-Learning-Collection/releases/download/1.0/ProGAN_weights.zip).
18+
Pretrained weights [here]().
1819

1920
Extract the zip file and put the pth.tar files in the directory with all the python files. Make sure you put LOAD_MODEL=True in the config.py file.
2021

ML/Pytorch/GANs/ProGAN/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from math import log2
44

5-
START_TRAIN_AT_IMG_SIZE = 4
5+
START_TRAIN_AT_IMG_SIZE = 128
66
DATASET = 'celeb_dataset'
77
CHECKPOINT_GEN = "generator.pth"
88
CHECKPOINT_CRITIC = "critic.pth"

ML/Pytorch/GANs/ProGAN/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def forward(self, x, alpha, steps):
134134

135135

136136
class Discriminator(nn.Module):
137-
def __init__(self, in_channels, img_channels=3):
137+
def __init__(self, z_dim, in_channels, img_channels=3):
138138
super(Discriminator, self).__init__()
139139
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
140140
self.leaky = nn.LeakyReLU(0.2)

ML/Pytorch/GANs/ProGAN/train.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,19 @@ def train_fn(
6060
scaler_critic,
6161
):
6262
loop = tqdm(loader, leave=True)
63-
# critic_losses = []
64-
reals = 0
65-
fakes = 0
6663
for batch_idx, (real, _) in enumerate(loop):
6764
real = real.to(config.DEVICE)
6865
cur_batch_size = real.shape[0]
6966

7067
# Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
7168
# which is equivalent to minimizing the negative of the expression
72-
noise = torch.randn(cur_batch_size, config.Z_DIM).to(config.DEVICE)
69+
noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1).to(config.DEVICE)
7370

7471
with torch.cuda.amp.autocast():
7572
fake = gen(noise, alpha, step)
7673
critic_real = critic(real, alpha, step)
7774
critic_fake = critic(fake.detach(), alpha, step)
78-
reals += critic_real.mean().item()
79-
fakes += critic_fake.mean().item()
80-
gp = gradient_penalty(critic, real, fake, device=config.DEVICE)
75+
gp = gradient_penalty(critic, real, fake, alpha, step, device=config.DEVICE)
8176
loss_critic = (
8277
-(torch.mean(critic_real) - torch.mean(critic_fake))
8378
+ config.LAMBDA_GP * gp
@@ -119,8 +114,6 @@ def train_fn(
119114
tensorboard_step += 1
120115

121116
loop.set_postfix(
122-
reals=reals / (batch_idx + 1),
123-
fakes=fakes / (batch_idx + 1),
124117
gp=gp.item(),
125118
loss_critic=loss_critic.item(),
126119
)
@@ -131,11 +124,12 @@ def train_fn(
131124
def main():
132125
# initialize gen and disc, note: discriminator should be called critic,
133126
# according to WGAN paper (since it no longer outputs between [0, 1])
127+
# but really who cares..
134128
gen = Generator(
135129
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
136130
).to(config.DEVICE)
137131
critic = Discriminator(
138-
config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
132+
config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
139133
).to(config.DEVICE)
140134

141135
# initialize optimizers and scalers for FP16 training
@@ -147,7 +141,7 @@ def main():
147141
scaler_gen = torch.cuda.amp.GradScaler()
148142

149143
# for tensorboard plotting
150-
writer = SummaryWriter(f"logs/gan")
144+
writer = SummaryWriter(f"logs/gan1")
151145

152146
if config.LOAD_MODEL:
153147
load_checkpoint(
@@ -163,10 +157,6 @@ def main():
163157
tensorboard_step = 0
164158
# start at step that corresponds to img size that we set in config
165159
step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
166-
167-
generate_examples(gen, step)
168-
import sys
169-
sys.exit()
170160
for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
171161
alpha = 1e-5 # start with very low alpha
172162
loader, dataset = get_loader(4 * 2 ** step) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
@@ -197,4 +187,4 @@ def main():
197187

198188

199189
if __name__ == "__main__":
200-
main()
190+
main()

ML/Pytorch/GANs/ProGAN/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,4 @@ def generate_examples(gen, steps, truncation=0.7, n=100):
8787
noise = torch.tensor(truncnorm.rvs(-truncation, truncation, size=(1, config.Z_DIM, 1, 1)), device=config.DEVICE, dtype=torch.float32)
8888
img = gen(noise, alpha, steps)
8989
save_image(img*0.5+0.5, f"saved_examples/img_{i}.png")
90-
gen.train()
91-
92-
90+
gen.train()

0 commit comments

Comments
 (0)