import torch
import logging
import numpy as np
import torch.nn as nn
import torch.functional as F
from typing import List
from argparse import Namespace
_logger = logging.getLogger(__name__)
# DDPM noise schedule, forward noising, and reverse sampling utilities.
class DDPM:
    """
    Denoising Diffusion Probabilistic Model (DDPM).

    Implements the DDPM forward noising process and reverse denoising process.
    Supports both linear and cosine beta schedules.

    Attributes:
        args (Namespace): The configuration object passed to `__init__`.
        beta (torch.FloatTensor): Noise schedule with a leading zero, so that
            index `t` addresses timestep `t`. Shape: (T + 1,).
        alpha (torch.FloatTensor): `1 - beta`. Shape: (T + 1,).
        alpha_bar (torch.FloatTensor): Cumulative product of `alpha`;
            `alpha_bar[0] == 1`. Shape: (T + 1,).
        flexibility (float): Variance interpolation coefficient used by
            `denoise`.
    """

    def __init__(self, args: Namespace, flexibility=0.0):
        """
        Initialize the DDPM noise schedule.

        Args:
            args (Namespace): Configuration object containing:
                - beta_schedule (str): `'linear'` or `'cosine'`.
                - T (int): Total number of diffusion steps.
                - device (str): Compute device the schedule is created on.
                - antithetic_sampling (bool): Whether to use antithetic
                  sampling in `add_noise` to reduce variance.
            flexibility (float, optional): Variance interpolation coefficient
                used in `denoise`. `0.0` selects the upper-bound variance
                `1 - alpha_bar_t / alpha_bar_{t-stride}` and `1.0` selects the
                posterior variance
                `(1 - alpha_bar_t / alpha_bar_{t-stride})
                * (1 - alpha_bar_{t-stride}) / (1 - alpha_bar_t)`;
                values in between interpolate linearly.

        Raises:
            ValueError: If `args.beta_schedule` is neither `'linear'` nor
                `'cosine'`.
        """
        self.args = args
        if args.beta_schedule == "cosine":
            beta = self.cosine_beta_schedule(args.T)
        elif args.beta_schedule == "linear":
            beta = self.linear_beta_schedule(args.T)
        else:
            raise ValueError(f"Unknown beta_schedule: {args.beta_schedule}")

        # Prepend beta_0 = 0 so the tensors are indexed directly by t in [0, T]
        # and alpha_bar[0] == 1 (no noise at t = 0).
        beta = beta.to(args.device)
        self.beta = torch.cat([beta.new_zeros(1), beta])  # (T+1,)
        self.alpha = 1 - self.beta
        self.alpha_bar = self.alpha.cumprod(dim=0)
        self.flexibility = flexibility

    def to(self, device):
        """
        Move the noise schedule tensors to the target device.

        Args:
            device (torch.device or str): Target device.

        Returns:
            self: Returns the instance itself for chaining.
        """
        self.beta = self.beta.to(device)
        self.alpha = self.alpha.to(device)
        self.alpha_bar = self.alpha_bar.to(device)
        return self

    @staticmethod
    def _expand_to(coef, ref):
        """
        Reshape a per-sample coefficient so it broadcasts against `ref`.

        Coefficients with more than one element are viewed as
        `(ref.shape[0], 1, ..., 1)`; single-element coefficients are returned
        unchanged because they already broadcast.
        """
        if coef.numel() > 1:
            return coef.view(ref.shape[0], *[1] * (ref.ndim - 1))
        return coef

    def _alpha_bar_at(self, denoise_t, ref):
        """
        Look up `alpha_bar` at `denoise_t`, shaped to broadcast against `ref`.

        Raises:
            ValueError: If `denoise_t` is an int equal to 0 or a tensor that
                contains 0.
        """
        has_zero = (
            (isinstance(denoise_t, int) and denoise_t == 0)
            or (isinstance(denoise_t, torch.Tensor) and (denoise_t == 0).any())
        )
        if has_zero:
            raise ValueError("`denoise_t` cannot be 0.")
        at = self.alpha_bar[denoise_t]
        if isinstance(denoise_t, torch.Tensor) and denoise_t.numel() > 1:
            at = self._expand_to(at, ref)
        return at

    def _sample_timesteps(self, batch_size, device):
        """
        Draw `batch_size` timesteps uniformly from [1, T].

        With `args.antithetic_sampling`, the first half is drawn at random and
        the second half mirrors it as `T + 1 - t`; an odd batch gets one extra
        independent draw.
        """
        T = self.args.T
        if not self.args.antithetic_sampling:
            return torch.randint(1, T + 1, (batch_size,), device=device)

        half = torch.randint(1, T + 1, (batch_size // 2,), device=device)
        timesteps = torch.cat([half, T + 1 - half], dim=0)
        if batch_size % 2 == 1:
            extra = torch.randint(1, T + 1, (1,), device=device)
            timesteps = torch.cat([timesteps, extra], dim=0)
        return timesteps

    def add_noise(self, x0, denoise_t=None):
        """
        DDPM forward process: add noise to clean data `x0` to produce `x_t`.

        q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)

        Args:
            x0 (torch.FloatTensor): Original clean data (t=0).
                Shape: (batch_size, ...) with arbitrary trailing dimensions.
            denoise_t (torch.LongTensor or int, optional): Timestep(s) t in
                [1, T]. A single value is shared by the whole batch;
                otherwise one value per sample is required.
                If None, t is sampled randomly according to the
                `args.antithetic_sampling` strategy.
                Shape: (batch_size, ).

        Returns:
            tuple:
                - xt (torch.FloatTensor): Noisy data. Same shape as x0.
                - noise (torch.FloatTensor): Added standard Gaussian noise
                  epsilon. Same shape as x0.
                - denoise_t (torch.LongTensor): Actual timestep t used.
                  Shape: (batch_size, ).

        Raises:
            ValueError: If an explicit `denoise_t` contains 0, lies outside
                [1, T], or does not match the batch size.
        """
        batch_size = x0.shape[0]
        # Use the schedule's own device: `args.device` is stale after `to()`.
        device = self.alpha_bar.device

        if denoise_t is None:
            denoise_t = self._sample_timesteps(batch_size, device)
        else:
            denoise_t = torch.as_tensor(denoise_t, device=device).reshape(-1)
            if denoise_t.numel() == 1:
                denoise_t = denoise_t.expand(batch_size)
            elif denoise_t.numel() != batch_size:
                raise ValueError(
                    "`denoise_t` must hold one timestep per sample in the batch."
                )
            if (denoise_t == 0).any():
                raise ValueError("`denoise_t` cannot be 0.")
            if ((denoise_t < 1) | (denoise_t > self.args.T)).any():
                raise ValueError("`denoise_t` must lie in [1, T].")
        denoise_t = denoise_t.long()

        # One coefficient per sample, broadcast over every non-batch dimension
        # of x0 (not just 4-D inputs).
        at = self.alpha_bar[denoise_t].view(batch_size, *[1] * (x0.ndim - 1))
        noise = torch.randn_like(x0)
        xt = torch.sqrt(at) * x0 + torch.sqrt(1 - at) * noise
        return xt, noise, denoise_t

    def denoise(self, xt, denoise_t, x0=None, noise=None, stride=1):
        """
        DDPM reverse process: sample x_{t-stride} from x_t using the predicted
        x0 or noise.

        p_theta(x_{t-1} | x_t) = N(x_{t-1}; mu_theta(x_t, t), sigma_t^2 * I)

        Args:
            xt (torch.FloatTensor): Noisy data at the current timestep t.
                Shape: (batch_size, ...)
            denoise_t (torch.LongTensor or int): Index of the current
                timestep t. Either a single value shared by the batch or one
                value per sample. Shape: (batch_size, )
            x0 (torch.FloatTensor, optional): Model-predicted original data x0.
            noise (torch.FloatTensor, optional): Model-predicted noise epsilon.
                Note: exactly one of x0 and noise must be provided.
            stride (int, optional): Reverse denoising step size. Default is 1.
                Used as a skip-step strategy to accelerate sampling.

        Returns:
            torch.FloatTensor: Denoised data at the previous step x_{t-stride}.
            Same shape as xt. Samples whose target step `t - stride` is 0
            receive the posterior mean with no noise added.

        Raises:
            ValueError: If not exactly one of `x0` / `noise` is given, if any
                `denoise_t` is 0, or if any `denoise_t - stride` is negative.
        """
        if (x0 is None) == (noise is None):
            raise ValueError("Exactly one of `x0` and `noise` must be provided.")

        t = torch.as_tensor(denoise_t, device=self.alpha_bar.device).long()
        t_next = t - stride
        # `.any()` keeps these checks valid for batched timesteps as well.
        if (t == 0).any():
            raise ValueError("`denoise_t` cannot be 0.")
        if (t_next < 0).any():
            raise ValueError("`denoise_t - stride` cannot be negative.")

        at = self._expand_to(self.alpha_bar[t], xt)
        at_next = self._expand_to(self.alpha_bar[t_next], xt)

        if x0 is None:
            # Posterior mean expressed through the predicted noise.
            coef1 = (at_next / at).sqrt()
            coef2 = - (1 - at / at_next) / ((1 - at) * at / at_next).sqrt()
            mean = coef1 * xt + coef2 * noise
        else:
            # Posterior mean expressed through the predicted clean sample.
            coef1 = (1 - at / at_next) * torch.sqrt(at_next) / (1 - at)
            coef2 = (1 - at_next) * torch.sqrt(at / at_next) / (1 - at)
            mean = coef1 * x0 + coef2 * xt

        # Only samples that have not yet reached t = 0 receive fresh noise.
        needs_noise = t_next > 0
        if needs_noise.any():
            var_upper = (1 - at / at_next)
            var_lower = (1 - at / at_next) * (1 - at_next) / (1 - at)
            var = (1 - self.flexibility) * var_upper + self.flexibility * var_lower
            std = var.sqrt() * self._expand_to(needs_noise, xt).to(var.dtype)
            mean = mean + std * torch.randn_like(xt)
        return mean

    def noise_to_x0(self, xt, denoise_t, noise):
        """
        Infer the original data x0 from the current noisy data xt and
        predicted noise epsilon.

        x_0 = (x_t - sqrt(1 - alpha_bar_t) * epsilon) / sqrt(alpha_bar_t)

        Args:
            xt (torch.FloatTensor): Noisy data x_t.
            denoise_t (torch.LongTensor or int): Timestep t.
            noise (torch.FloatTensor): Predicted noise epsilon.

        Returns:
            torch.FloatTensor: Estimated original data x0.

        Raises:
            ValueError: If `denoise_t` is 0 or contains 0.
        """
        at = self._alpha_bar_at(denoise_t, xt)
        return (xt - torch.sqrt(1 - at) * noise) / torch.sqrt(at)

    def x0_to_noise(self, xt, denoise_t, x0):
        """
        Infer the latent noise epsilon from the current noisy data xt and
        estimated x0.

        epsilon = (x_t - sqrt(alpha_bar_t) * x_0) / sqrt(1 - alpha_bar_t)

        Args:
            xt (torch.FloatTensor): Noisy data x_t.
            denoise_t (torch.LongTensor or int): Timestep t.
            x0 (torch.FloatTensor): Estimated original data x0.

        Returns:
            torch.FloatTensor: Inferred latent noise epsilon.

        Raises:
            ValueError: If `denoise_t` is 0 or contains 0.
        """
        at = self._alpha_bar_at(denoise_t, xt)
        return (xt - torch.sqrt(at) * x0) / torch.sqrt(1 - at)

    @staticmethod
    def cosine_beta_schedule(T, s=0.008):
        """
        Generate a cosine beta schedule.

        alpha_bar is defined as `cos^2(((t / T) + s) / (1 + s) * pi / 2)`,
        normalized so that `alpha_bar[0] == 1`; betas are derived from ratios
        of consecutive alpha_bar values.

        Args:
            T (int): Total number of timesteps.
            s (float, optional): Offset used to prevent beta from becoming too
                small at t=0. Default is 0.008.

        Returns:
            torch.FloatTensor: Sequence of beta values, clamped to at most
            0.999. Shape: (T,)
        """
        timesteps = torch.arange(T + 1) / T + s
        alpha_bar = timesteps / (1 + s) * np.pi / 2
        alpha_bar = torch.cos(alpha_bar).pow(2)
        alpha_bar = alpha_bar / alpha_bar[0]
        alpha = alpha_bar[1:] / alpha_bar[:-1]
        beta = 1 - alpha
        # The final ratio approaches 0, so cap beta to keep alpha positive.
        return beta.clamp(max=0.999)

    @staticmethod
    def linear_beta_schedule(T, beta_start=0.0001, beta_end=0.05):
        """
        Generate a linearly increasing beta schedule.

        Args:
            T (int): Total number of timesteps.
            beta_start (float, optional): Initial beta value. Default is 1e-4.
            beta_end (float, optional): Final beta value. Default is 0.05.

        Returns:
            torch.FloatTensor: Sequence of beta values. Shape: (T,)
        """
        return torch.linspace(beta_start, beta_end, T)