これにより、CeleBA HQ データセットで DDPM ベースのモデルがトレーニングされます。ダウンロードの説明は、fast.ai のこのディスカッションにあります。data/celebA
画像をフォルダーに保存します。
この論文では、モデルの指数移動平均を減衰させて使用していました。簡略化のため、ここでは省略しています
。20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet34class Configs(BaseConfigs):モデルをトレーニングするデバイス。DeviceConfigs
使用可能な CUDA デバイスを選択するか、デフォルトで CPU に設定します
41    device: torch.device = DeviceConfigs()用の U-Net モデル
44    eps_model: UNet46    diffusion: DenoiseDiffusion画像内のチャンネル数。RGB 用です。
49    image_channels: int = 3画像サイズ
51    image_size: int = 32初期機能マップのチャンネル数
53    n_channels: int = 64各解像度のチャンネル番号のリスト。チャンネル数は channel_multipliers[i] * n_channels
56    channel_multipliers: List[int] = [1, 2, 2, 4]各解像度で注意を向けるかどうかを示すブーリアンのリスト
58    is_attention: List[int] = [False, False, False, True]タイムステップ数
61    n_steps: int = 1_000バッチサイズ
63    batch_size: int = 64生成するサンプルの数
65    n_samples: int = 16学習率
67    learning_rate: float = 2e-5トレーニングエポックの数
70    epochs: int = 1_000データセット
73    dataset: torch.utils.data.Datasetデータローダー
75    data_loader: torch.utils.data.DataLoaderアダム・オプティマイザー
78    optimizer: torch.optim.Adam80    def init(self):モデル作成
82        self.eps_model = UNet(
83            image_channels=self.image_channels,
84            n_channels=self.n_channels,
85            ch_mults=self.channel_multipliers,
86            is_attn=self.is_attention,
87        ).to(self.device)90        self.diffusion = DenoiseDiffusion(
91            eps_model=self.eps_model,
92            n_steps=self.n_steps,
93            device=self.device,
94        )データローダーの作成
97        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)オプティマイザーを作成
99        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)画像ロギング
102        tracker.set_image("sample", True)104    def sample(self):108        with torch.no_grad():110            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111                            device=self.device)ステップのノイズ除去
114            for t_ in monit.iterate('Sample', self.n_steps):116                t = self.n_steps - t_ - 1からのサンプル
118                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))ログサンプル
121            tracker.save('sample', x)123    def train(self):データセットの反復処理
129        for data in monit.iterate('Train', self.data_loader):グローバルステップをインクリメント
131            tracker.add_global_step()データをデバイスに移動
133            data = data.to(self.device)グラデーションをゼロにする
136            self.optimizer.zero_grad()損失の計算
138            loss = self.diffusion.loss(data)勾配の計算
140            loss.backward()最適化の一歩を踏み出す
142            self.optimizer.step()損失をトラッキング
144            tracker.save('loss', loss)146    def run(self):150        for _ in monit.loop(self.epochs):モデルのトレーニング
152            self.train()いくつかの画像のサンプル
154            self.sample()コンソールの新しい行
156            tracker.new_line()モデルを保存する
158            experiment.save_checkpoint()161class CelebADataset(torch.utils.data.Dataset):166    def __init__(self, image_size: int):
167        super().__init__()セレバ画像フォルダー
170        folder = lab.get_data_path() / 'celebA'ファイルリスト
172        self._files = [p for p in folder.glob(f'**/*.jpg')]画像のサイズを変更してテンソルに変換する変換
175        self._transform = torchvision.transforms.Compose([
176            torchvision.transforms.Resize(image_size),
177            torchvision.transforms.ToTensor(),
178        ])データセットのサイズ
180    def __len__(self):184        return len(self._files)画像を取得
186    def __getitem__(self, index: int):190        img = Image.open(self._files[index])
191        return self._transform(img)CeleBA データセットの作成
194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):199    return CelebADataset(c.image_size)202class MNISTDataset(torchvision.datasets.MNIST):207    def __init__(self, image_size):
208        transform = torchvision.transforms.Compose([
209            torchvision.transforms.Resize(image_size),
210            torchvision.transforms.ToTensor(),
211        ])
212
213        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)215    def __getitem__(self, item):
216        return super().__getitem__(item)[0]MNIST データセットの作成
219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):224    return MNISTDataset(c.image_size)227def main():実験を作成
229    experiment.create(name='diffuse', writers={'screen', 'labml'})構成の作成
232    configs = Configs()構成を設定します。ディクショナリに値を渡すことでデフォルトをオーバーライドできます。
235    experiment.configs(configs, {
236        'dataset': 'CelebA',  # 'MNIST'
237        'image_channels': 3,  # 1,
238        'epochs': 100,  # 5,
239    })[初期化]
242    configs.init()保存および読み込み用のモデルを設定する
245    experiment.add_pytorch_models({'eps_model': configs.eps_model})トレーニングループを開始して実行する
248    with experiment.start():
249        configs.run()253if __name__ == '__main__':
254    main()