By Akash Goyal –
Technological advancements have put everyone in awe. We have seen applications that generate an image from text, turn an outline of a bag into a realistic bag image, produce collections of seemingly real but fake images, remove objects from an image with precision, and colorize black-and-white images. Ever wondered how this works? Thought it was AI/ML? Yes. Generative models are at work behind the curtain.
This blog gives an all-round introduction to Generative Adversarial Networks (GANs).
This post will revolve around:
- Context for GANs: discriminative vs. generative modeling.
- Training a GAN using both a generative and a discriminative model.
1. What is a GAN?
GANs are one kind of generative model, able to generate realistic examples across a range of problem domains. Popular uses include image-to-image translation, text-to-image generation, and generating photos of people, objects, and scenes realistic enough that humans often cannot tell they are fake. GANs perform generative modeling with deep learning methods, such as convolutional neural networks (CNNs).
Generative modeling is an unsupervised learning task in machine learning: the model automatically discovers and learns the patterns in the input data so that it can generate new examples that could plausibly have been drawn from the original dataset. The GAN architecture involves two sub-models: a generator that produces new samples and a discriminator that classifies samples as real (from the dataset) or fake (produced by the generator).
A Generative Adversarial Network (GAN) has two parts:
● The generator learns to generate plausible data. Its generated samples become the negative training examples for the discriminator, and the generator learns to make the discriminator classify its output as real.
● The discriminator learns to distinguish the generator's fake data from real data. The discriminator penalizes the generator for producing implausible results.
2. Some popular variations of GANs.
There are many types of GAN. A few are listed below:
a. Deep Convolutional GAN –
The generator uses transposed convolutions to perform up-sampling, instead of simple fully connected layers.
b. Conditional GAN –
Trains on a labeled dataset and lets you specify the label for each generated sample. For example, train a GAN on both cats and dogs, then request the desired output by passing the label (cat or dog) alongside the latent input.
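c. SRGAN – Converts low-resolution images to high-resolution ones (super-resolution).
d. Cycle GAN – Converts images from one domain to another without needing paired examples. Pix2Pix is a related image-to-image model that trains on paired examples.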
e. Information Maximizing GAN –
An information-theoretic extension to the GAN. InfoGAN is able to learn disentangled representations in a completely unsupervised manner, with the objective of learning interpretable and meaningful representations. This is done by maximizing the mutual information between a fixed small subset of the GAN's noise variables and the observations. This subset, the latent code, is fed to the generator alongside the noise input to give some control over the output.
f. Style GAN –
StyleGAN uses progressive growing, a noise mapping network, and a constant input in place of the traditional latent input. It generates the image progressively, starting at a low resolution and growing to a high resolution.
[Example image: applying styles]
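To make the transposed convolution from the Deep Convolutional GAN entry more concrete, here is a minimal sketch in plain Python. It assumes a 1-D signal, no padding, and a fixed kernel; the function name `conv_transpose_1d` and the numbers are hypothetical, and a real DCGAN generator learns 2-D kernels rather than using fixed ones.

```python
def conv_transpose_1d(signal, kernel, stride=2):
    """Upsample a 1-D signal with a transposed convolution (no padding)."""
    out_len = (len(signal) - 1) * stride + len(kernel)
    out = [0.0] * out_len
    for i, value in enumerate(signal):
        # Each input value adds a scaled copy of the kernel to the output,
        # shifted by `stride` positions per input step
        for k, weight in enumerate(kernel):
            out[i * stride + k] += value * weight
    return out


signal = [1.0, 2.0, 3.0]
kernel = [1.0, 1.0]  # hypothetical fixed weights; a DCGAN would learn these
upsampled = conv_transpose_1d(signal, kernel, stride=2)
print(len(signal), "->", len(upsampled))
print(upsampled)
```

With stride 2 and a kernel of length 2, the three input values become six output values, which is the up-sampling effect the generator relies on.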
3. Training a GAN model
Step 1 – Define a problem
Step 2 – Select GAN architecture
Step 3 – Train Discriminator on Real Data
Step 4 – Sample random noise as input for the Generator and generate fake data
Step 5 – Train Discriminator on Fake Data
Step 6 – Train Generator with the output of Discriminator
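Steps 3 to 6 can be sketched on a toy problem in plain Python. Everything here is an assumption made for illustration: the "real" data is drawn from a normal distribution with mean 4, the generator is the line G(z) = a*z + b, the discriminator is a single logistic unit D(x) = sigmoid(w*x + c), and the gradients are written out by hand. The generator step labels its fakes as real, mirroring how the Keras code later in this post passes real_y to the combined model.

```python
import math
import random


def sigmoid(s):
    """Numerically stable logistic function."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def train_toy_gan(steps=500, batch_size=16, lr=0.05, seed=0):
    rng = random.Random(seed)
    a, b = 1.0, 0.0  # generator G(z) = a*z + b (hypothetical toy model)
    w, c = 0.1, 0.0  # discriminator D(x) = sigmoid(w*x + c) (hypothetical toy model)

    def discriminator_step(samples, label):
        # One gradient step on binary cross-entropy; d(loss)/d(logit) = D(x) - label
        nonlocal w, c
        grad_w = sum((sigmoid(w * x + c) - label) * x for x in samples) / len(samples)
        grad_c = sum(sigmoid(w * x + c) - label for x in samples) / len(samples)
        w -= lr * grad_w
        c -= lr * grad_c

    for _ in range(steps):
        # Step 3: train the discriminator on real data (label 1)
        real = [rng.gauss(4.0, 1.0) for _ in range(batch_size)]
        discriminator_step(real, 1.0)
        # Step 4: sample random noise and generate fake data
        fake = [a * rng.gauss(0.0, 1.0) + b for _ in range(batch_size)]
        # Step 5: train the discriminator on fake data (label 0)
        discriminator_step(fake, 0.0)
        # Step 6: update only the generator, labelling its fakes as real
        noise = [rng.gauss(0.0, 1.0) for _ in range(batch_size)]
        d_out = [sigmoid(w * (a * z + b) + c) for z in noise]
        grad_a = sum((d - 1.0) * w * z for d, z in zip(d_out, noise)) / batch_size
        grad_b = sum((d - 1.0) * w for d in d_out) / batch_size
        a -= lr * grad_a
        b -= lr * grad_b
    return a, b, w, c


a, b, w, c = train_toy_gan()
print("generator: a=%.3f b=%.3f | discriminator: w=%.3f c=%.3f" % (a, b, w, c))
```

The discriminator's weights stay untouched in step 6; only `a` and `b` move, using the discriminator's output as the training signal.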
GAN loss function
There can be multiple loss functions depending on the type of GAN being used.
The most basic one is the minimax loss function.
The generator tries to minimize the following function while the discriminator tries to maximize it:
$$E_x[\log(D(x))] + E_z[\log(1 - D(G(z)))]$$
In this function:
- D(x) is the discriminator's estimate of the probability that real data instance x is real.
- E_x is the expected value over all real data instances.
- G(z) is the generator's output when given noise z.
- D(G(z)) is the discriminator's estimate of the probability that a fake instance is real.
- E_z is the expected value over all random inputs to the generator (in effect, the expected value over all generated fake instances G(z)).
- The formula derives from the cross-entropy between the real and generated distributions.
The generator can't directly affect the log(D(x)) term in the function, so, for the generator, minimizing the loss is equivalent to minimizing log(1 - D(G(z))).
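As a rough numeric illustration of the minimax function, the sketch below estimates it on a small batch by replacing the expectations with averages. The function name `minimax_value` and the discriminator outputs are made up for the example.

```python
import math


def minimax_value(d_real, d_fake):
    """Batch estimate of E_x[log D(x)] + E_z[log(1 - D(G(z)))]."""
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# Hypothetical discriminator outputs (probability that the input is real)
confident = minimax_value(d_real=[0.9, 0.8], d_fake=[0.1, 0.2])  # fakes rejected
fooled = minimax_value(d_real=[0.9, 0.8], d_fake=[0.7, 0.6])     # fakes accepted
undecided = minimax_value(d_real=[0.5, 0.5], d_fake=[0.5, 0.5])  # D outputs 0.5 everywhere

print("confident discriminator:", confident)
print("fooled discriminator:   ", fooled)
print("undecided discriminator:", undecided)
```

Between the first two calls only the outputs on fake data change, so the drop in value comes entirely from the log(1 - D(G(z))) term, which is the only part the generator can influence.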
Both the generator and the discriminator are neural networks. The generator output is connected directly to the discriminator input. Through backpropagation, the discriminator’s classification provides a signal that the generator uses to update its weights.
During discriminator training:
- The discriminator classifies both real data and fake data from the generator.
- The discriminator loss penalizes the discriminator for misclassifying a real instance as fake or a fake instance as real.
- The discriminator updates its weights through backpropagation from the discriminator loss through the discriminator network.
During generator training:
- The generator generates fake data from sampled random noise.
- The generator loss penalizes the generator for producing samples that the discriminator classifies as fake.
- The generator updates its weights through backpropagation from the generator loss through the discriminator network. Only the generator's weights are updated.
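The way the generator's gradient passes through the discriminator can be checked by hand on a tiny model. This is a sketch under assumptions: the generator is G(z) = a*z + b, the discriminator is D(x) = sigmoid(w*x + c), the loss is the log(1 - D(G(z))) term from the minimax function, and all parameter values are arbitrary. The analytic gradients are compared against finite differences.

```python
import math


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def generator_loss(a, b, w, c, z):
    """log(1 - D(G(z))) for G(z) = a*z + b and D(x) = sigmoid(w*x + c)."""
    return math.log(1.0 - sigmoid(w * (a * z + b) + c))


def generator_grads(a, b, w, c, z):
    """Backpropagate the loss through D into the generator parameters a and b."""
    d = sigmoid(w * (a * z + b) + c)
    dloss_dx = -d * w  # gradient arriving at the discriminator's input
    return dloss_dx * z, dloss_dx  # chain rule: dx/da = z, dx/db = 1


# Arbitrary illustrative values
a, b, w, c, z = 0.5, -0.2, 1.5, 0.3, 0.8
grad_a, grad_b = generator_grads(a, b, w, c, z)

# Finite-difference check of the analytic gradients
eps = 1e-6
numeric_a = (generator_loss(a + eps, b, w, c, z) - generator_loss(a - eps, b, w, c, z)) / (2 * eps)
numeric_b = (generator_loss(a, b + eps, w, c, z) - generator_loss(a, b - eps, w, c, z)) / (2 * eps)
print("analytic:", grad_a, grad_b)
print("numeric: ", numeric_a, numeric_b)

# One update of the generator only; the discriminator's w and c stay fixed
lr = 0.1
new_a, new_b = a - lr * grad_a, b - lr * grad_b
print("loss before:", generator_loss(a, b, w, c, z))
print("loss after: ", generator_loss(new_a, new_b, w, c, z))
```

The discriminator's weight `w` appears in the generator's gradient even though `w` itself is never changed in this step.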
4. Coding Space
Let's build a GAN model using TensorFlow Keras on the Fashion MNIST dataset to generate new clothing images.
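# Building a GAN model using TensorFlow Keras on the Fashion MNIST dataset
# This code draws on open sources (GitHub, YouTube, etc.)
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras import Sequential, Model
from tensorflow.keras.layers import Input, Reshape, Flatten, Dense, Activation, LeakyReLU, BatchNormalization, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import datasets
def build_discriminator(input_shape=(28, 28), verbose=True):
    # ASSUMPTION: the layer stack is missing from this post; this small MLP is a stand-in
    model = Sequential([Input(shape=input_shape), Flatten(),
                        Dense(256), LeakyReLU(0.2), Dropout(0.3),
                        Dense(1, activation='sigmoid')])
    if verbose:
        model.summary()
    return model
def build_generator(z_dim=100, output_shape=(28, 28), verbose=True):
    # ASSUMPTION: stand-in layers, as above; tanh matches the [-1, 1] pixel scaling used in train()
    model = Sequential([Input(shape=(z_dim,)), Dense(256), LeakyReLU(0.2), BatchNormalization(),
                        Dense(int(np.prod(output_shape)), activation='tanh'),
                        Reshape(output_shape)])
    if verbose:
        model.summary()
    return model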
def train(generator=None, discriminator=None, gan_model=None, epochs=1000, batch_size=128, sample_interval=50, z_dim=100):
    # Load Fashion MNIST and rescale pixels from [0, 255] to [-1, 1]
    (X_train, _), (_, _) = datasets.fashion_mnist.load_data()
    X_train = X_train / 127.5 - 1
    # Target labels: 1 for real images, 0 for generated ones
    real_y = np.ones((batch_size, 1))
    fake_y = np.zeros((batch_size, 1))
    for epoch in range(epochs):
        # Train the discriminator on one real batch and one generated batch
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        fake_imgs = generator.predict(noise, verbose=0)
        disc_loss_real = discriminator.train_on_batch(real_imgs, real_y)
        disc_loss_fake = discriminator.train_on_batch(fake_imgs, fake_y)
        # Each result is [loss, accuracy], assuming the discriminator was compiled with an accuracy metric
        discriminator_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)
        # Train the generator through the combined model, labelling its fakes as real
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        gen_loss = gan_model.train_on_batch(noise, real_y)
        print("%d [Discriminator loss: %f, acc.: %.2f%%] [Generator loss: %f]" % (epoch, discriminator_loss[0], 100 * discriminator_loss[1], gen_loss))
        if epoch % sample_interval == 0:
            sample_images(epoch, generator, z_dim=z_dim)
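# Noise dimension for the generator
z_dim = 300
# Build and compile the discriminator
# ASSUMPTION: the compile call is missing from this post; it reuses the combined model's
# optimizer settings and adds the accuracy metric that the training log prints
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0001, beta_1=0.5), metrics=['accuracy'])
# ASSUMPTION: the generator is never built in this post, but the next lines need it
generator = build_generator(z_dim=z_dim)
z = Input(shape=(z_dim,))
img = generator(z)
# Fix the discriminator so that only the generator is updated in the combined model
discriminator.trainable = False
# Get discriminator output
validity = discriminator(img)
# Stack discriminator on top of generator
gan_model = Model(z, validity)
gan_model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0001, beta_1=0.5))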
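def sample_images(epoch, generator, z_dim=100, save_output=True, output_dir="MyDrive"):
    # Draw an r x c grid of generated images
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, z_dim))
    gen_imgs = generator.predict(noise, verbose=0)
    # Rescale from [-1, 1] back to [0, 1] for display
    gen_imgs = 0.5 * gen_imgs + 0.5
    output_shape = len(generator.output_shape)
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            if output_shape == 3:
                # Generator output is (batch, height, width)
                axs[i, j].imshow(gen_imgs[cnt, :, :], cmap='gray')
            else:
                # Generator output is (batch, height, width, channels)
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            cnt += 1
    if save_output:
        # ASSUMPTION: the file name pattern is illustrative; output_dir must already exist
        fig.savefig("%s/%d.png" % (output_dir, epoch))
    else:
        plt.show()
    plt.close(fig)
# final step (z_dim is passed through so the noise matches the generator's input size)
train(generator, discriminator, gan_model, epochs=30000, batch_size=32, sample_interval=200, z_dim=z_dim)
# Sample outputs: the last four image grids, from close to 30,000 iterations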
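The two rescaling steps in the code above (X_train / 127.5 - 1 before training and 0.5 * gen_imgs + 0.5 before plotting) can be checked on single values. The helper names below are hypothetical and exist only for this illustration; the [-1, 1] range is one commonly paired with a tanh output layer.

```python
def to_model_range(pixel):
    """Map a pixel from [0, 255] to [-1, 1], as in X_train / 127.5 - 1."""
    return pixel / 127.5 - 1


def to_display_range(value):
    """Map a generator output from [-1, 1] to [0, 1], as in 0.5 * gen_imgs + 0.5."""
    return 0.5 * value + 0.5


for pixel in (0, 127.5, 255):
    scaled = to_model_range(pixel)
    print(pixel, "->", scaled, "->", to_display_range(scaled))
```

Black (0) maps to -1, white (255) maps to 1, and the display step brings both back into the [0, 1] range that imshow expects for floating-point images.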
I hope this post has given you useful insight into what a GAN is, where it can be used, its variations, and how to build one of your own. So now, when you come across applications that generate an image from text, turn an outline of a bag into a realistic bag image, produce collections of seemingly real but fake images, remove objects from an image with precision, or colorize black-and-white images, you can relate them to this concept and understand how the model in action would have been trained.
GANs are one type of generative model. Diffusion models fall in the same category and handle similar tasks; I will cover them in another blog. Stay tuned to this space for more.