- Notifications
You must be signed in to change notification settings - Fork 0
Update main.py #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,140 +1,281 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
| '''Example of VAE on MNIST dataset using MLP | ||||||||||||||||||||||||||||||||||||||||||||||
| The VAE has a modular design. The encoder, decoder and VAE | ||||||||||||||||||||||||||||||||||||||||||||||
| are 3 models that share weights. After training the VAE model, | ||||||||||||||||||||||||||||||||||||||||||||||
| the encoder can be used to generate latent vectors. | ||||||||||||||||||||||||||||||||||||||||||||||
| The decoder can be used to generate MNIST digits by sampling the | ||||||||||||||||||||||||||||||||||||||||||||||
| latent vector from a Gaussian distribution with mean = 0 and std = 1. | ||||||||||||||||||||||||||||||||||||||||||||||
| # Reference | ||||||||||||||||||||||||||||||||||||||||||||||
| [1] Kingma, Diederik P., and Max Welling. | ||||||||||||||||||||||||||||||||||||||||||||||
| "Auto-Encoding Variational Bayes." | ||||||||||||||||||||||||||||||||||||||||||||||
| https://arxiv.org/abs/1312.6114 | ||||||||||||||||||||||||||||||||||||||||||||||
| ''' | ||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import print_function | ||||||||||||||||||||||||||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||
| import torch.utils.data | ||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils.data import DataLoader | ||||||||||||||||||||||||||||||||||||||||||||||
| from torch import nn, optim | ||||||||||||||||||||||||||||||||||||||||||||||
| from torch.nn import functional as F | ||||||||||||||||||||||||||||||||||||||||||||||
| from torchvision import datasets, transforms | ||||||||||||||||||||||||||||||||||||||||||||||
| from torchvision.utils import save_image | ||||||||||||||||||||||||||||||||||||||||||||||
| from torchvision.utils import save_image, make_grid | ||||||||||||||||||||||||||||||||||||||||||||||
| import matplotlib.pyplot as plt | ||||||||||||||||||||||||||||||||||||||||||||||
| import matplotlib.animation as animation | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| parser = argparse.ArgumentParser(description='VAE MNIST Example') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--batch-size', type=int, default=128, metavar='N', | ||||||||||||||||||||||||||||||||||||||||||||||
| help='input batch size for training (default: 128)') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--epochs', type=int, default=10, metavar='N', | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--epochs', type=int, default=50, metavar='N', | ||||||||||||||||||||||||||||||||||||||||||||||
| help='number of epochs to train (default: 10)') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--no-cuda', action='store_true', default=False, | ||||||||||||||||||||||||||||||||||||||||||||||
| help='disables CUDA training') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--no-mps', action='store_true', default=False, | ||||||||||||||||||||||||||||||||||||||||||||||
| help='disables macOS GPU training') | ||||||||||||||||||||||||||||||||||||||||||||||
| help='enables CUDA training') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--seed', type=int, default=1, metavar='S', | ||||||||||||||||||||||||||||||||||||||||||||||
| help='random seed (default: 1)') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--log-interval', type=int, default=10, metavar='N', | ||||||||||||||||||||||||||||||||||||||||||||||
| help='how many batches to wait before logging training status') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--reduction', type=str, default='mean', metavar='N', | ||||||||||||||||||||||||||||||||||||||||||||||
| help='Type of reduction to do [choices: sum, mean] (default: mean)') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--use-mse', type=bool, default=False, metavar='N', | ||||||||||||||||||||||||||||||||||||||||||||||
| help='Whether to use MSE instead of BCE (default: False)') | ||||||||||||||||||||||||||||||||||||||||||||||
| parser.add_argument('--convert_path', type=str, default='C:/Program Files/ImageMagick/convert.exe', | ||||||||||||||||||||||||||||||||||||||||||||||
| metavar='N', help='Under windows, specify where convert.exe is located') | ||||||||||||||||||||||||||||||||||||||||||||||
| args = parser.parse_args() | ||||||||||||||||||||||||||||||||||||||||||||||
| args.cuda = not args.no_cuda and torch.cuda.is_available() | ||||||||||||||||||||||||||||||||||||||||||||||
| use_mps = not args.no_mps and torch.backends.mps.is_available() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| torch.manual_seed(args.seed) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| if args.cuda: | ||||||||||||||||||||||||||||||||||||||||||||||
| device = torch.device("cuda") | ||||||||||||||||||||||||||||||||||||||||||||||
| elif use_mps: | ||||||||||||||||||||||||||||||||||||||||||||||
| device = torch.device("mps") | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| device = torch.device("cpu") | ||||||||||||||||||||||||||||||||||||||||||||||
| device = torch.device("cuda" if args.cuda else "cpu") | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} | ||||||||||||||||||||||||||||||||||||||||||||||
| train_loader = torch.utils.data.DataLoader( | ||||||||||||||||||||||||||||||||||||||||||||||
| datasets.MNIST('../data', train=True, download=True, | ||||||||||||||||||||||||||||||||||||||||||||||
| transform=transforms.ToTensor()), | ||||||||||||||||||||||||||||||||||||||||||||||
| batch_size=args.batch_size, shuffle=True, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
| test_loader = torch.utils.data.DataLoader( | ||||||||||||||||||||||||||||||||||||||||||||||
| datasets.MNIST('../data', train=False, transform=transforms.ToTensor()), | ||||||||||||||||||||||||||||||||||||||||||||||
| batch_size=args.batch_size, shuffle=False, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
| train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()) | ||||||||||||||||||||||||||||||||||||||||||||||
| train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| test_dataset = datasets.MNIST('../data', train=False, transform=transforms.ToTensor()) | ||||||||||||||||||||||||||||||||||||||||||||||
| test_loader = DataLoader(test_dataset , batch_size=args.batch_size, shuffle=True, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If possible, it is better to rely on automatic pinning in PyTorch to avoid undefined behavior and for efficiency | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| class VAE(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||||||||||||||||||||||
| super(VAE, self).__init__() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| self.fc1 = nn.Linear(784, 400) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.fc21 = nn.Linear(400, 20) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.fc22 = nn.Linear(400, 20) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.fc3 = nn.Linear(20, 400) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.fc4 = nn.Linear(400, 784) | ||||||||||||||||||||||||||||||||||||||||||||||
| class VAE(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, embedding_size=2): | ||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def encode(self, x): | ||||||||||||||||||||||||||||||||||||||||||||||
| h1 = F.relu(self.fc1(x)) | ||||||||||||||||||||||||||||||||||||||||||||||
| return self.fc21(h1), self.fc22(h1) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.embedding_size = embedding_size | ||||||||||||||||||||||||||||||||||||||||||||||
| self.fc1 = nn.Linear(28*28, 512) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.fc1_mu = nn.Linear(512, self.embedding_size) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.fc1_std = nn.Linear(512, self.embedding_size) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def reparameterize(self, mu, logvar): | ||||||||||||||||||||||||||||||||||||||||||||||
| self.decoder = nn.Sequential( nn.Linear(self.embedding_size, 512), | ||||||||||||||||||||||||||||||||||||||||||||||
| nn.ReLU(), | ||||||||||||||||||||||||||||||||||||||||||||||
| nn.Linear(512, 28*28), | ||||||||||||||||||||||||||||||||||||||||||||||
| nn.Sigmoid()) | ||||||||||||||||||||||||||||||||||||||||||||||
| # VAEs sample from a random node z. Backprop cannot flow through a random node. | ||||||||||||||||||||||||||||||||||||||||||||||
| # In order to solve this, we randomly sample 'epsilon' from a unit Gaussian, | ||||||||||||||||||||||||||||||||||||||||||||||
| # and then simply shift it by the latent distrubtions mean and scale it by its varinace | ||||||||||||||||||||||||||||||||||||||||||||||
| # This is called "reparameterization trick". | ||||||||||||||||||||||||||||||||||||||||||||||
| # ϵ allows us to reparameterize z in a way that allows backprop to flow through the | ||||||||||||||||||||||||||||||||||||||||||||||
| # deterministic nodes. | ||||||||||||||||||||||||||||||||||||||||||||||
| # With this reparameterization, we can now optimize the parameters of the distribution | ||||||||||||||||||||||||||||||||||||||||||||||
| # while still maintaining the ability to randomly sample from that distribution. | ||||||||||||||||||||||||||||||||||||||||||||||
| # Note: In order to deal with the fact that the network may learn negative values | ||||||||||||||||||||||||||||||||||||||||||||||
| # for σ, we'll typically have the network learn log(σ) and exponentiate(exp)) this value | ||||||||||||||||||||||||||||||||||||||||||||||
| # to get the latent distribution's variance. | ||||||||||||||||||||||||||||||||||||||||||||||
| def reparamtrization_trick(self, mu, logvar): | ||||||||||||||||||||||||||||||||||||||||||||||
| # we divide by two because we are eliminating the negative values | ||||||||||||||||||||||||||||||||||||||||||||||
| # and we only care about the absolute possible deviance from standard. | ||||||||||||||||||||||||||||||||||||||||||||||
| std = torch.exp(0.5*logvar) | ||||||||||||||||||||||||||||||||||||||||||||||
| # epsilon sampled from normal distribution with N(0,1) | ||||||||||||||||||||||||||||||||||||||||||||||
| eps = torch.randn_like(std) | ||||||||||||||||||||||||||||||||||||||||||||||
| # How to sample from a normal distribution with known mean and variance? | ||||||||||||||||||||||||||||||||||||||||||||||
| # https://stats.stackexchange.com/questions/16334/ | ||||||||||||||||||||||||||||||||||||||||||||||
| # (tldr: just add the mu , multiply by the var) . | ||||||||||||||||||||||||||||||||||||||||||||||
| # why we use an epsilon? because without it, backprop wouldnt work. | ||||||||||||||||||||||||||||||||||||||||||||||
| return mu + eps*std | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def decode(self, z): | ||||||||||||||||||||||||||||||||||||||||||||||
| h3 = F.relu(self.fc3(z)) | ||||||||||||||||||||||||||||||||||||||||||||||
| return torch.sigmoid(self.fc4(h3)) | ||||||||||||||||||||||||||||||||||||||||||||||
| def encode(self, input): | ||||||||||||||||||||||||||||||||||||||||||||||
| input = input.view(input.size(0), -1) | ||||||||||||||||||||||||||||||||||||||||||||||
| output = F.relu(self.fc1(input)) | ||||||||||||||||||||||||||||||||||||||||||||||
| # ref: https://www.jeremyjordan.me/variational-autoencoders/ | ||||||||||||||||||||||||||||||||||||||||||||||
| # Note that we are not using any activation functions here. | ||||||||||||||||||||||||||||||||||||||||||||||
| # in other words our vectors μ and σ are unbounded that is they can take | ||||||||||||||||||||||||||||||||||||||||||||||
| # any values and thus our encoder will be able to learn to generate very | ||||||||||||||||||||||||||||||||||||||||||||||
| # different μ for different classes, clustering them apart, and minimize σ, | ||||||||||||||||||||||||||||||||||||||||||||||
| # making sure the encodings themselves don’t vary much for the same sample | ||||||||||||||||||||||||||||||||||||||||||||||
| # (that is, less uncertainty for the decoder). | ||||||||||||||||||||||||||||||||||||||||||||||
| # This allows the decoder to efficiently reconstruct the training data. | ||||||||||||||||||||||||||||||||||||||||||||||
| mu = self.fc1_mu(output) | ||||||||||||||||||||||||||||||||||||||||||||||
| log_var = self.fc1_std(output) | ||||||||||||||||||||||||||||||||||||||||||||||
| z = self.reparamtrization_trick(mu, log_var) | ||||||||||||||||||||||||||||||||||||||||||||||
| return z, mu, log_var | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, x): | ||||||||||||||||||||||||||||||||||||||||||||||
| mu, logvar = self.encode(x.view(-1, 784)) | ||||||||||||||||||||||||||||||||||||||||||||||
| z = self.reparameterize(mu, logvar) | ||||||||||||||||||||||||||||||||||||||||||||||
| def decode(self, z): | ||||||||||||||||||||||||||||||||||||||||||||||
| output = self.decoder(z).view(z.size(0), 1, 28, 28) | ||||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, input): | ||||||||||||||||||||||||||||||||||||||||||||||
| z, mu, logvar = self.encode(input) | ||||||||||||||||||||||||||||||||||||||||||||||
| return self.decode(z), mu, logvar | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| model = VAE().to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| embedding_size = 2 | ||||||||||||||||||||||||||||||||||||||||||||||
| model = VAE(embedding_size).to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| optimizer = optim.Adam(model.parameters(), lr=1e-3) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| # Reconstruction + KL divergence losses summed over all elements and batch | ||||||||||||||||||||||||||||||||||||||||||||||
| def loss_function(recon_x, x, mu, logvar): | ||||||||||||||||||||||||||||||||||||||||||||||
| BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') | ||||||||||||||||||||||||||||||||||||||||||||||
| def loss_function(outputs, inputs, mu, logvar, reduction ='mean', use_mse = False): | ||||||||||||||||||||||||||||||||||||||||||||||
| if reduction == 'sum': | ||||||||||||||||||||||||||||||||||||||||||||||
| criterion = nn.BCELoss(reduction='sum') | ||||||||||||||||||||||||||||||||||||||||||||||
| reconstruction_loss = criterion(outputs, inputs) | ||||||||||||||||||||||||||||||||||||||||||||||
| # see Appendix B from VAE paper: | ||||||||||||||||||||||||||||||||||||||||||||||
| # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 | ||||||||||||||||||||||||||||||||||||||||||||||
| # https://arxiv.org/abs/1312.6114 | ||||||||||||||||||||||||||||||||||||||||||||||
| # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) | ||||||||||||||||||||||||||||||||||||||||||||||
| KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) | ||||||||||||||||||||||||||||||||||||||||||||||
| return reconstruction_loss + KL | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| if use_mse: | ||||||||||||||||||||||||||||||||||||||||||||||
| criterion = nn.MSELoss() | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| criterion = nn.BCELoss(reduction='mean') | ||||||||||||||||||||||||||||||||||||||||||||||
| reconstruction_loss = criterion(outputs, inputs) | ||||||||||||||||||||||||||||||||||||||||||||||
| # normalize reconstruction loss | ||||||||||||||||||||||||||||||||||||||||||||||
| reconstruction_loss *= 28*28 | ||||||||||||||||||||||||||||||||||||||||||||||
| KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), -1) | ||||||||||||||||||||||||||||||||||||||||||||||
| return torch.mean(reconstruction_loss + KL) | ||||||||||||||||||||||||||||||||||||||||||||||
| Comment on lines +135 to +145 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This rule is about removing unnecessary else statements in your code. An else statement is considered unnecessary when it follows a return statement in the if block. In such cases, the else block can be safely removed without changing the logic of the code. Unnecessary else statements can make your code harder to read and understand. They can also lead to more complex code structures, which can increase the likelihood of introducing bugs. By removing unnecessary else statements, you can make your code simpler and more readable. Suggested change
| ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def plot_latent_space(): | ||||||||||||||||||||||||||||||||||||||||||||||
| dataloader_test = DataLoader(test_dataset, | ||||||||||||||||||||||||||||||||||||||||||||||
| batch_size = len(test_dataset), | ||||||||||||||||||||||||||||||||||||||||||||||
| num_workers = 2, | ||||||||||||||||||||||||||||||||||||||||||||||
| pin_memory=True) | ||||||||||||||||||||||||||||||||||||||||||||||
| imgs, labels = next(iter(dataloader_test)) | ||||||||||||||||||||||||||||||||||||||||||||||
| imgs = imgs.to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| z_test,_,_ = model.encode(imgs) | ||||||||||||||||||||||||||||||||||||||||||||||
| z_test = z_test.cpu().detach().numpy() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| # see Appendix B from VAE paper: | ||||||||||||||||||||||||||||||||||||||||||||||
| # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 | ||||||||||||||||||||||||||||||||||||||||||||||
| # https://arxiv.org/abs/1312.6114 | ||||||||||||||||||||||||||||||||||||||||||||||
| # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) | ||||||||||||||||||||||||||||||||||||||||||||||
| KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.figure(figsize=(12,10)) | ||||||||||||||||||||||||||||||||||||||||||||||
| img = plt.scatter(x=z_test[:,0], | ||||||||||||||||||||||||||||||||||||||||||||||
| y=z_test[:,1], | ||||||||||||||||||||||||||||||||||||||||||||||
| c=labels.numpy(), | ||||||||||||||||||||||||||||||||||||||||||||||
| alpha=.4, | ||||||||||||||||||||||||||||||||||||||||||||||
| s=3**2, | ||||||||||||||||||||||||||||||||||||||||||||||
| cmap='viridis') | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.colorbar() | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.xlabel('Z[0]') | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.ylabel('Z[1]') | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.savefig('vae_latent_space.png') | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.show() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| return BCE + KLD | ||||||||||||||||||||||||||||||||||||||||||||||
| def display_2d_manifold(model, digit_count=20): | ||||||||||||||||||||||||||||||||||||||||||||||
| # display a 2D manifold of the digits | ||||||||||||||||||||||||||||||||||||||||||||||
| embeddingsize = model.embedding_size | ||||||||||||||||||||||||||||||||||||||||||||||
| # figure with 20x20 digits | ||||||||||||||||||||||||||||||||||||||||||||||
| n = digit_count | ||||||||||||||||||||||||||||||||||||||||||||||
| digit_size = 28 | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| z1 = torch.linspace(-2, 2, n) | ||||||||||||||||||||||||||||||||||||||||||||||
| z2 = torch.linspace(-2, 2, n) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def train(epoch): | ||||||||||||||||||||||||||||||||||||||||||||||
| z_grid = np.dstack(np.meshgrid(z1, z2)) | ||||||||||||||||||||||||||||||||||||||||||||||
| z_grid = torch.from_numpy(z_grid).to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| z_grid = z_grid.reshape(-1, embeddingsize) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| x_pred_grid = model.decode(z_grid) | ||||||||||||||||||||||||||||||||||||||||||||||
| x_pred_grid= x_pred_grid.cpu().detach() | ||||||||||||||||||||||||||||||||||||||||||||||
| x = make_grid(x_pred_grid, nrow=n).numpy().transpose(1, 2, 0) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.figure(figsize=(10, 10)) | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.xlabel('Z_1') | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.ylabel('Z_2') | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.imshow(x) | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.savefig('vae_digits_2d_manifiold.png') | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.show() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def save_animation(model, sample_count = 30, use_mp4=True): | ||||||||||||||||||||||||||||||||||||||||||||||
| fig = plt.figure() | ||||||||||||||||||||||||||||||||||||||||||||||
| ax = fig.add_subplot(111) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| if os.name == 'nt': | ||||||||||||||||||||||||||||||||||||||||||||||
| plt.rcParams["animation.convert_path"] = args.convert_path | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| z = torch.randn(size = (sample_count, model.embedding_size)).to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| model.eval() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def animate(i): | ||||||||||||||||||||||||||||||||||||||||||||||
| imgs = model.decode(z * (i * 0.03) + 0.02) | ||||||||||||||||||||||||||||||||||||||||||||||
| img_grid = make_grid(imgs).cpu().detach().numpy().transpose(1, 2, 0) | ||||||||||||||||||||||||||||||||||||||||||||||
| ax.clear() | ||||||||||||||||||||||||||||||||||||||||||||||
| ax.imshow(img_grid) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| anim = animation.FuncAnimation(fig, animate, frames=100, interval=300, | ||||||||||||||||||||||||||||||||||||||||||||||
| repeat=True, repeat_delay=1000) | ||||||||||||||||||||||||||||||||||||||||||||||
| if use_mp4: | ||||||||||||||||||||||||||||||||||||||||||||||
| anim.save('vae_off.mp4', writer="ffmpeg", fps=20) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| anim.save('vae_off.gif', writer="imagemagick", extra_args="convert", fps=20) | ||||||||||||||||||||||||||||||||||||||||||||||
| # plt.show() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def train(epoch, reduction='mean', use_mse=False): | ||||||||||||||||||||||||||||||||||||||||||||||
| model.train() | ||||||||||||||||||||||||||||||||||||||||||||||
| train_loss = 0 | ||||||||||||||||||||||||||||||||||||||||||||||
| for batch_idx, (data, _) in enumerate(train_loader): | ||||||||||||||||||||||||||||||||||||||||||||||
| data = data.to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| for batch_idx, (imgs, _) in enumerate(train_loader): | ||||||||||||||||||||||||||||||||||||||||||||||
| imgs = imgs.to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| optimizer.zero_grad() | ||||||||||||||||||||||||||||||||||||||||||||||
| recon_batch, mu, logvar = model(data) | ||||||||||||||||||||||||||||||||||||||||||||||
| loss = loss_function(recon_batch, data, mu, logvar) | ||||||||||||||||||||||||||||||||||||||||||||||
| recons, mu, logvar = model(imgs) | ||||||||||||||||||||||||||||||||||||||||||||||
| loss = loss_function(recons, | ||||||||||||||||||||||||||||||||||||||||||||||
| imgs, | ||||||||||||||||||||||||||||||||||||||||||||||
| mu, | ||||||||||||||||||||||||||||||||||||||||||||||
| logvar, | ||||||||||||||||||||||||||||||||||||||||||||||
| reduction, | ||||||||||||||||||||||||||||||||||||||||||||||
| use_mse) | ||||||||||||||||||||||||||||||||||||||||||||||
| loss.backward() | ||||||||||||||||||||||||||||||||||||||||||||||
| train_loss += loss.item() | ||||||||||||||||||||||||||||||||||||||||||||||
| optimizer.step() | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| if batch_idx % args.log_interval == 0: | ||||||||||||||||||||||||||||||||||||||||||||||
| loss = loss/len(imgs) if (reduction == 'sum') else loss | ||||||||||||||||||||||||||||||||||||||||||||||
| print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( | ||||||||||||||||||||||||||||||||||||||||||||||
| epoch, batch_idx * len(data), len(train_loader.dataset), | ||||||||||||||||||||||||||||||||||||||||||||||
| epoch, batch_idx * len(imgs), len(train_loader.dataset), | ||||||||||||||||||||||||||||||||||||||||||||||
| 100. * batch_idx / len(train_loader), | ||||||||||||||||||||||||||||||||||||||||||||||
| loss.item() / len(data))) | ||||||||||||||||||||||||||||||||||||||||||||||
| loss.item())) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| print('====> Epoch: {} Average loss: {:.4f}'.format( | ||||||||||||||||||||||||||||||||||||||||||||||
| epoch, train_loss / len(train_loader.dataset))) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| def test(epoch): | ||||||||||||||||||||||||||||||||||||||||||||||
| def test(epoch, reduction='mean', use_mse=False): | ||||||||||||||||||||||||||||||||||||||||||||||
| model.eval() | ||||||||||||||||||||||||||||||||||||||||||||||
| test_loss = 0 | ||||||||||||||||||||||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||||||||||||||||||||||
| for i, (data, _) in enumerate(test_loader): | ||||||||||||||||||||||||||||||||||||||||||||||
| data = data.to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| recon_batch, mu, logvar = model(data) | ||||||||||||||||||||||||||||||||||||||||||||||
| test_loss += loss_function(recon_batch, data, mu, logvar).item() | ||||||||||||||||||||||||||||||||||||||||||||||
| for i, (imgs, _) in enumerate(test_loader): | ||||||||||||||||||||||||||||||||||||||||||||||
| imgs = imgs.to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| recons, mu, logvar = model(imgs) | ||||||||||||||||||||||||||||||||||||||||||||||
| test_loss += loss_function(recons, | ||||||||||||||||||||||||||||||||||||||||||||||
| imgs, | ||||||||||||||||||||||||||||||||||||||||||||||
| mu, | ||||||||||||||||||||||||||||||||||||||||||||||
| logvar, | ||||||||||||||||||||||||||||||||||||||||||||||
| reduction, | ||||||||||||||||||||||||||||||||||||||||||||||
| use_mse).item() | ||||||||||||||||||||||||||||||||||||||||||||||
| if i == 0: | ||||||||||||||||||||||||||||||||||||||||||||||
| n = min(data.size(0), 8) | ||||||||||||||||||||||||||||||||||||||||||||||
| comparison = torch.cat([data[:n], | ||||||||||||||||||||||||||||||||||||||||||||||
| recon_batch.view(args.batch_size, 1, 28, 28)[:n]]) | ||||||||||||||||||||||||||||||||||||||||||||||
| save_image(comparison.cpu(), | ||||||||||||||||||||||||||||||||||||||||||||||
| 'results/reconstruction_' + str(epoch) + '.png', nrow=n) | ||||||||||||||||||||||||||||||||||||||||||||||
| n = min(imgs.size(0), 8) | ||||||||||||||||||||||||||||||||||||||||||||||
| comparison = torch.cat([imgs[:n], recons[:n]]) | ||||||||||||||||||||||||||||||||||||||||||||||
| save_image(comparison.cpu(), 'results/reconstruction_' + str(epoch) + '.png', nrow=n) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| test_loss /= len(test_loader.dataset) | ||||||||||||||||||||||||||||||||||||||||||||||
| test_loss = test_loss/len(test_loader.dataset) if (reduction == 'sum') else test_loss/len(test_loader) | ||||||||||||||||||||||||||||||||||||||||||||||
| print('====> Test set loss: {:.4f}'.format(test_loss)) | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||
| if args.reduction =='sum' and args.use_mse: | ||||||||||||||||||||||||||||||||||||||||||||||
| print('Warning: reduction=sum will only use BCE. use_mse is ignored!') | ||||||||||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||||||||||
| for epoch in range(1, args.epochs + 1): | ||||||||||||||||||||||||||||||||||||||||||||||
| train(epoch) | ||||||||||||||||||||||||||||||||||||||||||||||
| test(epoch) | ||||||||||||||||||||||||||||||||||||||||||||||
| train(epoch, args.reduction, args.use_mse) | ||||||||||||||||||||||||||||||||||||||||||||||
| test(epoch, args.reduction, args.use_mse) | ||||||||||||||||||||||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||||||||||||||||||||||
| sample = torch.randn(64, 20).to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| sample = torch.randn(64, model.embedding_size).to(device) | ||||||||||||||||||||||||||||||||||||||||||||||
| sample = model.decode(sample).cpu() | ||||||||||||||||||||||||||||||||||||||||||||||
| save_image(sample.view(64, 1, 28, 28), | ||||||||||||||||||||||||||||||||||||||||||||||
| 'results/sample_' + str(epoch) + '.png') | ||||||||||||||||||||||||||||||||||||||||||||||
| save_image(sample,'results/sample_' + str(epoch) + '.png') | ||||||||||||||||||||||||||||||||||||||||||||||
| save_animation(model, sample_count=30, use_mp4=False) | ||||||||||||||||||||||||||||||||||||||||||||||
| plot_latent_space() | ||||||||||||||||||||||||||||||||||||||||||||||
| display_2d_manifold(model, digit_count=20) | ||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If possible, it is better to rely on automatic pinning in PyTorch to avoid undefined behavior and for efficiency