Skip to content
95 changes: 95 additions & 0 deletions templates/gan/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
[dataset.selectbox]
label = 'Dataset to use (dataset)'
options = ["cifar10", "lsun", "imagenet", "folder", "lfw", "fake", "mnist"]

[data_path.text_input]
label = 'Dataset path (data_path)'
value = './'

[filepath.text_input]
label = 'Logging file path (filepath)'
value = './logs'

[saved_G.text_input]
label = 'Path to saved generator (saved_G)'
value = '.'

[saved_D.text_input]
label = 'Path to saved discriminator (saved_D)'
value = '.'

[batch_size.number_input]
label = 'Train batch size (batch_size)'
min_value = 0
value = 4

[num_workers.number_input]
label = 'Number of workers (num_workers)'
min_value = 0
value = 2

[max_epochs.number_input]
label = 'Maximum epochs to train (max_epochs)'
min_value = 1
value = 2

[lr.number_input]
label = 'Learning rate used by torch.optim.* (lr)'
min_value = 0.0
value = 1e-3
format = '%e'

[log_train.number_input]
label = 'Logging interval of training iterations (log_train)'
min_value = 0
value = 50

[seed.number_input]
label = 'Seed used in ignite.utils.manual_seed() (seed)'
min_value = 0
value = 666

[nproc_per_node.number_input]
label = 'Number of processes to launch on each node (nproc_per_node)'
min_value = 1

[nnodes.number_input]
label = 'Number of nodes to use for distributed training (nnodes)'
min_value = 1

[node_rank.number_input]
label = 'Rank of the node for multi-node distributed training (node_rank)'
min_value = 0

[master_addr.text_input]
label = 'Master node TCP/IP address for torch native backends (master_addr)'
value = "'127.0.0.1'"

[master_port.number_input]
label = 'Master node port for torch native backends (master_port)'
value = 8080

[n_saved.number_input]
label = 'Number of best models to store (n_saved)'
min_value = 1
value = 2

[z_dim.number_input]
label = 'Size of the latent z vector (z_dim)'
value = 100

[alpha.number_input]
label = 'Running average decay factor (alpha)'
value = 0.98

[g_filters.number_input]
label = 'Number of filters in the second-to-last generator deconv layer (g_filters)'
value = 64

[d_filters.number_input]
label = 'Number of filters in first discriminator conv layer (d_filters)'
value = 64

[beta_1.number_input]
label = 'beta_1 for Adam optimizer (beta_1)'
value = 0.5
45 changes: 45 additions & 0 deletions templates/gan/datasets.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from torchvision import transforms as T
from torchvision import datasets as dset


def get_datasets(dataset, dataroot):
    """Resolve a dataset name to a torchvision dataset and its channel count.

    Args:
        dataset (str): Name of the dataset to use. See CLI help for details
        dataroot (str): root directory where the dataset will be stored.

    Returns:
        tuple: ``(dataset, num_channels)`` where ``dataset`` is a
        ``torch.utils.data.Dataset`` and ``num_channels`` is 1 or 3.

    Raises:
        RuntimeError: if ``dataset`` is not a recognized name.
    """
    resize = T.Resize(64)
    to_tensor = T.ToTensor()
    normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # Image-folder style datasets get a center crop; the archive datasets
    # (cifar10/mnist) are already square so only resize + normalize.
    cropped_pipeline = T.Compose([resize, T.CenterCrop(64), to_tensor, normalize])
    plain_pipeline = T.Compose([resize, to_tensor, normalize])

    if dataset in {"imagenet", "folder", "lfw"}:
        return dset.ImageFolder(root=dataroot, transform=cropped_pipeline), 3

    if dataset == "lsun":
        return dset.LSUN(root=dataroot, classes=["bedroom_train"], transform=cropped_pipeline), 3

    if dataset == "cifar10":
        return dset.CIFAR10(root=dataroot, download=True, transform=plain_pipeline), 3

    if dataset == "mnist":
        return dset.MNIST(root=dataroot, download=True, transform=plain_pipeline), 1

    if dataset == "fake":
        # Synthetic data; values are already in [0, 1] so no normalization.
        return dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor), 3

    raise RuntimeError(f"Invalid dataset name: {dataset}")
55 changes: 55 additions & 0 deletions templates/gan/fn.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch


def update(netD, netG, device, optimizerD, optimizerG, loss_fn, config, real_labels, fake_labels):
    """Build the per-batch training step for a DCGAN-style GAN.

    Args:
        netD, netG: discriminator and generator modules.
        device: device on which the batch and noise are placed.
        optimizerD, optimizerG: optimizers for the two networks.
        loss_fn: criterion applied to discriminator outputs vs. labels.
        config: exposes ``batch_size`` and ``z_dim`` used to sample noise.
        real_labels, fake_labels: pre-built label tensors sized for
            ``config.batch_size`` (so the loader is expected to yield
            full batches — NOTE(review): last-partial batches would
            mismatch; confirm the caller drops them).

    Returns:
        A ``step(engine, batch)`` callable suitable for an ignite Engine,
        returning a dict of scalar metrics.
    """

    def step(engine, batch):
        """Run one D update followed by one G update on a single batch."""
        # Batch is (images, labels); the class labels are irrelevant here.
        real_images = batch[0].to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # Real half of the discriminator loss.
        d_real_out = netD(real_images)
        loss_d_real = loss_fn(d_real_out, real_labels)
        D_x = d_real_out.mean().item()
        loss_d_real.backward()

        # Sample latent vectors and generate a fake batch.
        latent = torch.randn(config.batch_size, config.z_dim, 1, 1, device=device)
        generated = netG(latent)

        # Fake half — detach so no gradients flow into the generator here.
        d_fake_out = netD(generated.detach())
        loss_d_fake = loss_fn(d_fake_out, fake_labels)
        D_G_z1 = d_fake_out.mean().item()
        loss_d_fake.backward()

        # Gradients from both halves have accumulated; apply them.
        loss_d = loss_d_real + loss_d_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Push the generator toward outputs the discriminator calls "real".
        g_out = netD(generated)
        loss_g = loss_fn(g_out, real_labels)
        D_G_z2 = g_out.mean().item()
        loss_g.backward()

        optimizerG.step()

        return {"errD": loss_d.item(), "errG": loss_g.item(), "D_x": D_x, "D_G_z1": D_G_z1, "D_G_z2": D_G_z2}

    return step
Loading