Deep Learning 18: GANs with PyTorch

In this short post, I will share a very brief GAN (Generative Adversarial Network) model and show how to train it in practice using PyTorch. I will also include some training tips, since I found GANs hard to train, especially when working with my own data and model.

Training GAN models

I wrote a blog post before about how to understand GAN models; check it out. You can also refer to the official PyTorch DCGAN tutorial. We will focus on the official tutorial, and I will try to give my understanding of its main steps, along with some tips.


As the figure shows, the black arrows trace the feedforward path, in which the Discriminator D predicts labels for both fake and real data. The Generator G and D can be any type of model; let us assume they are neural networks and focus on how to do backpropagation. We first update D, then G. In the figure above, steps 1-4 are highlighted in red: steps 1-3 update D, and step 4 updates G.
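To see the four steps in one place, here is a minimal, self-contained sketch of a GAN training loop. It is not the tutorial's DCGAN: the toy 1-D data (samples from N(3, 0.5²)), the small MLPs, the noise size nz=8 and the Adam settings are all assumptions chosen only so the example runs quickly on a CPU. The variable names (netD, netG, errD_real, errD_fake, errG) follow the official tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy setup (assumed for illustration): 1-D "real" data from N(3, 0.5^2),
# small MLPs for G and D, and 8-dimensional noise.
nz, batch = 8, 64
netG = nn.Sequential(nn.Linear(nz, 16), nn.ReLU(), nn.Linear(16, 1))
netD = nn.Sequential(nn.Linear(1, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
real_label, fake_label = 1.0, 0.0


def train_step(real):
    b = real.size(0)
    # Step 1: gradients of D on the all-real batch
    netD.zero_grad()
    label = torch.full((b,), real_label)
    errD_real = criterion(netD(real).view(-1), label)
    errD_real.backward()
    # Step 2: gradients of D on the all-fake batch (added to step 1's gradients)
    fake = netG(torch.randn(b, nz))
    label.fill_(fake_label)
    errD_fake = criterion(netD(fake.detach()).view(-1), label)
    errD_fake.backward()
    # Step 3: update D with the accumulated gradients
    optimizerD.step()
    # Step 4: update G, asking the (updated) D to label the fakes as real
    netG.zero_grad()
    label.fill_(real_label)
    errG = criterion(netD(fake).view(-1), label)
    errG.backward()
    optimizerG.step()
    return (errD_real + errD_fake).item(), errG.item()


for _ in range(200):
    real = 3.0 + 0.5 * torch.randn(batch, 1)
    lossD, lossG = train_step(real)

print(f"last D loss {lossD:.3f}, last G loss {lossG:.3f}")
```

Only the losses of the last iteration are printed; whether G actually learns N(3, 0.5²) within 200 iterations is not guaranteed for this toy setup.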

The following pieces of code are from the training loop of the official tutorial: steps 1-3 come from Part 1 – Train the Discriminator, and step 4 comes from Part 2 – Train the Generator.

Step 1: Calculate D's gradients on the real data

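# First clear the gradients of the Discriminator
netD.zero_grad()
# Forward the all-real batch through D (label is filled with real_label)
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()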

Note that calling backward() only calculates the gradients and stores them in each parameter's .grad; it does not perform the update. In other words, D's weights have not changed yet.
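The difference between backward() and step() is easy to check. In this sketch, a single linear layer with a sigmoid stands in for D and random tensors stand in for a real batch (both are assumptions for illustration); plain SGD is used so any nonzero gradient visibly moves the weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in discriminator (assumed for illustration): one linear layer + sigmoid
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
optimizerD = torch.optim.SGD(netD.parameters(), lr=0.1)
criterion = nn.BCELoss()

real = torch.randn(8, 4)   # stand-in for an all-real batch
label = torch.ones(8)      # real_label = 1 for every sample
w_before = netD[0].weight.detach().clone()

netD.zero_grad()
errD_real = criterion(netD(real).view(-1), label)
errD_real.backward()

# backward() filled .grad but left the weights untouched
grad_filled = netD[0].weight.grad is not None
unchanged_after_backward = torch.equal(netD[0].weight, w_before)

# Only the optimizer step actually changes the weights
optimizerD.step()
changed_after_step = not torch.equal(netD[0].weight, w_before)

print(grad_filled, unchanged_after_backward, changed_after_step)
```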

Step 2: Calculate D's gradients on the fake data

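label.fill_(fake_label)  # the all-fake batch gets the fake label
output = netD(fake.detach()).view(-1)  # detach() keeps gradients out of G
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch (added to the ones from step 1)
errD_fake.backward()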

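This step is very similar to step 1, but it covers the other path: the fake batch is produced by G and then classified by D. Calling fake.detach() stops the gradients at G's output, so this backward pass only computes gradients for D.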
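The effect of detach() can be seen with tiny stand-ins for G and D (single linear layers, assumed for illustration, not the tutorial's DCGAN models). After the detached backward pass, D has gradients and G has none; the same loss without detach() reaches G as well.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Tiny stand-ins for G and D (assumed for illustration)
netG = nn.Linear(2, 4)
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
criterion = nn.BCELoss()

fake = netG(torch.randn(8, 2))  # an all-fake batch produced by G
label = torch.zeros(8)          # fake_label = 0 for every sample

# Step 2 as in the tutorial: classify fake.detach(), so the graph stops at G's output
errD_fake = criterion(netD(fake.detach()).view(-1), label)
errD_fake.backward()
d_has_grad = netD[0].weight.grad is not None  # D received gradients
g_has_no_grad = netG.weight.grad is None      # nothing flowed into G

# For contrast: without detach(), the same loss also reaches G's parameters
criterion(netD(fake).view(-1), label).backward()
g_has_grad_without_detach = netG.weight.grad is not None

print(d_has_grad, g_has_no_grad, g_has_grad_without_detach)
```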

Step 3: Update parameters of D

# Total D loss, kept for logging; the two backward() calls already summed the gradients
errD = errD_real + errD_fake
# Update the weights of D
optimizerD.step()

Because backward() accumulates gradients, D's .grad buffers now hold the sum of the gradients from steps 1 and 2, and step() applies them to actually change D's parameters.
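The claim that two backward() calls add up can be checked directly. This sketch uses a single-layer stand-in for D and random tensors in place of the real and fake batches (both assumptions for illustration), and compares the accumulated gradient with the gradient of the summed loss errD = errD_real + errD_fake.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())  # stand-in D (assumed)
criterion = nn.BCELoss()
real, fake = torch.randn(8, 4), torch.randn(8, 4)    # stand-in batches (assumed)
ones, zeros = torch.ones(8), torch.zeros(8)

# Two separate backward() calls, as in steps 1 and 2
netD.zero_grad()
criterion(netD(real).view(-1), ones).backward()
criterion(netD(fake).view(-1), zeros).backward()
accumulated = netD[0].weight.grad.clone()

# One backward() on the summed loss
netD.zero_grad()
errD = criterion(netD(real).view(-1), ones) + criterion(netD(fake).view(-1), zeros)
errD.backward()
summed = netD[0].weight.grad.clone()

same = torch.allclose(accumulated, summed)
print(same)
```

So summing errD_real and errD_fake before a single backward() would give the same update; the tutorial's two separate calls simply reuse the losses it has already computed.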

Step 4: Update G

Recall that G's job is to generate data that looks real, so we measure its error by passing G's output through D and comparing D's predictions with the real labels. That is, G wants every fake sample to be predicted as real data.

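netG.zero_grad()
label.fill_(real_label)  # G wants D to classify its fakes as real
output = netD(fake).view(-1)  # no detach() here: gradients must reach G
# Calculate G's loss based on this output
errG = criterion(output, label)
errG.backward()  # calculate gradients for G
optimizerG.step()  # update G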

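We then calculate the gradients and update G, which takes fewer steps than updating D. Note that errG.backward() also writes gradients into D's parameters, but only optimizerG.step() is called here, and netD.zero_grad() clears D's gradients at the start of the next iteration.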
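To check step 4, here is a small sketch with single-layer stand-ins for G and D, assumed only for illustration. It shows that with all-ones (real) targets, BCELoss reduces to -mean(log D(G(z))), the "non-saturating" generator loss, and that errG.backward() fills D's gradients while only G's weights change.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
netG = nn.Linear(2, 4)                               # stand-in G (assumed)
netD = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())  # stand-in D (assumed)
criterion = nn.BCELoss()
optimizerG = torch.optim.SGD(netG.parameters(), lr=0.1)

fake = netG(torch.randn(8, 2))
d_before = netD[0].weight.detach().clone()
g_before = netG.weight.detach().clone()

netG.zero_grad()
output = netD(fake).view(-1)
errG = criterion(output, torch.ones(8))  # real labels for G's fakes
# With all-ones targets, BCE is just -mean(log D(G(z)))
same_as_formula = torch.allclose(errG, -torch.log(output).mean())
errG.backward()
optimizerG.step()

d_got_grad = netD[0].weight.grad is not None         # D received gradients too...
d_unchanged = torch.equal(netD[0].weight, d_before)  # ...but was not updated
g_changed = not torch.equal(netG.weight, g_before)   # G was updated

print(same_as_formula, d_got_grad, d_unchanged, g_changed)
```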

Tips with PyTorch

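  1. Normally, in each batch, we call forward() and then backward() only once. But if, for some reason, you need to call backward() through the same graph more than once, the first call must be backward(retain_graph=True) in order to keep the computation graph. Otherwise, PyTorch frees the graph's saved intermediate values after the first call, and the second call raises a RuntimeError. (Gradients themselves are not cleared; they keep accumulating in .grad.)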
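  2. If a custom loss function returns a non-scalar tensor (for example, one loss value per sample), backward() needs a gradient argument that weights each element, e.g. loss.backward(gradient=torch.ones_like(loss)); in the functional form torch.autograd.backward(tensors, grad_tensors=...), this argument was called grad_variables in older PyTorch versions. It does not choose which variables are updated; to restrict which tensors receive gradients, recent PyTorch versions accept backward(inputs=[...]), or you can use torch.autograd.grad.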
  3. Training GANs is challenging. In one of my runs, the discriminator's accuracy dropped to 0.5 and stayed there, and training appeared to be stuck. An accuracy of 0.5 means D does no better than random guessing: either the discriminator performs really badly, or the generator is doing a good job and can almost fool it. In that case, we can try updating only the discriminator in our loop while keeping the generator fixed.
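Tips 1 and 2 can be tried on plain tensors. In this sketch, x, w and b are made-up tensors for illustration; torch.exp is used because it saves its output for the backward pass, so the freed graph is actually noticed on a second call. The inputs= argument of backward() requires PyTorch 1.8 or later.

```python
import torch

x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)

# Tip 1: without retain_graph=True, a second backward() through the same graph fails
y = torch.exp(x).sum()
y.backward()
try:
    y.backward()
    second_call_failed = False
except RuntimeError:
    second_call_failed = True  # the saved tensors were already freed

# With retain_graph=True the graph survives, and the gradients accumulate
x.grad = None
y = torch.exp(x).sum()
y.backward(retain_graph=True)
y.backward()
accumulated_ok = torch.allclose(x.grad, 2 * torch.exp(x.detach()))  # exp'(x), added twice

# Tip 2: a non-scalar loss needs an explicit gradient argument
x.grad = None
per_sample_loss = torch.exp(x)  # one value per element, not a scalar
per_sample_loss.backward(gradient=torch.ones_like(per_sample_loss))  # like .sum().backward()
vector_ok = torch.allclose(x.grad, torch.exp(x.detach()))

# Restricting which leaf tensors receive gradients: inputs=[...]
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
loss = (3.0 * w + b) ** 2
loss.backward(inputs=[w])  # only w.grad is populated; b.grad stays None

print(second_call_failed, accumulated_ok, vector_ok, w.grad, b.grad)
```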

Published by Irene

Keep calm and update blog.
