14import numpy as np
15import torch
16from matplotlib import pyplot as plt
17from torchvision.transforms.functional import to_pil_image, resize
18
19from labml import experiment, monit
20from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
21from labml_nn.diffusion.ddpm.experiment import Configs24class Sampler:diffusion
DenoiseDiffusion
インスタンスですimage_channels
は画像内のチャンネル数image_size
は画像サイズですdevice
モデルのデバイスです29    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):36        self.device = device
37        self.image_size = image_size
38        self.image_channels = image_channels
39        self.diffusion = diffusion42        self.n_steps = diffusion.n_steps44        self.eps_model = diffusion.eps_model46        self.beta = diffusion.beta48        self.alpha = diffusion.alpha50        self.alpha_bar = diffusion.alpha_bar52        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])64        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)66        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)68        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)70        self.sigma2 = self.beta画像を表示するヘルパー関数
72    def show_image(self, img, title=""):74        img = img.clip(0, 1)
75        img = img.cpu().numpy()
76        plt.imshow(img.transpose(1, 2, 0))
77        plt.title(title)
78        plt.show()動画を作成するためのヘルパー機能
80    def make_video(self, frames, path="video.mp4"):82        import imageio20 秒のビデオ
84        writer = imageio.get_writer(path, fps=len(frames) // 20)各画像を追加
86        for f in frames:
87            f = f.clip(0, 1)
88            f = to_pil_image(resize(f, [368, 368]))
89            writer.append_data(np.array(f))91        writer.close()93    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):104        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)ログに記録する間隔
107        interval = self.n_steps // n_framesビデオ用フレーム
109        frames = []サンプルステップ
111        for t_inv in monit.iterate('Denoise', self.n_steps):113            t_ = self.n_steps - t_inv - 1テンソルで
115            t = xt.new_full((1,), t_, dtype=torch.long)117            eps_theta = self.eps_model(xt, t)
118            if t_ % interval == 0:フレームの取得と追加
120                x0 = self.p_x0(xt, t, eps_theta)
121                frames.append(x0[0])
122                if not create_video:
123                    self.show_image(x0[0], f"{t_}")からのサンプル
125            xt = self.p_sample(xt, t, eps_theta)動画を作る
128        if create_video:
129            self.make_video(frames)131    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):サンプル数
150        n_samples = x1.shape[0]テンソル
152        t = torch.full((n_samples,), t_, device=self.device)154        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)157        return self._sample_x0(xt, t_)159    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
160                            create_video=True):元の画像を表示
172        self.show_image(x1, "x1")
173        self.show_image(x2, "x2")バッチディメンションを追加
175        x1 = x1[None, :, :, :]
176        x2 = x2[None, :, :, :]テンソル
178        t = torch.full((1,), t_, device=self.device)180        x1t = self.diffusion.q_sample(x1, t)182        x2t = self.diffusion.q_sample(x2, t)
183
184        frames = []異なるフレームを取得
186        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):188            lambda_ = i / n_frames190            xt = (1 - lambda_) * x1t + lambda_ * x2t192            x0 = self._sample_x0(xt, t_)フレームに追加
194            frames.append(x0[0])フレームを表示
196            if not create_video:
197                self.show_image(x0[0], f"{lambda_ :.2f}")動画を作る
200        if create_video:
201            self.make_video(frames)203    def _sample_x0(self, xt: torch.Tensor, n_steps: int):サンプル数
212        n_samples = xt.shape[0]ステップまで繰り返す
214        for t_ in monit.iterate('Denoise', n_steps):
215            t = n_steps - t_ - 1からのサンプル
217            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))リターン
220        return xt222    def sample(self, n_samples: int = 16):227        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)230        x0 = self._sample_x0(xt, self.n_steps)画像を表示
233        for i in range(n_samples):
234            self.show_image(x0[i])236    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):251        alpha = gather(self.alpha, t)253        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5256        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)258        var = gather(self.sigma2, t)261        eps = torch.randn(xt.shape, device=xt.device)[サンプル]
263        return mean + (var ** .5) * eps265    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):277        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)サンプルを生成
280def main():トレーニング実験実行 UUID
284    run_uuid = "a44333ea251411ec8007d1a1762ed686"評価を開始する
287    experiment.evaluate()コンフィグの作成
290    configs = Configs()トレーニングランのカスタム構成をロード
292    configs_dict = experiment.load_configs(run_uuid)構成を設定
294    experiment.configs(configs, configs_dict)[初期化]
297    configs.init()保存と読み込み用の PyTorch モジュールの設定
300    experiment.add_pytorch_models({'eps_model': configs.eps_model})負荷訓練実験
303    experiment.load(run_uuid)サンプラーの作成
306    sampler = Sampler(diffusion=configs.diffusion,
307                      image_channels=configs.image_channels,
308                      image_size=configs.image_size,
309                      device=configs.device)評価を開始する
312    with experiment.start():グラデーションなし
314        with torch.no_grad():ノイズ除去アニメーションによる画像のサンプリング
316            sampler.sample_animation()
317
318            if False:データからいくつかの画像を取得
320                data = next(iter(configs.data_loader)).to(configs.device)補間アニメーションの作成
323                sampler.interpolate_animate(data[0], data[1])327if __name__ == '__main__':
328    main()