විසරණ සම්භාවිතා ආකෘති නිරූපණය කිරීම (DDPM)

Open In Colab

මෙය PyTorch ක්රියාත්මක කිරීම/නිබන්ධනයකි කඩදාසි Denoising Diffusion Probilistic ආකෘති.

සරළව කිවහොත්, අපි දත්ත වලින් රූපයක් ලබාගෙන පියවරෙන් පියවර ශබ්දය එක් කරමු. ඉන්පසු අපි සෑම පියවරකදීම එම ශබ්දය පුරෝකථනය කිරීමට ආකෘතියක් පුහුණු කර රූප ජනනය කිරීමට ආකෘතිය භාවිතා කරමු.

පහත දැක්වෙන අර්ථ දැක්වීම් සහ ව්යුත්පන්නයන් මෙය ක්රියාත්මක වන ආකාරය පෙන්වයි. විස්තර සඳහා කරුණාකර කඩදාසි වෙත යොමු වන්න.

ඉදිරි ක්රියාවලිය

ඉදිරි ක්රියාවලිය කාලසටහන සඳහා දත්ත වලට ශබ්දය එක් කරයි.

විචලනය කාලසටහන කොහේද?

අපට ඕනෑම වේලාවක නියැදිය හැකිය,

කොහේද සහ

ප්රතිලෝම ක්රියාවලිය

ප්රතිලෝම ක්රියාවලිය මඟින් කාල පියවර සඳහා ආරම්භ වන ශබ්දය ඉවත් කරයි.

අපි පුහුණු පරාමිතීන් වේ.

පාඩුව

ELBO (ජෙන්සන්ගේ අසමානතාවයෙන්) සෘණ ලොග් සම්භාවිතාව මත අපි ප්රශස්තිකරණය කරමු.

අලාභය පහත පරිදි නැවත ලිවිය හැකිය.

අපි නියතව සිටින බැවින් නියත වේ.

පරිගණකකරණය

ඉදිරි ක්රියාවලිය posterior විසින් සමනය වේ,

කඩදාසි නියතයන්ට සකසා ඇති තැන සකසයි හෝ.

එවිට,

ලබා දී ඇති ශබ්දය සඳහා

මෙය ලබා දෙයි,

ශබ්දය පුරෝකථනය කිරීම සඳහා ආකෘතියක් සමඟ නැවත පරාමිතිකරණය කිරීම

ලබා දී ඇති අනාවැකි පළ කරන උගත් ශ්රිතයක් කොහේද?

මෙය ලබා දෙයි,

එනම්, අපි ශබ්දය පුරෝකථනය කිරීමට පුහුණු වෙමු.

සරල කළ අලාභය

බර කිරන විට සහ ඉවතලීම සඳහා මෙය අවම කරයි. බර ඉවතලීම ඉහළ (ඉහළ ශබ්ද මට්ටම් ඇති) ලබා දෙන බර වැඩි කරයි, එබැවින් නියැදියේ ගුණාත්මකභාවය වැඩි කරයි.

මෙම ගොනුව මඟින් පාඩු ගණනය කිරීම සහ පුහුණුව අතරතුර රූප ජනනය කිරීම සඳහා අප භාවිතා කරන මූලික නියැදි ක්රමයක් ක්රියාත්මක කරයි.

කේතය ලබා දෙන සහ පුහුණු කරන UNET ආකෘතිය මෙන්න. මෙම ගොනුවට පුහුණු ආකෘතියකින් සාම්පල සහ අන්තර්නිවේශනයන් ජනනය කළ හැකිය.

162from typing import Tuple, Optional
163
164import torch
165import torch.nn.functional as F
166import torch.utils.data
167from torch import nn
168
169from labml_nn.diffusion.ddpm.utils import gather

ඩෙනොයිස්විසරණය

172class DenoiseDiffusion:
  • eps_model ආකෘතිය වේ
  • n_steps වේ
  • device නියතයන් මත තැබීමට උපාංගය වේ
  • 177    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
    183        super().__init__()
    184        self.eps_model = eps_model

    රේඛීයව වැඩිවන විචල්යතා කාලසටහනක් සාදන්න

    187        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

    190        self.alpha = 1. - self.beta

    192        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    194        self.n_steps = n_steps

    196        self.sigma2 = self.beta

    බෙදා හැරීම ලබා ගන්න

    198    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

    රැස් කර ගණනය කරන්න

    208        mean = gather(self.alpha_bar, t) ** 0.5 * x0

    210        var = 1 - gather(self.alpha_bar, t)

    212        return mean, var

    වෙතින්නියැදිය

    214    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):

    224        if eps is None:
    225            eps = torch.randn_like(x0)

    ලබාගන්න

    228        mean, var = self.q_xt_x0(x0, t)

    වෙතින්නියැදිය

    230        return mean + (var ** 0.5) * eps

    වෙතින්නියැදිය

    232    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):

    246        eps_theta = self.eps_model(xt, t)
    248        alpha_bar = gather(self.alpha_bar, t)

    250        alpha = gather(self.alpha, t)

    252        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

    255        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

    257        var = gather(self.sigma2, t)

    260        eps = torch.randn(xt.shape, device=xt.device)

    නියැදිය

    262        return mean + (var ** .5) * eps

    සරලඅඞු කිරීමට

    264    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):

    කණ්ඩායම්ප්රමාණය ලබා ගන්න

    273        batch_size = x0.shape[0]

    කණ්ඩායමේඑක් එක් නියැදිය සඳහා අහඹු ලෙස ලබා ගන්න

    275        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

    278        if noise is None:
    279            noise = torch.randn_like(x0)

    සඳහා නියැදිය

    282        xt = self.q_sample(x0, t, eps=noise)

    ලබාගන්න

    284        eps_theta = self.eps_model(xt, t)

    MSEඅලාභය

    287        return F.mse_loss(noise, eps_theta)