これにより、CeleBA HQ データセットで DDPM ベースのモデルがトレーニングされます。ダウンロードの説明は、fast.ai のこのディスカッションにあります。data/celebA
画像をフォルダーに保存します。
この論文では、モデルの指数移動平均を減衰させて使用していました。簡略化のため、ここでは省略しています
。20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet34class Configs(BaseConfigs):モデルをトレーニングするデバイス。DeviceConfigs
使用可能な CUDA デバイスを選択するか、デフォルトで CPU に設定します
41 device: torch.device = DeviceConfigs()用の U-Net モデル
44 eps_model: UNet46 diffusion: DenoiseDiffusion画像内のチャンネル数。RGB 用です。
49 image_channels: int = 3画像サイズ
51 image_size: int = 32初期機能マップのチャンネル数
53 n_channels: int = 64各解像度のチャンネル番号のリスト。チャンネル数は channel_multipliers[i] * n_channels
56 channel_multipliers: List[int] = [1, 2, 2, 4]各解像度で注意を向けるかどうかを示すブーリアンのリスト
58 is_attention: List[int] = [False, False, False, True]タイムステップ数
61 n_steps: int = 1_000バッチサイズ
63 batch_size: int = 64生成するサンプルの数
65 n_samples: int = 16学習率
67 learning_rate: float = 2e-5トレーニングエポックの数
70 epochs: int = 1_000データセット
73 dataset: torch.utils.data.Datasetデータローダー
75 data_loader: torch.utils.data.DataLoaderアダム・オプティマイザー
78 optimizer: torch.optim.Adam80 def init(self):モデル作成
82 self.eps_model = UNet(
83 image_channels=self.image_channels,
84 n_channels=self.n_channels,
85 ch_mults=self.channel_multipliers,
86 is_attn=self.is_attention,
87 ).to(self.device)90 self.diffusion = DenoiseDiffusion(
91 eps_model=self.eps_model,
92 n_steps=self.n_steps,
93 device=self.device,
94 )データローダーの作成
97 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)オプティマイザーを作成
99 self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)画像ロギング
102 tracker.set_image("sample", True)104 def sample(self):108 with torch.no_grad():110 x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111 device=self.device)ステップのノイズ除去
114 for t_ in monit.iterate('Sample', self.n_steps):116 t = self.n_steps - t_ - 1からのサンプル
118 x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))ログサンプル
121 tracker.save('sample', x)123 def train(self):データセットの反復処理
129 for data in monit.iterate('Train', self.data_loader):グローバルステップをインクリメント
131 tracker.add_global_step()データをデバイスに移動
133 data = data.to(self.device)グラデーションをゼロにする
136 self.optimizer.zero_grad()損失の計算
138 loss = self.diffusion.loss(data)勾配の計算
140 loss.backward()最適化の一歩を踏み出す
142 self.optimizer.step()損失をトラッキング
144 tracker.save('loss', loss)146 def run(self):150 for _ in monit.loop(self.epochs):モデルのトレーニング
152 self.train()いくつかの画像のサンプル
154 self.sample()コンソールの新しい行
156 tracker.new_line()モデルを保存する
158 experiment.save_checkpoint()161class CelebADataset(torch.utils.data.Dataset):166 def __init__(self, image_size: int):
167 super().__init__()セレバ画像フォルダー
170 folder = lab.get_data_path() / 'celebA'ファイルリスト
172 self._files = [p for p in folder.glob(f'**/*.jpg')]画像のサイズを変更してテンソルに変換する変換
175 self._transform = torchvision.transforms.Compose([
176 torchvision.transforms.Resize(image_size),
177 torchvision.transforms.ToTensor(),
178 ])データセットのサイズ
180 def __len__(self):184 return len(self._files)画像を取得
186 def __getitem__(self, index: int):190 img = Image.open(self._files[index])
191 return self._transform(img)CeleBA データセットの作成
194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):199 return CelebADataset(c.image_size)202class MNISTDataset(torchvision.datasets.MNIST):207 def __init__(self, image_size):
208 transform = torchvision.transforms.Compose([
209 torchvision.transforms.Resize(image_size),
210 torchvision.transforms.ToTensor(),
211 ])
212
213 super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)215 def __getitem__(self, item):
216 return super().__getitem__(item)[0]MNIST データセットの作成
219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):224 return MNISTDataset(c.image_size)227def main():実験を作成
229 experiment.create(name='diffuse', writers={'screen', 'labml'})構成の作成
232 configs = Configs()構成を設定します。ディクショナリに値を渡すことでデフォルトをオーバーライドできます。
235 experiment.configs(configs, {
236 'dataset': 'CelebA', # 'MNIST'
237 'image_channels': 3, # 1,
238 'epochs': 100, # 5,
239 })[初期化]
242 configs.init()保存および読み込み用のモデルを設定する
245 experiment.add_pytorch_models({'eps_model': configs.eps_model})トレーニングループを開始して実行する
248 with experiment.start():
249 configs.run()253if __name__ == '__main__':
254 main()