这将基于 CeleBA HQ 数据集训练基于 DDPM 的模型。你可以在 fast.ai 的讨论中找到下载说明。将图像保存在data/celebA
文件夹中。
该论文使用了该模型的指数移动平均线,其衰减量为。为简单起见,我们跳过了这个。
20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet34class Configs(BaseConfigs):用于训练模型的设备。DeviceConfigs
选择可用的 CUDA 设备或默认为 CPU。
41 device: torch.device = DeviceConfigs()U-Net 模型用于
44 eps_model: UNet图像中的通道数。对于 RGB。
49 image_channels: int = 3图像大小
51 image_size: int = 32初始特征图中的频道数量
53 n_channels: int = 64每种分辨率下的通道编号列表。频道的数量是channel_multipliers[i] * n_channels
56 channel_multipliers: List[int] = [1, 2, 2, 4]指示是否在每个分辨率下使用注意力的布尔值列表
58 is_attention: List[int] = [False, False, False, True]时间步数
61 n_steps: int = 1_000批量大小
63 batch_size: int = 64要生成的样本数
65 n_samples: int = 16学习率
67 learning_rate: float = 2e-5训练周期的数量
70 epochs: int = 1_000数据集
73 dataset: torch.utils.data.Dataset数据加载器
75 data_loader: torch.utils.data.DataLoaderAdam 优化器
78 optimizer: torch.optim.Adam80 def init(self):创建模型
82 self.eps_model = UNet(
83 image_channels=self.image_channels,
84 n_channels=self.n_channels,
85 ch_mults=self.channel_multipliers,
86 is_attn=self.is_attention,
87 ).to(self.device)90 self.diffusion = DenoiseDiffusion(
91 eps_model=self.eps_model,
92 n_steps=self.n_steps,
93 device=self.device,
94 )创建数据加载器
97 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)创建优化器
99 self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)图像日志记录
102 tracker.set_image("sample", True)104 def sample(self):108 with torch.no_grad():110 x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111 device=self.device)消除台阶噪音
114 for t_ in monit.iterate('Sample', self.n_steps):116 t = self.n_steps - t_ - 1样本来自
118 x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))日志样本
121 tracker.save('sample', x)123 def train(self):遍历数据集
129 for data in monit.iterate('Train', self.data_loader):递增全局步长
131 tracker.add_global_step()将数据移动到设备
133 data = data.to(self.device)将渐变设为零
136 self.optimizer.zero_grad()计算损失
138 loss = self.diffusion.loss(data)计算梯度
140 loss.backward()采取优化步骤
142 self.optimizer.step()追踪损失
144 tracker.save('loss', loss)146 def run(self):150 for _ in monit.loop(self.epochs):训练模型
152 self.train()对一些图像进行采样
154 self.sample()控制台中的新行
156 tracker.new_line()保存模型
158 experiment.save_checkpoint()161class CelebADataset(torch.utils.data.Dataset):166 def __init__(self, image_size: int):
167 super().__init__()CeleBA 图片文件夹
170 folder = lab.get_data_path() / 'celebA'文件清单
172 self._files = [p for p in folder.glob(f'**/*.jpg')]用于调整图像大小并转换为张量的转换
175 self._transform = torchvision.transforms.Compose([
176 torchvision.transforms.Resize(image_size),
177 torchvision.transforms.ToTensor(),
178 ])数据集的大小
180 def __len__(self):184 return len(self._files)获取一张图片
186 def __getitem__(self, index: int):190 img = Image.open(self._files[index])
191 return self._transform(img)创建 CeleBA 数据集
194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):199 return CelebADataset(c.image_size)202class MNISTDataset(torchvision.datasets.MNIST):207 def __init__(self, image_size):
208 transform = torchvision.transforms.Compose([
209 torchvision.transforms.Resize(image_size),
210 torchvision.transforms.ToTensor(),
211 ])
212
213 super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)215 def __getitem__(self, item):
216 return super().__getitem__(item)[0]创建 MNIST 数据集
219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):224 return MNISTDataset(c.image_size)227def main():创建实验
229 experiment.create(name='diffuse', writers={'screen', 'labml'})创建配置
232 configs = Configs()设置配置。您可以通过在字典中传递值来覆盖默认值。
235 experiment.configs(configs, {
236 'dataset': 'CelebA', # 'MNIST'
237 'image_channels': 3, # 1,
238 'epochs': 100, # 5,
239 })初始化
242 configs.init()设置用于保存和加载的模型
245 experiment.add_pytorch_models({'eps_model': configs.eps_model})启动并运行训练循环
248 with experiment.start():
249 configs.run()253if __name__ == '__main__':
254 main()