这是论文《去噪扩散概率模型》的 PyTorc h 实现/教程。
简而言之,我们从数据中获取图像并逐步添加噪点。然后我们训练一个模型来预测每一步的噪声,然后使用该模型生成图像。
以下定义和派生说明了它的工作原理。详情请参阅论文。
对于时间步长,转发过程会给数据增加噪音。
其中是差异计划。
我们可以在任何时间段采样,
在哪里和
相反的过程从开始消除时间步长的噪音。
是我们训练的参数。
我们根据负对数似然优化ELBO(来自简森的不等式)。
损失可以按如下方式重写。
是恒定的,因为我们保持不变。
前进过程的后方条件是,
本文将哪里设置为常量或。
那么,
对于给定的噪音,使用
这给了,
使用模型重新参数化以预测噪声
wh ere 是一个预测给定的学习函数。
这给了,
也就是说,我们正在训练以预测噪音。
这样可以最大限度地减少丢弃权重的时间和用途。丢弃权重会增加赋予较高权重(噪声级更高),从而提高样品质量。
该文件实现了损失计算和我们在训练期间用来生成图像的基本采样方法。
163from typing import Tuple, Optional
164
165import torch
166import torch.nn.functional as F
167import torch.utils.data
168from torch import nn
169
170from labml_nn.diffusion.ddpm.utils import gather173class DenoiseDiffusion:eps_model
是模特n_steps
是device
是用来放置常量的设备178    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):184        super().__init__()
185        self.eps_model = eps_model创建线性增加的差异计划
188        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)191        self.alpha = 1. - self.beta193        self.alpha_bar = torch.cumprod(self.alpha, dim=0)195        self.n_steps = n_steps197        self.sigma2 = self.beta199    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:211        var = 1 - gather(self.alpha_bar, t)213        return mean, var215    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):225        if eps is None:
226            eps = torch.randn_like(x0)得到
229        mean, var = self.q_xt_x0(x0, t)样本来自
231        return mean + (var ** 0.5) * eps233    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):247        eps_theta = self.eps_model(xt, t)251        alpha = gather(self.alpha, t)253        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5256        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)258        var = gather(self.sigma2, t)261        eps = torch.randn(xt.shape, device=xt.device)样本
263        return mean + (var ** .5) * eps265    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):获取批次大小
274        batch_size = x0.shape[0]批次中每个样本的随机获取
276        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)279        if noise is None:
280            noise = torch.randn_like(x0)的样本
283        xt = self.q_sample(x0, t, eps=noise)得到
285        eps_theta = self.eps_model(xt, t)MSE 亏损
288        return F.mse_loss(noise, eps_theta)