Generative adversarial networks (GANs) were introduced in 2014 by Ian Goodfellow and colleagues. A GAN is essentially two separate models competing against each other. We want to produce images that are similar to the training images without being exact copies of them, which is why we build it with generative networks.

If you don't have enough data, you can manufacture more from the data you already have, and this is what a GAN does.

In a classifier, or most other deep learning models, we only try to predict the correct class and don't really care about the underlying data distribution. With a GAN, we are much more interested in being able to reproduce that distribution.

TensorFlow GANs

The generator's job is to come up with new versions of images, and the discriminator checks each image and decides whether it is real or fake.

This works because as the generator gets better at making fake images, the discriminator has to get better at detecting them. The core challenge with a GAN is keeping the two in balance while we optimize them.
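For reference, the original 2014 formulation writes this balance as a two-player minimax game over a value function V(D, G), where D(x) is the discriminator's estimated probability that x is real, G(z) is the generator's output for noise z, and p_z is the noise distribution:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$

The discriminator tries to make V large (score real images near 1 and fakes near 0), while the generator tries to make it small by fooling the discriminator.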

Vanilla GAN

It uses a very simple concept: we feed some latent noise, which we call Z, into a generator. We also take some real data, then randomly present real and generated samples to a discriminator, which has to decide whether each one is real or fake. We use those decisions to compute the loss and update the weights, so the model gets better at turning that noise into realistic images.
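To make the data flow concrete, here is a minimal NumPy sketch of scoring one mixed batch. The toy linear generator and logistic-regression discriminator are assumptions for illustration only (not the models used later in this post), and the weight update itself is left out, since in practice an autodiff framework such as TensorFlow computes the gradients.

```python
import numpy as np

rng = np.random.default_rng(0)


def toy_generator(z, w):
    # Toy generator (assumption): a linear map from noise Z squashed into [-1, 1]
    return np.tanh(z @ w)


def toy_discriminator(x, v):
    # Toy discriminator (assumption): logistic regression returning P(real)
    return 1.0 / (1.0 + np.exp(-(x @ v)))


def score_mixed_batch(real, noise_dim, w, v):
    """Mix real and generated samples in random order and score the discriminator."""
    n = real.shape[0]
    z = rng.normal(size=(n, noise_dim))             # latent noise Z
    fake = toy_generator(z, w)                      # generated samples
    x = np.concatenate([real, fake])
    y = np.concatenate([np.ones(n), np.zeros(n)])   # 1 = real, 0 = fake
    order = rng.permutation(2 * n)                  # present them in random order
    x, y = x[order], y[order]
    p = np.clip(toy_discriminator(x, v), 1e-7, 1 - 1e-7)
    # Binary cross-entropy: low when D separates real from fake well
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)), y


noise_dim, data_dim = 8, 16
w = rng.normal(size=(noise_dim, data_dim))
v = rng.normal(size=data_dim)
real = rng.uniform(-1, 1, size=(32, data_dim))
d_loss, labels = score_mixed_batch(real, noise_dim, w, v)
print(d_loss, labels.sum())  # half of the 64 shuffled samples are labelled real
```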

Keras GAN

We take the generator loss and the discriminator loss and add them together to get the total loss of the network. The generator's loss is flipped (negated), because the generator is actually trying to push the discriminator's loss up rather than down, and we then use that total loss for training.
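The "flip" is easiest to see written out. A minimal NumPy sketch, assuming the discriminator outputs raw logits (so D(x) = sigmoid(logit)); the function names are illustrative, although TFGAN ships pre-made equivalents such as `tfgan.losses.minimax_generator_loss` and `tfgan.losses.modified_generator_loss`.

```python
import numpy as np


def log_sigmoid(t):
    # Numerically stable log(sigmoid(t)) for discriminator logits t
    return -np.logaddexp(0.0, -t)


def discriminator_loss(real_logits, fake_logits):
    # D maximizes log D(x) + log(1 - D(G(z))); as a loss we minimize the negative.
    # Note log(1 - sigmoid(t)) == log(sigmoid(-t)).
    return -np.mean(log_sigmoid(real_logits)) - np.mean(log_sigmoid(-fake_logits))


def minimax_generator_loss(fake_logits):
    # The "flip": G minimizes log(1 - D(G(z))), the fake-image term with its sign
    # reversed, so lowering this loss pushes the discriminator's loss up
    return np.mean(log_sigmoid(-fake_logits))


def non_saturating_generator_loss(fake_logits):
    # Common variant: G maximizes log D(G(z)) instead, which gives stronger
    # gradients early in training when D easily rejects the fakes
    return -np.mean(log_sigmoid(fake_logits))


real_logits = np.array([2.0, 1.0])
fake_logits = np.array([-1.0, 0.5])
print(discriminator_loss(real_logits, fake_logits), minimax_generator_loss(fake_logits))
```

In practice each network is updated with its own optimizer on its own loss, which is what the GANEstimator later in this post does with separate generator and discriminator optimizers.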

GAN loss

TFGAN Library

TFGAN is a lightweight library for GANs in TensorFlow. It provides a set of pre-made losses and GAN components. With TFGAN you can take these off-the-shelf losses and components and plug them into a model, which is a much simpler way to build a GAN. It also provides a GANEstimator, so a GAN can be trained like any other TensorFlow Estimator.


We're going to do the same sort of thing with MNIST. First, we do our imports and load our data.

import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

from tensorflow.keras.layers import (UpSampling2D, Conv2D, BatchNormalization, Reshape,
                                     Activation, Dense, Flatten, MaxPooling2D)
from tensorflow.keras.models import Sequential

import matplotlib.pyplot as plt

# TFGAN lives in tf.contrib, which only exists in TensorFlow 1.x (removed in 2.x)
tfgan = tf.contrib.gan

Here's the input pipeline that pulls the data in and formats it in an Estimator-friendly way for the model.

def train_input_fn(batch_size, num_epochs, noise_dim):
    def resize_image(features):
        # uint8 [0, 255] -> float32 [0, 1] -> [-1, 1], matching a tanh generator output
        image = tf.image.convert_image_dtype(features["image"], dtype=tf.float32)
        image = (image - 0.5) * 2
        image = tf.image.resize(image, size=(28, 28))
        # GANEstimator expects (features, labels) = (generator noise, real images)
        noise = tf.random_normal([noise_dim], name="train_noise")

        return noise, image

    def _input_fn():
        dataset = tfds.load("mnist", split=tfds.Split.TRAIN)
        dataset = dataset.map(resize_image)
        dataset = dataset.batch(batch_size, drop_remainder=True).repeat(num_epochs)
        return dataset
    return _input_fn
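The scaling inside resize_image is easy to miss: convert_image_dtype rescales uint8 pixels to [0, 1], and (image - 0.5) * 2 then maps them to [-1, 1]. A NumPy sketch of the same arithmetic, with hypothetical helper names to_tanh_range and from_tanh_range (the inverse is handy later for viewing generated samples):

```python
import numpy as np


def to_tanh_range(pixels):
    # Same arithmetic as the pipeline: uint8 [0, 255] -> float [0, 1] -> [-1, 1]
    image = pixels.astype(np.float32) / 255.0
    return (image - 0.5) * 2


def from_tanh_range(image):
    # Inverse mapping back to displayable uint8 pixels
    return np.clip((image + 1) / 2 * 255.0, 0, 255).round().astype(np.uint8)


pixels = np.array([0, 128, 255], dtype=np.uint8)
print(to_tanh_range(pixels))
print(from_tanh_range(to_tanh_range(pixels)))
```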

Generator and Discriminator

Here is our generator, an unconditional one: we pass in only noise, the latent vector Z we described before, with no class label.

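def generator():
    model = Sequential()
    # Project the 64-dim noise vector to 7*7*64 units so it can be reshaped below
    model.add(Dense(units=7 * 7 * 64, input_dim=64))
    model.add(Activation('relu'))
    model.add(Reshape((7, 7, 64)))  # 7x7 image
    model.add(UpSampling2D(size=(2, 2)))  # 14x14 image
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('relu'))
    model.add(UpSampling2D(size=(2, 2)))  # 28x28 image
    model.add(Conv2D(1, (5, 5), padding='same'))
    model.add(Activation('tanh'))  # outputs in [-1, 1], like the scaled real images
    return model

def generator_fn(inputs, mode):
    # GANEstimator passes `mode` because it appears in this function's signature
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    model = generator()

    return model(inputs, training=is_training)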
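The generator's spatial sizes come from reshaping the dense output to 7x7x64 and doubling it twice with UpSampling2D, which is also why the dense layer must produce 7 * 7 * 64 = 3136 values for the Reshape to work. A NumPy sketch of just the reshape and nearest-neighbour upsampling (the helper upsample_nearest is hypothetical, and the convolutions are skipped because 'same' padding with stride 1 does not change height or width):

```python
import numpy as np


def upsample_nearest(x, factor=2):
    # Nearest-neighbour upsampling, as Keras UpSampling2D does by default:
    # every pixel is repeated `factor` times along height (axis 1) and width (axis 2)
    return x.repeat(factor, axis=1).repeat(factor, axis=2)


dense_out = np.zeros((4, 7 * 7 * 64), dtype=np.float32)  # batch of 4 dense outputs
x = dense_out.reshape(-1, 7, 7, 64)                       # Reshape((7, 7, 64))
x = upsample_nearest(x)                                    # 14x14
x = upsample_nearest(x)                                    # 28x28
print(x.shape)  # (4, 28, 28, 64); the final Conv2D(1, ...) maps the 64 channels to 1

tiny = np.array([[1, 2], [3, 4]]).reshape(1, 2, 2, 1)
print(upsample_nearest(tiny)[0, :, :, 0])
```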

The discriminator's whole job is to detect images that aren't real, and as you can see its model is much simpler.

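def discriminator():
    model = Sequential()
    model.add(Conv2D(32, (5, 5), padding='same', input_shape=(28, 28, 1)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, (5, 5)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(1))  # a single unnormalized logit: real vs. fake
    return model

# TFGAN calls discriminator_fn twice (generated and real images) and relies on
# variable-scope reuse, which Keras layers ignore, so build one model per graph.
_discriminators = {}

def discriminator_fn(inputs, conditioning, mode):
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    graph = tf.get_default_graph()
    if graph not in _discriminators:
        _discriminators[graph] = discriminator()
    return _discriminators[graph](inputs, training=is_training)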
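To turn the discriminator's feature maps into a single real/fake score, they are typically flattened and passed to a Dense layer (both are imported above), so it helps to know the size of the last feature map. A small sketch tracing the layers with the standard output-size rules, assuming stride-1 convolutions, Keras' default 'valid' padding for the second Conv2D, and MaxPooling2D's default stride equal to the pool size:

```python
def conv_out(size, kernel, padding):
    # Output size of a stride-1 convolution
    return size if padding == "same" else size - kernel + 1


def pool_out(size, pool=2):
    # MaxPooling2D with strides equal to pool size and 'valid' padding
    return size // pool


size = 28
size = conv_out(size, 5, "same")   # Conv2D(32, (5, 5), padding='same') -> 28
size = pool_out(size)              # MaxPooling2D((2, 2)) -> 14
size = conv_out(size, 5, "valid")  # Conv2D(64, (5, 5)), default 'valid' -> 10
size = pool_out(size)              # MaxPooling2D((2, 2)) -> 5
print(size, size * size * 64)      # 5 1600
```

So a Flatten layer at this point yields 5 * 5 * 64 = 1600 features per image.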

GANs Loss Function

One of the nice things about TFGAN is that it comes with the loss functions already implemented and optimized, so you don't have to write them yourself.

Much of how GANs have improved over time has come from different loss functions and different ways of handling training, and TFGAN has many of these built in.
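As one example of an alternative loss, the Wasserstein (WGAN) loss replaces log-probabilities with raw, unbounded critic scores. TFGAN ships it as `tfgan.losses.wasserstein_generator_loss` and `tfgan.losses.wasserstein_discriminator_loss`; the NumPy sketch below only illustrates the formulas, with illustrative function names, and omits the weight clipping or gradient penalty that WGAN training also needs.

```python
import numpy as np


def wasserstein_discriminator_loss(real_scores, fake_scores):
    # The critic wants real scores high and fake scores low
    return np.mean(fake_scores) - np.mean(real_scores)


def wasserstein_generator_loss(fake_scores):
    # The generator wants the critic to score its samples high
    return -np.mean(fake_scores)


real = np.array([2.0, 1.5, 2.5])
fake = np.array([-1.0, 0.0, -0.5])
print(wasserstein_discriminator_loss(real, fake))  # -2.5
print(wasserstein_generator_loss(fake))            # 0.5
```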


This is still a vanilla GAN, but built as a GANEstimator. An Estimator has a few key pieces: the model function, the input functions, and some sort of evaluation function.

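def gan():
    # hyperparameters
    model_dir = "../logs-2/"
    batch_size = 64
    num_epochs = 10
    noise_dim = 64
    # Run configuration: where and how often to write summaries and checkpoints
    run_config = tf.estimator.RunConfig(
        model_dir=model_dir, save_summary_steps=100, save_checkpoints_steps=1000)
    gan_estimator = tfgan.estimator.GANEstimator(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        # Assumed choice: TFGAN's pre-made non-saturating ("modified") GAN losses
        generator_loss_fn=tfgan.losses.modified_generator_loss,
        discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
        generator_optimizer=tf.train.AdamOptimizer(0.0002, 0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(0.0002, 0.5),
        config=run_config)
    input_fn = train_input_fn(batch_size, num_epochs, noise_dim)
    # train() returns the Estimator itself; max_steps=None runs until the input is exhausted
    model = gan_estimator.train(input_fn, max_steps=None)

    return model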

To build the model, we construct this Estimator and tell it which function to use as the generator, which as the discriminator, and which generator and discriminator loss functions to use.

gan_model = gan()

Then we train it just like any other Estimator. Afterwards we can generate some samples with predict and print them out.

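def predict_input_fn(batch_size, noise_dim):  # feeds only generator noise, no real images
    return lambda: tf.data.Dataset.from_tensors(tf.random_normal([batch_size, noise_dim]))
predict_batch = 36  # number of digits to generate (illustrative value)
input_fn = predict_input_fn(predict_batch, 64)
predict = gan_model.predict(input_fn)
result = [next(predict) for _ in range(predict_batch)]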
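To actually print the samples out, they need to be mapped from the generator's [-1, 1] range back to [0, 1] and laid out in a grid. A hedged NumPy sketch with a hypothetical helper tile_images, assuming each entry of result is a 28x28x1 array as produced by the generator:

```python
import numpy as np


def tile_images(images, cols):
    """Arrange N images of shape (H, W, 1) in [-1, 1] into one (rows*H, cols*W) grid in [0, 1]."""
    images = np.asarray(images)[..., 0]          # drop the channel axis
    n, h, w = images.shape
    rows = -(-n // cols)                         # ceiling division
    grid = np.zeros((rows * h, cols * w), dtype=np.float32)
    for i, img in enumerate(images):
        r, c = divmod(i, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = (img + 1) / 2  # back to [0, 1]
    return grid


# Usage with the generated `result` list and matplotlib:
# plt.imshow(tile_images(result, cols=6), cmap="gray"); plt.axis("off"); plt.show()
```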

At this point we're able to produce realistic MNIST digits that are not copies of any digit the model was given.