Skip to content

Commit ab7cb38

Browse files
Kaixhinsoumith
authored andcommitted
Balance VAE losses, add reconstruction + sampling
1 parent ddf9e30 commit ab7cb38

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

vae/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.png

vae/main.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
import argparse
33
import torch
44
import torch.utils.data
5-
import torch.nn as nn
6-
import torch.optim as optim
5+
from torch import nn, optim
76
from torch.autograd import Variable
87
from torchvision import datasets, transforms
8+
from torchvision.utils import save_image
99

10-
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
10+
11+
parser = argparse.ArgumentParser(description='VAE MNIST Example')
1112
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
1213
help='input batch size for training (default: 128)')
1314
parser.add_argument('--epochs', type=int, default=10, metavar='N',
@@ -54,18 +55,21 @@ def encode(self, x):
5455
h1 = self.relu(self.fc1(x))
5556
return self.fc21(h1), self.fc22(h1)
5657

57-
def reparametrize(self, mu, logvar):
58-
std = logvar.mul(0.5).exp_()
59-
eps = Variable(std.data.new(std.size()).normal_())
60-
return eps.mul(std).add_(mu)
58+
def reparameterize(self, mu, logvar):
59+
if self.training:
60+
std = logvar.mul(0.5).exp_()
61+
eps = Variable(std.data.new(std.size()).normal_())
62+
return eps.mul(std).add_(mu)
63+
else:
64+
return mu
6165

6266
def decode(self, z):
6367
h3 = self.relu(self.fc3(z))
6468
return self.sigmoid(self.fc4(h3))
6569

6670
def forward(self, x):
6771
mu, logvar = self.encode(x.view(-1, 784))
68-
z = self.reparametrize(mu, logvar)
72+
z = self.reparameterize(mu, logvar)
6973
return self.decode(z), mu, logvar
7074

7175

@@ -74,7 +78,6 @@ def forward(self, x):
7478
model.cuda()
7579

7680
reconstruction_function = nn.BCELoss()
77-
reconstruction_function.size_average = False
7881

7982

8083
def loss_function(recon_x, x, mu, logvar):
@@ -86,6 +89,8 @@ def loss_function(recon_x, x, mu, logvar):
8689
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
8790
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
8891
KLD = torch.sum(KLD_element).mul_(-0.5)
92+
# Normalise by same number of elements as in reconstruction
93+
KLD /= args.batch_size * 784
8994

9095
return BCE + KLD
9196

@@ -119,12 +124,15 @@ def train(epoch):
119124
def test(epoch):
120125
model.eval()
121126
test_loss = 0
122-
for data, _ in test_loader:
127+
for i, (data, _) in enumerate(test_loader):
123128
if args.cuda:
124129
data = data.cuda()
125130
data = Variable(data, volatile=True)
126131
recon_batch, mu, logvar = model(data)
127132
test_loss += loss_function(recon_batch, data, mu, logvar).data[0]
133+
if i == 0:
134+
save_image(recon_batch.data.view(args.batch_size, 1, 28, 28),
135+
'reconstruction_' + str(epoch) + '.png')
128136

129137
test_loss /= len(test_loader.dataset)
130138
print('====> Test set loss: {:.4f}'.format(test_loss))
@@ -133,3 +141,5 @@ def test(epoch):
133141
for epoch in range(1, args.epochs + 1):
134142
train(epoch)
135143
test(epoch)
144+
sample = model.decode(Variable(torch.randn(64, 20)))
145+
save_image(sample.data.view(64, 1, 28, 28), 'sample_' + str(epoch) + '.png')

0 commit comments

Comments
 (0)