Generative Modelling Background
At a high level, generative modelling describes the branch of AI dealing with creating new data, rather than drawing conclusions from existing data.
For example, given a limited dataset of dog images, a generative modelling task could be to train a model that generates new images of dogs. Those images could then serve as extra training data for a different model that classifies images as cats or dogs. A few general uses of generative modelling include Image-to-Image Translation, Text-to-Image Translation, and Semantic-Image-to-Photo Translation.
Traditionally, these tasks have been handled by GAN-based models. As a quick background, a GAN (Generative Adversarial Network) consists of a generator and a discriminator, both of which are just neural networks. The two networks are trained together, each working against the other.
The generator’s purpose is to create “fake” training data from noise. After training, the generator should be able to output data that could have plausibly existed in the training dataset.
The discriminator’s job is to learn to judge whether a given input sample comes from the actual dataset or from the generator. Its judgements provide the training signal for the generator: the generator’s weights are updated so that its samples become more likely to be judged real.
This training style is called “adversarial” because the generator and discriminator are locked in a zero-sum game, where one network’s gain is the other’s loss.
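For reference, the standard objective from the original GAN paper (Goodfellow et al., 2014) makes the zero-sum game explicit: the discriminator D tries to maximize the value V, while the generator G tries to minimize the same quantity. Here p_data is the training distribution and p_z is the noise distribution the generator samples from.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```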
GANs have been a focus of generative modelling research since their introduction, and new papers frequently present a GAN optimized for a specific use case. Popular examples include SRGAN, for image super-resolution, and cGAN, which can generate data conditioned on an input label.
Diffusion Models Intro
GANs, while groundbreaking and performing extremely well in certain use cases, have 3 key flaws stemming from the adversarial training that hinder their ability to perform in many scenarios.
Suppose we train on the MNIST dataset, and during training the generator loss for images of “2” is much lower than for other digits. This would incentivize the generator to produce only images of “2”. This is called mode collapse.
Next, suppose the discriminator becomes so good at classifying samples that it constantly rejects the generator’s images. The generator then receives vanishing gradients and never learns. Finally, GANs are often non-convergent, because the generator’s and discriminator’s parameters can keep oscillating as the discriminator tries to maximize a shared objective and the generator tries to minimize it.
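A small numerical sketch of the vanishing-gradient problem, assuming a discriminator with a sigmoid output and the original minimax generator loss log(1 - D(G(z))). If the discriminator's logit a for a fake sample is very negative (it confidently rejects the sample), the gradient of that loss with respect to a is -sigmoid(a), which shrinks toward zero. For contrast, the block also computes the "non-saturating" loss -log D(G(z)) suggested in the original GAN paper. The function names are hypothetical.

```python
import math


def sigmoid(a: float) -> float:
    """Logistic function mapping a discriminator logit to a probability."""
    return 1.0 / (1.0 + math.exp(-a))


def minimax_generator_grad(logit: float) -> float:
    """d/da of log(1 - sigmoid(a)), the original minimax generator loss term.

    The derivative simplifies to -sigmoid(a).
    """
    return -sigmoid(logit)


def non_saturating_generator_grad(logit: float) -> float:
    """d/da of -log(sigmoid(a)), the non-saturating alternative loss.

    The derivative simplifies to sigmoid(a) - 1.
    """
    return sigmoid(logit) - 1.0


# The more confidently the discriminator rejects a fake (more negative logit),
# the closer the minimax gradient gets to zero.
for logit in (0.0, -5.0, -10.0):
    print(f"logit={logit:6.1f}  minimax grad={minimax_generator_grad(logit):.6f}  "
          f"non-saturating grad={non_saturating_generator_grad(logit):.6f}")
```

At a logit of -10 the minimax gradient is on the order of 1e-5, so the generator barely moves, while the non-saturating gradient stays close to -1.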
Enter Diffusion models.
Diffusion models were named after the thermodynamic concept where gas molecules diffuse from high to low density regions. These models became popular after the release of the paper Denoising Diffusion Probabilistic Models (DDPM), which specifically addressed the task of image synthesis.
Diffusion models aim to generate valid data from pure noise. They learn this by taking data samples (images, audio, etc.) and progressively destroying them with Gaussian noise until nothing but noise remains. The model is then trained to reverse this process and recover the input.
A Deeper Look
To understand the potential and need for diffusion models, we have to first understand their architecture and training process. I’ll first cover this at a high level, and then go through the procedure again while delving a bit deeper into the details.
Let’s start with an arbitrary training sample, such as this picture of Bob the cheetah.
The training process first involves forward diffusion, where we apply Gaussian noise to the image over T timesteps. At the end of this process, the image should ideally be complete noise.
Now, we train a model to learn the reverse process of recovering the image from pure noise. Note the flexibility diffusion models afford: any model can be used in this step, as long as its input and output have the same dimensionality. In practice, this is usually a U-Net.
When sampling after training, the model is fed Gaussian noise and should generate outputs similar to the images in the training dataset by running through the reverse process.
That’s it. Now let’s do this again, but with a bit more math.
Starting with a training sample x_0, we progressively apply Gaussian noise until we reach x_T, following a Markov chain q in which x_t depends only on x_(t-1).
When working with images, this means the noise added at each step depends only on the previous image. Each step is sampled from the following conditional Gaussian distribution:
q(x_t | x_(t-1)) = N(x_t; sqrt(1-β_t)x_(t-1), β_t I)
The important terms in this equation are x_t, the output; sqrt(1-β_t)x_(t-1), the mean, which is a slightly scaled-down copy of the previous image; β_t, the value of the variance schedule at step t (set manually); and β_t I, the fixed covariance, where I is the identity matrix.
For example, when working with images, each pixel has the form (r, g, b), where the red, green, and blue values are integers in the interval [0, 255] (these values are usually scaled to [-1, 1] for training). So to generate a new value for each pixel at time t, we take each of r, g, and b at t-1 as x_(t-1) and sample x_t from q.
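A minimal sketch of one forward step for a single (r, g, b) pixel in plain Python, assuming the per-step distribution q(x_t | x_(t-1)) = N(x_t; sqrt(1-β_t)x_(t-1), β_t I). Because the covariance is diagonal, each channel can be noised independently with standard deviation sqrt(β_t). The helper names and the β value of 0.02 are illustrative choices, not taken from any particular library.

```python
import math
import random


def scale_to_unit_range(value: int) -> float:
    """Map an 8-bit channel value in [0, 255] to [-1, 1]."""
    return value / 127.5 - 1.0


def forward_step(x_prev: list[float], beta_t: float, rng: random.Random) -> list[float]:
    """Sample x_t ~ N(sqrt(1 - beta_t) * x_prev, beta_t * I), one channel at a time."""
    mean_scale = math.sqrt(1.0 - beta_t)
    std = math.sqrt(beta_t)
    return [mean_scale * x + std * rng.gauss(0.0, 1.0) for x in x_prev]


rng = random.Random(0)
pixel = [scale_to_unit_range(v) for v in (255, 128, 0)]  # an (r, g, b) pixel
noisy = forward_step(pixel, beta_t=0.02, rng=rng)
print(pixel, noisy)
```

Applying `forward_step` repeatedly, with a β_t for each step, walks the pixel through the whole forward chain.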
The main takeaway from this forward process is that β_t controls how rapidly the image decays: higher values turn the image into pure noise faster, while lower values take longer. The schedule is usually chosen empirically; the DDPM paper, for instance, increased β_t linearly from 10^-4 to 0.02 over T = 1000 steps.
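To make the effect of β_t concrete, the sketch below compares how much of the original image survives after all T steps under two linear schedules. After t steps, x_0 is multiplied by sqrt((1-β_1)(1-β_2)...(1-β_t)). The first schedule uses the DDPM values; the second, ten times smaller, is a hypothetical comparison point.

```python
import math


def linear_beta_schedule(T: int, beta_start: float, beta_end: float) -> list[float]:
    """Linearly spaced variances beta_1..beta_T."""
    if T == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (T - 1)
    return [beta_start + i * step for i in range(T)]


def signal_remaining(betas: list[float]) -> float:
    """Coefficient sqrt(prod(1 - beta_s)) multiplying x_0 after len(betas) steps."""
    return math.sqrt(math.prod(1.0 - b for b in betas))


T = 1000
ddpm = linear_beta_schedule(T, 1e-4, 0.02)     # values reported in the DDPM paper
gentle = linear_beta_schedule(T, 1e-5, 0.002)  # hypothetical, 10x smaller betas
print(f"DDPM schedule: {signal_remaining(ddpm):.4f} of x_0 left after {T} steps")
print(f"smaller betas: {signal_remaining(gentle):.4f} of x_0 left after {T} steps")
```

With the DDPM schedule well under 1% of x_0 is left, so x_T is essentially pure noise; with the smaller betas, more than half of the signal is still present after the same number of steps.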
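During training, x_t can also be sampled directly from x_0 in closed form, without stepping through every intermediate timestep. Writing α_t = 1 - β_t and ᾱ_t = α_1 α_2 ... α_t:
q(x_t | x_0) = N(x_t; sqrt(ᾱ_t) x_0, (1-ᾱ_t) I)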
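A sketch of that closed-form jump in plain Python, assuming the standard DDPM result: with α_t = 1 - β_t and ᾱ_t = α_1 α_2 ... α_t, we can write x_t = sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε, where ε is standard Gaussian noise. The function names are hypothetical, and the noise is passed in explicitly so the result is reproducible.

```python
import math
import random


def alpha_bars(betas: list[float]) -> list[float]:
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s), for t = 1..T."""
    out, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        out.append(running)
    return out


def q_sample(x0: list[float], t: int, betas: list[float], noise: list[float]) -> list[float]:
    """Jump straight from x_0 to x_t (t is 1-indexed) using the closed form."""
    a_bar = alpha_bars(betas)[t - 1]
    return [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * e for x, e in zip(x0, noise)]


rng = random.Random(0)
betas = [1e-4 + i * (0.02 - 1e-4) / 999 for i in range(1000)]  # DDPM-style linear schedule
x0 = [1.0, 0.0039, -1.0]  # a pixel already scaled to [-1, 1]
eps = [rng.gauss(0.0, 1.0) for _ in x0]
x_500 = q_sample(x0, 500, betas, eps)  # no need to run steps 1..499
print(x_500)
```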
For the reverse process, we start at T, where our data sample has become pure noise: p(x_T) = N(x_T; 0, I). Note the mean of 0 and unit variance.
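Next, we step backwards, learning the probability distribution of the earlier timestep x_(t-1) given the current one x_t:
p_θ(x_(t-1) | x_t) = N(x_(t-1); μ_θ(x_t, t), Σ_θ(x_t, t))
where μ_θ and Σ_θ are the mean and covariance predicted by the model with parameters θ.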
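A minimal sketch of the reverse process in plain Python, assuming DDPM's specific choices: the network predicts the added noise ε_θ(x_t, t) rather than the mean directly, so μ_θ = (x_t - β_t / sqrt(1 - ᾱ_t) · ε_θ(x_t, t)) / sqrt(α_t), and the variance is fixed to σ_t² = β_t. The function names and the `zero_model` placeholder are hypothetical; a real sampler would call a trained U-Net instead.

```python
import math
import random
from typing import Callable

# Hypothetical stand-in type for the trained network eps_theta(x_t, t),
# which returns a noise prediction with the same length as x_t.
NoisePredictor = Callable[[list[float], int], list[float]]


def p_sample_step(x_t: list[float], t: int, betas: list[float],
                  eps_model: NoisePredictor, rng: random.Random) -> list[float]:
    """One DDPM reverse step x_t -> x_(t-1) with sigma_t^2 = beta_t (t is 1-indexed)."""
    beta_t = betas[t - 1]
    alpha_t = 1.0 - beta_t
    alpha_bar_t = math.prod(1.0 - b for b in betas[:t])
    eps = eps_model(x_t, t)
    coef = beta_t / math.sqrt(1.0 - alpha_bar_t)
    mean = [(x - coef * e) / math.sqrt(alpha_t) for x, e in zip(x_t, eps)]
    if t == 1:
        return mean  # the final step returns the mean without adding noise
    sigma_t = math.sqrt(beta_t)
    return [m + sigma_t * rng.gauss(0.0, 1.0) for m in mean]


def sample(dim: int, betas: list[float], eps_model: NoisePredictor,
           rng: random.Random) -> list[float]:
    """Draw x_T ~ N(0, I), then apply every reverse step from t = T down to t = 1."""
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    for t in range(len(betas), 0, -1):
        x = p_sample_step(x, t, betas, eps_model, rng)
    return x


def zero_model(x_t: list[float], t: int) -> list[float]:
    """Hypothetical untrained predictor that always guesses zero noise."""
    return [0.0] * len(x_t)


betas = [1e-4 + i * (0.02 - 1e-4) / 999 for i in range(1000)]  # DDPM-style linear schedule
print(sample(3, betas, zero_model, random.Random(0)))
```

With the zero predictor the output is meaningless; the point is the control flow, where `sample` visits every timestep from T down to 1.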
During training, we usually sample a random timestep t for each example and train only on that single step (t to t-1), rather than running the whole sequence from T to 0. When doing inference with a trained model, however, we must iterate through the full sequence to generate valid output.
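A sketch of one training example, assuming DDPM's simplified objective: the mean squared error between the true noise ε and the model's prediction ε_θ(x_t, t), where x_t is produced with the closed-form jump from x_0. Only one random timestep is used per example, which is why training never runs the full chain. `training_loss` and `zero_model` are hypothetical names; in practice the loss would be backpropagated through a real network.

```python
import math
import random


def training_loss(x0, betas, eps_model, rng):
    """Return (t, loss) for one training example.

    1. pick a random timestep t (1-indexed) instead of running the whole chain,
    2. noise x_0 straight to x_t with the closed form of q(x_t | x_0),
    3. score the model's noise prediction with mean squared error.
    """
    t = rng.randint(1, len(betas))
    alpha_bar_t = math.prod(1.0 - b for b in betas[:t])
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(alpha_bar_t) * x + math.sqrt(1.0 - alpha_bar_t) * e
           for x, e in zip(x0, eps)]
    pred = eps_model(x_t, t)
    loss = sum((e - p) ** 2 for e, p in zip(eps, pred)) / len(x0)
    return t, loss


def zero_model(x_t, t):
    """Hypothetical untrained predictor that always outputs zero noise."""
    return [0.0] * len(x_t)


rng = random.Random(0)
betas = [1e-4 + i * (0.02 - 1e-4) / 999 for i in range(1000)]
t, loss = training_loss([1.0, 0.0, -1.0], betas, zero_model, rng)
print(f"sampled t={t}, loss={loss:.4f}")
```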
For image data, pixel values are scaled from integers in [0, 255] to [-1, 1] before training, but the final output must map back onto those 256 discrete values. To handle this, the last step of the reverse process (x_1 to x_0) uses a discrete decoder that gives a log likelihood for each possible value of every pixel.
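A sketch of that per-pixel decoder, assuming the model's final prediction for a pixel is a Gaussian with a given mean and standard deviation, as in DDPM. Each integer value gets the probability mass of its bin [x - 1/255, x + 1/255] in the scaled space, and the outermost bins are stretched to -∞ and +∞ so the 256 probabilities sum to 1. The function names are hypothetical.

```python
import math


def normal_cdf(x: float) -> float:
    """CDF of the standard normal distribution."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def discretized_gaussian_log_likelihood(value: int, mean: float, std: float) -> float:
    """Log probability that a pixel takes the integer value `value` in [0, 255]."""
    x = value / 127.5 - 1.0  # bin centre in [-1, 1]
    upper = 1.0 if value == 255 else normal_cdf((x + 1 / 255 - mean) / std)
    lower = 0.0 if value == 0 else normal_cdf((x - 1 / 255 - mean) / std)
    return math.log(max(upper - lower, 1e-12))  # clamp to avoid log(0) in the tails


probs = [math.exp(discretized_gaussian_log_likelihood(v, mean=0.0, std=0.1)) for v in range(256)]
print(f"total probability over all 256 values: {sum(probs):.6f}")
```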
Since the model’s parameters are shared across all timesteps, the network needs to be told which timestep it is denoising. The authors of the DDPM paper therefore added a positional embedding of the timestep, similar to the one used in Transformers, computed with sine and cosine functions.
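One common variant of the Transformer-style sinusoidal embedding, sketched in plain Python. The frequency spacing and `max_period = 10000` follow the Transformer convention; the function name is hypothetical, and real implementations usually compute this as a tensor for a whole batch of timesteps before feeding it into the network.

```python
import math


def timestep_embedding(t: int, dim: int, max_period: float = 10000.0) -> list[float]:
    """Sinusoidal embedding of an integer timestep.

    The first half of the vector holds sines and the second half cosines,
    at geometrically spaced frequencies from 1 down toward 1 / max_period.
    """
    if dim % 2:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


emb = timestep_embedding(t=10, dim=8)
print([round(v, 3) for v in emb])
```

Nearby timesteps get similar vectors while distant ones differ, which gives the shared network a smooth signal about how much noise to expect.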
For a more detailed look into the math involved behind diffusion models, check out this great explanation: https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
Why use Diffusion Models?
The main advantage of diffusion models is a more robust training procedure than that of GANs. Since they don’t use adversarial training, they avoid some of its common pitfalls, such as mode collapse. As stated in this paper by OpenAI, “GANs are able to trade off diversity for fidelity, producing high-quality samples but not covering the whole distribution”. Diffusion models are better at producing images that represent the depth and breadth of a training set.
Additionally, training is much easier to manage, as these models do not require as much handholding as GANs, which can need extensive hyperparameter tuning to converge and produce valid results.
However, these advantages also come at a cost — the slow, iterative approach diffusion models take for sampling can make them less feasible for situations where speed is desired, as noted in this technical walkthrough by Nvidia.
There is ongoing research into improving the speed and scalability of these models, which would be vital for choosing them over GANs in time-sensitive settings.
Cutting Edge Applications
There have been quite a few recent applications of diffusion models that push the limits of generative modelling, including Stable Diffusion by Stability AI and GLIDE by OpenAI. Recent research also suggests that diffusion models could replace GANs for generative modelling in the future, such as this paper by OpenAI: Diffusion Models Beat GANs on Image Synthesis.
Sample Code to Get Started on Training
To see a full implementation from scratch, check out the great guide over here: https://huggingface.co/blog/annotated-diffusion
However, if you just want to get hands-on quickly, I’ve provided some starter code below to begin training on your own dataset using the denoising_diffusion_pytorch package, which already contains basic implementations of U-Net and Gaussian diffusion.
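from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

# noise-prediction network used for the reverse process
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of forward diffusion steps (T)
    sampling_timesteps = 250,   # fewer steps at sampling time for faster inference
    loss_type = 'l1'            # 'l1' or 'l2' in older releases; newer releases may not accept this argument
)

trainer = Trainer(
    diffusion,
    'path/to/your/images',      # placeholder: folder containing your training images
    train_batch_size = 8,
    train_lr = 8e-5,
    train_num_steps = 10000,    # total training steps
    gradient_accumulate_every = 2,
    ema_decay = 0.995,
    amp = False                 # turn on for mixed precision, I've found it better to keep it off
)

trainer.train()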
If you found this article interesting, please let me know on LinkedIn — I’d love to hear your thoughts.