Deep Learning 13: Understanding Generative Adversarial Network

Proposed in 2014, the Generative Adversarial Network (GAN) now has many variants. You might not be surprised that the relevant papers read more like statistics research: when a model is proposed, the evaluations are based on some fundamental probability distributions, from which more general applications start.

GAN Model

There are two networks in the model, trained simultaneously: a Generator G and a Discriminator D. The framework is shown below.

[Figure: the GAN framework]

The blue rectangles are the networks G and D. The input of G is a noise variable z, which flows through the network G and becomes fake image data G(z). The fake and real images are then the inputs of the Discriminator D, which is trained to tell whether an input image is real or fake.

A probabilistic view is given below[*]. The goal of GAN is to generate new samples that look like the given data set. The common approach for generative models is to learn the distribution of the real training set and then sample from it to obtain new data. GAN, however, tries to learn the map between a variable z and the real training samples x. Normally we let z follow a Gaussian distribution. The graph below shows that the Discriminator takes either real inputs x from the training sample distribution q(x), or generated inputs produced by mapping noise drawn from another distribution p(z).

[Figure: Generative adversarial networks, probabilistic view]
During the training stage, we wish the Generator to do its best to “lie” to the Discriminator, making D believe that the images generated by G are real. The cost function and optimization goal are given below:
[Figure: the GAN cost function]
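Written out, the value function of this minimax game from the original paper is

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim P_{data}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]

D tries to maximize V(D, G) by assigning high probability to real samples x and low probability to fakes G(z), while G tries to minimize it by making D(G(z)) as large as possible.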

Training

Mini-batch SGD is applied to train the GAN, and momentum (read here) is used on top of the gradient-based optimization. When computing gradients, we split the work into two subtasks; it is a “minimax two-player game”.
Discriminator: take a mini-batch of m real images x and m fake images G(z), where z is the noise input, then use gradient ascent to update \theta_d. This process is repeated k times per outer iteration, but for simplicity the authors used k=1 in the experiments.

A good discriminator is a prerequisite for training a good generator.
Generator: take a mini-batch of m noise inputs z and compute the gradient as below. Gradient descent is then used to update \theta_g.

[Figure: gradient updates for the Discriminator and the Generator]
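In LaTeX form, the two mini-batch gradient estimates from the original paper's algorithm are

Discriminator (gradient ascent on \theta_d):
\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m} \left[ \log D(x^{(i)}) + \log(1 - D(G(z^{(i)}))) \right]

Generator (gradient descent on \theta_g):
\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^{m} \log(1 - D(G(z^{(i)})))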
During training, we fix one of D and G and update the other: D is updated to spot the fakes, and G is updated to maximize D's mistakes. At the end, the Generator learns a distribution P_g, which we hope converges to the real data distribution P_{data}. The optimal solution is reached when P_g = P_{data}, a Nash equilibrium; at that point the discriminator is unable to differentiate between the two distributions, i.e. D(x) = 0.5.
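To make the alternating updates concrete, below is a minimal PyTorch sketch of this training loop. The toy one-dimensional “real” data, the network sizes, and the hyper-parameters are illustrative assumptions of mine, not values from the paper (the references also link a Keras implementation).

    import torch
    import torch.nn as nn

    latent_dim, data_dim, batch_size = 8, 1, 64

    # Generator G: maps noise z to a fake sample G(z)
    G = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(), nn.Linear(32, data_dim))
    # Discriminator D: outputs the probability that its input is real
    D = nn.Sequential(nn.Linear(data_dim, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid())

    # Mini-batch SGD with momentum, as described above
    opt_G = torch.optim.SGD(G.parameters(), lr=0.01, momentum=0.9)
    opt_D = torch.optim.SGD(D.parameters(), lr=0.01, momentum=0.9)

    for step in range(2000):
        # --- Discriminator step (G fixed): ascend log D(x) + log(1 - D(G(z))) ---
        x_real = 2.0 + 0.5 * torch.randn(batch_size, data_dim)   # toy real data ~ N(2, 0.5)
        z = torch.randn(batch_size, latent_dim)                  # Gaussian noise input
        x_fake = G(z).detach()                                   # detach so G is not updated here
        loss_D = -(torch.log(D(x_real) + 1e-8).mean()
                   + torch.log(1 - D(x_fake) + 1e-8).mean())     # ascent = descent on the negative
        opt_D.zero_grad(); loss_D.backward(); opt_D.step()

        # --- Generator step (D fixed): descend log(1 - D(G(z))) ---
        z = torch.randn(batch_size, latent_dim)
        loss_G = torch.log(1 - D(G(z)) + 1e-8).mean()
        opt_G.zero_grad(); loss_G.backward(); opt_G.step()

Detaching G's output in the discriminator step, and only calling opt_G.step() in the generator step, mirrors the “fix one network, update the other” procedure above; here k=1, as in the paper's experiments.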

Performance

We can see how GAN works through an example from the original paper. Remember the goal is to learn a map from a random variable z to the training data x, x = G(z). The arrows below show this mapping, and the x and z lines are the sampling spaces. The dashed black line is the training sample distribution, the green line is the fake (generated) distribution, and the blue line is the probability that the discriminator gives the correct answer.

In the graph, (a) shows an initial state of the networks, where the fake and real sample distributions are not the same. The blue line stands for the ability of the discriminator to tell whether a sample is fake: where its value is high, the discriminator can tell, which means the two distributions are very different there. (b) and (c) show the states as training and updating proceed. On the left part of these two panels the discriminator still does well, meaning the distributions are not the same, but their right tails are less different, where the values of the blue line are smaller. After several rounds of updating we finally reach (d), where the fake and real distributions show little difference and the discriminator is unable to differentiate between the two distributions.

[Figure: training progression of the real and generated distributions, panels (a)–(d)]

The figure below shows samples of handwritten digits. The highlighted ones were created by the generator; they are at least very competitive, high-quality samples.

[Figure: handwritten digit samples]

Comments

According to the authors, the advantages of GAN compared with other generative models are:

  • No inference stage during learning.
  • Avoids sampling from a Markov chain between learning steps.
  • Better samples generated, i.e. higher quality images.
  • The GAN framework can train any generative network.

However, there are problems:

  • Freedom. GAN no longer tries to formulate p(x) explicitly, but uses a learned map to sample directly. This brings a lot of freedom and makes the model hard to control.
  • Non-convergence problem. SGD is only guaranteed to reach a Nash equilibrium when the objective is convex, which is not the case here, so training may oscillate instead of converging.
  • Collapse problem. It is hard to tell whether training is making progress. If the Generator gives the same sample point every time, the learning process will not continue; if the Discriminator collapses, it gives similar outputs for every input.

References

Goodfellow et al., “Generative Adversarial Nets”, arXiv:1406.2661v1, 2014.

Other blog posts 1,2,[*].
Implement GAN in Keras.

