විසරණ සම්භාවිතාව ආකෘති (DDPM) ඇගයීම/නියැදීම

රූපඋත්පාදනය කිරීමට සහ ලබා දී ඇති රූප අතර අන්තර්ක්රියාකාරිත්වයන් නිර්මාණය කිරීමේ කේතය මෙයයි.

14import numpy as np
15import torch
16from matplotlib import pyplot as plt
17from torchvision.transforms.functional import to_pil_image, resize
18
19from labml import experiment, monit
20from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
21from labml_nn.diffusion.ddpm.experiment import Configs

නියැදිපන්තිය

24class Sampler:
  • diffusion DenoiseDiffusion උදාහරණයක් වේ
  • image_channels යනු රූපයේ ඇති නාලිකා ගණන
  • image_size යනු රූපයේ ප්රමාණයයි
  • device ආකෘතියේ උපාංගය වේ
  • 29    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
    36        self.device = device
    37        self.image_size = image_size
    38        self.image_channels = image_channels
    39        self.diffusion = diffusion

    42        self.n_steps = diffusion.n_steps

    44        self.eps_model = diffusion.eps_model

    46        self.beta = diffusion.beta

    48        self.alpha = diffusion.alpha

    50        self.alpha_bar = diffusion.alpha_bar

    52        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])

    ගණනයකිරීමට

    64        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)

    66        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)

    68        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)

    70        self.sigma2 = self.beta

    රූපයක්ප්රදර්ශනය කිරීම සඳහා උපකාරක කාර්යය

    72    def show_image(self, img, title=""):
    74        img = img.clip(0, 1)
    75        img = img.cpu().numpy()
    76        plt.imshow(img.transpose(1, 2, 0))
    77        plt.title(title)
    78        plt.show()

    වීඩියෝවක්නිර්මාණය කිරීම සඳහා උපකාරක කාර්යය

    80    def make_video(self, frames, path="video.mp4"):
    82        import imageio

    20දෙවන වීඩියෝව

    84        writer = imageio.get_writer(path, fps=len(frames) // 20)

    එක්එක් රූපය එකතු කරන්න

    86        for f in frames:
    87            f = f.clip(0, 1)
    88            f = to_pil_image(resize(f, [368, 368]))
    89            writer.append_data(np.array(f))

    91        writer.close()

    පියවරෙන්පියවර භාවිතා කරමින් රූපයක් සාම්පල කරන්න

    අපිපියවරෙන් පියවර රූපයක් සාම්පල කර ඇති අතර සෑම පියවරකදීම ඇස්තමේන්තුව පෙන්වයි

    93    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):

    104        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)

    ලොග්වීමට පරතරය

    107        interval = self.n_steps // n_frames

    වීඩියෝසඳහා රාමු

    109        frames = []

    නියැදි පියවර

    111        for t_inv in monit.iterate('Denoise', self.n_steps):

    113            t_ = self.n_steps - t_inv - 1

    එය tensor දී

    115            t = xt.new_full((1,), t_, dtype=torch.long)

    117            eps_theta = self.eps_model(xt, t)
    118            if t_ % interval == 0:

    රාමුලබා ගන්න සහ එකතු කරන්න

    120                x0 = self.p_x0(xt, t, eps_theta)
    121                frames.append(x0[0])
    122                if not create_video:
    123                    self.show_image(x0[0], f"{t_}")

    වෙතින්නියැදිය

    125            xt = self.p_sample(xt, t, eps_theta)

    වීඩියෝසාදන්න

    128        if create_video:
    129            self.make_video(frames)

    රූපදෙකක් අන්තර්ග්රහණය කරන්න සහ

    අපිලබා ගන්න සහ .

    ඉන්පසුඅන්තර්පොලේට් කරන්න

    ඉන්පසුලබා ගන්න

    • x1 වේ
    • x2 වේ
    • lambda_ වේ
  • t_ වේ
  • 131    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):

    සාම්පලගණන

    150        n_samples = x1.shape[0]

    ටෙන්සර්

    152        t = torch.full((n_samples,), t_, device=self.device)

    154        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)

    157        return self._sample_x0(xt, t_)

    රූපදෙකක් අන්තර්ග්රහණය කර වීඩියෝවක් සාදන්න

    • x1 වේ
    • x2 වේ
    • n_frames යනු රූපය සඳහා රාමු ගණන
    • t_ වේ
    • create_video වීඩියෝවක් සෑදිය යුතුද නැතහොත් එක් එක් රාමුව පෙන්විය යුතුද යන්න නියම කරයි
    159    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
    160                            create_video=True):

    මුල්රූප පෙන්වන්න

    172        self.show_image(x1, "x1")
    173        self.show_image(x2, "x2")

    කණ්ඩායම්මානය එක් කරන්න

    175        x1 = x1[None, :, :, :]
    176        x2 = x2[None, :, :, :]

    ටෙන්සර්

    178        t = torch.full((1,), t_, device=self.device)

    180        x1t = self.diffusion.q_sample(x1, t)

    182        x2t = self.diffusion.q_sample(x2, t)
    183
    184        frames = []

    විවිධරාමු ලබා ගන්න

    186        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):

    188            lambda_ = i / n_frames

    190            xt = (1 - lambda_) * x1t + lambda_ * x2t

    192            x0 = self._sample_x0(xt, t_)

    රාමුවලට එකතු කරන්න

    194            frames.append(x0[0])

    රාමුවපෙන්වන්න

    196            if not create_video:
    197                self.show_image(x0[0], f"{lambda_ :.2f}")

    වීඩියෝසාදන්න

    200        if create_video:
    201            self.make_video(frames)

    භාවිතාකරමින් රූපයක් සාම්පල කරන්න

    • xt වේ
  • n_steps වේ
  • 203    def _sample_x0(self, xt: torch.Tensor, n_steps: int):

    නියැදිගණන

    212        n_samples = xt.shape[0]

    පියවර ගන්නා තෙක් නැවත ක්රියාත්මක වන්න

    214        for t_ in monit.iterate('Denoise', n_steps):
    215            t = n_steps - t_ - 1

    වෙතින්නියැදිය

    217            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))

    ආපසු

    220        return xt

    රූපජනනය කරන්න

    222    def sample(self, n_samples: int = 16):

    227        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)

    230        x0 = self._sample_x0(xt, self.n_steps)

    පින්තූරපෙන්වන්න

    233        for i in range(n_samples):
    234            self.show_image(x0[i])

    වෙතින්නියැදිය

    236    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
    249        alpha_bar = gather(self.alpha_bar, t)

    251        alpha = gather(self.alpha, t)

    253        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

    256        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

    258        var = gather(self.sigma2, t)

    261        eps = torch.randn(xt.shape, device=xt.device)

    නියැදිය

    263        return mean + (var ** .5) * eps

    ඇස්තමේන්තුව

    265    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
    273        alpha_bar = gather(self.alpha_bar, t)

    277        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)

    සාම්පලජනනය කරන්න

    280def main():

    පුහුණුඅත්හදා බැලීම UUID

    284    run_uuid = "a44333ea251411ec8007d1a1762ed686"

    ඇගයීමක්ආරම්භ කරන්න

    287    experiment.evaluate()

    වින්යාසසාදන්න

    290    configs = Configs()

    පුහුණුධාවනයේ අභිරුචි වින්යාසය පූරණය කරන්න

    292    configs_dict = experiment.load_configs(run_uuid)

    වින්යාසයන්සකසන්න

    294    experiment.configs(configs, configs_dict)

    ආරම්භකරන්න

    297    configs.init()

    ඉතිරිකිරීම සහ පැටවීම සඳහා පයිටෝච් මොඩියුල සකසන්න

    300    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    පුහුණුඅත්හදා බැලීම

    303    experiment.load(run_uuid)

    නියැදියසාදන්න

    306    sampler = Sampler(diffusion=configs.diffusion,
    307                      image_channels=configs.image_channels,
    308                      image_size=configs.image_size,
    309                      device=configs.device)

    ඇගයීමආරම්භ කරන්න

    312    with experiment.start():

    කිසිදුඅනුක්රමික

    314        with torch.no_grad():

    නිරූපණයකරන සජීවිකරණයක් සහිත රූපයක් සාම්පල කරන්න

    316            sampler.sample_animation()
    317
    318            if False:

    දත්තවලින් පින්තූර කිහිපයක් ලබා ගන්න

    320                data = next(iter(configs.data_loader)).to(configs.device)

    අන්තර්නිවේෂණයසජීවිකරණයක් සාදන්න

    323                sampler.interpolate_animate(data[0], data[1])

    327if __name__ == '__main__':
    328    main()