这将在 CeleBA HQ 数据集上训练基于 DDPM 的模型。你可以在这篇关于 fast.ai 的讨论中找到下载说明。将图像保存在data/celebA
文件夹中。
该论文使用了衰减为的模型的指数移动平均线。为了简单起见,我们跳过了这个。
21from typing import List
22
23import torch
24import torch.utils.data
25import torchvision
26from PIL import Image
27
28from labml import lab, tracker, experiment, monit
29from labml.configs import BaseConfigs, option
30from labml_helpers.device import DeviceConfigs
31from labml_nn.diffusion.ddpm import DenoiseDiffusion
32from labml_nn.diffusion.ddpm.unet import UNet35class Configs(BaseConfigs):用于训练模型的设备。DeviceConfigs
选择可用的 CUDA 设备或默认为 CPU。
42 device: torch.device = DeviceConfigs()U-Net 模型用于
45 eps_model: UNet图像中的通道数。对于 RGB。
50 image_channels: int = 3图像大小
52 image_size: int = 32初始特征图中的频道数量
54 n_channels: int = 64每种分辨率下的通道编号列表。频道的数量是channel_multipliers[i] * n_channels
57 channel_multipliers: List[int] = [1, 2, 2, 4]指示是否在每个分辨率下使用注意力的布尔值列表
59 is_attention: List[int] = [False, False, False, True]时间步数
62 n_steps: int = 1_000批量大小
64 batch_size: int = 64要生成的样本数
66 n_samples: int = 16学习率
68 learning_rate: float = 2e-5训练周期的数量
71 epochs: int = 1_000数据集
74 dataset: torch.utils.data.Dataset数据加载器
76 data_loader: torch.utils.data.DataLoaderAdam 优化器
79 optimizer: torch.optim.Adam81 def init(self):创建模型
83 self.eps_model = UNet(
84 image_channels=self.image_channels,
85 n_channels=self.n_channels,
86 ch_mults=self.channel_multipliers,
87 is_attn=self.is_attention,
88 ).to(self.device)91 self.diffusion = DenoiseDiffusion(
92 eps_model=self.eps_model,
93 n_steps=self.n_steps,
94 device=self.device,
95 )创建数据加载器
98 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)创建优化器
100 self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)图像日志记录
103 tracker.set_image("sample", True)105 def sample(self):109 with torch.no_grad():111 x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
112 device=self.device)消除台阶噪音
115 for t_ in monit.iterate('Sample', self.n_steps):117 t = self.n_steps - t_ - 1样本来自
119 x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))日志样本
122 tracker.save('sample', x)124 def train(self):遍历数据集
130 for data in monit.iterate('Train', self.data_loader):递增全局步长
132 tracker.add_global_step()将数据移动到设备
134 data = data.to(self.device)将渐变设为零
137 self.optimizer.zero_grad()计算损失
139 loss = self.diffusion.loss(data)计算梯度
141 loss.backward()采取优化步骤
143 self.optimizer.step()追踪损失
145 tracker.save('loss', loss)147 def run(self):151 for _ in monit.loop(self.epochs):训练模型
153 self.train()对一些图像进行采样
155 self.sample()控制台中的新行
157 tracker.new_line()保存模型
159 experiment.save_checkpoint()162class CelebADataset(torch.utils.data.Dataset):167 def __init__(self, image_size: int):
168 super().__init__()CeleBA 图片文件夹
171 folder = lab.get_data_path() / 'celebA'文件清单
173 self._files = [p for p in folder.glob(f'**/*.jpg')]用于调整图像大小并转换为张量的转换
176 self._transform = torchvision.transforms.Compose([
177 torchvision.transforms.Resize(image_size),
178 torchvision.transforms.ToTensor(),
179 ])数据集的大小
181 def __len__(self):185 return len(self._files)获取一张图片
187 def __getitem__(self, index: int):191 img = Image.open(self._files[index])
192 return self._transform(img)创建 CeleBA 数据集
195@option(Configs.dataset, 'CelebA')
196def celeb_dataset(c: Configs):200 return CelebADataset(c.image_size)203class MNISTDataset(torchvision.datasets.MNIST):208 def __init__(self, image_size):
209 transform = torchvision.transforms.Compose([
210 torchvision.transforms.Resize(image_size),
211 torchvision.transforms.ToTensor(),
212 ])
213
214 super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)216 def __getitem__(self, item):
217 return super().__getitem__(item)[0]创建 MNIST 数据集
220@option(Configs.dataset, 'MNIST')
221def mnist_dataset(c: Configs):225 return MNISTDataset(c.image_size)228def main():创建实验
230 experiment.create(name='diffuse', writers={'screen', 'comet'})创建配置
233 configs = Configs()设置配置。您可以通过在字典中传递值来覆盖默认值。
236 experiment.configs(configs, {
237 'dataset': 'CelebA', # 'MNIST'
238 'image_channels': 3, # 1,
239 'epochs': 100, # 5,
240 })初始化
243 configs.init()设置用于保存和加载的模型
246 experiment.add_pytorch_models({'eps_model': configs.eps_model})启动并运行训练循环
249 with experiment.start():
250 configs.run()254if __name__ == '__main__':
255 main()