这是《去噪扩散概率模型》论文的 PyTorch 实现/教程。
简而言之,我们从数据中获取图像并逐步添加噪点。然后,我们训练一个模型来预测每个步骤的噪声,并使用该模型生成图像。
以下定义和派生说明了其工作原理。详情请参阅论文。
在时间步长内,转发过程会给数据增加噪音。
方差计划在哪里。
我们可以随时采样,
在哪里和
相反的过程会从四个时间步长开始消除噪音。
是我们训练的参数。
我们根据负对数概率优化 ELBO(来自简森不等式)。
损失可以改写如下。
是恒定的,因为我们保持不变。
后验的前向过程是,
论文将其中设置为常量或.
然后,
对于给定的噪音,使用
这给了,
使用模型重新参数化以预测噪声
其中是预测给定值的学习函数。
这给了,
也就是说,我们正在训练预测噪音。
这样可以最大限度地减少放弃权重的时间和时间。丢弃权重会增加给出更高的权重(噪声等级更高),从而提高样本质量。
该文件实现了损失计算和基本采样方法,我们在训练期间使用该方法生成图像。
162from typing import Tuple, Optional
163
164import torch
165import torch.nn.functional as F
166import torch.utils.data
167from torch import nn
168
169from labml_nn.diffusion.ddpm.utils import gather172class DenoiseDiffusion:eps_model
是模特n_steps
是device
是用来放置常量的设备177    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):183        super().__init__()
184        self.eps_model = eps_model创建线性增加的差异计划
187        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)190        self.alpha = 1. - self.beta192        self.alpha_bar = torch.cumprod(self.alpha, dim=0)194        self.n_steps = n_steps196        self.sigma2 = self.beta198    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:210        var = 1 - gather(self.alpha_bar, t)212        return mean, var214    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):224        if eps is None:
225            eps = torch.randn_like(x0)得到
228        mean, var = self.q_xt_x0(x0, t)样本来自
230        return mean + (var ** 0.5) * eps232    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):246        eps_theta = self.eps_model(xt, t)250        alpha = gather(self.alpha, t)252        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5255        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)257        var = gather(self.sigma2, t)260        eps = torch.randn(xt.shape, device=xt.device)样本
262        return mean + (var ** .5) * eps264    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):获取批次大小
273        batch_size = x0.shape[0]批次中每个样本的随机获取
275        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)278        if noise is None:
279            noise = torch.randn_like(x0)的样本
282        xt = self.q_sample(x0, t, eps=noise)得到
284        eps_theta = self.eps_model(xt, t)MSE 亏损
287        return F.mse_loss(noise, eps_theta)