Some of the generative work done in the past year or two using generative adversarial networks (GANs) has been pretty exciting and has produced some very impressive results. The general idea is that you train two models: one (G) to generate some sort of output example given random noise as input, and one (A) to distinguish generated examples from real ones. Then, once A is an effective discriminator, we can stack G and A to form our GAN, freeze the weights in the adversarial part of the network, and train the generative network's weights to push random noisy inputs towards the "real" class output of the adversarial half.
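Written out as losses, and assuming the two-class softmax discriminator and "real"-label generator updates described later in this post, the two training objectives look roughly like this. This is a sketch of this particular setup, not the only way to formulate a GAN. Here $A(\cdot)_{\text{real}}$ and $A(\cdot)_{\text{fake}}$ are the two softmax outputs (they sum to 1), and only G's weights change when minimizing $\mathcal{L}_G$:

```latex
\begin{aligned}
\mathcal{L}_A &= -\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\log A(x)_{\text{real}}\big]
                 \;-\; \mathbb{E}_{z \sim p_z}\big[\log A(G(z))_{\text{fake}}\big] \\
\mathcal{L}_G &= -\,\mathbb{E}_{z \sim p_z}\big[\log A(G(z))_{\text{real}}\big]
\end{aligned}
```

Training G against the "real" label in this way is the commonly used non-saturating variant, rather than having G directly maximize A's loss.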
Building this style of network in the latest versions of Keras is actually quite straightforward. I've wanted to try this out on a number of things, so I put together a relatively simple version that uses the classic MNIST dataset and a GAN approach to generate random handwritten digits.
Before going further, I should mention that all of this code is available on GitHub here.
We set up a relatively straightforward generative model in Keras using the functional API. It takes 100 random inputs and eventually maps them to a [1,28,28] image to match the MNIST data shape. We begin by generating a dense set of 14×14 feature maps (nch = 200 of them), upsample them to 28×28, and then run through a handful of convolutional filters of varying sizes and numbers of channels. The model is compiled with an Adam optimizer and binary cross-entropy loss, although we only ever use the generator in the forward direction and never train on this model directly. We use a sigmoid on the output layer to help saturate pixels into 0 or 1 states rather than a range of grays in between, and use batch normalization to help accelerate training and ensure that a wide range of activations are used within each layer.
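```python
# Keras 1.x API (channels-first images, as in the original setup)
from keras.layers import Input, Dense, Reshape, Activation
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Model

# Build Generative model ...
# (`opt` is the Adam optimizer instance, defined in the elided code)
nch = 200
g_input = Input(shape=[100])                        # 100 random noise inputs
H = Dense(nch*14*14, init='glorot_normal')(g_input)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Reshape([nch, 14, 14])(H)                       # nch feature maps of 14x14
H = UpSampling2D(size=(2, 2))(H)                    # -> nch x 28 x 28
H = Convolution2D(nch//2, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(nch//4, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(1, 1, 1, border_mode='same', init='glorot_uniform')(H)
g_V = Activation('sigmoid')(H)                      # pixel values in (0, 1)
generator = Model(g_input, g_V)
generator.compile(loss='binary_crossentropy', optimizer=opt)
generator.summary()
```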
We now have a network which could in theory take in 100 random inputs and output digits, although the current weights are all random and this clearly isn’t happening just yet.
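As a quick sanity check on how 100 noise values end up as a [1,28,28] image, the sketch below traces the channels-first tensor shape through each generator layer in plain Python, with no Keras required. The helper name `trace_generator_shapes` is hypothetical; it simply mirrors the layer list in the code above with nch = 200.

```python
def trace_generator_shapes(n_noise=100, nch=200):
    """Return (layer name, output shape) pairs for the generator described above."""
    shapes = [("input", (n_noise,))]
    # Dense layer producing nch * 14 * 14 units
    shapes.append(("dense", (nch * 14 * 14,)))
    # Reshape into nch feature maps of 14x14
    shapes.append(("reshape", (nch, 14, 14)))
    # UpSampling2D(size=(2, 2)) doubles height and width
    c, h, w = shapes[-1][1]
    shapes.append(("upsample", (c, 2 * h, 2 * w)))
    # 'same'-padded stride-1 convolutions keep 28x28 and only change channels
    for name, filters in (("conv3x3_a", nch // 2),
                          ("conv3x3_b", nch // 4),
                          ("conv1x1", 1)):
        _, h, w = shapes[-1][1]
        shapes.append((name, (filters, h, w)))
    return shapes


for name, shape in trace_generator_shapes():
    print(f"{name:10s} {shape}")
```

The final entry is (1, 28, 28), and the Dense layer has 200 × 14 × 14 = 39,200 units.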
We build an adversarial discriminator network that takes in [1,28,28] images and decides whether they are real or fake. It uses two strided convolutional layers, a dense layer, lots of dropout, and a two-element softmax output layer encoding [0,1] = fake and [1,0] = real. This is a relatively simple network, but the goal here is largely to get something that works passably and trains relatively quickly for experimentation.
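```python
# Keras 1.x API (channels-first images)
from keras.layers import Input, Dense, Dropout, Flatten
from keras.layers.convolutional import Convolution2D
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model

# Build Discriminative model ...
# (`shp` is the [1, 28, 28] image shape; `dropout_rate` and the optimizer
#  `dopt` are defined in the elided code)
d_input = Input(shape=shp)
# No built-in activation on the convolutions: a ReLU here would zero out
# negative values before LeakyReLU sees them, so its 0.2 slope would never apply
H = Convolution2D(256, 5, 5, subsample=(2, 2), border_mode='same')(d_input)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
H = Convolution2D(512, 5, 5, subsample=(2, 2), border_mode='same')(H)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
H = Flatten()(H)
H = Dense(256)(H)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
d_V = Dense(2, activation='softmax')(H)     # [1,0] = real, [0,1] = fake
discriminator = Model(d_input, d_V)
discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)
discriminator.summary()
```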
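The same kind of trace for the discriminator shows where the spatial resolution goes: each stride-2, 'same'-padded convolution divides the height and width by 2, rounding up, so 28×28 becomes 14×14 and then 7×7 before flattening. This is a hypothetical plain-Python helper, assuming shp = [1, 28, 28] as described in the text.

```python
import math


def same_conv_out(size, stride):
    # With 'same' padding, the output size is ceil(input / stride)
    return math.ceil(size / stride)


def trace_discriminator_shapes(shp=(1, 28, 28)):
    """Return (layer name, output shape) pairs for the discriminator above."""
    c, h, w = shp
    shapes = [("input", (c, h, w))]
    # Two 5x5 convolutions with subsample=(2, 2)
    for name, filters in (("conv5x5_a", 256), ("conv5x5_b", 512)):
        h, w = same_conv_out(h, 2), same_conv_out(w, 2)
        shapes.append((name, (filters, h, w)))
    c, h, w = shapes[-1][1]
    shapes.append(("flatten", (c * h * w,)))
    shapes.append(("dense", (256,)))
    shapes.append(("softmax", (2,)))       # [real, fake] probabilities
    return shapes


for name, shape in trace_discriminator_shapes():
    print(f"{name:10s} {shape}")
```

Flattening the last convolution gives 512 × 7 × 7 = 25,088 features feeding the 256-unit dense layer.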
We pre-train the discriminative model by generating a handful of random images using the untrained generative model, concatenating them with an equal number of real images of digits, labeling them appropriately, and then fitting until we reach a relatively stable loss value, which here took one epoch over 20,000 examples. This is an important step that should not be skipped: pre-training accelerates the GAN massively, and I was not able to achieve convergence without it (possibly due to impatience).
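A minimal sketch of assembling one pre-training batch with NumPy, using the label encoding above ([1,0] = real, [0,1] = fake). The function name `make_discriminator_batch` is hypothetical, and it assumes the real and generated images already share a shape such as (n, 1, 28, 28).

```python
import numpy as np


def make_discriminator_batch(real_images, fake_images):
    """Stack real and generated images and build one-hot labels.

    Hypothetical helper: [1, 0] = real, [0, 1] = fake.
    """
    real_images = np.asarray(real_images)
    fake_images = np.asarray(fake_images)
    X = np.concatenate([real_images, fake_images], axis=0)
    y = np.zeros((len(X), 2), dtype="float32")
    y[: len(real_images), 0] = 1.0   # real -> [1, 0]
    y[len(real_images):, 1] = 1.0    # fake -> [0, 1]
    return X, y
```

With the models above, `fake_images` would be the output of `generator.predict(...)` on a batch of noise, and the resulting `X, y` pair could be passed to `discriminator.fit`, which shuffles the samples by default.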
Generative Adversarial Model
Now that we have both the generative and adversarial models, we can combine them to make a GAN quite easily in Keras. Using the functional API, we can simply re-use the network objects we have already instantiated, and they will conveniently share the same weights with the previously compiled models. Since we want to freeze the weights in the adversarial half of the network during back-propagation of the joint model, we first set the Keras trainable flag to False for each element in this part of the network. For now, this seems to need to be applied at the primitive layer level rather than only on the high-level network, so we introduce a simple function to do this.
```python
from keras.layers import Input
from keras.models import Model

# Freeze weights in the discriminator for stacked training
def make_trainable(net, val):
    net.trainable = val
    for l in net.layers:          # apply the flag to every primitive layer
        l.trainable = val

make_trainable(discriminator, False)

# Build stacked GAN model: noise -> generator -> (frozen) discriminator
gan_input = Input(shape=[100])
H = generator(gan_input)
gan_V = discriminator(H)
GAN = Model(gan_input, gan_V)
GAN.compile(loss='categorical_crossentropy', optimizer=opt)
GAN.summary()
```
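To make the effect of freezing concrete, here is a deliberately tiny, framework-free sketch (not Keras code). G and A are each reduced to a single scalar weight, the stacked model is A(G(z)), and the "real" target is just a number. One gradient step on the stacked loss updates only the unfrozen weight: the gradient still flows through A's weight to reach G, but A's weight itself is left unchanged. The function `stacked_step` is hypothetical.

```python
def stacked_step(g_w, a_w, z, target, lr=0.1, a_trainable=False):
    """One gradient-descent step on loss = (a_w * (g_w * z) - target) ** 2."""
    pred = a_w * (g_w * z)
    err = pred - target
    # Chain rule: dL/dg_w passes *through* a_w; dL/da_w passes through g_w * z
    grad_g = 2 * err * a_w * z
    grad_a = 2 * err * g_w * z
    new_g = g_w - lr * grad_g
    # A frozen weight keeps its value even though it shaped grad_g
    new_a = a_w - lr * grad_a if a_trainable else a_w
    return new_g, new_a


print(stacked_step(0.5, 2.0, 1.0, 2.0))
```

Starting from g_w = 0.5, a_w = 2.0, z = 1 and target 2, one frozen step moves g_w to 0.9 while a_w stays at 2.0.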
At this point, we have a randomly initialized generator, a (poorly) trained discriminator, and a GAN that can be trained across the stacked model of both networks. The core of the training routine for a GAN looks something like this:
- Generate images using G and random noise (forward pass only).
- Perform a batch update of weights in A given generated images, real images, and labels.
- Perform a batch update of weights in G given noise and forced "real" labels in the full GAN.
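The three steps above can be sketched as a loop. This is a hedged, framework-agnostic version: it assumes Keras-style models exposing `predict` and `train_on_batch`, assumes uniform noise in [0, 1) (the noise distribution is an assumption here), and draws real images at random from an array such as (n, 1, 28, 28) MNIST images. The function `train_gan` and its optional `set_trainable` hook are hypothetical; the hook is meant for calling something like `make_trainable(discriminator, ...)` around each update, since depending on the Keras version the frozen flag may need to be switched back on before A's own update.

```python
import numpy as np

REAL = np.array([1.0, 0.0], dtype="float32")  # [1, 0] = real
FAKE = np.array([0.0, 1.0], dtype="float32")  # [0, 1] = fake


def train_gan(generator, discriminator, gan, real_images, n_steps,
              batch_size=32, noise_dim=100, set_trainable=None, rng=None):
    """Run the generate / update-A / update-G loop; return per-step losses."""
    rng = np.random.default_rng() if rng is None else rng
    losses = {"d": [], "g": []}
    for _ in range(n_steps):
        # 1. Generate images from noise (forward pass only)
        noise = rng.uniform(0.0, 1.0, size=(batch_size, noise_dim))
        fake = generator.predict(noise)

        # 2. Update A on a half-real, half-generated batch
        idx = rng.integers(0, len(real_images), size=batch_size)
        X = np.concatenate([real_images[idx], fake], axis=0)
        y = np.vstack([np.tile(REAL, (batch_size, 1)),
                       np.tile(FAKE, (batch_size, 1))])
        if set_trainable is not None:
            set_trainable(True)
        losses["d"].append(discriminator.train_on_batch(X, y))

        # 3. Update G through the stacked GAN with labels forced to "real"
        noise = rng.uniform(0.0, 1.0, size=(batch_size, noise_dim))
        y_gan = np.tile(REAL, (batch_size, 1))
        if set_trainable is not None:
            set_trainable(False)
        losses["g"].append(gan.train_on_batch(noise, y_gan))
    return losses
```

With the models above, a call might look like `train_gan(generator, discriminator, GAN, X_train, n_steps=..., set_trainable=lambda v: make_trainable(discriminator, v))`, where `X_train` is a hypothetical array of real MNIST images. The returned lists are what the loss plots below are drawn from.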
Running this process for a number of epochs, we can plot the GAN and adversarial (discriminator) losses over time to get our GAN loss plots during training.
And finally, we can plot some samples from the trained generative model, which look reasonably similar to the original MNIST digits, alongside some examples from the original dataset for comparison.