 
 parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
 parser.add_argument('--batch-size', type=int, default=128, metavar='N',
-                    help='input batch size for training (default: 64)')
+                    help='input batch size for training (default: 128)')
 parser.add_argument('--epochs', type=int, default=10, metavar='N',
-                    help='number of epochs to train (default: 2)')
+                    help='number of epochs to train (default: 10)')
 parser.add_argument('--no-cuda', action='store_true', default=False,
                     help='enables CUDA training')
 parser.add_argument('--seed', type=int, default=1, metavar='S',
@@ -56,11 +56,7 @@ def encode(self, x):
 
     def reparametrize(self, mu, logvar):
         std = logvar.mul(0.5).exp_()
-        if args.cuda:
-            eps = torch.cuda.FloatTensor(std.size()).normal_()
-        else:
-            eps = torch.FloatTensor(std.size()).normal_()
-        eps = Variable(eps)
+        eps = Variable(std.data.new(std.size()).normal_())
         return eps.mul(std).add_(mu)
 
     def decode(self, z):
@@ -82,7 +78,7 @@ def forward(self, x):
 
 
 def loss_function(recon_x, x, mu, logvar):
-    BCE = reconstruction_function(recon_x, x)
+    BCE = reconstruction_function(recon_x, x.view(-1, 784))
 
     # see Appendix B from VAE paper:
     # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
0 commit comments