Faking images using a Generative Adversarial Network

Recently, I have been working quite a bit on training deep models for my personal projects. Most of these have been in TensorFlow. However, after reading some reviews and seeing how PyTorch seems to be more Pythonic, I started looking at that framework. For no specific reason, I decided to create a deep convolutional generative adversarial network to generate some fake images. Maybe it is because I have been working on something related to this topic, or maybe it is because when you search for DCGAN, the first result links to PyTorch.

Overview of GAN

I will not go into much theory, as there are plenty of good resources out there that explain what a GAN is.

From a 10,000-foot view, a GAN is a machine learning approach that combines a discriminator network and a generator network to create new samples that resemble a set of real images. The two networks are trained in tandem. The discriminator is trained on both real images and the generator's output, and learns to tell real images from fake ones. The generator takes in random noise and turns it into images. These images are passed to the discriminator, which outputs the probability that each one is real. The generator is trained to minimize the loss derived from the discriminator's verdict, in other words to make its images pass as real. After quite a number of epochs, the generator 'learns' what a real image looks like and can create realistic images from noise. Ideally, the discriminator can no longer tell generated images from real ones and can do no better than output 0.5 for every image.
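For reference, the original GAN paper (Goodfellow et al., 2014) formalizes this tug-of-war as a two-player minimax game over a value function V(D, G), where D(x) is the discriminator's estimated probability that x is real and G(z) is the image generated from noise z:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

For a fixed generator whose images follow a distribution p_g, the best possible discriminator is

$$
D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
$$

which equals 1/2 everywhere once p_g matches p_data. That is the "can no longer tell them apart" condition described above.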

In the example that follows, both networks are convolutional. This type of model is called a deep convolutional generative adversarial network (DCGAN).

Working Dataset

I am using a dataset that contains over 6,500 images of flowers from 102 different varieties. I believe it comes from the Oxford 102 Flowers dataset. Given below is a sample subset of the flowers in this dataset.

Creating the network

We will use Python 3.8 with the following libraries:

  • torch
  • torchvision
  • numpy
  • matplotlib

PyTorch was installed with CUDA enabled:

conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge
Load Data

Before we start loading data, we check whether CUDA is available.

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device: %s' % device)
import torch.utils.data as dta
import torchvision.datasets as dset
import torchvision.transforms as T

# Load dataset
data_dir = './datasets/flowers'
workers = 1
image_size = 64
batch_size = 128

dataset = dset.ImageFolder(
    root=data_dir,
    transform=T.Compose([
        T.Resize(size=(image_size, image_size)),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
)
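# Pinned (page-locked) host memory speeds up host-to-GPU copies.
# device is a torch.device, so compare its .type, not the device itself, with 'cuda'
kwargs = {'pin_memory': True} if device.type == 'cuda' else {}
print('KWargs: %s' % kwargs)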
dataloader = dta.DataLoader(
    dataset, batch_size=batch_size, num_workers=workers,
    shuffle=True, **kwargs
)

I created this notebook on a Windows PC so that I could use CUDA for faster training. However, setting the number of workers to anything above 1 kept crashing a Windows DLL, so I had to settle for 1. We use torchvision.datasets.ImageFolder to load the images and apply three transforms. First, we resize the images to 64×64 pixels. Next, we convert them to tensors, which scales the pixel values into the [0, 1] range. Finally, we normalize each channel with a mean and standard deviation of 0.5, which maps the values into the [-1, 1] range. The resulting dataloader variable is used later for training.
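The normalization step is simple per-channel arithmetic: T.Normalize computes (x - mean) / std for each channel. The pure-Python sketch below (the helper names are illustrative, not torchvision functions) applies the same mean and std of 0.5 to single RGB pixels, showing that values in [0, 1] land in [-1, 1], and includes the inverse you would apply to bring an image back to [0, 1] for display.

```python
# Per-channel mean and std passed to T.Normalize in the loading code above
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)


def normalize_pixel(rgb):
    """Apply (x - mean) / std per channel, as T.Normalize does."""
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


def denormalize_pixel(rgb):
    """Invert the normalization (x * std + mean), back to the [0, 1] range."""
    return tuple(x * s + m for x, m, s in zip(rgb, MEAN, STD))


# T.ToTensor() has already scaled 8-bit values into [0, 1]
black, grey, white = (0.0,) * 3, (0.5,) * 3, (1.0,) * 3
print(normalize_pixel(black))  # (-1.0, -1.0, -1.0)
print(normalize_pixel(grey))   # (0.0, 0.0, 0.0)
print(normalize_pixel(white))  # (1.0, 1.0, 1.0)
```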

I have the following CUDA version on my PC.

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:54:10_Pacific_Daylight_Time_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.relgpu_drvr455TC455_06.29190527_0

Generator

"""
a. Take a 100 dim noise
b. Convert to 4x4x1024 feature maps
c. Using 4 convolutions progressively to convert to a 64x64 image
d. For all hidden layers we use BatchNorm and relu as activation
e. For output layer we use tanh to keep pixels within [-1, 1] range
"""
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn

zvector = 100
num_channel = 3
hidden_dim = 64

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        print('Generator Parameters: Zvector: %d, Channels: %d, Hidden: %d' % \
              (zvector, num_channel, hidden_dim))
        self.main = nn.Sequential(
            self.create_hidden_layer(zvector, hidden_dim * 8, 4, 1, 0),
            self.create_hidden_layer(hidden_dim * 8, hidden_dim * 4),
            self.create_hidden_layer(hidden_dim * 4, hidden_dim * 2),
            self.create_hidden_layer(hidden_dim * 2, hidden_dim),
            self.create_output_layer(hidden_dim, num_channel),
        )
    
    def create_hidden_layer(
            self, 
            in_channels, 
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1):
        
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )
    
    def create_output_layer(
            self, 
            in_channels, 
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1):
        
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.Tanh()
        )
    
    def forward(self, input):
        return self.main(input)


netgen = Generator().to(device)
if (device.type == 'cuda'):
    netgen = nn.DataParallel(netgen, list(range(1)))
print(netgen)

Definition of the generator is given below.

Generator Parameters: Zvector: 100, Channels: 3, Hidden: 64
DataParallel(
  (module): Generator(
    (main): Sequential(
      (0): Sequential(
        (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (2): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (3): Sequential(
        (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (4): Sequential(
        (0): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): Tanh()
      )
    )
  )
)

We use ReLU activation in all the inner layers. For the final output layer, we use Tanh, which keeps the generated pixel values within [-1, 1], the same range as the normalized real images.
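The shapes in the printed generator can be checked by hand. The PyTorch ConvTranspose2d documentation gives the output height as (H_in - 1) × stride - 2 × padding + dilation × (kernel_size - 1) + output_padding + 1, which reduces to (H_in - 1) × stride - 2 × padding + kernel_size with the defaults used here. The pure-Python sketch below is not part of the original notebook; its layer table is transcribed from the printed model with hidden_dim = 64, and it traces the noise vector from 1×1 up to a 64×64, 3-channel image.

```python
def conv_transpose_out(size, kernel_size, stride, padding):
    # Output size of nn.ConvTranspose2d with dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + kernel_size


# (out_channels, kernel_size, stride, padding) of each ConvTranspose2d,
# transcribed from the printed generator with hidden_dim = 64
GENERATOR_LAYERS = [(512, 4, 1, 0), (256, 4, 2, 1), (128, 4, 2, 1),
                    (64, 4, 2, 1), (3, 4, 2, 1)]


def trace_generator(zvector=100, size=1):
    """Return (channels, height) after each layer, starting from the
    zvector x 1 x 1 noise input."""
    shapes = [(zvector, size)]
    for out_channels, k, s, p in GENERATOR_LAYERS:
        size = conv_transpose_out(size, k, s, p)
        shapes.append((out_channels, size))
    return shapes


print(trace_generator())
# [(100, 1), (512, 4), (256, 8), (128, 16), (64, 32), (3, 64)]
```

With kernel 4, stride 2 and padding 1, each layer exactly doubles the spatial size, so targeting a different image_size would mean adding or removing a layer rather than just changing a parameter.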

Discriminator

"""
a. This is the reverse of Generator
b. We progressively decrease the image size
c. For hidden layers, we use LeakyRelu as activation
"""
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        print('Discriminator Parameters: Zvector: %d, Channels: %d, Hidden: %d' % \
              (zvector, num_channel, hidden_dim))
        self.main = nn.Sequential(
            self.create_hidden_layer(num_channel, hidden_dim),
            self.create_hidden_layer(hidden_dim, hidden_dim * 2),
            self.create_hidden_layer(hidden_dim * 2, hidden_dim * 4),
            self.create_hidden_layer(hidden_dim * 4, hidden_dim * 8),
            self.create_output_layer(hidden_dim * 8, 1)
        )
        
    def create_hidden_layer(
            self, 
            in_channels, 
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1):
        
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )
    
    def create_output_layer(
            self, 
            in_channels, 
            out_channels,
            kernel_size=4,
            stride=1,
            padding=0):
        
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, input):
        return self.main(input)
      
netdis = Discriminator().to(device)
if (device.type == 'cuda'):
    netdis = nn.DataParallel(netdis, list(range(1)))
print(netdis)

Definition of the discriminator is given below.

Discriminator Parameters: Zvector: 100, Channels: 3, Hidden: 64
DataParallel(
  (module): Discriminator(
    (main): Sequential(
      (0): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (1): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (2): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (3): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (4): Sequential(
        (0): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
        (1): Sigmoid()
      )
    )
  )
)

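This is the discriminator network. We use LeakyReLU activation (negative slope 0.2) for all hidden layers. This avoids a problem with ReLU, which outputs zero, and passes back zero gradient, for every negative input. LeakyReLU keeps a small slope on the negative side, so gradients still flow. For the final layer we use a Sigmoid, so the output is a probability in (0, 1) that the input image is real.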
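The same kind of shape check works for the discriminator, and the ReLU versus LeakyReLU difference is easy to see numerically. The pure-Python sketch below (illustrative helpers, not PyTorch functions) uses the Conv2d output-size formula from the PyTorch documentation with dilation 1, floor((H_in + 2 × padding - kernel_size) / stride) + 1, to trace a 64×64 image down to a single score, and compares ReLU and LeakyReLU (slope 0.2, as in the discriminator) along with their gradients.

```python
# (out_channels, kernel_size, stride, padding) of each Conv2d in the discriminator
DISCRIMINATOR_LAYERS = [(64, 4, 2, 1), (128, 4, 2, 1), (256, 4, 2, 1),
                        (512, 4, 2, 1), (1, 4, 1, 0)]


def conv_out(size, kernel_size, stride, padding):
    # Output size of nn.Conv2d with dilation=1
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_discriminator(channels=3, size=64):
    """Return (channels, height) after each layer for a square input image."""
    shapes = [(channels, size)]
    for out_channels, k, s, p in DISCRIMINATOR_LAYERS:
        size = conv_out(size, k, s, p)
        shapes.append((out_channels, size))
    return shapes


def relu(x):
    return max(0.0, x)


def relu_grad(x):
    # Zero for every negative input: no gradient flows back
    return 1.0 if x > 0 else 0.0


def leaky_relu(x, negative_slope=0.2):
    return x if x >= 0 else negative_slope * x


def leaky_relu_grad(x, negative_slope=0.2):
    # Small but non-zero slope on the negative side
    return 1.0 if x > 0 else negative_slope


print(trace_discriminator())  # [(3, 64), (64, 32), (128, 16), (256, 8), (512, 4), (1, 1)]
for x in (-2.0, 3.0):
    print(x, relu(x), relu_grad(x), leaky_relu(x), leaky_relu_grad(x))
```

The final 1×1 map is why the training loop calls .view(-1): it flattens the N×1×1×1 output into one score per image, matching the length of the label vector.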

Optimizer and Loss functions

import torch.optim as optim

torch.manual_seed(322)

# Binary Cross Entropy loss
criterion = nn.BCELoss()
# 64 fixed noise vectors, reused to visualise the generator's progress
fixed_noise = torch.randn(64, zvector, 1, 1, device=device)
lr = 0.0002   # Learning rate
beta1 = 0.5   # Beta1 hyperparam

# Setup Adam optimizers for both Gen and Dis
optimdis = optim.Adam(netdis.parameters(), lr=lr, betas=(beta1, 0.999))
optimgen = optim.Adam(netgen.parameters(), lr=lr, betas=(beta1, 0.999))

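def calculate_loss(pdata, plabel):
    # BCE loss of the discriminator output against the labels
    err_loss = criterion(pdata, plabel)
    # Backward pass: accumulate gradients into the parameters' .grad fields
    err_loss.backward()
    # Mean discriminator output for this batch, reported as D(x) or D(G(z))
    mean_output = pdata.mean().item()
    return err_loss, mean_output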

We use binary cross-entropy (BCE) as the loss function and Adam as the optimizer, with the learning rate (0.0002) and beta1 (0.5) recommended in the original DCGAN paper.
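nn.BCELoss compares each discriminator output p (the Sigmoid probability) with its label y and, by default, averages -[y log p + (1 - y) log(1 - p)] over the batch. With a real label of 1 only the -log p term survives, and with a fake label of 0 only -log(1 - p). The pure-Python sketch below mirrors that formula so the numbers can be checked by hand; it assumes 0 < p < 1, whereas PyTorch additionally clamps each log term at -100.

```python
import math


def bce(probabilities, labels):
    """Mean binary cross-entropy, as nn.BCELoss() computes with reduction='mean'.

    Assumes every probability is strictly between 0 and 1.
    """
    total = 0.0
    for p, y in zip(probabilities, labels):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probabilities)


# Discriminator fairly sure a real image is real: small loss
print('%.4f' % bce([0.9], [1.0]))             # 0.1054
# Discriminator wrongly scoring a real image as fake: large loss
print('%.4f' % bce([0.1], [1.0]))             # 2.3026
# Batch of fakes correctly scored low, against the fake label 0
print('%.4f' % bce([0.1, 0.2], [0.0, 0.0]))   # 0.1643
```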

Training

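The training code is adapted from the PyTorch DCGAN tutorial. For each batch, we first pass the real images through the discriminator and compute its loss against the real label. Next, we sample random noise, pass it through the generator to create fake images, and pass those through the discriminator with the fake label. The discriminator is updated using the sum of both losses. Finally, we pass the fake images through the updated discriminator again, this time labelled as real, and use that loss to update the generator. We repeat this process for the defined number of epochs, capturing intermediate results for reporting.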

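from datetime import datetime

import torchvision.utils as vutils


def print_current_time(label):
    # Minimal stand-in for a helper the notebook calls but does not show
    print('%s: %s' % (label, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))


# Values that we assign for Real/ Fake
REAL_LABEL = 1.
FAKE_LABEL = 0.

num_epochs = 100
steps = 50      # print stats every `steps` batches (assumed value; not shown in the original)
img_list = []
G_losses = []
D_losses = []
iters = 0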

print('Start training...')
print_current_time('Training Start Time')
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        netdis.zero_grad()
        netgen.zero_grad()
        
        # Load data
        # Batch for Real images
        bat_set = data[0].to(device)
        bat_size = bat_set.size(0)
        bat_lbl = torch.full((bat_size,), REAL_LABEL, dtype=torch.float, device=device)
        
        # Send through discriminator
        output = netdis(bat_set).view(-1)
        assert output.size(0) == bat_lbl.size(0), \
            'Size Error: Output %d, Label: %d' % (output.size(0), bat_lbl.size(0))
        
        # Calculate loss on all-real batch
        errD_real, D_x = calculate_loss(output, bat_lbl)
        
        # Fake Batch
        noise = torch.randn(bat_size, zvector, 1, 1, device=device)
        fake_set = netgen(noise)
        bat_lbl.fill_(FAKE_LABEL)
        # Classify all fake batch with D
        output = netdis(fake_set.detach()).view(-1)
        assert output.size(0) == bat_lbl.size(0), \
            'Size Error: Output %d, Label: %d' % (output.size(0), bat_lbl.size(0))

        # Calculate D's loss on the all-fake batch
        errD_fake, D_G_z1 = calculate_loss(output, bat_lbl)
        
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimdis.step()
        
        # Update Generator
        bat_lbl.fill_(REAL_LABEL)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netdis(fake_set).view(-1)
        # Calculate G's loss based on this output
        errG, D_G_z2 = calculate_loss(output, bat_lbl)
        optimgen.step()
        
        # Output training stats
        if i % steps == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake_img = netgen(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake_img, padding=2, normalize=True))

        iters += 1
print("Complete Training...")
print_current_time('Training End Time')
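In the generator step the labels are refilled with REAL_LABEL, so errG is BCE against 1, that is -log D(G(z)), rather than the log(1 - D(G(z))) term from the minimax objective. The original GAN paper suggests this swap because log(1 - D(G(z))) saturates: early in training, when the discriminator confidently rejects fakes, its gradient almost vanishes. The pure-Python sketch below (illustrative helpers, not part of the notebook) compares the gradient of each generator loss with respect to the discriminator's pre-Sigmoid logit a, where D(G(z)) = sigmoid(a).

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def minimax_generator_grad(logit):
    # d/da of log(1 - sigmoid(a)): the generator term of the minimax objective
    return -sigmoid(logit)


def non_saturating_generator_grad(logit):
    # d/da of -log(sigmoid(a)): BCE against the real label, as errG uses
    return -(1.0 - sigmoid(logit))


for logit in (-5.0, 0.0, 5.0):
    print('D(G(z)) = %.3f  minimax grad = %.4f  non-saturating grad = %.4f' % (
        sigmoid(logit), minimax_generator_grad(logit),
        non_saturating_generator_grad(logit)))
```

With a logit of -5 (D(G(z)) of roughly 0.007), the minimax gradient is only about 0.007 in magnitude while the non-saturating one is about 0.993, so the generator still receives a useful signal when it is losing badly.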

The chart below shows the losses of both the discriminator and the generator over the 100 epochs we trained for.

After 100 epochs, quite a few of the generated flowers looked convincing. Training the network for more epochs would likely improve the results further.
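A loss chart like the one above can be drawn from the G_losses and D_losses lists filled in during training. The sketch below uses matplotlib from the library list; plot_losses and its optional path argument are illustrative choices, not the author's code.

```python
import matplotlib

matplotlib.use('Agg')  # non-interactive backend; remove this line inside a notebook
import matplotlib.pyplot as plt


def plot_losses(g_losses, d_losses, path=None):
    """Plot per-iteration generator and discriminator losses on one chart."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title('Generator and Discriminator Loss During Training')
    ax.plot(g_losses, label='G')
    ax.plot(d_losses, label='D')
    ax.set_xlabel('Iterations')
    ax.set_ylabel('Loss')
    ax.legend()
    if path is not None:
        fig.savefig(path)  # e.g. 'losses.png'
    return fig
```

In the notebook, calling plot_losses(G_losses, D_losses) after training draws the chart; pass a file name as path to save it as well.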

Real vs Fake images
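A comparison like this can be built by placing a grid of real images next to the last grid in img_list. The sketch below assumes both grids come from vutils.make_grid(..., normalize=True), so their values are already in [0, 1], and that they have been converted to channel-first NumPy arrays, for example with .cpu().numpy(); to_hwc and show_real_vs_fake are hypothetical helper names.

```python
import numpy as np
import matplotlib

matplotlib.use('Agg')  # non-interactive backend; remove this line inside a notebook
import matplotlib.pyplot as plt


def to_hwc(grid):
    """Convert a (C, H, W) image grid to the (H, W, C) layout imshow expects."""
    return np.transpose(np.asarray(grid), (1, 2, 0))


def show_real_vs_fake(real_grid, fake_grid):
    """Show two (C, H, W) grids with values in [0, 1] side by side."""
    fig, axes = plt.subplots(1, 2, figsize=(15, 15))
    panels = zip(axes, (real_grid, fake_grid), ('Real Images', 'Fake Images'))
    for ax, grid, title in panels:
        ax.axis('off')
        ax.set_title(title)
        ax.imshow(to_hwc(grid))
    return fig
```

For example, take real_batch = next(iter(dataloader)), build real_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True).numpy(), and call show_real_vs_fake(real_grid, img_list[-1].numpy()).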

Conclusion

Since this is an early foray into PyTorch for me, I know very little of the complete API. However, it did not look too complicated. I still have plenty to learn before I understand all the intricacies of this library. Ciao for now!