""" This module contains the implementation of a Generative Adversarial Network (GAN) using TensorFlow and Keras. The GAN is composed of a generator and a discriminator, both implemented as Keras models. The generator takes a random noise vector as input and produces an image, while the discriminator takes an image as input and classifies it as real or fake. The module provides functions to build the generator, the discriminator, and the complete GAN, as well as to train the GAN on the MNIST dataset and generate and save images. Functions: build_generator(latent_dim: int) -> models.Model build_discriminator(img_shape: tuple) -> models.Model build_gan(generator: models.Model, discriminator: models.Model) -> models. Model train_gan(generator: models.Model, discriminator: models.Model, gan: models.Model, mnist_data: np.array, latent_dim: int, epochs: int, batch_size: int, sample_interval: int) generate_and_save_images(generator: models.Model, epoch: int, test_input: np.array) """ import matplotlib.pyplot as plt import numpy as np import tensorflow as tf from tensorflow.keras import layers, models def build_generator(latent_dim): """ Build the generator part of the GAN. Args: latent_dim (int): The size of the random noise vector used as input for the generator. Returns: model: A Keras model representing the generator. """ model = models.Sequential() model.add(layers.Dense(7 * 7 * 128, input_dim=latent_dim)) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Reshape((7, 7, 128))) model.add(layers.Conv2DTranspose( 128, (4, 4), strides=(2, 2), padding='same')) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Conv2DTranspose( 128, (4, 4), strides=(2, 2), padding='same')) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Conv2D(1, (7, 7), activation='tanh', padding='same')) return model def build_discriminator(img_shape): """ Build the discriminator part of the GAN. Args: img_shape (tuple): The shape of the input images. 
Returns: model: A Keras model representing the discriminator. """ model = models.Sequential() model.add(layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=img_shape)) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same')) model.add(layers.LeakyReLU(alpha=0.2)) model.add(layers.Flatten()) model.add(layers.Dense(1, activation='sigmoid')) return model def build_gan(generator, discriminator): """ Build the complete GAN by stacking the generator and discriminator. Args: generator (Keras model): The generator model. discriminator (Keras model): The discriminator model. Returns: model: A Keras model representing the complete GAN. """ discriminator.trainable = False model = models.Sequential() model.add(generator) model.add(discriminator) return model def load_mnist(): """ Load and preprocess the MNIST dataset. Returns: X_train (numpy array): The preprocessed training images. """ (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data() X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=-1) return X_train def train_gan(generator, discriminator, gan, mnist_data, latent_dim, epochs=10000, batch_size=128, sample_interval=1000): """ Train the GAN on the MNIST data. Args: generator (Keras model): The generator model. discriminator (Keras model): The discriminator model. gan (Keras model): The complete GAN model. mnist_data (numpy array): The preprocessed training images. latent_dim (int): The size of the random noise vector used as input for the generator. epochs (int, optional): The number of epochs to train for. Defaults to 10000. batch_size (int, optional): The batch size. Defaults to 128. sample_interval (int, optional): The interval at which to save generated images. Defaults to 1000. 
""" half_batch = batch_size // 2 for epoch in range(epochs): # Train Discriminator idx = np.random.randint(0, mnist_data.shape[0], half_batch) imgs = mnist_data[idx] noise = np.random.normal(0, 1, (half_batch, latent_dim)) gen_imgs = generator.predict(noise) d_loss_real = discriminator.train_on_batch( imgs, np.ones((half_batch, 1))) d_loss_fake = discriminator.train_on_batch( gen_imgs, np.zeros((half_batch, 1))) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # Train Generator noise = np.random.normal(0, 1, (batch_size, latent_dim)) valid_labels = np.ones((batch_size, 1)) g_loss = gan.train_on_batch(noise, valid_labels) # Print progress and save generated images at specified intervals if epoch % sample_interval == 0: print(f"Epoch {epoch}, D Loss: {d_loss[0]}, G Loss: {g_loss}") save_generated_images(epoch, generator) def generate_and_save_images(generator, epoch, test_input): """ Generate and save images using the generator and display them using matplotlib. Args: generator (Keras model): The generator model. epoch (int): The current epoch. test_input (numpy array): A random noise vector used as input for the generator. """ predictions = generator(test_input, training=False) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i + 1) plt.imshow(predictions[i, :, :, 0] * 0.5 + 0.5, cmap='gray') plt.axis('off') plt.savefig(f"gan_generated_image_epoch_{epoch}.png") plt.show() def save_generated_images(epoch, generator, latent_dim=100, examples=100, dim=(10, 10), figsize=(10, 10)): """ Save the images generated by the generator. Args: epoch (int): The current epoch. generator (Keras model): The generator model. latent_dim (int, optional): The size of the random noise vector used as input for the generator. Defaults to 100. examples (int, optional): The number of examples to generate. Defaults to 100. dim (tuple, optional): The dimensions of the grid of images. Defaults to (10, 10). figsize (tuple, optional): The size of the figure. 
Defaults to (10, 10). """ noise = np.random.normal(0, 1, size=[examples, latent_dim]) generated_images = generator.predict(noise) generated_images = generated_images.reshape(examples, 28, 28) plt.figure(figsize=figsize) for i in range(generated_images.shape[0]): plt.subplot(dim[0], dim[1], i+1) plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r') plt.axis('off') plt.tight_layout() plt.savefig(f'gan_generated_image_epoch_{epoch}.png') if __name__ == "__main__": # Define the size of the random vector used as input for the generator latent_dim = 100 # Build the generator generator = build_generator(latent_dim) # Build the discriminator img_shape = (28, 28, 1)
# Photo by Steve Johnson on Unsplash
# Top comments (0)