去噪扩散概率模型 (DDPM) 训练

Open In ColabOpen In Comet

这将在 CeleBA HQ 数据集上训练基于 DDPM 的模型。你可以在这篇关于 fast.ai 的讨论中找到下载说明。将图像保存在data/celebA 文件夹中

该论文使用了衰减为的模型的指数移动平均线。为了简单起见,我们跳过了这个。

21from typing import List
22
23import torch
24import torch.utils.data
25import torchvision
26from PIL import Image
27
28from labml import lab, tracker, experiment, monit
29from labml.configs import BaseConfigs, option
30from labml_helpers.device import DeviceConfigs
31from labml_nn.diffusion.ddpm import DenoiseDiffusion
32from labml_nn.diffusion.ddpm.unet import UNet

配置

35class Configs(BaseConfigs):

用于训练模型的设备。DeviceConfigs 选择可用的 CUDA 设备或默认为 CPU。

42    device: torch.device = DeviceConfigs()

U-Net 模型用于

45    eps_model: UNet
47    diffusion: DenoiseDiffusion

图像中的通道数。对于 RGB。

50    image_channels: int = 3

图像大小

52    image_size: int = 32

初始特征图中的频道数量

54    n_channels: int = 64

每种分辨率下的通道编号列表。频道的数量是channel_multipliers[i] * n_channels

57    channel_multipliers: List[int] = [1, 2, 2, 4]

指示是否在每个分辨率下使用注意力的布尔值列表

59    is_attention: List[int] = [False, False, False, True]

时间步数

62    n_steps: int = 1_000

批量大小

64    batch_size: int = 64

要生成的样本数

66    n_samples: int = 16

学习率

68    learning_rate: float = 2e-5

训练周期的数量

71    epochs: int = 1_000

数据集

74    dataset: torch.utils.data.Dataset

数据加载器

76    data_loader: torch.utils.data.DataLoader

Adam 优化器

79    optimizer: torch.optim.Adam
81    def init(self):

创建模型

83        self.eps_model = UNet(
84            image_channels=self.image_channels,
85            n_channels=self.n_channels,
86            ch_mults=self.channel_multipliers,
87            is_attn=self.is_attention,
88        ).to(self.device)

创建 DDPM 类

91        self.diffusion = DenoiseDiffusion(
92            eps_model=self.eps_model,
93            n_steps=self.n_steps,
94            device=self.device,
95        )

创建数据加载器

98        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)

创建优化器

100        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

图像日志记录

103        tracker.set_image("sample", True)

样本图片

105    def sample(self):
109        with torch.no_grad():

111            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
112                            device=self.device)

消除台阶噪音

115            for t_ in monit.iterate('Sample', self.n_steps):

117                t = self.n_steps - t_ - 1

样本来自

119                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

日志样本

122            tracker.save('sample', x)

火车

124    def train(self):

遍历数据集

130        for data in monit.iterate('Train', self.data_loader):

递增全局步长

132            tracker.add_global_step()

将数据移动到设备

134            data = data.to(self.device)

将渐变设为零

137            self.optimizer.zero_grad()

计算损失

139            loss = self.diffusion.loss(data)

计算梯度

141            loss.backward()

采取优化步骤

143            self.optimizer.step()

追踪损失

145            tracker.save('loss', loss)

训练循环

147    def run(self):
151        for _ in monit.loop(self.epochs):

训练模型

153            self.train()

对一些图像进行采样

155            self.sample()

控制台中的新行

157            tracker.new_line()

保存模型

159            experiment.save_checkpoint()

CeleBA HQ 数据集

162class CelebADataset(torch.utils.data.Dataset):
167    def __init__(self, image_size: int):
168        super().__init__()

CeleBA 图片文件夹

171        folder = lab.get_data_path() / 'celebA'

文件清单

173        self._files = [p for p in folder.glob(f'**/*.jpg')]

用于调整图像大小并转换为张量的转换

176        self._transform = torchvision.transforms.Compose([
177            torchvision.transforms.Resize(image_size),
178            torchvision.transforms.ToTensor(),
179        ])

数据集的大小

181    def __len__(self):
185        return len(self._files)

获取一张图片

187    def __getitem__(self, index: int):
191        img = Image.open(self._files[index])
192        return self._transform(img)

创建 CeleBA 数据集

195@option(Configs.dataset, 'CelebA')
196def celeb_dataset(c: Configs):
200    return CelebADataset(c.image_size)

MNIST 数据集

203class MNISTDataset(torchvision.datasets.MNIST):
208    def __init__(self, image_size):
209        transform = torchvision.transforms.Compose([
210            torchvision.transforms.Resize(image_size),
211            torchvision.transforms.ToTensor(),
212        ])
213
214        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
216    def __getitem__(self, item):
217        return super().__getitem__(item)[0]

创建 MNIST 数据集

220@option(Configs.dataset, 'MNIST')
221def mnist_dataset(c: Configs):
225    return MNISTDataset(c.image_size)
228def main():

创建实验

230    experiment.create(name='diffuse', writers={'screen', 'comet'})

创建配置

233    configs = Configs()

设置配置。您可以通过在字典中传递值来覆盖默认值。

236    experiment.configs(configs, {
237        'dataset': 'CelebA',  # 'MNIST'
238        'image_channels': 3,  # 1,
239        'epochs': 100,  # 5,
240    })

初始化

243    configs.init()

设置用于保存和加载的模型

246    experiment.add_pytorch_models({'eps_model': configs.eps_model})

启动并运行训练循环

249    with experiment.start():
250        configs.run()

254if __name__ == '__main__':
255    main()