Latent Diffusion Models

Latent diffusion models use an auto-encoder to map between image space and latent space. The diffusion model operates on the latent space, which makes it much easier to train. This implementation is based on the paper High-Resolution Image Synthesis with Latent Diffusion Models.

They use a pre-trained auto-encoder and train the diffusion U-Net on the latent space of the pre-trained auto-encoder.

For a simpler diffusion implementation, refer to our DDPM implementation. We use the same notation for the α_t and β_t schedules, etc.
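To make this concrete, here is a minimal, self-contained sketch of the forward (noising) process on a latent, using the same kind of β schedule that the class below constructs. The latent shape and the schedule endpoints are assumptions for illustration, not values fixed by this module.

import torch

n_steps = 1000
linear_start, linear_end = 0.00085, 0.0120  # assumed schedule endpoints
# β schedule: square of a linear interpolation between the square roots of the endpoints
beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps) ** 2
alpha_bar = torch.cumprod(1. - beta, dim=0)  # ᾱ_t = ∏ α_s

z0 = torch.randn(1, 4, 32, 32)        # stand-in for an auto-encoder latent
t = torch.randint(0, n_steps, (1,))   # random diffusion step
noise = torch.randn_like(z0)

# q(z_t | z_0): the noisy latent that the U-Net is trained to denoise
zt = alpha_bar[t].sqrt().view(-1, 1, 1, 1) * z0 \
    + (1. - alpha_bar[t]).sqrt().view(-1, 1, 1, 1) * noise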

from typing import List

import torch
import torch.nn as nn

from labml_nn.diffusion.stable_diffusion.model.autoencoder import Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel

This is an empty wrapper class around the U-Net. We keep this to have the same model structure as CompVis/stable-diffusion so that we do not have to map the checkpoint weights explicitly.

class DiffusionWrapper(nn.Module):
    def __init__(self, diffusion_model: UNetModel):
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
        return self.diffusion_model(x, time_steps, context)
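As a hedged illustration of the naming: parameters of this wrapper get the prefix diffusion_model., and since LatentDiffusion below stores the wrapper in an attribute called model, the full parameter names become model.diffusion_model.…, which is how the CompVis/stable-diffusion checkpoints store the U-Net weights. The toy module below is only a stand-in to show the prefix (it reuses the imports above); it is not a real U-Net.

class TinyNet(nn.Module):
    # toy stand-in for the real U-Net, used only to show how parameter names are prefixed
    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(4, 4, 1)

    def forward(self, x, time_steps, context):
        return self.proj(x)

wrapper = DiffusionWrapper(TinyNet())
print(list(wrapper.state_dict().keys()))
# ['diffusion_model.proj.weight', 'diffusion_model.proj.bias']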

Latent diffusion model

This contains the following components:

  • the U-Net (wrapped in DiffusionWrapper) that does the denoising in latent space
  • the AutoEncoder
  • the CLIP embeddings generator

class LatentDiffusion(nn.Module):
    model: DiffusionWrapper
    first_stage_model: Autoencoder
    cond_stage_model: CLIPTextEmbedder
  • unet_model is the U-Net that predicts the noise ε in latent space
  • autoencoder is the AutoEncoder
  • clip_embedder is the CLIP embeddings generator
  • latent_scaling_factor is the scaling factor for the latent space. The encodings of the autoencoder are scaled by this before feeding into the U-Net.
  • n_steps is the number of diffusion steps T.
  • linear_start is the start of the β schedule.
  • linear_end is the end of the β schedule.
    def __init__(self,
                 unet_model: UNetModel,
                 autoencoder: Autoencoder,
                 clip_embedder: CLIPTextEmbedder,
                 latent_scaling_factor: float,
                 n_steps: int,
                 linear_start: float,
                 linear_end: float,
                 ):
        super().__init__()

Wrap the U-Net to keep the same model structure as CompVis/stable-diffusion.

        self.model = DiffusionWrapper(unet_model)

Auto-encoder and scaling factor

        self.first_stage_model = autoencoder
        self.latent_scaling_factor = latent_scaling_factor
        # CLIP embeddings generator
        self.cond_stage_model = clip_embedder

Number of steps

        self.n_steps = n_steps

β schedule

        beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
        self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)

        # α_t = 1 - β_t
        alpha = 1. - beta

        # ᾱ_t = ∏_{s=1}^t α_s
        alpha_bar = torch.cumprod(alpha, dim=0)
        self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
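With the constructor in place, here is a hedged sketch of how the model might be assembled. The hyperparameter values below match the Stable Diffusion v1 configuration as an assumption; construction of the three sub-models is omitted (see their respective implementations).

unet = ...           # a UNetModel instance (construction omitted)
autoencoder = ...    # an Autoencoder instance (construction omitted)
clip_embedder = ...  # a CLIPTextEmbedder instance (construction omitted)

ldm = LatentDiffusion(unet_model=unet,
                      autoencoder=autoencoder,
                      clip_embedder=clip_embedder,
                      latent_scaling_factor=0.18215,  # assumed Stable Diffusion v1 value
                      n_steps=1000,
                      linear_start=0.00085,
                      linear_end=0.0120)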

Get model device

    @property
    def device(self):
        return next(iter(self.model.parameters())).device

Get CLIP embeddings for a list of text prompts

    def get_text_conditioning(self, prompts: List[str]):
        return self.cond_stage_model(prompts)
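A hedged usage sketch, assuming ldm is a LatentDiffusion instance as in the sketch above; the output shape (77 tokens, 768-dimensional embeddings) is what Stable Diffusion v1's CLIP text encoder produces and is an assumption here.

prompts = ["a painting of a fox in the snow", ""]  # second prompt is empty, e.g. for unconditional guidance
context = ldm.get_text_conditioning(prompts)
# context.shape == (2, 77, 768) for the Stable Diffusion v1 text encoder (assumed)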

Get scaled latent space representation of the image

The encoder output is a distribution. We sample from that and multiply by the scaling factor.

    def autoencoder_encode(self, image: torch.Tensor):
        return self.latent_scaling_factor * self.first_stage_model.encode(image).sample()

Get image from the latent representation

We divide by the scaling factor and then decode.

    def autoencoder_decode(self, z: torch.Tensor):
        return self.first_stage_model.decode(z / self.latent_scaling_factor)
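Putting encoding and decoding together, a hedged round-trip sketch (again assuming ldm from above; the 8× spatial down-sampling and 4 latent channels are Stable Diffusion v1 specifics and are assumptions here):

image = torch.randn(1, 3, 512, 512)          # stand-in for a batch of images in [-1, 1]
z = ldm.autoencoder_encode(image)            # scaled latent, shape (1, 4, 64, 64) (assumed)
reconstruction = ldm.autoencoder_decode(z)   # back to image space, shape (1, 3, 512, 512)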

Predict noise

Predict noise given the latent representation x_t, time step t, and the conditioning context c.

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
        return self.model(x, t, context)
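The samplers elsewhere in this repository drive this forward pass. As a hedged, simplified sketch of one way the prediction can be used (assuming ldm from above, and a noisy latent zt, time steps t, and text conditioning context of matching batch size), the predicted noise together with ᾱ_t gives a DDPM-style estimate of the clean latent:

eps = ldm(zt, t, context)                             # predicted noise ε(z_t, t, c)
alpha_bar_t = ldm.alpha_bar[t].view(-1, 1, 1, 1)
z0_hat = (zt - (1. - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()  # estimate of z_0
image = ldm.autoencoder_decode(z0_hat)                # decode the estimate to image space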