src.diffusion.ddpm module
- class src.diffusion.ddpm.DDPM(args: Namespace, flexibility=0.0)
Bases: object
Denoising Diffusion Probabilistic Model (DDPM).
Implements the DDPM forward noising process and reverse denoising process. Supports both linear and cosine beta schedules.
- __init__(args: Namespace, flexibility=0.0)
Initialize the DDPM noise schedule.
- Parameters:
args (Namespace) –
Configuration object containing:
  - beta_schedule (str): ‘linear’ or ‘cosine’.
  - T (int): Total number of diffusion steps.
  - device (str): Compute device.
  - antithetic_sampling (bool): Whether to use antithetic sampling in add_noise to reduce variance.
flexibility (float, optional) – Variance interpolation coefficient controlling stochasticity during generation: 0.0 corresponds to fixed variance and 1.0 to full DDPM variance. Default is 0.0.
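For orientation, here is a minimal sketch of the quantities a DDPM schedule is usually built from, assuming `__init__` derives them from the betas in the standard way. The helper `build_schedule` and its dictionary keys are hypothetical, only the linear branch is shown (with the `linear_beta_schedule` defaults), and `flexibility` is left out because its exact effect on the variance is not spelled out here.

```python
from argparse import Namespace

import torch


def build_schedule(args: Namespace) -> dict:
    """Hypothetical helper: precompute the tensors a DDPM schedule needs."""
    if args.beta_schedule != "linear":
        # Only the linear case is sketched here.
        raise NotImplementedError(args.beta_schedule)
    betas = torch.linspace(1e-4, 0.05, args.T)  # defaults of linear_beta_schedule
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)  # alpha_bar_t = prod_{s<=t} alpha_s
    # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 for the first step.
    alpha_bars_prev = torch.cat([torch.ones(1), alpha_bars[:-1]])
    # Variance of the true posterior q(x_{t-1} | x_t, x_0).
    posterior_var = betas * (1.0 - alpha_bars_prev) / (1.0 - alpha_bars)
    sched = dict(betas=betas, alphas=alphas, alpha_bars=alpha_bars,
                 alpha_bars_prev=alpha_bars_prev, posterior_var=posterior_var)
    return {k: v.to(args.device) for k, v in sched.items()}


args = Namespace(beta_schedule="linear", T=1000, device="cpu", antithetic_sampling=True)
sched = build_schedule(args)
```

The posterior variance is exactly 0 at the first step, since x_0 is then fully determined, and never exceeds beta_t.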
- to(device)
Move the noise schedule tensors to the target device.
- Parameters:
device (torch.device or str) – Target device.
- Returns:
The instance itself, to allow method chaining.
- Return type:
DDPM
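Because `DDPM` derives from `object` rather than `torch.nn.Module`, its schedule tensors are presumably plain attributes that `to` must move one by one. The class below is a hypothetical stand-in that only illustrates the move-then-`return self` pattern that makes chaining work.

```python
import torch


class ScheduleHolder:
    """Hypothetical stand-in for a plain-object schedule container."""

    def __init__(self, T: int = 10):
        self.betas = torch.linspace(1e-4, 0.05, T)
        self.alpha_bars = torch.cumprod(1.0 - self.betas, dim=0)

    def to(self, device):
        # Plain attributes are not tracked like nn.Module buffers, so move each one.
        self.betas = self.betas.to(device)
        self.alpha_bars = self.alpha_bars.to(device)
        return self  # returning self enables chaining


device = "cuda" if torch.cuda.is_available() else "cpu"
holder = ScheduleHolder().to(device)  # construct and move in one expression
```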
- add_noise(x0, denoise_t=None)
DDPM forward process: add noise to clean data x0 to produce x_t.
q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
- Parameters:
x0 (torch.FloatTensor) – Original clean data (t=0). Shape: (batch_size, …) with arbitrary trailing dimensions.
denoise_t (torch.LongTensor, optional) – Specified timestep t. If None, t is sampled randomly according to the args.antithetic_sampling strategy. Shape: (batch_size, ).
- Returns:
xt (torch.FloatTensor): Noisy data. Same shape as x0.
noise (torch.FloatTensor): Added standard Gaussian noise epsilon. Same shape as x0.
denoise_t (torch.LongTensor): Actual timestep t used. Shape: (batch_size, ).
- Return type:
tuple
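A minimal sketch of the forward step under stated assumptions: the schedule is passed in as a precomputed `alpha_bars` tensor of shape (T,), and antithetic sampling is assumed to pair each sampled index t with its mirror T - 1 - t, which is one common scheme; the module's actual strategy may differ. The function signature is hypothetical.

```python
import torch


def add_noise(x0, alpha_bars, denoise_t=None, antithetic=True):
    """Hypothetical sketch of sampling x_t ~ q(x_t | x_0)."""
    B, T = x0.shape[0], alpha_bars.shape[0]
    if denoise_t is None:
        if antithetic:
            # Assumed scheme: draw half the batch, then mirror each t to T - 1 - t.
            half = torch.randint(0, T, ((B + 1) // 2,), device=x0.device)
            denoise_t = torch.cat([half, T - 1 - half])[:B]
        else:
            denoise_t = torch.randint(0, T, (B,), device=x0.device)
    noise = torch.randn_like(x0)
    # Gather alpha_bar_t and reshape to (B, 1, ..., 1) to broadcast over x0.
    a = alpha_bars.to(x0.device)[denoise_t].view(B, *([1] * (x0.dim() - 1)))
    xt = a.sqrt() * x0 + (1.0 - a).sqrt() * noise
    return xt, noise, denoise_t


torch.manual_seed(0)
alpha_bars = torch.cumprod(1.0 - torch.linspace(1e-4, 0.05, 1000), dim=0)
x0 = torch.randn(5, 3, 8)
xt, noise, t = add_noise(x0, alpha_bars)
```

Pairing t with T - 1 - t puts low-noise and high-noise samples in the same batch, which can lower the variance of the batch-averaged loss compared with drawing every t independently.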
- denoise(xt, denoise_t, x0=None, noise=None, stride=1)
DDPM reverse process: sample x_{t-stride} from x_t using the predicted x0 or noise.
For stride = 1: p_theta(x_{t-1} | x_t) = N(x_{t-1}; mu_theta(x_t, t), sigma_t^2 * I)
- Parameters:
xt (torch.FloatTensor) – Noisy data at the current timestep t. Shape: (batch_size, …)
denoise_t (torch.LongTensor) – Index of the current timestep t. Shape: (batch_size, )
x0 (torch.FloatTensor, optional) – Model-predicted original data x0.
noise (torch.FloatTensor, optional) – Model-predicted noise epsilon. Note: exactly one of x0 and noise must be provided.
stride (int, optional) – Number of timesteps to jump per reverse step. Default is 1. Values greater than 1 skip intermediate steps to accelerate sampling.
- Returns:
Denoised data x_{t-stride} at the previous step. Same shape as xt.
- Return type:
torch.FloatTensor
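One way to realize a strided reverse step is to sample from the Gaussian posterior q(x_{t-stride} | x_t, x0_hat), treating the jump from t to t - stride as a single step with alpha_{t|t-stride} = alpha_bar_t / alpha_bar_{t-stride}. The sketch below assumes that formulation, uses the posterior variance as sigma^2, and treats an index below 0 as clean data (alpha_bar = 1), so the final step returns x0_hat. The signature and the explicit `alpha_bars` argument are hypothetical, and how the real method folds in `flexibility` is not shown.

```python
import torch


def denoise(xt, denoise_t, alpha_bars, x0=None, noise=None, stride=1):
    """Hypothetical sketch: sample x_{t-stride} ~ q(x_{t-stride} | x_t, x0_hat).

    denoise_t is a LongTensor of shape (B,) indexing into alpha_bars (shape (T,)).
    """
    if (x0 is None) == (noise is None):
        raise ValueError("provide exactly one of x0 and noise")
    shape = (xt.shape[0],) + (1,) * (xt.dim() - 1)  # broadcast over data dims
    abar_t = alpha_bars[denoise_t].view(shape)
    prev_t = denoise_t - stride
    # alpha_bar at the target index; indices below 0 mean clean data (alpha_bar = 1).
    abar_p = alpha_bars[prev_t.clamp(min=0)]
    abar_p = torch.where(prev_t >= 0, abar_p, torch.ones_like(abar_p)).view(shape)
    if x0 is None:
        # Recover x0_hat from the predicted noise.
        x0 = (xt - (1 - abar_t).sqrt() * noise) / abar_t.sqrt()
    alpha_step = abar_t / abar_p  # alpha_{t|t-stride}
    beta_step = 1 - alpha_step
    # Posterior mean and variance of the jump t -> t - stride.
    coef_x0 = abar_p.sqrt() * beta_step / (1 - abar_t)
    coef_xt = alpha_step.sqrt() * (1 - abar_p) / (1 - abar_t)
    mean = coef_x0 * x0 + coef_xt * xt
    var = (beta_step * (1 - abar_p) / (1 - abar_t)).clamp(min=0)
    return mean + var.sqrt() * torch.randn_like(xt)


torch.manual_seed(0)
alpha_bars = torch.cumprod(1.0 - torch.linspace(1e-4, 0.05, 1000), dim=0)
x0 = torch.randn(4, 2)
eps = torch.randn(4, 2)
t = torch.full((4,), 100, dtype=torch.long)
a = alpha_bars[t].view(-1, 1)
xt = a.sqrt() * x0 + (1 - a).sqrt() * eps  # forward-noised sample at t = 100
x_prev = denoise(xt, t, alpha_bars, noise=eps, stride=10)  # one jump to t = 90
```

With stride = 1 this reduces to the usual DDPM posterior step; larger strides simply reuse the same formula with the compounded alpha.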
- noise_to_x0(xt, denoise_t, noise)
Infer the original data x0 from the current noisy data xt and predicted noise epsilon.
x_0 = (x_t - sqrt(1 - alpha_bar_t) * epsilon) / sqrt(alpha_bar_t)
- Parameters:
xt (torch.FloatTensor) – Noisy data x_t.
denoise_t (torch.LongTensor or int) – Timestep t.
noise (torch.FloatTensor) – Predicted noise epsilon.
- Returns:
Estimated original data x0.
- Return type:
torch.FloatTensor
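The inversion follows from the reparameterized form of q(x_t | x_0) given under add_noise: x_t is a scaled copy of x_0 plus scaled standard Gaussian noise, so it can be solved for x_0 once epsilon is known (x0_to_noise solves the same identity for epsilon):

```latex
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon
\quad\Longrightarrow\quad
x_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon}{\sqrt{\bar\alpha_t}}
```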
- x0_to_noise(xt, denoise_t, x0)
Infer the latent noise epsilon from the current noisy data xt and estimated x0.
epsilon = (x_t - sqrt(alpha_bar_t) * x_0) / sqrt(1 - alpha_bar_t)
- Parameters:
xt (torch.FloatTensor) – Noisy data x_t.
denoise_t (torch.LongTensor or int) – Timestep t.
x0 (torch.FloatTensor) – Estimated original data x0.
- Returns:
Inferred latent noise epsilon.
- Return type:
torch.FloatTensor
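The two conversions are inverses of each other for a fixed t. The sketch below assumes the schedule is available as an `alpha_bars` tensor of shape (T,) (the helper `_abar` and the extra argument are hypothetical) and checks the round trip numerically for both a LongTensor and an int timestep, matching the `torch.LongTensor or int` type documented above.

```python
import torch


def _abar(alpha_bars, denoise_t, like):
    # alpha_bar_t is 0-dim for an int t and (B,) for a LongTensor t;
    # a (B,) lookup is reshaped to (B, 1, ..., 1) to broadcast over `like`.
    a = alpha_bars[denoise_t]
    return a.view(-1, *([1] * (like.dim() - 1))) if a.dim() == 1 else a


def x0_to_noise(xt, denoise_t, x0, alpha_bars):
    a = _abar(alpha_bars, denoise_t, xt)
    return (xt - a.sqrt() * x0) / (1.0 - a).sqrt()


def noise_to_x0(xt, denoise_t, noise, alpha_bars):
    a = _abar(alpha_bars, denoise_t, xt)
    return (xt - (1.0 - a).sqrt() * noise) / a.sqrt()


torch.manual_seed(0)
alpha_bars = torch.cumprod(1.0 - torch.linspace(1e-4, 0.05, 1000), dim=0)
x0, eps = torch.randn(3, 4), torch.randn(3, 4)
t = torch.tensor([5, 50, 300])
a = alpha_bars[t].view(-1, 1)
xt = a.sqrt() * x0 + (1.0 - a).sqrt() * eps  # forward process at each t
```

Near t = 0 the denominator sqrt(1 - alpha_bar_t) of x0_to_noise becomes small, so epsilon inferred from x0 is numerically sensitive there; noise_to_x0 has the mirror-image issue near t = T, where sqrt(alpha_bar_t) approaches 0.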
- static cosine_beta_schedule(T, s=0.008)
Generate a cosine beta schedule.
- Parameters:
T (int) – Total number of timesteps.
s (float, optional) – Offset used to prevent beta from becoming too small at t=0. Default is 0.008.
- Returns:
Sequence of beta values. Shape: (T,)
- Return type:
torch.FloatTensor
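The default offset s = 0.008 is the value used in the cosine schedule of Nichol & Dhariwal (2021), where alpha_bar_t follows a squared cosine and each beta_t is recovered from consecutive ratios of alpha_bar. A sketch of that construction follows; the clip at 0.999 is taken from that paper, and whether this module clips, and at what value, is an assumption.

```python
import math

import torch


def cosine_beta_schedule(T, s=0.008):
    """Sketch of the squared-cosine schedule; returns betas of shape (T,)."""
    steps = torch.arange(T + 1, dtype=torch.float64)
    f = torch.cos(((steps / T) + s) / (1 + s) * math.pi / 2) ** 2
    alpha_bar = f / f[0]  # normalized so alpha_bar(0) = 1
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    # Clip so the final step is not singular (alpha_bar -> 0 at t = T).
    return betas.clamp(max=0.999).float()


b = cosine_beta_schedule(1000)
```

Without the offset, the first beta would be roughly (pi / (2T))^2, which is tiny; with s = 0.008 it is larger, which matches the purpose of `s` described above.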
- static linear_beta_schedule(T, beta_start=0.0001, beta_end=0.05)
Generate a linearly increasing beta schedule.
- Parameters:
T (int) – Total number of timesteps.
beta_start (float, optional) – Initial beta value. Default is 1e-4.
beta_end (float, optional) – Final beta value. Default is 0.05.
- Returns:
Sequence of beta values. Shape: (T,)
- Return type:
torch.FloatTensor
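A linear schedule is just evenly spaced betas. The sketch below assumes an endpoint-inclusive `torch.linspace` and also computes the final alpha_bar to show how much of x_0 survives at the last step with these defaults.

```python
import torch


def linear_beta_schedule(T, beta_start=0.0001, beta_end=0.05):
    """Sketch: T evenly spaced betas from beta_start to beta_end, endpoints included."""
    return torch.linspace(beta_start, beta_end, T)


betas = linear_beta_schedule(1000)
# alpha_bar_T = prod(1 - beta_t): the squared coefficient on x_0 at the last step.
alpha_bar_T = torch.cumprod(1.0 - betas, dim=0)[-1]
```

The default beta_end = 0.05 is larger than the 0.02 used in the original DDPM paper (Ho et al., 2020), so alpha_bar decays faster; with T = 1000 the final alpha_bar is below 1e-6, meaning x_T is essentially pure noise.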