විසරණසම්භාවිතාව ආකෘති නිරූපණය කිරීම (DDPM)

Open In Colab Open In Comet

මෙය PyTorch ක්රියාත්මක කිරීම/නිබන්ධනයකි කඩදාසි Denoising Diffusion සම්භාවිතාව ආකෘති .

සරළවකිවහොත්, අපි දත්ත වලින් රූපයක් ලබාගෙන පියවරෙන් පියවර ශබ්දය එක් කරමු. ඉන්පසු අපි සෑම පියවරකදීම එම ශබ්දය පුරෝකථනය කිරීමට ආකෘතියක් පුහුණු කර රූප ජනනය කිරීමට ආකෘතිය භාවිතා කරමු.

පහතදැක්වෙන අර්ථ දැක්වීම් සහ ව්යුත්පන්නයන් මෙය ක්රියාත්මක වන ආකාරය පෙන්වයි. විස්තර සඳහා කරුණාකර කඩදාසි වෙතයොමු වන්න.

ඉදිරික්රියාවලිය

ඉදිරික්රියාවලිය කාලසටහන සඳහා දත්ත වලට ශබ්දය එක් කරයි.

විචලනයකාලසටහන කොහේද?

අපටඕනෑම වේලාවක නියැදිය හැකිය,

කොහේද සහ

ප්රතිලෝමක්රියාවලිය

ප්රතිලෝමක්රියාවලිය මඟින් කාල පියවර සඳහා ආරම්භ වන ශබ්දය ඉවත් කරයි.

අපි පුහුණු පරාමිතීන් වේ.

පාඩුව

අපිELBO (ජෙන්සන්ගේ අසමානතාවයෙන්) සෘණ ලොග් සම්භාවිතාව මත ප්රශස්තිකරණය කරමු.

අලාභයපහත පරිදි නැවත ලිවිය හැකිය.

අපි නියතව සිටින බැවින් නියත වේ.

පරිගණක

ඉදිරික්රියාවලිය posterior විසින් සමනය වේ,

කඩදාසිනියම කර ඇති තැන සකසයි හෝ .

එවිට,

ලබාදී ඇති ශබ්දය සඳහා

මෙයලබා දෙයි,

ශබ්දයපුරෝකථනය කිරීම සඳහා ආකෘතියක් සමඟ නැවත පරාමිතිකරණය කිරීම

ලබා දී ඇති අනාවැකි පළ කරන උගත් ශ්රිතයක් කොහේද?

මෙයලබා දෙයි,

එනම්, ශබ්දය පුරෝකථනය කිරීමට අපි පුහුණු වෙමු.

සරලඅලාභය

බරඅඩු කිරීමේදී සහ ඉවතලීම සඳහා මෙය අවම කරයි . බර ඉවතලීම ඉහළ (ඉහළ ශබ්ද මට්ටම් ඇති) දක්වා ඇති බර වැඩි කරයි, එබැවින් නියැදි ගුණාත්මකභාවය වැඩි කරයි.

පුහුණුවඅතරතුර රූප ජනනය කිරීම සඳහා අප භාවිතා කරන පාඩු ගණනය කිරීම සහ මූලික නියැදි ක්රමයක් මෙම ගොනුව ක්රියාත්මක කරයි.

කේතය ලබා දෙන සහ පුහුණු කරන UNET ආකෘතිය මෙන්න. මෙම ගොනුවට පුහුණු ආකෘතියකින් සාම්පල සහ අන්තර්නිවේශනයන් ජනනය කළ හැකිය.

163from typing import Tuple, Optional
164
165import torch
166import torch.nn.functional as F
167import torch.utils.data
168from torch import nn
169
170from labml_nn.diffusion.ddpm.utils import gather

ඩෙනොයිස්විසරණය

173class DenoiseDiffusion:
  • eps_model ආකෘතිය වේ
  • n_steps වේ
  • device නියතයන් මත තැබීමට උපාංගය වේ
  • 178    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
    184        super().__init__()
    185        self.eps_model = eps_model

    රේඛීයව වැඩිවන විචල්යතා කාලසටහනක් සාදන්න

    188        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

    191        self.alpha = 1. - self.beta

    193        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    195        self.n_steps = n_steps

    197        self.sigma2 = self.beta

    බෙදා හැරීම ලබා ගන්න

    199    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

    රැස් කර ගණනය කරන්න

    209        mean = gather(self.alpha_bar, t) ** 0.5 * x0

    211        var = 1 - gather(self.alpha_bar, t)

    213        return mean, var

    වෙතින්නියැදිය

    215    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):

    225        if eps is None:
    226            eps = torch.randn_like(x0)

    ලබාගන්න

    229        mean, var = self.q_xt_x0(x0, t)

    වෙතින්නියැදිය

    231        return mean + (var ** 0.5) * eps

    වෙතින්නියැදිය

    233    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):

    247        eps_theta = self.eps_model(xt, t)
    249        alpha_bar = gather(self.alpha_bar, t)

    251        alpha = gather(self.alpha, t)

    253        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

    256        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

    258        var = gather(self.sigma2, t)

    261        eps = torch.randn(xt.shape, device=xt.device)

    නියැදිය

    263        return mean + (var ** .5) * eps

    සරලඅඞු කිරීමට

    265    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):

    කණ්ඩායම්ප්රමාණය ලබා ගන්න

    274        batch_size = x0.shape[0]

    කණ්ඩායමේඑක් එක් නියැදිය සඳහා අහඹු ලෙස ලබා ගන්න

    276        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

    279        if noise is None:
    280            noise = torch.randn_like(x0)

    සඳහා නියැදිය

    283        xt = self.q_sample(x0, t, eps=noise)

    ලබාගන්න

    285        eps_theta = self.eps_model(xt, t)

    MSEඅලාභය

    288        return F.mse_loss(noise, eps_theta)