这是论文《去噪扩散概率模型》的 PyTorc h 实现/教程。
简而言之,我们从数据中获取图像并逐步添加噪点。然后我们训练一个模型来预测每一步的噪声,然后使用该模型生成图像。
以下定义和派生说明了它的工作原理。详情请参阅论文。
对于时间步长,转发过程会给数据增加噪音。
其中是差异计划。
我们可以在任何时间段采样,
在哪里和
相反的过程从开始消除时间步长的噪音。
是我们训练的参数。
我们根据负对数似然优化ELBO(来自简森的不等式)。
损失可以按如下方式重写。
是恒定的,因为我们保持不变。
前进过程的后方条件是,
本文将哪里设置为常量或。
那么,
对于给定的噪音,使用
这给了,
使用模型重新参数化以预测噪声
wh ere 是一个预测给定的学习函数。
这给了,
也就是说,我们正在训练以预测噪音。
这样可以最大限度地减少丢弃权重的时间和用途。丢弃权重会增加赋予较高权重(噪声级更高),从而提高样品质量。
该文件实现了损失计算和我们在训练期间用来生成图像的基本采样方法。
163from typing import Tuple, Optional
164
165import torch
166import torch.nn.functional as F
167import torch.utils.data
168from torch import nn
169
170from labml_nn.diffusion.ddpm.utils import gather173class DenoiseDiffusion:eps_model
是模特n_steps
是device
是用来放置常量的设备178 def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):184 super().__init__()
185 self.eps_model = eps_model创建线性增加的差异计划
188 self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)191 self.alpha = 1. - self.beta193 self.alpha_bar = torch.cumprod(self.alpha, dim=0)195 self.n_steps = n_steps197 self.sigma2 = self.beta199 def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:211 var = 1 - gather(self.alpha_bar, t)213 return mean, var215 def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):225 if eps is None:
226 eps = torch.randn_like(x0)得到
229 mean, var = self.q_xt_x0(x0, t)样本来自
231 return mean + (var ** 0.5) * eps233 def p_sample(self, xt: torch.Tensor, t: torch.Tensor):247 eps_theta = self.eps_model(xt, t)251 alpha = gather(self.alpha, t)253 eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5256 mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)258 var = gather(self.sigma2, t)261 eps = torch.randn(xt.shape, device=xt.device)样本
263 return mean + (var ** .5) * eps265 def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):获取批次大小
274 batch_size = x0.shape[0]批次中每个样本的随机获取
276 t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)279 if noise is None:
280 noise = torch.randn_like(x0)的样本
283 xt = self.q_sample(x0, t, eps=noise)得到
285 eps_theta = self.eps_model(xt, t)MSE 亏损
288 return F.mse_loss(noise, eps_theta)