September 19, 2022

 

Generative Adversarial Networks (GAN)

By Akash Goyal

Recent advances in generative AI are striking. We have seen applications where a text prompt generates an image, an outline of a bag is converted into a realistic bag image, collections of seemingly real but fake images are produced, objects are removed from an image with precision, and black-and-white images are colorized. Ever wondered how this works? Thought it was AI/ML? Yes. Generative models are at work behind the curtain.

Ref: Image generated from model inference. Model taken from https://huggingface.co/CompVis/stable-diffusion-v1-4
Ref: Example Image taken from https://saic-mdal.github.io/lama-project. Paper: https://arxiv.org/pdf/2109.07161.pdf
Ref: Example Image taken from Paper: https://arxiv.org/pdf/1611.07004.pdf
Ref: Example Image taken from Paper: https://arxiv.org/pdf/1611.07004.pdf
Ref: Google Search `Anime GAN results` – Image from https://github.com/ANIME305/Anime-GAN-tensorflow
Ref: Google Search `GAN person faces results` Image from – https://paperswithcode.com/task/face-generation

This blog gives an all-round introduction to Generative Adversarial Networks.
This post will revolve around:

  • Context for GANs: discriminative vs. generative modeling.
  • Training a GAN model using both a generative and a discriminative model.

1. What is a GAN?

GANs are one kind of generative model that can generate realistic examples across a range of problem domains. Popular uses include image-to-image translation, generating realistic photos of people, objects, and scenes that even humans cannot tell are fake, and text-to-image generation. Generative Adversarial Networks perform generative modeling using deep learning methods, such as CNNs.

Generative modeling is an unsupervised learning task in machine learning: it involves automatically discovering and learning the patterns in input data so that the model can generate new examples that could plausibly have been drawn from the original dataset. The GAN architecture involves two sub-models: a generator model for generating new samples, and a discriminator model for classifying whether samples are real (from the dataset) or fake (generated by the generator model).

A Generative Adversarial Network (GAN) has two parts:

●  The generator learns to generate plausible data. The generated samples become the negative training examples for the discriminator. The generator learns to make the discriminator classify its output as real.

●  The discriminator learns to distinguish the generator’s fake data from real data. The discriminator penalizes the generator for producing implausible results.

2. Some popular variations of GANs.

There are many types of GAN. A few are named below:

a. Deep Convolutional GAN

The generator uses transposed convolutions to perform up-sampling, instead of simple fully connected layers.
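
To make the up-sampling idea concrete, here is a minimal sketch of a 1-D transposed convolution in plain Python. The function name `transposed_conv1d`, the kernel values and the stride are assumptions chosen for illustration; an actual DCGAN would use 2-D layers (for example Keras `Conv2DTranspose`) whose kernels are learned during training.

```python
def transposed_conv1d(signal, kernel, stride=2):
    """Up-sample a 1-D signal: every input value stamps a scaled copy
    of the kernel into the output, `stride` positions apart."""
    out_len = (len(signal) - 1) * stride + len(kernel)
    out = [0.0] * out_len
    for i, value in enumerate(signal):
        start = i * stride
        for k, weight in enumerate(kernel):
            out[start + k] += value * weight
    return out


signal = [1.0, 2.0, 3.0]
kernel = [1.0, 0.5]  # hypothetical values; learned in a real model
upsampled = transposed_conv1d(signal, kernel, stride=2)

print(len(signal), "->", len(upsampled))  # the output is longer than the input
print(upsampled)
```

With a stride of 2, three input values grow to six output values, which is how a generator can turn a small feature map into a larger image step by step.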

b. Conditional GAN

Trains on a labeled dataset and lets you specify the label for each generated sample. For example, train a GAN on both cats and dogs, then request the desired output by passing the label (cat or dog) alongside the latent input.
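
One common way to pass the label to the generator is to append a one-hot encoding of it to the noise vector. The sketch below only builds that input and assumes this concatenation scheme; the names `CLASSES`, `one_hot` and `conditional_input` are hypothetical and the noise size is arbitrary.

```python
import random

CLASSES = ["cat", "dog"]  # hypothetical label set, taken from the example above


def one_hot(label, classes=CLASSES):
    """Encode a label as a list with a single 1.0 at the label's position."""
    return [1.0 if c == label else 0.0 for c in classes]


def conditional_input(label, z_dim=4):
    """Build a generator input: random noise followed by the one-hot label."""
    noise = [random.gauss(0.0, 1.0) for _ in range(z_dim)]
    return noise + one_hot(label)


random.seed(0)
gen_input = conditional_input("dog", z_dim=4)

print(len(gen_input))   # z_dim noise values plus one entry per class
print(gen_input[-2:])   # the label part of the input
```

During training the discriminator would see the same label next to each image, so it can penalize samples that do not match the requested class.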

c. SRGAN – low resolution to high resolution

Ref: Image taken from https://modelzoo.co/model/srgan More examples in paper: https://arxiv.org/pdf/1609.04802.pdf

d. CycleGAN – Converts an image from one domain to another without needing paired training examples. Pix2Pix is a related model that is trained on paired images.

Ref: Image taken from https://arxiv.org/pdf/1703.10593.pdf

e. Information Maximizing GAN

An information-theoretic extension to GAN. InfoGAN is able to learn disentangled representations in a completely unsupervised manner. The objective is to learn interpretable and meaningful representations. This is done by maximizing the mutual information between a fixed small subset of the GAN’s noise variables and the observations. This subset, a structured latent code, is added to the generator’s noise input to give some control over the output.

Ref: Image taken from https://arxiv.org/pdf/1606.03657.pdf

f. Style GAN

StyleGAN uses progressive growing, a mapping network for the latent code, per-layer noise inputs, and a learned constant input in place of the traditional latent input. It generates the image progressively, starting from a low resolution and growing to a high resolution. Example image: applying styles.

Ref: Image taken from StyleGAN paper https://arxiv.org/pdf/1812.04948.pdf

3. Training a GAN model

Step 1 – Define a problem

Step 2 – Select GAN architecture

Step 3 – Train Discriminator on Real Data

Step 4 – Sample random noise as input for the Generator and use it to generate fake data

Step 5 – Train Discriminator on Fake Data

Step 6 – Train Generator with the output of Discriminator 

GAN loss function

There can be multiple loss functions depending on the type of GAN being used.
The most basic one is the minimax loss function.
The generator tries to minimize the following function while the discriminator tries to maximize it:

Ex[log(D(x))] + Ez[log(1 – D(G(z)))]

Ref: Image taken from https://theaisummer.com/static/16451c8109d2540415babbfd245fe0cc/d80c4/gan_training.jpg

In this function:

  • D(x) is the discriminator’s estimate of the probability that real data instance x is real.
  • Ex is the expected value over all real data instances.
  • G(z) is the generator’s output when given noise z.
  • D(G(z)) is the discriminator’s estimate of the probability that a fake instance is real.
  • Ez is the expected value over all random inputs to the generator (in effect, the expected value over all generated fake instances G(z)).
  • The formula derives from the cross-entropy between the real and generated distributions.

The generator can’t directly affect the log(D(x)) term in the function, so, for the generator, minimizing the loss is equivalent to minimizing log(1 – D(G(z))).
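
As a rough illustration of how this function behaves, the sketch below estimates Ex[log(D(x))] + Ez[log(1 – D(G(z)))] from small batches of discriminator outputs. The probabilities are made-up numbers rather than results from a trained model, and the function names are only for illustration.

```python
import math


def minimax_value(d_real, d_fake):
    """Batch estimate of Ex[log(D(x))] + Ez[log(1 - D(G(z)))].

    d_real: discriminator outputs D(x) on real samples
    d_fake: discriminator outputs D(G(z)) on generated samples
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def generator_loss(d_fake):
    """The only part of the function the generator can influence."""
    return sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)


# Hypothetical discriminator outputs.
strong_d = minimax_value(d_real=[0.9, 0.8], d_fake=[0.1, 0.2])  # D spots the fakes
fooled_d = minimax_value(d_real=[0.9, 0.8], d_fake=[0.7, 0.6])  # D is being fooled

print(strong_d > fooled_d)                                 # D prefers the first case
print(generator_loss([0.7, 0.6]) < generator_loss([0.1, 0.2]))  # G prefers the second
```

The value is higher when the discriminator separates real from fake well, and the generator's term is lower when its samples are scored as real, which is exactly the tug of war described above.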

GAN Network 

Both the generator and the discriminator are neural networks. The generator output is connected directly to the discriminator input. Through backpropagation, the discriminator’s classification provides a signal that the generator uses to update its weights.

During discriminator training:

  • The discriminator classifies both real data and fake data from the generator.
  • The discriminator loss penalizes the discriminator for misclassifying a real instance as fake or a fake instance as real.
  • The discriminator updates its weights through backpropagation from the discriminator loss through the discriminator network.
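
A toy version of one discriminator update can make these steps easier to follow. The sketch assumes a one-weight discriminator D(x) = sigmoid(w_d * x) and a binary cross-entropy loss; the sample values, learning rate and function names are arbitrary choices for illustration.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def discriminator_loss(w_d, x_real, x_fake):
    """Binary cross-entropy: real sample labeled 1, fake sample labeled 0."""
    return -math.log(sigmoid(w_d * x_real)) - math.log(1.0 - sigmoid(w_d * x_fake))


def discriminator_step(w_d, x_real, x_fake, lr=0.1):
    """One gradient-descent step on the discriminator weight only."""
    d_real = sigmoid(w_d * x_real)
    d_fake = sigmoid(w_d * x_fake)
    # Gradient of the loss above with respect to w_d.
    grad = -(1.0 - d_real) * x_real + d_fake * x_fake
    return w_d - lr * grad


w_d = 0.0
x_real, x_fake = 2.0, -1.0  # hypothetical real and generated samples

loss_before = discriminator_loss(w_d, x_real, x_fake)
w_d = discriminator_step(w_d, x_real, x_fake)
loss_after = discriminator_loss(w_d, x_real, x_fake)

print(loss_after < loss_before)  # the discriminator got better at this pair
```

The generator does not appear in the update at all: its output is treated as fixed data while the discriminator learns.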

During generator training:

  • The generator generates fake data from sampled random noise.
  • The generator loss penalizes the generator for generating a sample not close to real.
  • The generator updates its weights through backpropagation: the gradient of the generator loss flows back through the discriminator network and into the generator. Only the generator’s weights are updated.
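
Here is a toy sketch of one generator update, assuming a one-weight generator G(z) = w_g * z, a one-weight discriminator D(x) = sigmoid(w_d * x), and the loss log(1 – D(G(z))). The gradient is written out with the chain rule so the path through the discriminator is visible; all names and numbers are hypothetical.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def generator_step(w_g, w_d, z, lr=0.1):
    """One generator update; the discriminator weight w_d stays frozen."""
    fake = w_g * z                # G(z)
    d_out = sigmoid(w_d * fake)   # D(G(z))
    loss = math.log(1.0 - d_out)  # generator loss for this sample

    # Chain rule: dloss/dw_g = dloss/dD * dD/dfake * dfake/dw_g
    dloss_dd = -1.0 / (1.0 - d_out)
    dd_dfake = d_out * (1.0 - d_out) * w_d  # passes through the discriminator
    dfake_dwg = z
    grad_w_g = dloss_dd * dd_dfake * dfake_dwg

    return w_g - lr * grad_w_g, w_d, loss


w_g, w_d, z = 0.5, 1.0, 2.0
new_w_g, new_w_d, loss_before = generator_step(w_g, w_d, z)
_, _, loss_after = generator_step(new_w_g, new_w_d, z)

print(new_w_d == w_d)            # the discriminator was not changed
print(loss_after < loss_before)  # the generator's loss went down
```

In the Keras code later in this post the same effect is obtained by setting `discriminator.trainable = False` before building the combined model.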

====== 

4. Coding Space

Let’s build a GAN model using TensorFlow Keras on the Fashion MNIST dataset to generate new clothing images.

# Building a GAN model using TensorFlow Keras on the Fashion MNIST dataset

# This code has references from open sources (GitHub, YouTube, etc.)

import numpy as np

from matplotlib import pyplot as plt

from tensorflow.keras import Sequential, Model

from tensorflow.keras.layers import Input, Reshape, Flatten, Dense, Activation, LeakyReLU, BatchNormalization, Dropout

from tensorflow.keras.optimizers import Adam

from tensorflow.keras import datasets

def build_discriminator(input_shape=(28,28,), verbose=True):

    model = Sequential()

    model.add(Input(shape=input_shape))

    model.add(Flatten())

    model.add(Dense(512))

    model.add(LeakyReLU(alpha=0.2))

    model.add(Dense(256))

    model.add(LeakyReLU(alpha=0.2))

    model.add(Dense(1, activation='sigmoid'))

    if verbose:

        model.summary()

    return model

def build_generator(z_dim=100, output_shape=(28, 28), verbose=True):

    model = Sequential()

    model.add(Input(shape=(z_dim,)))

    model.add(Dense(256))

    model.add(LeakyReLU(alpha=0.2))

    model.add(Dropout(0.3))

    model.add(Dense(512))

    model.add(LeakyReLU(alpha=0.2))

    model.add(Dropout(0.3)) 

    model.add(Dense(1024))

    model.add(LeakyReLU(alpha=0.2))

    model.add(Dropout(0.3))

    model.add(Dense(int(np.prod(output_shape)), activation='tanh'))

    model.add(Reshape(output_shape))

    if verbose:

        model.summary()

    return model

def train(generator=None,discriminator=None,gan_model=None,epochs=1000, batch_size=128, sample_interval=50,z_dim=100):

    (X_train, _), (_, _) = datasets.fashion_mnist.load_data()

    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
    X_train = X_train / 127.5 - 1

    real_y = np.ones((batch_size, 1))

    fake_y = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        idx = np.random.randint(0, X_train.shape[0], batch_size)

        real_imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, z_dim))

        fake_imgs = generator.predict(noise)

        disc_loss_real = discriminator.train_on_batch(real_imgs, real_y)

        disc_loss_fake = discriminator.train_on_batch(fake_imgs, fake_y)

        discriminator_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)

        noise = np.random.normal(0, 1, (batch_size, z_dim))

        gen_loss = gan_model.train_on_batch(noise, real_y)

        print("%d [Discriminator loss: %f, acc.: %.2f%%] [Generator loss: %f]" % (epoch, discriminator_loss[0], 100*discriminator_loss[1], gen_loss))

        if epoch % sample_interval == 0:

            sample_images(epoch,generator)

discriminator = build_discriminator()

discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

generator=build_generator()

# Noise for generator (must match the z_dim the generator was built with; the default is 100)

z_dim = 100

z = Input(shape=(z_dim,))

img = generator(z)

# Fix the discriminator

discriminator.trainable = False

# Get discriminator output

validity = discriminator(img)

# Stack discriminator on top of generator

gan_model = Model(z, validity)

gan_model.compile(loss='binary_crossentropy', optimizer=Adam(0.0001, 0.5))

gan_model.summary()

def sample_images(epoch, generator, z_dim=100, save_output=True, output_dir="MyDrive"):

    r, c = 5, 5

    noise = np.random.normal(0, 1, (r * c, z_dim))

    gen_imgs = generator.predict(noise)

    gen_imgs = 0.5 * gen_imgs + 0.5

    output_shape = len(generator.output_shape)

    fig, axs = plt.subplots(r, c)

    cnt = 0

    for i in range(r):

        for j in range(c):

            if output_shape == 3:

                axs[i, j].imshow(gen_imgs[cnt, :, :], cmap='gray')

            else:

                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')

            axs[i, j].axis('off')

            cnt += 1

    plt.show()

    if save_output:

        fig.savefig("{}/{}.png".format(output_dir, epoch))

    plt.close()

# final step

train(generator,discriminator,gan_model,epochs=30000, batch_size=32, sample_interval=200)

The last four sample outputs, generated close to 30,000 iterations (images 1 to 4).
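
One detail in the code above that is easy to miss is pixel scaling. `train` maps pixel values from [0, 255] to [-1, 1] so that real images live in the same range as the generator's tanh output, and `sample_images` maps generated values back to [0, 1] before plotting. A small sketch of that round trip, with helper names made up for illustration:

```python
def to_tanh_range(pixel):
    """Map a pixel in [0, 255] to [-1, 1], as done in train()."""
    return pixel / 127.5 - 1.0


def to_display_range(value):
    """Map a generator output in [-1, 1] to [0, 1], as done in sample_images()."""
    return 0.5 * value + 0.5


for pixel in (0, 127.5, 255):
    scaled = to_tanh_range(pixel)
    print(pixel, "->", scaled, "->", to_display_range(scaled))
```

If the two ranges did not match, the discriminator could tell real from fake simply by looking at the value range, without learning anything about the images.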

I hope this post has given you useful insight into what a GAN is, where it can be used, its variations, and how to build one of your own. Now, when you come across applications where a text prompt generates an image, an outline of a bag is converted into a realistic bag image, collections of seemingly real but fake images are produced, objects are removed from an image with precision, or black-and-white images are colorized, you can relate them to this concept and understand how the model in action would have been trained.

GANs are one type of generative model. Diffusion models also fall in this category and are used for similar tasks; I will cover them in another blog. Stay tuned to this space for more.
