Proposed in 2014, the Generative Adversarial Network (GAN) now has many interesting variants. You might not be surprised that the relevant papers read more like statistics research: when a model is proposed, its evaluation is based on some fundamental probability distributions, and generalized applications start from there.
GAN Model
There are two networks in the model, trained simultaneously: a Generator G and a Discriminator D. The framework is shown below.
The blue rectangles are the networks G and D. The input of G is a noise variable z. It flows through the network G and becomes fake image data G(z). The fake and real images are then the inputs of the Discriminator D, which is set up to tell whether an input image is real or fake.
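To make the framework concrete, here is a minimal sketch of the two networks in PyTorch. This is my own illustration rather than the paper's exact architecture; the layer sizes, latent_dim, and img_dim are assumptions.

```python
import torch
import torch.nn as nn

latent_dim = 100   # dimension of the noise variable z (assumed)
img_dim = 28 * 28  # flattened image size, e.g. MNIST (assumed)

# Generator G: maps noise z to a fake image G(z)
G = nn.Sequential(
    nn.Linear(latent_dim, 256),
    nn.ReLU(),
    nn.Linear(256, img_dim),
    nn.Tanh(),          # outputs scaled to [-1, 1]
)

# Discriminator D: maps an image (real or fake) to a probability of being real
D = nn.Sequential(
    nn.Linear(img_dim, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
    nn.Sigmoid(),       # D(x) in (0, 1)
)

z = torch.randn(16, latent_dim)   # a batch of noise
fake_images = G(z)                # G(z)
scores = D(fake_images)           # D's belief that the inputs are real
```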
A probabilistic view is given below [*]. The goal of a GAN is to generate new samples from a given data set. The common approach for generative models is to learn the distribution of the real training set, then sample from it to obtain new data. GAN, however, tries to learn the map from a noise variable $z$ to the real training samples $x$. Normally we let the prior $p_z(z)$ be a Gaussian distribution. The graph below shows that the Discriminator takes either real inputs $x$ from the training sample distribution $p_{data}(x)$, or generated inputs $G(z)$ from another distribution $p_g$.
During the training stage, we want the Generator to do its best to "lie" to the Discriminator, making D believe that the images generated by G are real. The cost function and optimization goal are given below (the value function from the original paper):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
Training
Mini-batch SGD is applied to train the GAN. Momentum (read here) is also used on top of the gradient-based optimization during training. When computing gradients, we split the work into two subtasks; it is a "minimax two-player game".
Discriminator: take a minibatch of m real images x and m fake images G(z), where z is the noise input. Then we use gradient ascent on

$$\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m} \left[ \log D(x^{(i)}) + \log\left(1 - D(G(z^{(i)}))\right) \right]$$

to update the discriminator parameters $\theta_d$. This process is repeated k times per generator update, but for simplicity the authors used k=1 in their experiments.
A good discriminator is what guarantees that we can train a good generator.
Generator: take m noise inputs z. Compute the gradient as below, and use gradient descent to update the generator parameters $\theta_g$ (a minimal sketch of the full training loop follows this list):

$$\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^{m} \log\left(1 - D(G(z^{(i)}))\right)$$
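To make the alternating updates concrete, here is a minimal sketch of the training loop in PyTorch. This is my own illustration, not code from the paper; the G and D modules and latent_dim come from the earlier sketch, and the real_loader data source, learning rate, and batch handling are assumptions.

```python
import torch

# One optimizer per network; the paper uses SGD with momentum.
opt_D = torch.optim.SGD(D.parameters(), lr=0.01, momentum=0.9)
opt_G = torch.optim.SGD(G.parameters(), lr=0.01, momentum=0.9)
k = 1        # discriminator steps per generator step (the paper uses k = 1)
eps = 1e-8   # numerical safety inside the logs

for real_batch in real_loader:                # real images x ~ p_data (loader assumed)
    x = real_batch.view(real_batch.shape[0], -1)
    m = x.shape[0]                            # minibatch size

    # Discriminator: k steps of gradient ASCENT on
    # (1/m) * sum_i [ log D(x_i) + log(1 - D(G(z_i))) ]
    for _ in range(k):
        z = torch.randn(m, latent_dim)
        d_obj = (torch.log(D(x) + eps)
                 + torch.log(1 - D(G(z).detach()) + eps)).mean()
        loss_D = -d_obj                       # ascending the objective = descending its negative
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

    # Generator: one step of gradient DESCENT on
    # (1/m) * sum_i log(1 - D(G(z_i)))
    z = torch.randn(m, latent_dim)
    loss_G = torch.log(1 - D(G(z)) + eps).mean()
    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()
```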
During training, we fix either D or G and update the other, maximizing the other's mistake. In the end, the Generator learns a distribution $p_g$, which we wish to converge to the real data distribution $p_{data}$. The optimal solution exists when $p_g = p_{data}$, which is a Nash equilibrium: the discriminator is unable to differentiate between the two distributions, i.e. $D(x) = 0.5$.
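As a quick check of the $D(x) = 0.5$ claim, the original paper shows that for a fixed generator the optimal discriminator is

$$D^*_G(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)},$$

so when $p_g = p_{data}$ this reduces to $D^*_G(x) = \frac{1}{2}$: the best the discriminator can do is a coin flip.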
Performance
We can see how a GAN works through an example from the original paper. Remember the goal is to learn a map from a random variable $z$ to the training data $x$, i.e. $x = G(z)$. The arrows below show this mapping, and the $x$ and $z$ lines are the sampling spaces. The dashed black line is the training sample distribution $p_{data}$, the green line is the fake distribution $p_g$, and the blue one is the discriminator's output $D(x)$.
In the graph, (a) shows the initial state of the networks, where the fake and real sample distributions are not the same. The blue line stands for the ability of the discriminator to tell whether a sample is fake or not: where its value is high, it can tell, which means the two distributions are very different there. (b) and (c) show the states as training and updating proceed. On the left part of these two panels the discriminator does well, meaning the distributions are not the same there, but on the right tails the two distributions are less different and the blue line takes smaller values. After several rounds of updating we finally reach (d), where the fake and real distributions differ very little and the discriminator is unable to differentiate between the two distributions.
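The intuition behind the blue curve can be reproduced with a tiny toy example. Below is a NumPy sketch (my own illustration; the one-dimensional Gaussian stand-ins for $p_{data}$ and $p_g$ are assumptions) that computes the optimal discriminator for a fixed generator:

```python
import numpy as np

def gaussian_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

xs = np.linspace(-5.0, 5.0, 201)
p_data = gaussian_pdf(xs, mu=0.0, sigma=1.0)  # "real" distribution (black dashed line)
p_g    = gaussian_pdf(xs, mu=2.0, sigma=1.0)  # "fake" distribution (green line), not yet converged

# Optimal discriminator for a fixed generator (blue line): close to 1 where
# only real data lives, close to 0.5 where the two densities overlap.
d_star = p_data / (p_data + p_g)

print(d_star[xs < -2].mean())                 # ~1: D is confident these samples are real
print(d_star[np.abs(xs - 1.0) < 0.5].mean())  # ~0.5: the densities overlap near x = 1
```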
The figure below shows samples of handwritten digits. The highlighted ones were created by the generator; at the very least, they are highly competitive, high-quality samples.
Comments
According to the authors, the advantages of GAN compared with other generative models are:
- No inference is needed during learning, and sampling from a Markov chain between learning steps is avoided.
- Better samples generated, i.e. higher quality images.
- The GAN framework can train any kind of generator network.
However, there are problems:
- Freedom. GAN no longer tries to formulate p(x) explicitly, but instead uses a distribution for direct sampling. This brings too much freedom and makes the model hard to control.
- Non-convergence problem. SGD is only guaranteed to reach a Nash equilibrium when the objective is convex, which is not the case here.
- Collapse problem. It is hard to tell whether training is making progress. If the Generator produces the same sample every time, the learning process cannot continue; if the Discriminator collapses, it gives similar outputs for every input.
References
Goodfellow et al., "Generative Adversarial Nets", 2014. arXiv:1406.2661.
Other blog posts: 1, 2, [*].
Implement GAN in Keras.