# Source code for src.diffusion.ddim
import torch
from argparse import Namespace
from .ddpm import DDPM
# [docs]
class DDIM(DDPM):
    """
    Denoising Diffusion Implicit Model (DDIM).

    Inherits from `DDPM`. DDIM enables faster sampling with skipped steps while
    preserving generation quality through a non-Markovian reverse process.
    """

    def __init__(self, args: Namespace, eta=0.0):
        """
        Initialize the DDIM sampler.

        Args:
            args (Namespace): Configuration argument object, forwarded unchanged
                to `DDPM.__init__`.
            eta (float, optional): Hyperparameter controlling the stochasticity
                of the sampling process; it linearly scales the per-step noise
                standard deviation sigma_t.
                - eta=0.0: deterministic sampling (standard DDIM).
                - eta=1.0: DDPM-equivalent variance (standard DDPM when
                  stride=1).
                Values with |eta| > 1 can make the square-root arguments in
                `denoise` negative, which yields NaN.
                Default is 0.0.
        """
        super().__init__(args)
        self.eta = eta

    def denoise(self, xt, denoise_t, x0=None, noise=None, stride=1):
        """
        DDIM reverse process: derive x_{t-stride} from x_t, deterministically
        (eta=0) or stochastically (eta>0).

        This method overrides the parent `DDPM.denoise` implementation and uses
        the DDIM update rule. Let a_t = alpha_bar[t] and
        a_s = alpha_bar[t - stride]. Then

            sigma_t^2 = eta^2 * (1 - a_s) / (1 - a_t) * (1 - a_t / a_s)

            x_{t-stride} = sqrt(a_s) * x0_hat
                           + sqrt(1 - a_s - sigma_t^2) * eps_hat
                           + sigma_t * z,           z ~ N(0, I)

        Whichever of x0_hat / eps_hat is not supplied is eliminated through
        x_t = sqrt(a_t) * x0_hat + sqrt(1 - a_t) * eps_hat, which yields the
        closed-form coefficients used below. The sigma_t * z term is omitted
        on the step that lands on index 0.

        Args:
            xt (torch.FloatTensor): Noisy data at the current timestep t.
            denoise_t (torch.LongTensor | int): Index of the current timestep t.
                It is used in Python `if` conditions, so it must be a Python int
                or a single-element tensor.
            x0 (torch.FloatTensor, optional): Model-predicted original data x0.
            noise (torch.FloatTensor, optional): Model-predicted noise epsilon.
                Note: exactly one of x0 and noise must be provided.
            stride (int, optional): Number of timesteps skipped per call, used
                for accelerated sampling. Must be >= 1. Default is 1.

        Returns:
            torch.FloatTensor: Denoised data at the earlier step x_{t-stride}.

        Raises:
            ValueError: If not exactly one of `x0` / `noise` is given, if
                `denoise_t` is 0, if `stride` < 1, or if `denoise_t - stride`
                is negative.
        """
        if not ((x0 is None) ^ (noise is None)):
            raise ValueError("Exactly one of `x0` and `noise` must be provided.")
        if denoise_t == 0:
            raise ValueError("`denoise_t` cannot be 0.")
        # A non-positive stride would either be a no-op (0) or step *forward*
        # in time (< 0), where `1 - at / at_next` turns negative and the
        # square roots below produce NaN.
        if stride < 1:
            raise ValueError("`stride` must be a positive integer.")
        if denoise_t - stride < 0:
            raise ValueError("`denoise_t - stride` cannot be negative.")

        # NOTE(review): assumes self.alpha_bar (set up by DDPM) is a tensor
        # indexed by timestep; torch.sqrt below requires tensor elements.
        at = self.alpha_bar[denoise_t]
        at_next = self.alpha_bar[denoise_t - stride]
        # sigma_t^2 from the DDIM paper, scaled by eta^2.
        var = self.eta**2 * (1 - at_next) / (1 - at) * (1 - at / at_next)
        # Squared weight of the "direction pointing to x_t" (predicted-noise)
        # term; non-negative whenever |eta| <= 1.
        dir_var = 1 - at_next - var

        if x0 is None:
            # Noise was predicted: substitute x0_hat = (xt - sqrt(1-at)*eps)/sqrt(at).
            coef1 = torch.sqrt(at_next / at)
            coef2 = torch.sqrt(dir_var) - torch.sqrt(at_next * (1 - at) / at)
            mean = coef1 * xt + coef2 * noise
        else:
            # x0 was predicted: substitute eps_hat = (xt - sqrt(at)*x0)/sqrt(1-at).
            coef1 = torch.sqrt(at_next) - torch.sqrt(dir_var * at / (1 - at))
            coef2 = torch.sqrt(dir_var / (1 - at))
            mean = coef1 * x0 + coef2 * xt

        # No fresh noise on the final step that lands on index 0.
        if denoise_t - stride > 0:
            mean = mean + var.sqrt() * torch.randn_like(mean)
        return mean