@@ -2,12 +2,13 @@
 import argparse
 import torch
 import torch.utils.data
-import torch.nn as nn
-import torch.optim as optim
+from torch import nn, optim
 from torch.autograd import Variable
 from torchvision import datasets, transforms
+from torchvision.utils import save_image
 
-parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+
+parser = argparse.ArgumentParser(description='VAE MNIST Example')
 parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                     help='input batch size for training (default: 128)')
 parser.add_argument('--epochs', type=int, default=10, metavar='N',
@@ -54,18 +55,21 @@ def encode(self, x):
         h1 = self.relu(self.fc1(x))
         return self.fc21(h1), self.fc22(h1)
 
-    def reparametrize(self, mu, logvar):
-        std = logvar.mul(0.5).exp_()
-        eps = Variable(std.data.new(std.size()).normal_())
-        return eps.mul(std).add_(mu)
+    def reparameterize(self, mu, logvar):
+        if self.training:
+            std = logvar.mul(0.5).exp_()
+            eps = Variable(std.data.new(std.size()).normal_())
+            return eps.mul(std).add_(mu)
+        else:
+            return mu
 
     def decode(self, z):
         h3 = self.relu(self.fc3(z))
         return self.sigmoid(self.fc4(h3))
 
     def forward(self, x):
         mu, logvar = self.encode(x.view(-1, 784))
-        z = self.reparametrize(mu, logvar)
+        z = self.reparameterize(mu, logvar)
         return self.decode(z), mu, logvar
 
 
@@ -74,7 +78,6 @@ def forward(self, x):
     model.cuda()
 
 reconstruction_function = nn.BCELoss()
-reconstruction_function.size_average = False
 
 
 def loss_function(recon_x, x, mu, logvar):
@@ -86,6 +89,8 @@ def loss_function(recon_x, x, mu, logvar):
     # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
     KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
     KLD = torch.sum(KLD_element).mul_(-0.5)
+    # Normalise by same number of elements as in reconstruction
+    KLD /= args.batch_size * 784
 
     return BCE + KLD
 
@@ -119,12 +124,15 @@ def train(epoch):
 def test(epoch):
     model.eval()
    test_loss = 0
-    for data, _ in test_loader:
+    for i, (data, _) in enumerate(test_loader):
         if args.cuda:
             data = data.cuda()
         data = Variable(data, volatile=True)
         recon_batch, mu, logvar = model(data)
         test_loss += loss_function(recon_batch, data, mu, logvar).data[0]
+        if i == 0:
+            save_image(recon_batch.data.view(args.batch_size, 1, 28, 28),
+                       'reconstruction_' + str(epoch) + '.png')
 
     test_loss /= len(test_loader.dataset)
     print('====> Test set loss: {:.4f}'.format(test_loss))
@@ -133,3 +141,5 @@ def test(epoch):
 for epoch in range(1, args.epochs + 1):
     train(epoch)
     test(epoch)
+    sample = model.decode(Variable(torch.randn(64, 20)))
+    save_image(sample.data.view(64, 1, 28, 28), 'sample_' + str(epoch) + '.png')
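Note on the loss: the KLD_element expression in loss_function is the standard closed form of the KL divergence between the diagonal Gaussian posterior q(z|x) = N(mu, sigma^2) and the unit Gaussian prior, summed over latent dimensions j:

    D_{KL}(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)) = -\frac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)

Because the commit stops setting size_average = False on nn.BCELoss, BCE is now a mean over all args.batch_size * 784 pixel values; the added KLD /= args.batch_size * 784 line divides the KL term by that same element count so the two terms stay on a comparable scale.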
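On PyTorch 0.4 and later, Variable is deprecated and the same reparameterization trick can be written directly on tensors. A minimal standalone sketch, assuming torch >= 0.4; the Reparameterize module name and the toy shapes are illustrative, not part of this commit:

import torch
from torch import nn

class Reparameterize(nn.Module):
    # z = mu + sigma * eps with eps ~ N(0, I); the sampling is pushed into
    # eps, so gradients flow through mu and logvar (reparameterization trick).
    def forward(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
            eps = torch.randn_like(std)    # noise with matching shape/device
            return mu + eps * std
        return mu                          # eval mode: deterministic mean

layer = Reparameterize()
mu, logvar = torch.zeros(4, 20), torch.zeros(4, 20)
z = layer(mu, logvar)                       # stochastic sample in train mode
layer.eval()
assert torch.equal(layer(mu, logvar), mu)   # the mean is returned at eval time

This mirrors the if self.training branch added above: sampling only happens during training, while evaluation uses the posterior mean deterministically.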