
DCGAN on MNIST Digits dataset

Building on the original GAN [1], the Deep Convolutional Generative Adversarial Network (DCGAN) [2] replaces some of the dense layers with convolutional layers.

The result is similar but takes advantage of the properties of convolutional layers: fewer parameters to train and translation invariance.
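To make the parameter saving concrete, here is a small sketch comparing a dense layer with a convolutional layer on a 32x32 grayscale image (the layer sizes are illustrative only, not taken from the models below):

from tensorflow.keras import layers, models

# Dense layer on the flattened image: one weight per (pixel, unit) pair
dense = models.Sequential([
    layers.Flatten(input_shape=(32, 32, 1)),
    layers.Dense(64),
])
# Convolutional layer: 64 filters of 5x5 weights, shared across all positions
conv = models.Sequential([
    layers.Conv2D(64, kernel_size=(5, 5), padding="same", input_shape=(32, 32, 1)),
])
print(dense.count_params())  # 32*32*64 + 64 = 65,600
print(conv.count_params())   # 5*5*1*64 + 64 = 1,664

The shared kernel is also what gives the convolution its translation invariance: the same feature detector is applied at every position of the image.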

Learning goals

  • Starting from the GAN implementation notebook (HTML / Jupyter), use convolutional layers
In [1]:
COLAB = True

if COLAB:
  from google.colab import drive
  drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [2]:
!pip install tensorview
In [3]:
import sys
import tensorflow as tf
import numpy as np
from tensorflow.keras import models, layers, losses, optimizers, metrics
import tensorflow_datasets as tf_ds
import tensorview as tv
import matplotlib.pyplot as plt
from pathlib import Path
In [4]:
if COLAB:
  model_path = Path('/content/drive/My Drive/Colab Notebooks/DsStepByStep')
else:
  model_path = Path('model')
In [5]:
batch_size = 200
latent_dim = 100
# Padded to 32x32 so the convolutional layers divide the image evenly
image_width, image_height, image_channels = 32, 32, 1 
mnist_dim = image_width * image_height * image_channels
In [6]:
disc_learning_rate = 2e-4
gen_learning_rate  = 2e-4

relu_alpha = 0.001

Data

The MNIST dataset is optimized for compact storage: images are closely cropped to 28x28 pixels and stored with 1 byte per pixel (uint8). To get proper performance, the input data is padded to 32x32 and the pixels are converted to 32-bit floats scaled to approximately [-1, 1].

In [7]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32` in ~[-1, 1] and pads to 32 x 32"""
  image_float = tf.cast(image, tf.float32) / 128. - 1.
  image_padded = tf.pad(image_float, [[0, 0], [2, 2], [2, 2], [0, 0]])
  return image_padded, label

# Batching is performed by tf_ds.load through the batch_size argument
(ds_train, ds_test) = tf_ds.load('mnist', split=['train', 'test'], batch_size=batch_size, as_supervised=True)
ds_train = ds_train.map(normalize_img)
ds_train = ds_train.cache()

ds_test = ds_test.map(normalize_img)
ds_test = ds_test.cache()

ds_train, ds_test
Out[7]:
(<CacheDataset shapes: ((None, 32, 32, 1), (None,)), types: (tf.float32, tf.int64)>,
 <CacheDataset shapes: ((None, 32, 32, 1), (None,)), types: (tf.float32, tf.int64)>)

Models

The GAN is built from a generator and a discriminator:

  • The generator takes as input random noise from a 100-dimensional latent space and outputs an image (a 32x32-pixel raster)
  • The discriminator is trained to distinguish images produced by the generator (i.e. fakes) from reference MNIST images

The generator and discriminator architectures are more or less symmetrical. The generator increases the output space dimensions step by step, using upsampling layers followed by convolutional layers. The discriminator is similar to other classification networks, reducing the input space dimensions down to the binary classification layer using strides on the convolutional layers. An alternative implementation of the generator is based on transposed convolutions, whose stride performs the upsampling.

The "game" is to jointly train the generator and the discriminator so as to obtain the best possible generator while still being able to detect generated images.

In [8]:
generator = models.Sequential([
    # Project the latent vector onto an 8x8x128 feature map
    layers.Dense(128 * 8 * 8, input_dim=latent_dim, name='g_1'),
    layers.LeakyReLU(relu_alpha),
    layers.Reshape((8, 8, 128)),
    layers.BatchNormalization(),
    # Upsample 8x8 -> 16x16, then convolve
    layers.UpSampling2D((2, 2)),
    layers.Dropout(0.3),
    layers.Conv2D(64, kernel_size=(5, 5), padding="same", name='g_c1'),
    layers.LeakyReLU(relu_alpha),
    layers.BatchNormalization(),
    # Upsample 16x16 -> 32x32; tanh matches the [-1, 1] pixel scaling
    layers.UpSampling2D((2, 2)),
    layers.Conv2D(1, kernel_size=(5, 5), padding="same", activation='tanh', name='g_c2')
], name='generator')

generator.compile()
generator.summary()
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
g_1 (Dense)                  (None, 8192)              827392    
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 8192)              0         
_________________________________________________________________
reshape (Reshape)            (None, 8, 8, 128)         0         
_________________________________________________________________
batch_normalization (BatchNo (None, 8, 8, 128)         512       
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 16, 16, 128)       0         
_________________________________________________________________
dropout (Dropout)            (None, 16, 16, 128)       0         
_________________________________________________________________
g_c1 (Conv2D)                (None, 16, 16, 64)        204864    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 64)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 16, 16, 64)        256       
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 32, 32, 64)        0         
_________________________________________________________________
g_c2 (Conv2D)                (None, 32, 32, 1)         1601      
=================================================================
Total params: 1,034,625
Trainable params: 1,034,241
Non-trainable params: 384
_________________________________________________________________
In [9]:
discriminator = models.Sequential([
    # 32x32 -> 16x16 through the stride
    layers.Conv2D(8, kernel_size=(5, 5), strides=(2, 2), padding="same",
                  name='d_c1', input_shape=[32, 32, 1]),
    layers.LeakyReLU(relu_alpha),
    layers.Dropout(0.3),
    # 16x16 -> 8x8
    layers.Conv2D(64, kernel_size=(5, 5), strides=(2, 2), padding="same", name='d_c2'),
    layers.LeakyReLU(relu_alpha),
    layers.Flatten(),
    # No sigmoid activation: the loss below is computed from logits
    layers.Dense(1, name='d_1')
], name='discriminator')

discriminator.compile()
discriminator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
d_c1 (Conv2D)                (None, 16, 16, 8)         208       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16, 16, 8)         0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 16, 16, 8)         0         
_________________________________________________________________
d_c2 (Conv2D)                (None, 8, 8, 64)          12864     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 8, 8, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 4096)              0         
_________________________________________________________________
d_1 (Dense)                  (None, 1)                 4097      
=================================================================
Total params: 17,169
Trainable params: 17,169
Non-trainable params: 0
_________________________________________________________________

Training

The discriminator and the generator are trained jointly: each training step updates both, with separate losses and optimizers.

The discriminator is trained on a batch of genuine images and an equally sized batch of generated images.

The generator is trained with its output fed into the discriminator (whose weights are left untouched in this phase: only the generator's variables receive the gradients of the generator loss).

GANs' reputation for being difficult to train is well deserved, and it originates in the joint optimization, which is a minimax problem (minimize the discrimination error, maximize the fidelity of the fakes), written out below. As seen below, the noise on the losses and accuracies is high. The main facilitators helping this training are:

  • Leaky ReLU activations, to avoid vanishing gradients
  • A small learning rate, to decrease the noise and instability
  • Batch normalization layers, to reduce the variance at layer inputs

These are actually the recommendations of the DCGAN paper [2].
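For reference, the objective from the original GAN paper [1], with generator $G$, discriminator $D$, data distribution $p_{data}$ and latent prior $p_z$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The `generator_loss` below implements the non-saturating variant also proposed in [1]: the generator maximizes $\log D(G(z))$ rather than minimizing $\log(1 - D(G(z)))$, which provides stronger gradients early in training.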

Due to the use of convolutional layers and the small input images, the number of trainable parameters of the discriminator is low. The number of parameters of the generator is not that low, since its dense layer is connected to all 100 latent inputs.

In [10]:
epochs = 60
batch_per_epoch = 60000 // batch_size # the MNIST train split has 60,000 images
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
In [11]:
def generator_loss(disc_generated_output):
    return loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
In [12]:
def discriminator_loss(disc_real_output, disc_generated_output):

    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

    return real_loss + generated_loss
In [13]:
@tf.function
def train_step(generator, discriminator, 
               generator_optimizer, discriminator_optimizer, 
               generator_latent, batch, 
               epoch):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        
        gen_latent = generator_latent()
        
        gen_output = generator(gen_latent, training=True)

        disc_real_output = discriminator(batch, training=True)
        disc_generated_output = discriminator(gen_output, training=True)

        gen_loss = generator_loss(disc_generated_output)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    # Compute the gradients outside the tape contexts so the gradient
    # computation itself is not recorded
    generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

    return gen_loss, disc_loss
In [14]:
generator_optimizer = tf.keras.optimizers.Adam(gen_learning_rate, beta_1=0.05)
discriminator_optimizer = tf.keras.optimizers.Adam(disc_learning_rate, beta_1=0.05)
In [15]:
tv_plot = tv.train.PlotMetrics(wait_num=200, columns=2, iter_num=epochs * batch_per_epoch)

def generator_latent():
    # Sample a batch of latent vectors from a standard normal distribution
    return tf.random.normal((batch_size, latent_dim), 0, 1)

for epoch in range(epochs):

    for train_batch in iter(ds_train):
        
        g_loss, d_loss = train_step(generator, discriminator, 
                                    generator_optimizer, discriminator_optimizer, 
                                    generator_latent, train_batch[0], 
                                    epoch)
        # Plot the losses
        tv_plot.update({'discriminator_loss': d_loss,
                        'generator_loss': g_loss})
        tv_plot.draw()
In [16]:
gen_latent = generator_latent()
# training=True keeps the batch-normalization layers in batch-statistics mode
gen_imgs = generator(gen_latent, training=True)
        
fig, axes = plt.subplots(8, 8, sharex=True, sharey=True, figsize=(10, 10))
for img, ax in zip(gen_imgs, axes.ravel()):
    ax.imshow(img.numpy().reshape(image_width, image_height), interpolation='nearest', cmap='gray')
    ax.axis('off')
fig.tight_layout()
In [17]:
discriminator.save(model_path / 'mnist_dcgan_discriminator.h5')
generator.save(model_path / 'mnist_dcgan_generator.h5')
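To reuse the trained generator later, it can be reloaded from the file saved above and sampled (a minimal sketch):

# Reload the saved generator and sample images from fresh latent vectors
reloaded_generator = models.load_model(model_path / 'mnist_dcgan_generator.h5')
samples = reloaded_generator(tf.random.normal((16, latent_dim), 0, 1))
print(samples.shape)  # (16, 32, 32, 1)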

Conclusion

Compared to the original GAN implementation, the discriminator accuracy is no longer stuck above 90%: the generator seems more able to create convincing fakes. But the variance on the metrics is also high. Visually, there is less noise in the background and the digits are rounder. Straight digits like 7 and 1 seem harder for this model to generate, and there are still some "ghosts" around the main shapes.

Where to go from here

  • The original GAN based on dense layers (HTML / Jupyter)
  • Revisit the fundamentals about deep neural networks in the CNN versus Dense classification (HTML / Jupyter)

References

  1. "Generative adversarial nets", I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, Y. Bengio, NIPS 2014
  2. "Unsupervised representation learning with deep convolutional generative adversarial networks", A. Radford, L. Metz, S. Chintala, ICLR 2016