
Commit 5f52459

added standalone schedulers
1 parent 5e86c32 commit 5f52459

File tree

6 files changed: +920 -3 lines changed


backends/stable_diffusion/schedulers/__init__.py

Whitespace-only changes.
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# source : https://github.com/huggingface/diffusers/

from typing import Union

import numpy as np
import torch  # required for the default "pt" tensor format used below


class SchedulerMixin:
    """
    Mixin containing common functions for the schedulers.
    """

    ignore_for_config = ["tensor_format"]

    def set_format(self, tensor_format="pt"):
        self.tensor_format = tensor_format
        if tensor_format == "pt":
            for key, value in vars(self).items():
                if isinstance(value, np.ndarray):
                    setattr(self, key, torch.from_numpy(value))

        return self

    def clip(self, tensor, min_value=None, max_value=None):
        tensor_format = getattr(self, "tensor_format", "pt")

        if tensor_format == "np":
            return np.clip(tensor, min_value, max_value)
        elif tensor_format == "pt":
            return torch.clamp(tensor, min_value, max_value)

        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def log(self, tensor):
        tensor_format = getattr(self, "tensor_format", "pt")

        if tensor_format == "np":
            return np.log(tensor)
        elif tensor_format == "pt":
            return torch.log(tensor)

        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def match_shape(self, values: Union[np.ndarray], broadcast_array: Union[np.ndarray]):
        """
        Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.

        Args:
            values: an array or tensor of values to extract.
            broadcast_array: an array with a larger shape of K dimensions with the batch
                dimension equal to the length of timesteps.
        Returns:
            a tensor of shape [batch_size, 1, ...] where the shape has K dims.
        """
        tensor_format = getattr(self, "tensor_format", "pt")
        values = values.flatten()

        while len(values.shape) < len(broadcast_array.shape):
            values = values[..., None]
        if tensor_format == "pt":
            values = values.to(broadcast_array.device)

        return values

    def norm(self, tensor):
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            return np.linalg.norm(tensor)
        elif tensor_format == "pt":
            return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()

        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def randn_like(self, tensor, generator=None):
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            return np.random.randn(*np.shape(tensor))
        elif tensor_format == "pt":
            # return torch.randn_like(tensor)
            return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)

        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def zeros_like(self, tensor):
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            return np.zeros_like(tensor)
        elif tensor_format == "pt":
            return torch.zeros_like(tensor)

        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
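
Every helper in the mixin above dispatches on `tensor_format`. A minimal sketch of the numpy code paths, not part of the commit: it assumes the file is importable as `scheduler_mixin` (the module name the DDIM scheduler below imports from), and `DummyScheduler` is purely illustrative.

import numpy as np

from scheduler_mixin import SchedulerMixin  # assumed module name, not shown in this diff


class DummyScheduler(SchedulerMixin):
    def __init__(self):
        self.alphas_cumprod = np.linspace(1.0, 0.01, 1000, dtype=np.float32)
        self.set_format(tensor_format="np")  # keep everything as numpy arrays


sched = DummyScheduler()

# match_shape turns a per-timestep 1-D array into something broadcastable against
# a batch of latents, e.g. (2,) -> (2, 1, 1, 1) for a (2, 4, 64, 64) batch.
coeffs = sched.alphas_cumprod[np.array([10, 500])] ** 0.5
latents = np.zeros((2, 4, 64, 64), dtype=np.float32)
print(sched.match_shape(coeffs, latents).shape)  # (2, 1, 1, 1)

# clip and zeros_like dispatch on tensor_format as well
print(sched.clip(np.array([-2.0, 0.5, 2.0]), -1, 1))  # [-1.   0.5  1. ]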
Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
# source : https://github.com/huggingface/diffusers/


# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from typing import Optional, Tuple, Union

import numpy as np
import torch  # only exercised when eta > 0 or tensor_format == "pt"

from .scheduler_mixin import SchedulerMixin


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float32)


class DDIMScheduler(SchedulerMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
    [`~ConfigMixin.from_config`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional): TODO
        timestep_values (`np.ndarray`, optional): TODO
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        set_alpha_to_one (`bool`, default `True`):
            if alpha for final step is 1 or the final alpha of the "non-previous" one.
        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        timestep_values: Optional[np.ndarray] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        tensor_format: str = "pt",
    ):
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        self.num_train_timesteps = num_train_timesteps

        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.trained_betas = trained_betas
        self.timestep_values = timestep_values
        self.clip_sample = clip_sample
        self.set_alpha_to_one = set_alpha_to_one
        self.tensor_format = tensor_format

        # At every step in ddim, we are looking into the previous alphas_cumprod.
        # For the final step, there is no previous alphas_cumprod because we are already at 0.
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

        self.tensor_format = tensor_format
        # stand-in for diffusers' ConfigMixin, so `self.config.<attr>` keeps working in this standalone version
        self.config = self
        self.set_format(tensor_format=tensor_format)

    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            offset (`int`): TODO
        """
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.arange(
            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
        )[::-1].copy()
        self.timesteps += offset
        self.set_format(tensor_format=self.tensor_format)

    def step(
        self,
        model_output: Union[np.ndarray],
        timestep: int,
        sample: Union[np.ndarray],
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        return_dict: bool = True,
    ) -> Union[Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): TODO
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail for a full understanding.

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            # torch.randn_like does not accept a generator, so sample explicitly and move to the right device
            device = model_output.device if torch.is_tensor(model_output) else "cpu"
            noise = torch.randn(model_output.shape, generator=generator).to(device)
            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise

            if not torch.is_tensor(model_output):
                variance = variance.numpy()

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return {"prev_sample": prev_sample}  # SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: Union[np.ndarray],
        noise: Union[np.ndarray],
        timesteps: Union[np.ndarray],
    ) -> Union[np.ndarray]:
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
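
A minimal sketch, not part of the commit, of how the standalone DDIMScheduler above can be driven entirely with numpy arrays (eta defaults to 0, so no torch is needed): call set_timesteps once, then step once per timestep. `fake_denoiser` is a hypothetical stand-in for the real UNet call, and the module path for DDIMScheduler is assumed since it is not shown in this diff.

import numpy as np

# DDIMScheduler as defined above
scheduler = DDIMScheduler(tensor_format="np")  # keep all internal arrays as numpy
scheduler.set_timesteps(50)                    # 50 inference steps out of 1000 training steps


def fake_denoiser(latents, t):
    # a real pipeline would call the UNet here; random noise keeps the sketch self-contained
    return np.random.randn(*latents.shape).astype(np.float32)


latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
for t in scheduler.timesteps:                  # [980, 960, ..., 20, 0]
    noise_pred = fake_denoiser(latents, t)
    latents = scheduler.step(noise_pred, t, latents)["prev_sample"]

print(latents.shape)  # (1, 4, 64, 64): the denoised latent that would go to the VAE decoder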
