去噪扩散概率模型 (DDPM)

Open In ColabOpen In Comet

这是论文《去噪扩散概率模型》的 PyTorc h 实现/教程。

简而言之,我们从数据中获取图像并逐步添加噪点。然后我们训练一个模型来预测每一步的噪声,然后使用该模型生成图像。

以下定义和派生说明了它的工作原理。详情请参阅论文

转发流程

对于时间步长,转发过程会给数据增加噪音。

其中是差异计划。

我们可以在任何时间段采样

在哪里

逆向流程

相反的过程从开始消除时间步长的噪音。

是我们训练的参数。

亏损

我们根据负对数似然优化ELBO(来自简森的不等式)。

损失可以按如下方式重写。

是恒定的,因为我们保持不变。

计算

前进过程的后方条件是,

本文将哪里设置为常量

那么,

对于给定的噪音,使用

这给了,

使用模型重新参数化以预测噪声

wh ere 是一个预测给定的学习函数

这给了,

也就是说,我们正在训练以预测噪音。

简化损失

这样可以最大限减少丢弃权重的时间和用途。丢弃权重会增加赋予较高权重(噪声级更高),从而提高样品质量。

该文件实现了损失计算和我们在训练期间用来生成图像的基本采样方法。

这里是提供训练代码unET 模型该文件可以根据训练后的模型生成样本和插值。

163from typing import Tuple, Optional
164
165import torch
166import torch.nn.functional as F
167import torch.utils.data
168from torch import nn
169
170from labml_nn.diffusion.ddpm.utils import gather

降噪扩散

173class DenoiseDiffusion:
  • eps_model模特
  • n_steps
  • device 是用来放置常量的设备
178    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
184        super().__init__()
185        self.eps_model = eps_model

创建线性增加的差异计划

188        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

191        self.alpha = 1. - self.beta

193        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

195        self.n_steps = n_steps

197        self.sigma2 = self.beta

获取分发

199    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

收集和计算

209        mean = gather(self.alpha_bar, t) ** 0.5 * x0

211        var = 1 - gather(self.alpha_bar, t)

213        return mean, var

样本来自

215    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):

225        if eps is None:
226            eps = torch.randn_like(x0)

得到

229        mean, var = self.q_xt_x0(x0, t)

样本来自

231        return mean + (var ** 0.5) * eps

样本来自

233    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):

247        eps_theta = self.eps_model(xt, t)

收集

249        alpha_bar = gather(self.alpha_bar, t)

251        alpha = gather(self.alpha, t)

253        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

256        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

258        var = gather(self.sigma2, t)

261        eps = torch.randn(xt.shape, device=xt.device)

样本

263        return mean + (var ** .5) * eps

简化损失

265    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):

获取批次大小

274        batch_size = x0.shape[0]

批次中每个样本的随机获取

276        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

279        if noise is None:
280            noise = torch.randn_like(x0)

的样本

283        xt = self.q_sample(x0, t, eps=noise)

得到

285        eps_theta = self.eps_model(xt, t)

MSE 亏损

288        return F.mse_loss(noise, eps_theta)