Denoising Diffusion Implicit Models (DDIM) Sampling

This implements DDIM sampling from the paper Denoising Diffusion Implicit Models

16from typing import Optional, List 17 18import numpy as np 19import torch 20 21from labml import monit 22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion 23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler

DDIM Sampler

This extends the DiffusionSampler base class.

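DDIM samples images by repeatedly removing noise by sampling step by step using,

$$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \Bigg( \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}} \, \epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}} \Bigg) + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) + \sigma_{\tau_i} \epsilon_{\tau_i}$$

where $\epsilon_{\tau_i}$ is random noise, $\tau$ is a subsequence of $[1, 2, \dots, T]$ of length $S$, and

$$\sigma_{\tau_i} = \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}} \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$

Note that $\alpha_t$ in the DDIM paper refers to $\bar\alpha_t$ from DDPM.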

26class DDIMSampler(DiffusionSampler):
52 model: LatentDiffusion
  • model is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
  • n_steps is the number of DDIM sampling steps, $S$
  • ddim_discretize specifies how to extract $\tau$ from $[1, 2, \dots, T]$. It can be either uniform or quad.
  • ddim_eta is $\eta$, used to calculate $\sigma_{\tau_i}$. $\eta = 0$ makes the sampling process deterministic.
def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
    r"""
    Pre-compute the DDIM time-step subsequence $\tau$ and the per-step coefficients.

    :param model: is the model to predict noise
    :param n_steps: is the number of DDIM sampling steps, $S$
    :param ddim_discretize: specifies how to extract $\tau$ from $[1, 2, \dots, T]$.
        It can be either `uniform` or `quad`.
    :param ddim_eta: is $\eta$, used to calculate $\sigma_{\tau_i}$.
        $\eta = 0$ makes the sampling process deterministic.
    :raises ValueError: if `ddim_discretize` is `uniform` and `n_steps` is not in
        `[1, model.n_steps]`
    :raises NotImplementedError: if `ddim_discretize` is neither `uniform` nor `quad`
    """
    super().__init__(model)

    # Number of steps, $T$
    self.n_steps = model.n_steps

    # Calculate $\tau$ to be uniformly distributed across $[1, 2, \dots, T]$
    if ddim_discretize == 'uniform':
        # Without this check the stride below is zero when `n_steps > T`
        # (an opaque "range() arg 3 must not be zero") and `n_steps == 0`
        # divides by zero; fail early with a message that names the argument.
        if not 1 <= n_steps <= self.n_steps:
            raise ValueError(
                f"n_steps must be between 1 and {self.n_steps} for 'uniform' discretization, got {n_steps}")
        c = self.n_steps // n_steps
        # NOTE(review): the `+ 1` shift yields an index equal to `self.n_steps` when `c == 1`;
        # confirm `alpha_bar` has an entry at that index.
        self.time_steps = np.arange(0, self.n_steps, c) + 1
    # Calculate $\tau$ to be quadratically distributed across $[1, 2, \dots, T]$
    elif ddim_discretize == 'quad':
        self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
    else:
        raise NotImplementedError(ddim_discretize)

    with torch.no_grad():
        # Get $\bar\alpha_t$
        alpha_bar = self.model.alpha_bar

        # $\alpha_{\tau_i}$
        self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
        # $\sqrt{\alpha_{\tau_i}}$
        self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
        # $\alpha_{\tau_{i-1}}$; the first entry reuses `alpha_bar[0]` as the predecessor of $\tau_1$
        self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

        # $$\sigma_{\tau_i} = \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
        # \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$
        self.ddim_sigma = (ddim_eta *
                           ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                            (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

        # $\sqrt{1 - \alpha_{\tau_i}}$
        self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5

Sampling Loop

  • shape is the shape of the generated images in the form [batch_size, channels, height, width]
  • cond is the conditional embeddings
  • temperature is the noise temperature (random noise gets multiplied by this)
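  • x_last is $x_{\tau_S}$. If not provided random noise will be used.
  • uncond_scale is the unconditional guidance scale $s$. This is used for classifier-free guidance, when the conditional and unconditional noise predictions are combined into $\epsilon_\theta(x_t, c)$.
  • uncond_cond is the conditional embedding for empty prompt $c_u$
  • skip_steps is the number of time steps to skip, $i'$. We start sampling from $S - i'$, and x_last is then $x_{\tau_{S - i'}}$.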
@torch.no_grad()
def sample(self,
           shape: List[int],
           cond: torch.Tensor,
           repeat_noise: bool = False,
           temperature: float = 1.,
           x_last: Optional[torch.Tensor] = None,
           uncond_scale: float = 1.,
           uncond_cond: Optional[torch.Tensor] = None,
           skip_steps: int = 0,
           ):
    r"""
    ### Sampling Loop

    Walks the DDIM time steps from last to first, calling `p_sample` at each one.

    :param shape: is the shape of the generated images in the
        form `[batch_size, channels, height, width]`
    :param cond: is the conditional embeddings
    :param repeat_noise: is forwarded to `p_sample`
    :param temperature: is the noise temperature (random noise gets multiplied by this)
    :param x_last: is the starting latent. If not provided random noise will be used.
    :param uncond_scale: is the unconditional guidance scale, forwarded to `p_sample`
    :param uncond_cond: is the conditional embedding for empty prompt
    :param skip_steps: is the number of (latest) time steps to skip
    :return: the latent after the final sampling step
    """
    # Get device and batch size
    device = self.model.device
    batch_size = shape[0]

    # Starting latent: caller-supplied, or fresh Gaussian noise
    if x_last is None:
        x = torch.randn(shape, device=device)
    else:
        x = x_last

    # Time steps to sample at, latest first, with the first `skip_steps` dropped
    remaining_steps = np.flip(self.time_steps)[skip_steps:]
    n_remaining = len(remaining_steps)

    for i, step in monit.enum('Sample', remaining_steps):
        # Index of this step in the ascending list of time steps
        index = n_remaining - 1 - i
        # Time step, repeated for every sample in the batch
        ts = x.new_full((batch_size,), step, dtype=torch.long)

        # Take one denoising step; only the updated latent is needed here
        x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                repeat_noise=repeat_noise,
                                temperature=temperature,
                                uncond_scale=uncond_scale,
                                uncond_cond=uncond_cond)

    return x

Sample $x_{\tau_{i-1}}$

  • x is $x_{\tau_i}$ of shape [batch_size, channels, height, width]
  • c is the conditional embeddings $c$ of shape [batch_size, emb_size]
  • t is $\tau_i$ of shape [batch_size]
  • step is the step $\tau_i$ as an integer
  • index is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
  • repeat_noise specifies whether the noise should be the same for all samples in the batch
  • temperature is the noise temperature (random noise gets multiplied by this)
  • uncond_scale is the unconditional guidance scale $s$. This is used for classifier-free guidance, when the conditional and unconditional noise predictions are combined into $\epsilon_\theta(x_t, c)$.
  • uncond_cond is the conditional embedding for empty prompt $c_u$
@torch.no_grad()
def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
             repeat_noise: bool = False,
             temperature: float = 1.,
             uncond_scale: float = 1.,
             uncond_cond: Optional[torch.Tensor] = None):
    r"""
    ### Sample $x_{\tau_{i-1}}$

    :param x: is $x_{\tau_i}$ of shape `[batch_size, channels, height, width]`
    :param c: is the conditional embeddings of shape `[batch_size, emb_size]`
    :param t: is $\tau_i$ of shape `[batch_size]`
    :param step: is the step $\tau_i$ as an integer (not used by this method)
    :param index: is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
    :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
    :param temperature: is the noise temperature (random noise gets multiplied by this)
    :param uncond_scale: is the unconditional guidance scale, forwarded to `get_eps`
    :param uncond_cond: is the conditional embedding for empty prompt, forwarded to `get_eps`
    :return: the tuple `(x_prev, pred_x0, e_t)`
    """
    # Noise estimate $\epsilon_\theta(x_{\tau_i})$
    noise_pred = self.get_eps(x, t, c,
                              uncond_scale=uncond_scale,
                              uncond_cond=uncond_cond)

    # $x_{\tau_{i-1}}$ and the predicted $x_0$
    step_result = self.get_x_prev_and_pred_x0(noise_pred, index, x,
                                              temperature=temperature,
                                              repeat_noise=repeat_noise)
    x_prev, pred_x0 = step_result

    return x_prev, pred_x0, noise_pred

Sample $x_{\tau_{i-1}}$ given $\epsilon_\theta(x_{\tau_i})$

def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
                           temperature: float,
                           repeat_noise: bool):
    r"""
    ### Sample $x_{\tau_{i-1}}$ given $\epsilon_\theta(x_{\tau_i})$

    :param e_t: is the noise estimate $\epsilon_\theta(x_{\tau_i})$
    :param index: is index $i$ into the pre-computed per-step coefficients
    :param x: is $x_{\tau_i}$
    :param temperature: is the noise temperature (random noise gets multiplied by this)
    :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
    :return: the tuple `(x_prev, pred_x0)`
    """
    # $\alpha_{\tau_i}$
    alpha = self.ddim_alpha[index]
    # $\alpha_{\tau_{i-1}}$
    alpha_prev = self.ddim_alpha_prev[index]
    # $\sigma_{\tau_i}$
    sigma = self.ddim_sigma[index]
    # $\sqrt{1 - \alpha_{\tau_i}}$
    sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

    # Current prediction for $x_0$,
    # $$\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}$$
    pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
    # Direction pointing to $x_t$,
    # $$\sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})$$
    dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

    # Draw the noise in the same floating dtype as `x`. `torch.randn` otherwise uses the
    # default dtype, which mixes precisions when `x` is not float32.
    # Non-floating `x` keeps the previous default-dtype behaviour.
    noise_dtype = x.dtype if x.is_floating_point() else None

    # No noise is added, when $\sigma_{\tau_i} = 0$ (i.e. $\eta = 0$)
    if sigma == 0.:
        noise = 0.
    # If same noise is used for all samples in the batch
    elif repeat_noise:
        noise = torch.randn((1, *x.shape[1:]), device=x.device, dtype=noise_dtype)
    # Different noise for each sample
    else:
        noise = torch.randn(x.shape, device=x.device, dtype=noise_dtype)

    # Multiply noise by the temperature
    noise = noise * temperature

    # $$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \, \hat{x}_0
    # + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})
    # + \sigma_{\tau_i} \epsilon_{\tau_i}$$
    x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

    return x_prev, pred_x0

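Sample from $q_{\sigma,\tau}(x_{\tau_i} | x_0)$

$$q_{\sigma,\tau}(x_{\tau_i} | x_0) = \mathcal{N}\Big(x_{\tau_i}; \sqrt{\alpha_{\tau_i}} \, x_0, (1 - \alpha_{\tau_i}) \mathbf{I}\Big)$$

  • x0 is $x_0$ of shape [batch_size, channels, height, width]
  • index is the time step $\tau_i$ index $i$
  • noise is the noise, $\epsilon$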
@torch.no_grad()
def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
    r"""
    ### Sample from $q_{\sigma,\tau}(x_{\tau_i} | x_0)$

    Returns $\sqrt{\alpha_{\tau_i}} \, x_0 + \sqrt{1 - \alpha_{\tau_i}} \, \epsilon$.

    :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
    :param index: is the time step $\tau_i$ index $i$
    :param noise: is the noise, $\epsilon$. Random noise is drawn if it is not specified.
    """
    # Random noise, if noise is not specified
    eps = torch.randn_like(x0) if noise is None else noise

    # Pre-computed $\sqrt{\alpha_{\tau_i}}$ and $\sqrt{1 - \alpha_{\tau_i}}$
    signal_scale = self.ddim_alpha_sqrt[index]
    noise_scale = self.ddim_sqrt_one_minus_alpha[index]

    return signal_scale * x0 + noise_scale * eps

Painting Loop

  • x is $x_{\tau_{S'}}$ of shape [batch_size, channels, height, width]
  • cond is the conditional embeddings $c$
  • t_start is the sampling step to start from, $S'$
  • orig is the original image in latent space which we are inpainting. If this is not provided, it'll be an image-to-image transformation.
  • mask is the mask to keep the original image.
  • orig_noise is fixed noise to be added to the original image.
  • uncond_scale is the unconditional guidance scale $s$. This is used for classifier-free guidance, when the conditional and unconditional noise predictions are combined into $\epsilon_\theta(x_t, c)$.
  • uncond_cond is the conditional embedding for empty prompt $c_u$
@torch.no_grad()
def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
          orig: Optional[torch.Tensor] = None,
          mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
          uncond_scale: float = 1.,
          uncond_cond: Optional[torch.Tensor] = None,
          ):
    r"""
    ### Painting Loop

    :param x: is the starting latent of shape `[batch_size, channels, height, width]`
    :param cond: is the conditional embeddings
    :param t_start: is the sampling step to start from; the first `t_start` entries
        of the time-step list are used
    :param orig: is the original image in latent space which we are inpainting.
        If this is not provided, it'll be an image-to-image transformation.
    :param mask: is the mask to keep the original image. Required when `orig` is given.
    :param orig_noise: is fixed noise to be added to the original image
        (passed to `q_sample` as `noise`)
    :param uncond_scale: is the unconditional guidance scale, forwarded to `p_sample`
    :param uncond_cond: is the conditional embedding for empty prompt
    :raises TypeError: if `orig` is given without `mask`
    :return: the latent after the final sampling step
    """
    # Blending needs both tensors. Check before the loop, otherwise the failure
    # only surfaces as an opaque `Tensor * None` error after a full denoising step.
    if orig is not None and mask is None:
        raise TypeError("`mask` must be provided when `orig` is given")

    # Get batch size
    bs = x.shape[0]

    # Time steps to sample at, latest first
    time_steps = np.flip(self.time_steps[:t_start])

    for i, step in monit.enum('Paint', time_steps):
        # Index $i$ in the ascending list of time steps
        index = len(time_steps) - i - 1
        # Time step $\tau_i$, repeated for every sample in the batch
        ts = x.new_full((bs,), step, dtype=torch.long)

        # Sample $x_{\tau_{i-1}}$
        x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                uncond_scale=uncond_scale,
                                uncond_cond=uncond_cond)

        # Replace the masked area with original image
        if orig is not None:
            # Noised version of the original image at this step
            orig_t = self.q_sample(orig, index, noise=orig_noise)
            # Where `mask` is 1 the original is kept; where it is 0 the sample is kept
            x = orig_t * mask + x * (1 - mask)

    return x