In this short post, I will share a very brief GAN (Generative Adversarial Network) model and show how we train it in practice using PyTorch. I will also include some training tips, since I found GANs hard to train myself, especially when working with my own data and model.
Training GAN models
I previously wrote a blog post on how to understand GAN models; check it out. You can also find the official PyTorch tutorial here. We will focus on the official tutorial, and I will try to share my understanding of, and tips for, its main steps.
As the figure shows, the black arrows mark the feedforward path, where the Discriminator D predicts labels for both fake and real data. The Generator G and D can be any type of model; let us assume they are neural networks, since we are interested in how backpropagation works. In short, we first update D, then G. In the figure, I highlight steps 1-4 in red: steps 1-3 update D, and step 4 updates G.
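To make the excerpts concrete, here is a minimal setup sketch. It swaps the tutorial's convolutional DCGAN for two small fully connected networks; the sizes `nz`, `data_dim` and `batch_size`, the layers, and the random stand-in for real data are my own assumptions. The names `netG`, `netD`, `criterion`, `optimizerD` and `optimizerG` match the snippets, and `nn.BCELoss` with Adam (`lr=0.0002`, `betas=(0.5, 0.999)`) follows the official tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only to keep the example small
nz, data_dim, batch_size = 8, 2, 16

# Generator: maps a noise vector z to a fake sample
netG = nn.Sequential(
    nn.Linear(nz, 32),
    nn.ReLU(),
    nn.Linear(32, data_dim),
)

# Discriminator: maps a sample to the probability that it is real
netD = nn.Sequential(
    nn.Linear(data_dim, 32),
    nn.LeakyReLU(0.2),
    nn.Linear(32, 1),
    nn.Sigmoid(),
)

criterion = nn.BCELoss()  # binary cross-entropy on D's probabilities
optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
real_label, fake_label = 1.0, 0.0

# Feedforward path: D predicts on a real batch and on a fake batch
real = torch.randn(batch_size, data_dim)  # stand-in for real data
noise = torch.randn(batch_size, nz)
fake = netG(noise)
out_real = netD(real).view(-1)
out_fake = netD(fake).view(-1)
```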
The following code snippets are taken from the training loop of the official tutorial: steps 1-3 correspond to its Part 1 - Train the Discriminator, and step 4 to its Part 2 - Train the Generator.
Step 1: calculate the gradients of D on the real data
```python
# First clear the gradients of the Discriminator
netD.zero_grad()
...
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()
```
Note that calling backward() only computes the gradients and stores them in each parameter's .grad attribute; it does not perform the update. In other words, the weights of D are not changed yet.
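A quick way to check this is to compare D's weights before and after `backward()`, and again after `step()`. In this sketch the tiny `netD`, the random data, and plain SGD (picked so that the update `w - lr * grad` is easy to verify) are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

netD = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())  # tiny stand-in discriminator
criterion = nn.BCELoss()
optimizerD = torch.optim.SGD(netD.parameters(), lr=0.1)  # no momentum

real = torch.randn(4, 2)
label = torch.ones(4)  # all-real batch

w_before = netD[0].weight.detach().clone()

netD.zero_grad()
output = netD(real).view(-1)
errD_real = criterion(output, label)
errD_real.backward()  # fills .grad, leaves the weights untouched

w_after_backward = netD[0].weight.detach().clone()
grad = netD[0].weight.grad.clone()

optimizerD.step()  # only now are the weights changed: w <- w - lr * grad
w_after_step = netD[0].weight.detach().clone()
```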
Step 2: calculate the gradients of D on the fake data
```python
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch
errD_fake.backward()
```
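This is very similar to step 1, but now we go through the other path: the fake data generated by G is fed to D, together with fake labels. The new gradients are added to those already stored from step 1.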
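In the official tutorial, the fake batch is passed to D as `fake.detach()` in this step, so that D's loss does not backpropagate into G. The sketch below, with tiny stand-in networks that are my own assumptions, checks that only D receives gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

nz, data_dim, batch_size = 8, 2, 4
netG = nn.Linear(nz, data_dim)                              # tiny stand-in generator
netD = nn.Sequential(nn.Linear(data_dim, 1), nn.Sigmoid())  # tiny stand-in discriminator
criterion = nn.BCELoss()

noise = torch.randn(batch_size, nz)
fake = netG(noise)

# Step 2: D's loss on the all-fake batch
label = torch.zeros(batch_size)        # fake_label = 0
output = netD(fake.detach()).view(-1)  # detach() cuts the graph back into G
errD_fake = criterion(output, label)
errD_fake.backward()

# G received no gradient; D did
g_grad = netG.weight.grad         # stays None
d_grad = netD[0].weight.grad      # populated by backward()
```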
Step 3: Update parameters of D
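```python
# Total loss of D over the all-real and all-fake batches (used for logging;
# the gradients of both batches were already accumulated by the two backward() calls)
errD = errD_real + errD_fake
# Update the weights of D
optimizerD.step()
```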
Because the gradients from both batches have already been accumulated in D's parameters, we then call optimizerD.step() to actually update D's parameters.
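Since `backward()` adds to whatever is already stored in `.grad`, calling it once per batch gives the same gradients as a single call on the summed loss, which is why `errD` itself is not needed for the update. A minimal check, with a made-up tiny `netD` and random tensors standing in for the real and (detached) fake batches:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

netD = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())  # tiny stand-in discriminator
criterion = nn.BCELoss()
real = torch.randn(4, 2)
fake = torch.randn(4, 2)  # stand-in for netG(noise).detach()
ones, zeros = torch.ones(4), torch.zeros(4)

# Tutorial style: two backward() calls, gradients accumulate in .grad
netD.zero_grad()
errD_real = criterion(netD(real).view(-1), ones)
errD_real.backward()
errD_fake = criterion(netD(fake).view(-1), zeros)
errD_fake.backward()
errD = errD_real + errD_fake  # only used for logging
grad_two_calls = netD[0].weight.grad.clone()

# Equivalent: a single backward() on the summed loss
netD.zero_grad()
errD_sum = criterion(netD(real).view(-1), ones) + criterion(netD(fake).view(-1), zeros)
errD_sum.backward()
grad_one_call = netD[0].weight.grad.clone()
```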
Step 4: Update G
Recall that G's job is to generate data that looks real, so we measure its error using D's predictions on G's output together with real labels. That is, G wants all of its outputs to be predicted as real data.
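Why real labels? Writing out `nn.BCELoss` for one prediction $p$ with target $y$ (BCELoss averages this over the batch) shows what each update optimizes, with $x$ a real sample and $z$ a noise vector:

```latex
\begin{aligned}
\ell(p, y) &= -\big[\, y \log p + (1 - y)\log(1 - p) \,\big] \\
\ell_D &= \ell\big(D(x), 1\big) + \ell\big(D(G(z)), 0\big) = -\log D(x) - \log\big(1 - D(G(z))\big) \\
\ell_G &= \ell\big(D(G(z)), 1\big) = -\log D(G(z))
\end{aligned}
```

Minimizing $\ell_D$ maximizes $\log D(x) + \log(1 - D(G(z)))$, and minimizing $\ell_G$ maximizes $\log D(G(z))$. The official tutorial uses this form for G instead of minimizing $\log(1 - D(G(z)))$ directly, because the latter gives weak gradients early in training, when D easily rejects the fakes.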
```python
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
# Update G
optimizerG.step()
```
We then compute the gradients and update G, which is simpler than updating D since it needs only one backward pass.
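Putting steps 1-4 together, one training iteration might look like the sketch below. It follows the order of the tutorial's loop, but the small fully connected networks, the toy Gaussian "real" data, and the helper name `train_step` are my own assumptions, not the tutorial's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
nz, data_dim, batch_size = 8, 2, 64
real_label, fake_label = 1.0, 0.0

netG = nn.Sequential(nn.Linear(nz, 32), nn.ReLU(), nn.Linear(32, data_dim))
netD = nn.Sequential(nn.Linear(data_dim, 32), nn.LeakyReLU(0.2),
                     nn.Linear(32, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))


def train_step(real):
    """One GAN iteration: update D (steps 1-3), then G (step 4)."""
    b_size = real.size(0)
    # Step 1: gradients of D on the all-real batch
    netD.zero_grad()
    label = torch.full((b_size,), real_label)
    errD_real = criterion(netD(real).view(-1), label)
    errD_real.backward()
    # Step 2: gradients of D on the all-fake batch (added to step 1's)
    fake = netG(torch.randn(b_size, nz))
    label.fill_(fake_label)
    errD_fake = criterion(netD(fake.detach()).view(-1), label)
    errD_fake.backward()
    # Step 3: update D
    optimizerD.step()
    # Step 4: update G, scoring the fake batch against real labels
    netG.zero_grad()
    label.fill_(real_label)
    errG = criterion(netD(fake).view(-1), label)  # uses the freshly updated D
    errG.backward()
    optimizerG.step()
    return (errD_real + errD_fake).item(), errG.item()


# Hypothetical "real" data: points from a Gaussian centred at (2, 2)
for _ in range(100):
    real_batch = torch.randn(batch_size, data_dim) + 2.0
    lossD, lossG = train_step(real_batch)
```

Step 4 reuses `fake` without `detach()`, so this time the gradients flow back into G. The gradients this also leaves in D's parameters are cleared by `netD.zero_grad()` at the start of the next iteration.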
Tips with PyTorch
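- Normally, we call backward() only once per batch. If for some reason you need to call backward() more than once through the same computational graph, pass backward(retain_graph=True) to every call except the last. This keeps the graph (the intermediate values saved for backpropagation), which PyTorch otherwise frees after the first call, making the second call raise an error. Note that retain_graph does not affect the gradients themselves: they accumulate in .grad either way.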
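A minimal illustration of `retain_graph` on a single scalar `x` (a made-up example, not GAN code). `exp()` saves its output for the backward pass, so a second `backward()` through the same graph needs that buffer to be kept. The GAN loop above does not need this: `errD_real` and `errD_fake` come from separate forward passes, and since the tutorial's step 2 passes `fake.detach()` to D, step 4 is the first backward pass through G's part of the graph.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x.exp() * 2  # dy/dx = 2 * e^x; exp() saves its output for backward

y.backward(retain_graph=True)  # keep the graph for a second call
y.backward()                   # second call through the same graph
# Gradients accumulate: x.grad == 2 * (2 * e^3)
grad_after_two = x.grad.clone()

# Without retain_graph=True, the second call fails
z = x.exp() * 2
z.backward()
try:
    z.backward()
    second_call_failed = False
except RuntimeError:
    second_call_failed = True
```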
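- If we have a custom loss that returns a non-scalar tensor (for example, one value per sample), backward() needs a gradient argument, backward(gradient=...), with the same shape as the loss; it weights each element's contribution to the gradients. In older PyTorch versions, the corresponding argument of torch.autograd.backward was called grad_variables. To choose which tensors receive gradients, recent versions of backward() also accept inputs=[...].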
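A small sketch of both arguments on a made-up per-element loss (`w`, `b` and `per_element_loss` are hypothetical): passing `gradient=torch.ones_like(loss)` gives the same gradients as `loss.sum().backward()`, and `inputs=[w]` leaves `b.grad` untouched.

```python
import torch

w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)


def per_element_loss():
    # Hypothetical custom loss returning one value per element
    return (2 * w + b) ** 2


# 1) Non-scalar loss: `gradient` gives the weight of each element
loss = per_element_loss()
loss.backward(gradient=torch.ones_like(loss))  # same as loss.sum().backward()
grad_w_vector = w.grad.clone()

# 2) The same gradients from the scalar (summed) loss
w.grad = None
b.grad = None
per_element_loss().sum().backward()
grad_w_scalar = w.grad.clone()  # 4 * (2w + b)

# 3) Restrict which tensors receive gradients with `inputs`
w.grad = None
b.grad = None
per_element_loss().sum().backward(inputs=[w])  # b.grad stays None
```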
- Training GANs is challenging. In one case, the discriminator's accuracy dropped to 0.5 and stayed there, and nothing seemed to change afterwards. An accuracy of 0.5 means the discriminator is no better than random guessing: either it performs really badly, or the generator is doing a good job and can almost always fool it. In that case, we can try updating only the discriminator in our loop for a while, keeping the generator fixed.
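One way to do this is sketched below: freeze G with `requires_grad_(False)`, run a few iterations that only step `optimizerD`, and monitor D's accuracy. The tiny networks, the toy data, `d_only_steps` and the helper `d_accuracy` are hypothetical choices for illustration, not a recommended recipe.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
nz, data_dim, batch_size = 8, 2, 32
netG = nn.Linear(nz, data_dim)                              # tiny stand-in generator
netD = nn.Sequential(nn.Linear(data_dim, 1), nn.Sigmoid())  # tiny stand-in discriminator
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))


def d_accuracy(real, fake):
    """Fraction of samples D classifies correctly (threshold 0.5)."""
    with torch.no_grad():
        correct = (netD(real) > 0.5).sum() + (netD(fake) <= 0.5).sum()
    return correct.item() / (real.size(0) + fake.size(0))


d_only_steps = 5  # hypothetical number of D-only iterations
g_before = [p.detach().clone() for p in netG.parameters()]
d_before = [p.detach().clone() for p in netD.parameters()]

netG.requires_grad_(False)  # freeze G: autograd computes no gradients for it
for _ in range(d_only_steps):
    real = torch.randn(batch_size, data_dim) + 2.0  # stand-in real data
    fake = netG(torch.randn(batch_size, nz))
    netD.zero_grad()
    errD = (criterion(netD(real).view(-1), torch.ones(batch_size))
            + criterion(netD(fake).view(-1), torch.zeros(batch_size)))
    errD.backward()
    optimizerD.step()  # only D is updated; no G update in this phase
    acc = d_accuracy(real, fake)
netG.requires_grad_(True)  # unfreeze G before resuming the normal loop
```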