Skip to content

Commit 3648cbc

Browse files
r9y9soumith
authored andcommitted
vae: Fix UserWarning (#220)
1 parent 30b9c0e commit 3648cbc

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

vae/main.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
1111
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
12-
help='input batch size for training (default: 64)')
12+
help='input batch size for training (default: 128)')
1313
parser.add_argument('--epochs', type=int, default=10, metavar='N',
14-
help='number of epochs to train (default: 2)')
14+
help='number of epochs to train (default: 10)')
1515
parser.add_argument('--no-cuda', action='store_true', default=False,
1616
help='enables CUDA training')
1717
parser.add_argument('--seed', type=int, default=1, metavar='S',
@@ -56,11 +56,7 @@ def encode(self, x):
5656

5757
def reparametrize(self, mu, logvar):
5858
std = logvar.mul(0.5).exp_()
59-
if args.cuda:
60-
eps = torch.cuda.FloatTensor(std.size()).normal_()
61-
else:
62-
eps = torch.FloatTensor(std.size()).normal_()
63-
eps = Variable(eps)
59+
eps = Variable(std.data.new(std.size()).normal_())
6460
return eps.mul(std).add_(mu)
6561

6662
def decode(self, z):
@@ -82,7 +78,7 @@ def forward(self, x):
8278

8379

8480
def loss_function(recon_x, x, mu, logvar):
85-
BCE = reconstruction_function(recon_x, x)
81+
BCE = reconstruction_function(recon_x, x.view(-1, 784))
8682

8783
# see Appendix B from VAE paper:
8884
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014

0 commit comments

Comments
 (0)