diff --git a/docs/diffusion/ddpm/evaluate.html b/docs/diffusion/ddpm/evaluate.html new file mode 100644 index 00000000..ad988108 --- /dev/null +++ b/docs/diffusion/ddpm/evaluate.html @@ -0,0 +1,1198 @@ + + + + + + + + + + + + + + + + + + + + + + + Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling + + + + + + + + +
+
+
+
+

+ home + diffusion + ddpm +

+

+ + + Github + + Twitter +

+
+
+
+
+ +

Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling

+

This is the code to generate images and create interpolations between given images.

+
+
+
14import numpy as np
+15import torch
+16from matplotlib import pyplot as plt
+17from torchvision.transforms.functional import to_pil_image, resize
+18
+19from labml import experiment, monit
+20from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
+21from labml_nn.diffusion.ddpm.experiment import Configs
+
+
+
+
+ +

Sampler class

+
+
+
24class Sampler:
+
+
+
+
+ +
    +
  • diffusion is the DenoiseDiffusion instance
  • +
  • image_channels is the number of channels in the image
  • +
  • image_size is the image size
  • +
  • device is the device of the model
  • +
+
+
+
29    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
+
+
+
+
+ + +
+
+
36        self.device = device
+37        self.image_size = image_size
+38        self.image_channels = image_channels
+39        self.diffusion = diffusion
+
+
+
+
+ +

$T$

+
+
+
42        self.n_steps = diffusion.n_steps
+
+
+
+
+ +

$\color{cyan}{\epsilon_\theta}(x_t, t)$

+
+
+
44        self.eps_model = diffusion.eps_model
+
+
+
+
+ +

$\beta_t$

+
+
+
46        self.beta = diffusion.beta
+
+
+
+
+ +

$\alpha_t$

+
+
+
48        self.alpha = diffusion.alpha
+
+
+
+
+ +

$\bar\alpha_t$

+
+
+
50        self.alpha_bar = diffusion.alpha_bar
+
+
+
+
+ +

$\bar\alpha_{t-1}$

+
+
+
52        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])
+
+
+
+
+ +

To calculate $$\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$$ and $$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t} x_0 + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t$$

+
+
+
+
+
+
+
+ +

$\tilde\beta_t$

+
+
+
63        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
+
+
+
+
+ +

$$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$$

+
+
+
65        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
+
+
+
+
+ +

$$\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}$$

+
+
+
67        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
+
+
+
+
+ +

$\sigma^2 = \beta$

+
+
+
69        self.sigma2 = self.beta
+
+
+
+
+ +

Helper function to display an image

+
+
+
71    def show_image(self, img, title=""):
+
+
+
+
+ + +
+
+
73        img = img.clip(0, 1)
+74        img = img.cpu().numpy()
+75        plt.imshow(img.transpose(1, 2, 0))
+76        plt.title(title)
+77        plt.show()
+
+
+
+
+ +

Helper function to create a video

+
+
+
79    def make_video(self, frames, path="video.mp4"):
+
+
+
+
+ + +
+
+
81        import imageio
+
+
+
+
+ +

20 second video

+
+
+
83        writer = imageio.get_writer(path, fps=len(frames) // 20)
+
+
+
+
+ +

Add each image

+
+
+
85        for f in frames:
+86            f = f.clip(0, 1)
+87            f = to_pil_image(resize(f, [368, 368]))
+88            writer.append_data(np.array(f))
+
+
+
+
+ + +
+
+
90        writer.close()
+
+
+
+
+ +

Sample an image step-by-step using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

+

We sample an image step-by-step using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$ and at each step show the estimate $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t - \sqrt{1 - \bar\alpha_t}\, \color{cyan}{\epsilon_\theta}(x_t, t)\Big)$$

+
+
+
92    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
+
+
+
+
+ +

$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$

+
+
+
103        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)
+
+
+
+
+ +

Interval to log $\hat{x}_0$

+
+
+
106        interval = self.n_steps // n_frames
+
+
+
+
+ +

Frames for video

+
+
+
108        frames = []
+
+
+
+
+ +

Sample $T$ steps

+
+
+
110        for t_inv in monit.iterate('Denoise', self.n_steps):
+
+
+
+
+ +

$t$

+
+
+
112            t_ = self.n_steps - t_inv - 1
+
+
+
+
+ +

$t$ in a tensor

+
+
+
114            t = xt.new_full((1,), t_, dtype=torch.long)
+
+
+
+
+ +

$\color{cyan}{\epsilon_\theta}(x_t, t)$

+
+
+
116            eps_theta = self.eps_model(xt, t)
+117            if t_ % interval == 0:
+
+
+
+
+ +

Get $\hat{x}_0$ and add to frames

+
+
+
119                x0 = self.p_x0(xt, t, eps_theta)
+120                frames.append(x0[0])
+121                if not create_video:
+122                    self.show_image(x0[0], f"{t_}")
+
+
+
+
+ +

Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

+
+
+
124            xt = self.p_sample(xt, t, eps_theta)
+
+
+
+
+ +

Make video

+
+
+
127        if create_video:
+128            self.make_video(frames)
+
+
+
+
+ +

Interpolate two images $x_0$ and $x’_0$

+

We get $x_t \sim q(x_t|x_0)$ and $x’_t \sim q(x’_t|x_0)$.

+

Then interpolate to $$\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$$

+

Then get $$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$

+
    +
  • x1 is $x_0$
  • +
  • x2 is $x’_0$
  • +
  • lambda_ is $\lambda$
  • +
  • t_ is $t$
  • +
+
+
+
130    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
+
+
+
+
+ +

Number of samples

+
+
+
149        n_samples = x1.shape[0]
+
+
+
+
+ +

$t$ tensor

+
+
+
151        t = torch.full((n_samples,), t_, device=self.device)
+
+
+
+
+ +

$$\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$$

+
+
+
153        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)
+
+
+
+
+ +

$$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$

+
+
+
156        return self._sample_x0(xt, t_)
+
+
+
+
+ +

Interpolate two images $x_0$ and $x’_0$ and make a video

+
    +
  • x1 is $x_0$
  • +
  • x2 is $x’_0$
  • +
  • n_frames is the number of frames in the interpolation animation
  • +
  • t_ is $t$
  • +
  • create_video specifies whether to make a video or to show each frame
  • +
+
+
+
158    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
+159                            create_video=True):
+
+
+
+
+ +

Show original images

+
+
+
171        self.show_image(x1, "x1")
+172        self.show_image(x2, "x2")
+
+
+
+
+ +

Add batch dimension

+
+
+
174        x1 = x1[None, :, :, :]
+175        x2 = x2[None, :, :, :]
+
+
+
+
+ +

$t$ tensor

+
+
+
177        t = torch.full((1,), t_, device=self.device)
+
+
+
+
+ +

$x_t \sim q(x_t|x_0)$

+
+
+
179        x1t = self.diffusion.q_sample(x1, t)
+
+
+
+
+ +

$x’_t \sim q(x’_t|x_0)$

+
+
+
181        x2t = self.diffusion.q_sample(x2, t)
+182
+183        frames = []
+
+
+
+
+ +

Get frames with different $\lambda$

+
+
+
185        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
+
+
+
+
+ +

$\lambda$

+
+
+
187            lambda_ = i / n_frames
+
+
+
+
+ +

$$\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$$

+
+
+
189            xt = (1 - lambda_) * x1t + lambda_ * x2t
+
+
+
+
+ +

$$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$

+
+
+
191            x0 = self._sample_x0(xt, t_)
+
+
+
+
+ +

Add to frames

+
+
+
193            frames.append(x0[0])
+
+
+
+
+ +

Show frame

+
+
+
195            if not create_video:
+196                self.show_image(x0[0], f"{lambda_ :.2f}")
+
+
+
+
+ +

Make video

+
+
+
199        if create_video:
+200            self.make_video(frames)
+
+
+
+
+ +

Sample an image using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

+
    +
  • xt is $x_t$
  • +
  • n_steps is $t$
  • +
+
+
+
202    def _sample_x0(self, xt: torch.Tensor, n_steps: int):
+
+
+
+
+ +

Number of samples

+
+
+
211        n_samples = xt.shape[0]
+
+
+
+
+ +

Iterate until $t$ steps

+
+
+
213        for t_ in monit.iterate('Denoise', n_steps):
+214            t = n_steps - t_ - 1
+
+
+
+
+ +

Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

+
+
+
216            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))
+
+
+
+
+ +

Return $x_0$

+
+
+
219        return xt
+
+
+
+
+ +

Generate images

+
+
+
221    def sample(self, n_samples: int = 16):
+
+
+
+
+ +

$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$

+
+
+
226        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)
+
+
+
+
+ +

$$x_0 \sim \color{cyan}{p_\theta}(x_0|x_T)$$

+
+
+
229        x0 = self._sample_x0(xt, self.n_steps)
+
+
+
+
+ +

Show images

+
+
+
232        for i in range(n_samples):
+233            self.show_image(x0[i])
+
+
+
+
+ +

Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

+

\begin{align} \color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I}\big) \\ \color{cyan}{\mu_\theta}(x_t, t) &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}} \color{cyan}{\epsilon_\theta}(x_t, t)\Big) \end{align}

+
+
+
235    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
+
+
+
+
+ +

gather $\bar\alpha_t$

+
+
+
248        alpha_bar = gather(self.alpha_bar, t)
+
+
+
+
+ +

$\alpha_t$

+
+
+
250        alpha = gather(self.alpha, t)
+
+
+
+
+ +

$\frac{\beta}{\sqrt{1-\bar\alpha_t}}$

+
+
+
252        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
+
+
+
+
+ +

$$\frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}} \color{cyan}{\epsilon_\theta}(x_t, t)\Big)$$

+
+
+
255        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
+
+
+
+
+ +

$\sigma^2$

+
+
+
257        var = gather(self.sigma2, t)
+
+
+
+
+ +

$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

+
+
+
260        eps = torch.randn(xt.shape, device=xt.device)
+
+
+
+
+ +

Sample

+
+
+
262        return mean + (var ** .5) * eps
+
+
+
+
+ +

Estimate $x_0$

+

+ +

+
+
+
264    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
+
+
+
+
+ +

gather $\bar\alpha_t$

+
+
+
272        alpha_bar = gather(self.alpha_bar, t)
+
+
+
+
+ +

$$\frac{x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon}{\sqrt{\bar\alpha_t}}$$

+
+
+
276        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
+
+
+
+
+ +

Generate samples

+
+
+
279def main():
+
+
+
+
+ +

Training experiment run UUID

+
+
+
283    run_uuid = "a44333ea251411ec8007d1a1762ed686"
+
+
+
+
+ +

Start an evaluation

+
+
+
286    experiment.evaluate()
+
+
+
+
+ +

Create configs

+
+
+
289    configs = Configs()
+
+
+
+
+ +

Load custom configuration of the training run

+
+
+
291    configs_dict = experiment.load_configs(run_uuid)
+
+
+
+
+ +

Set configurations

+
+
+
293    experiment.configs(configs, configs_dict)
+
+
+
+
+ +

Initialize

+
+
+
296    configs.init()
+
+
+
+
+ +

Set PyTorch modules for saving and loading

+
+
+
299    experiment.add_pytorch_models({'eps_model': configs.eps_model})
+
+
+
+
+ +

Load training experiment

+
+
+
302    experiment.load(run_uuid)
+
+
+
+
+ +

Create sampler

+
+
+
305    sampler = Sampler(diffusion=configs.diffusion,
+306                      image_channels=configs.image_channels,
+307                      image_size=configs.image_size,
+308                      device=configs.device)
+
+
+
+
+ +

Start evaluation

+
+
+
311    with experiment.start():
+
+
+
+
+ +

No gradients

+
+
+
313        with torch.no_grad():
+
+
+
+
+ +

Sample an image with a denoising animation

+
+
+
315            sampler.sample_animation()
+316
+317            if False:
+
+
+
+
+ +

Get some images from data

+
+
+
319                data = next(iter(configs.data_loader)).to(configs.device)
+
+
+
+
+ +

Create an interpolation animation

+
+
+
322                sampler.interpolate_animate(data[0], data[1])
+
+
+
+
+ + +
+
+
326if __name__ == '__main__':
+327    main()
+
+
+ +
+ + + + + + \ No newline at end of file diff --git a/docs/diffusion/ddpm/experiment.html b/docs/diffusion/ddpm/experiment.html new file mode 100644 index 00000000..e1aeb082 --- /dev/null +++ b/docs/diffusion/ddpm/experiment.html @@ -0,0 +1,945 @@ + + + + + + + + + + + + + + + + + + + + + + + Denoising Diffusion Probabilistic Models (DDPM) training + + + + + + + + +
+
+
+
+

+ home + diffusion + ddpm +

+

+ + + Github + + Twitter +

+
+
+
+
+ +

Denoising Diffusion Probabilistic Models (DDPM) training

+

This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this +discussion on fast.ai. +Save the images inside the data/celebA folder.

+

The paper used an exponential moving average of the model weights with a decay of $0.9999$. We have skipped this for +simplicity.
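For reference, a minimal sketch (not part of this implementation) of how such an exponential moving average of the weights could be maintained:

# A minimal sketch (an assumption, not part of this file) of keeping an EMA of
# the model weights with a decay of 0.9999, updated once per optimizer step.
import copy
import torch

def make_ema(model: torch.nn.Module) -> torch.nn.Module:
    # A frozen copy of the model that will hold the averaged weights
    ema = copy.deepcopy(model)
    for p in ema.parameters():
        p.requires_grad_(False)
    return ema

@torch.no_grad()
def update_ema(ema: torch.nn.Module, model: torch.nn.Module, decay: float = 0.9999):
    # ema_p <- decay * ema_p + (1 - decay) * p
    for ema_p, p in zip(ema.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)

The EMA copy would then be used for sampling instead of the raw training weights.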

+
+
+
18from typing import List
+19
+20import torch
+21import torch.utils.data
+22import torchvision
+23from PIL import Image
+24
+25from labml import lab, tracker, experiment, monit
+26from labml.configs import BaseConfigs, option
+27from labml_helpers.device import DeviceConfigs
+28from labml_nn.diffusion.ddpm import DenoiseDiffusion
+29from labml_nn.diffusion.ddpm.unet import UNet
+
+
+
+
+ +

Configurations

+
+
+
32class Configs(BaseConfigs):
+
+
+
+
+ +

Device to train the model on. +DeviceConfigs + picks up an available CUDA device or defaults to CPU.

+
+
+
39    device: torch.device = DeviceConfigs()
+
+
+
+
+ +

U-Net model for $\color{cyan}{\epsilon_\theta}(x_t, t)$

+
+
+
42    eps_model: UNet
+
+
+
+
+ +

DDPM algorithm

+
+
+
44    diffusion: DenoiseDiffusion
+
+
+
+
+ +

Number of channels in the image. $3$ for RGB.

+
+
+
47    image_channels: int = 3
+
+
+
+
+ +

Image size

+
+
+
49    image_size: int = 32
+
+
+
+
+ +

Number of channels in the initial feature map

+
+
+
51    n_channels: int = 64
+
+
+
+
+ +

The list of channel numbers at each resolution. +The number of channels is channel_multipliers[i] * n_channels

+
+
+
54    channel_multipliers: List[int] = [1, 2, 2, 4]
+
+
+
+
+ +

The list of booleans that indicate whether to use attention at each resolution

+
+
+
56    is_attention: List[int] = [False, False, False, True]
+
+
+
+
+ +

Number of time steps $T$

+
+
+
59    n_steps: int = 1_000
+
+
+
+
+ +

Batch size

+
+
+
61    batch_size: int = 64
+
+
+
+
+ +

Number of samples to generate

+
+
+
63    n_samples: int = 16
+
+
+
+
+ +

Learning rate

+
+
+
65    learning_rate: float = 2e-5
+
+
+
+
+ +

Number of training epochs

+
+
+
68    epochs: int = 1_000
+
+
+
+
+ +

Dataset

+
+
+
71    dataset: torch.utils.data.Dataset
+
+
+
+
+ +

Dataloader

+
+
+
73    data_loader: torch.utils.data.DataLoader
+
+
+
+
+ +

Adam optimizer

+
+
+
76    optimizer: torch.optim.Adam
+
+
+
+
+ + +
+
+
78    def init(self):
+
+
+
+
+ +

Create $\color{cyan}{\epsilon_\theta}(x_t, t)$ model

+
+
+
80        self.eps_model = UNet(
+81            image_channels=self.image_channels,
+82            n_channels=self.n_channels,
+83            ch_mults=self.channel_multipliers,
+84            is_attn=self.is_attention,
+85        ).to(self.device)
+
+
+
+
+ +

Create DDPM class

+
+
+
88        self.diffusion = DenoiseDiffusion(
+89            eps_model=self.eps_model,
+90            n_steps=self.n_steps,
+91            device=self.device,
+92        )
+
+
+
+
+ +

Create dataloader

+
+
+
95        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
+
+
+
+
+ +

Create optimizer

+
+
+
97        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
+
+
+
+
+ +

Image logging

+
+
+
100        tracker.set_image("sample", True)
+
+
+
+
+ +

Sample images

+
+
+
102    def sample(self):
+
+
+
+
+ + +
+
+
106        with torch.no_grad():
+
+
+
+
+ +

$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$

+
+
+
108            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
+109                            device=self.device)
+
+
+
+
+ +

Remove noise for $T$ steps

+
+
+
112            for t_ in monit.iterate('Sample', self.n_steps):
+
+
+
+
+ +

$t$

+
+
+
114                t = self.n_steps - t_ - 1
+
+
+
+
+ +

Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

+
+
+
116                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
+
+
+
+
+ +

Log samples

+
+
+
119            tracker.save('sample', x)
+
+
+
+
+ +

Train

+
+
+
121    def train(self):
+
+
+
+
+ +

Iterate through the dataset

+
+
+
127        for data in monit.iterate('Train', self.data_loader):
+
+
+
+
+ +

Increment global step

+
+
+
129            tracker.add_global_step()
+
+
+
+
+ +

Move data to device

+
+
+
131            data = data.to(self.device)
+
+
+
+
+ +

Make the gradients zero

+
+
+
134            self.optimizer.zero_grad()
+
+
+
+
+ +

Calculate loss

+
+
+
136            loss = self.diffusion.loss(data)
+
+
+
+
+ +

Compute gradients

+
+
+
138            loss.backward()
+
+
+
+
+ +

Take an optimization step

+
+
+
140            self.optimizer.step()
+
+
+
+
+ +

Track the loss

+
+
+
142            tracker.save('loss', loss)
+
+
+
+
+ +

Training loop

+
+
+
144    def run(self):
+
+
+
+
+ + +
+
+
148        for _ in monit.loop(self.epochs):
+
+
+
+
+ +

Train the model

+
+
+
150            self.train()
+
+
+
+
+ +

Sample some images

+
+
+
152            self.sample()
+
+
+
+
+ +

New line in the console

+
+
+
154            tracker.new_line()
+
+
+
+
+ +

Save the model

+
+
+
156            experiment.save_checkpoint()
+
+
+
+
+ +

CelebA HQ dataset

+
+
+
159class CelebADataset(torch.utils.data.Dataset):
+
+
+
+
+ + +
+
+
164    def __init__(self, image_size: int):
+165        super().__init__()
+
+
+
+
+ +

CelebA images folder

+
+
+
168        folder = lab.get_data_path() / 'celebA'
+
+
+
+
+ +

List of files

+
+
+
170        self._files = [p for p in folder.glob(f'**/*.jpg')]
+
+
+
+
+ +

Transformations to resize the image and convert to tensor

+
+
+
173        self._transform = torchvision.transforms.Compose([
+174            torchvision.transforms.Resize(image_size),
+175            torchvision.transforms.ToTensor(),
+176        ])
+
+
+
+
+ +

Size of the dataset

+
+
+
178    def __len__(self):
+
+
+
+
+ + +
+
+
182        return len(self._files)
+
+
+
+
+ +

Get an image

+
+
+
184    def __getitem__(self, index: int):
+
+
+
+
+ + +
+
+
188        img = Image.open(self._files[index])
+189        return self._transform(img)
+
+
+
+
+ +

Create CelebA dataset

+
+
+
192@option(Configs.dataset, 'CelebA')
+193def celeb_dataset(c: Configs):
+
+
+
+
+ + +
+
+
197    return CelebADataset(c.image_size)
+
+
+
+
+ +

MNIST dataset

+
+
+
200class MNISTDataset(torchvision.datasets.MNIST):
+
+
+
+
+ + +
+
+
205    def __init__(self, image_size):
+206        transform = torchvision.transforms.Compose([
+207            torchvision.transforms.Resize(image_size),
+208            torchvision.transforms.ToTensor(),
+209        ])
+210
+211        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
+
+
+
+
+ + +
+
+
213    def __getitem__(self, item):
+214        return super().__getitem__(item)[0]
+
+
+
+
+ +

Create MNIST dataset

+
+
+
217@option(Configs.dataset, 'MNIST')
+218def mnist_dataset(c: Configs):
+
+
+
+
+ + +
+
+
222    return MNISTDataset(c.image_size)
+
+
+
+
+ + +
+
+
225def main():
+
+
+
+
+ +

Create experiment

+
+
+
227    experiment.create(name='diffuse')
+
+
+
+
+ +

Create configurations

+
+
+
230    configs = Configs()
+
+
+
+
+ +

Set configurations. You can override the defaults by passing the values in the dictionary.

+
+
+
233    experiment.configs(configs, {
+234    })
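For example, an assumed override (any of the Configs attributes above can be set this way) that switches training to the MNIST dataset option defined further below:

experiment.configs(configs, {
    'dataset': 'MNIST',   # registered below with @option(Configs.dataset, 'MNIST')
    'image_channels': 1,  # MNIST images are grayscale
})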
+
+
+
+
+ +

Initialize

+
+
+
237    configs.init()
+
+
+
+
+ +

Set models for saving and loading

+
+
+
240    experiment.add_pytorch_models({'eps_model': configs.eps_model})
+
+
+
+
+ +

Start and run the training loop

+
+
+
243    with experiment.start():
+244        configs.run()
+
+
+
+
+ + +
+
+
248if __name__ == '__main__':
+249    main()
+
+
+ +
+ + + + + + \ No newline at end of file diff --git a/docs/diffusion/ddpm/index.html b/docs/diffusion/ddpm/index.html new file mode 100644 index 00000000..579e0077 --- /dev/null +++ b/docs/diffusion/ddpm/index.html @@ -0,0 +1,666 @@ + + + + + + + + + + + + + + + + + + + + + + + Denoising Diffusion Probabilistic Models (DDPM) + + + + + + + + +
+
+
+
+

+ home + diffusion + ddpm +

+

+ + + Github + + Twitter +

+
+
+
+
+ +

Denoising Diffusion Probabilistic Models (DDPM)

+

This is a PyTorch implementation/tutorial of the paper +Denoising Diffusion Probabilistic Models.

+

In simple terms, we get an image from data and add noise step by step. +Then we train a model to predict that noise at each step and use the model to +generate images.

+

The following definitions and derivations show how this works. +For details please refer to the paper.

+

Forward Process

+

The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.

+

$$q(x_t | x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1 - \beta_t}\, x_{t-1}, \beta_t \mathbf{I}\big)$$

+

where $\beta_1, \dots, \beta_T$ is the variance schedule.

+

We can sample $x_t$ at any timestep $t$ with,

+

$$q(x_t|x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t) \mathbf{I}\big)$$

+

where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
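As a quick sanity check, here is a minimal sketch (an assumption for illustration, using the same linear $\beta$ schedule as the implementation below) of sampling $x_t$ directly from $x_0$ with this closed form:

# A minimal sketch of sampling x_t directly from x_0 using the closed form above.
import torch

n_steps = 1_000
beta = torch.linspace(0.0001, 0.02, n_steps)      # beta_1 ... beta_T (linear schedule)
alpha = 1. - beta                                  # alpha_t = 1 - beta_t
alpha_bar = torch.cumprod(alpha, dim=0)            # alpha_bar_t = prod_{s<=t} alpha_s

x0 = torch.randn(1, 3, 32, 32)                     # a stand-in for a data sample
t = 500                                            # some intermediate time step
eps = torch.randn_like(x0)                         # epsilon ~ N(0, I)
xt = alpha_bar[t] ** 0.5 * x0 + (1 - alpha_bar[t]) ** 0.5 * eps

This is what DenoiseDiffusion.q_sample below does for a whole batch of per-sample time steps.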

+

Reverse Process

+

The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$ +for $T$ time steps.

+

$$\color{cyan}{p_\theta}(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \color{cyan}{\mu_\theta}(x_t, t), \color{cyan}{\Sigma_\theta}(x_t, t)\big)$$

+

$\color{cyan}\theta$ are the parameters we train.

+

Loss

+

We optimize the ELBO (from Jensen’s inequality) on the negative log likelihood.

+

$$\mathbb{E}\big[-\log \color{cyan}{p_\theta}(x_0)\big] \le \mathbb{E}_q \Bigg[ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} \Bigg] = L$$

+

The loss can be rewritten as follows.

+

$$L = \mathbb{E}_q \Bigg[ D_{KL}\big(q(x_T|x_0) \Vert p(x_T)\big) + \sum_{t > 1} D_{KL}\big(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t)\big) - \log \color{cyan}{p_\theta}(x_0|x_1) \Bigg]$$

+

$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.

+

Computing $L_{t-1} = D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))$

+

The forward process posterior conditioned by $x_0$ is,

+

\begin{align} q(x_{t-1}|x_t, x_0) &= \mathcal{N}\Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I}\Big) \\ \tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t} x_0 + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t \\ \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \end{align}

+

The paper sets $\color{cyan}{\Sigma_\theta}(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants +$\beta_t$ or $\tilde\beta_t$.

+

Then, $$L_{t-1} = \mathbb{E}_q \Bigg[ \frac{1}{2\sigma_t^2} \Big\Vert \tilde\mu_t(x_t, x_0) - \color{cyan}{\mu_\theta}(x_t, t) \Big\Vert^2 \Bigg] + C$$ where $C$ is a constant that does not depend on $\theta$.

+

For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ using $q(x_t|x_0)$

+

\begin{align} x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon \\ x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t(x_0, \epsilon) - \sqrt{1-\bar\alpha_t}\,\epsilon\Big) \end{align}

+

This gives,

+

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{1}{2\sigma_t^2} \bigg\Vert \frac{1}{\sqrt{\alpha_t}} \Big(x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon\Big) - \color{cyan}{\mu_\theta}\big(x_t(x_0, \epsilon), t\big) \bigg\Vert^2 \Bigg]$$

+

Re-parameterizing with a model to predict noise

+

\begin{align} \color{cyan}{\mu_\theta}(x_t, t) &= \tilde\mu_t\bigg(x_t, \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t - \sqrt{1-\bar\alpha_t}\, \color{cyan}{\epsilon_\theta}(x_t, t)\Big)\bigg) \\ &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \color{cyan}{\epsilon_\theta}(x_t, t)\Big) \end{align}

+

where $\color{cyan}{\epsilon_\theta}$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.

+

This gives,

+

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)} \Big\Vert \epsilon - \color{cyan}{\epsilon_\theta}\big(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t\big) \Big\Vert^2 \Bigg]$$

+

That is, we are training to predict the noise.

+

Simplified loss

+

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \Bigg[ \Big\Vert \epsilon - \color{cyan}{\epsilon_\theta}\big(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t\big) \Big\Vert^2 \Bigg]$$

+

This minimizes $-\log \color{cyan}{p_\theta}(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t\gt1$, discarding the +weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$ +increases the weight given to higher $t$ (which have higher noise levels), therefore improving sample quality.
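For reference, a minimal sketch of one training step with this simplified loss, assuming a generic eps_model(xt, t) that predicts the noise (the U-Net below plays this role) and a precomputed alpha_bar tensor:

# A minimal sketch (an assumption for illustration) of a single L_simple step.
import torch
import torch.nn.functional as F

def simple_loss_step(eps_model, x0, alpha_bar, n_steps):
    batch_size = x0.shape[0]
    # Random t for each sample, and epsilon ~ N(0, I)
    t = torch.randint(0, n_steps, (batch_size,), device=x0.device, dtype=torch.long)
    eps = torch.randn_like(x0)
    # x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    xt = ab ** 0.5 * x0 + (1 - ab) ** 0.5 * eps
    # || eps - eps_theta(x_t, t) ||^2
    return F.mse_loss(eps, eps_model(xt, t))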

+

This file implements the loss calculation and a basic sampling method that we use to generate images during +training.

+

Here is the UNet model that gives $\color{cyan}{\epsilon_\theta}(x_t, t)$ and +training code. +This file can generate samples and interpolations from a trained model.

+

View Run

+
+
+
162from typing import Tuple, Optional
+163
+164import torch
+165import torch.nn.functional as F
+166import torch.utils.data
+167from torch import nn
+168
+169from labml_nn.diffusion.ddpm.utils import gather
+
+
+
+
+ +

Denoise Diffusion

+
+
+
172class DenoiseDiffusion:
+
+
+
+
+ +
    +
  • eps_model is $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
  • +
  • n_steps is the number of time steps $T$
  • +
  • device is the device to place constants on
  • +
+
+
+
177    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
+
+
+
+
+ + +
+
+
183        super().__init__()
+184        self.eps_model = eps_model
+
+
+
+
+ +

Create the linearly increasing variance schedule $\beta_1, \dots, \beta_T$

+
+
+
187        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
+
+
+
+
+ +

$\alpha_t = 1 - \beta_t$

+
+
+
190        self.alpha = 1. - self.beta
+
+
+
+
+ +

$\bar\alpha_t = \prod_{s=1}^t \alpha_s$

+
+
+
192        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
+
+
+
+
+ +

$T$

+
+
+
194        self.n_steps = n_steps
+
+
+
+
+ +

$\sigma^2 = \beta$

+
+
+
196        self.sigma2 = self.beta
+
+
+
+
+ +

Get $q(x_t|x_0)$ distribution

+

$$q(x_t|x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t) \mathbf{I}\big)$$

+
+
+
198    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
+
+
+
+ +

gather $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$

+
+
+
208        mean = gather(self.alpha_bar, t) ** 0.5 * x0
+
+
+
+
+ +

$(1-\bar\alpha_t) \mathbf{I}$

+
+
+
210        var = 1 - gather(self.alpha_bar, t)
+
+
+
+
+ + +
+
+
212        return mean, var
+
+
+
+
+ +

Sample from $q(x_t|x_0)$

+

$$q(x_t|x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t) \mathbf{I}\big)$$

+
+
+
214    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
+
+
+
+
+ +

$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

+
+
+
224        if eps is None:
+225            eps = torch.randn_like(x0)
+
+
+
+
+ +

get $q(x_t|x_0)$

+
+
+
228        mean, var = self.q_xt_x0(x0, t)
+
+
+
+
+ +

Sample from $q(x_t|x_0)$

+
+
+
230        return mean + (var ** 0.5) * eps
+
+
+
+
+ +

Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

+

\begin{align} \color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I}\big) \\ \color{cyan}{\mu_\theta}(x_t, t) &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}} \color{cyan}{\epsilon_\theta}(x_t, t)\Big) \end{align}

+
+
+
232    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
+
+
+
+
+ +

$\color{cyan}{\epsilon_\theta}(x_t, t)$

+
+
+
246        eps_theta = self.eps_model(xt, t)
+
+
+
+
+ +

gather $\bar\alpha_t$

+
+
+
248        alpha_bar = gather(self.alpha_bar, t)
+
+
+
+
+ +

$\alpha_t$

+
+
+
250        alpha = gather(self.alpha, t)
+
+
+
+
+ +

$\frac{\beta}{\sqrt{1-\bar\alpha_t}}$

+
+
+
252        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
+
+
+
+
+ +

$$\frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}} \color{cyan}{\epsilon_\theta}(x_t, t)\Big)$$

+
+
+
255        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
+
+
+
+
+ +

$\sigma^2$

+
+
+
257        var = gather(self.sigma2, t)
+
+
+
+
+ +

$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

+
+
+
260        eps = torch.randn(xt.shape, device=xt.device)
+
+
+
+
+ +

Sample

+
+
+
262        return mean + (var ** .5) * eps
+
+
+
+
+ +

Simplified Loss

+

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \Bigg[ \Big\Vert \epsilon - \color{cyan}{\epsilon_\theta}\big(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t\big) \Big\Vert^2 \Bigg]$$

+
+
+
264    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
+
+
+
+
+ +

Get batch size

+
+
+
273        batch_size = x0.shape[0]
+
+
+
+
+ +

Get random $t$ for each sample in the batch

+
+
+
275        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
+
+
+
+
+ +

$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

+
+
+
278        if noise is None:
+279            noise = torch.randn_like(x0)
+
+
+
+
+ +

Sample $x_t$ for $q(x_t|x_0)$

+
+
+
282        xt = self.q_sample(x0, t, eps=noise)
+
+
+
+
+ +

Get $\color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$

+
+
+
284        eps_theta = self.eps_model(xt, t)
+
+
+
+
+ +

MSE loss

+
+
+
287        return F.mse_loss(noise, eps_theta)
+
+
+ +
+ + + + + + \ No newline at end of file diff --git a/docs/diffusion/ddpm/readme.html b/docs/diffusion/ddpm/readme.html new file mode 100644 index 00000000..d5138696 --- /dev/null +++ b/docs/diffusion/ddpm/readme.html @@ -0,0 +1,150 @@ + + + + + + + + + + + + + + + + + + + + + + + Denoising Diffusion Probabilistic Models (DDPM) + + + + + + + + +
+
+
+
+

+ home + diffusion + ddpm +

+

+ + + Github + + Twitter +

+
+
+
+
+ +

Denoising Diffusion Probabilistic Models (DDPM)

+

This is a PyTorch implementation/tutorial of the paper +Denoising Diffusion Probabilistic Models.

+

In simple terms, we get an image from data and add noise step by step. +Then we train a model to predict that noise at each step and use the model to +generate images.

+

Here is the UNet model that predicts the noise and +training code. +This file can generate samples and interpolations +from a trained model.

+

View Run

+
+
+ +
+
+ +
+ + + + + + \ No newline at end of file diff --git a/docs/diffusion/ddpm/unet.html b/docs/diffusion/ddpm/unet.html new file mode 100644 index 00000000..039e0982 --- /dev/null +++ b/docs/diffusion/ddpm/unet.html @@ -0,0 +1,1336 @@ + + + + + + + + + + + + + + + + + + + + + + + U-Net model for Denoising Diffusion Probabilistic Models (DDPM) + + + + + + + + +
+
+
+
+

+ home + diffusion + ddpm +

+

+ + + Github + + Twitter +

+
+
+
+
+ +

U-Net model for Denoising Diffusion Probabilistic Models (DDPM)

+

This is a U-Net based model to predict noise +$\color{cyan}{\epsilon_\theta}(x_t, t)$.

+

U-Net gets its name from the U shape of the model diagram. +It processes a given image by progressively lowering (halving) the feature map resolution and then +increasing it back. +There are pass-through connections at each resolution.

+

U-Net diagram from paper

+

This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention) + and also adds time-step embeddings for $t$.
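A short usage sketch (with assumed shapes): the model maps a noisy image and its time step to a tensor with the same shape as the image, the predicted noise.

# A short usage sketch; shapes here are assumptions for illustration.
import torch
from labml_nn.diffusion.ddpm.unet import UNet

model = UNet(image_channels=3, n_channels=64,
             ch_mults=(1, 2, 2, 4), is_attn=(False, False, True, True))
x = torch.randn(2, 3, 32, 32)                         # [batch_size, channels, height, width]
t = torch.randint(0, 1_000, (2,), dtype=torch.long)   # one time step per sample, [batch_size]
eps = model(x, t)                                      # predicted noise
assert eps.shape == x.shape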

+
+
+
24import math
+25from typing import Optional, Tuple, Union, List
+26
+27import torch
+28from torch import nn
+29
+30from labml_helpers.module import Module
+
+
+
+
+ +

Swish activation function

+

$$x \cdot \sigma(x)$$

+
+
+
33class Swish(Module):
+
+
+
+
+ + +
+
+
40    def forward(self, x):
+41        return x * torch.sigmoid(x)
+
+
+
+
+ +

Embeddings for $t$

+
+
+
44class TimeEmbedding(nn.Module):
+
+
+
+
+ +
    +
  • n_channels is the number of dimensions in the embedding
  • +
+
+
+
49    def __init__(self, n_channels: int):
+
+
+
+
+ + +
+
+
53        super().__init__()
+54        self.n_channels = n_channels
+
+
+
+
+ +

First linear layer

+
+
+
56        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
+
+
+
+
+ +

Activation

+
+
+
58        self.act = Swish()
+
+
+
+
+ +

Second linear layer

+
+
+
60        self.lin2 = nn.Linear(self.n_channels, self.n_channels)
+
+
+
+
+ + +
+
+
62    def forward(self, t: torch.Tensor):
+
+
+
+
+ +

Create sinusoidal position embeddings, same as those from the transformer: \begin{align} PE^{(1)}_{t,i} &= \sin\Bigg(\frac{t}{10000^{\frac{i}{d-1}}}\Bigg) \\ PE^{(2)}_{t,i} &= \cos\Bigg(\frac{t}{10000^{\frac{i}{d-1}}}\Bigg) \end{align} where $d$ is half_dim

+
+
+
70        half_dim = self.n_channels // 8
+71        emb = math.log(10_000) / (half_dim - 1)
+72        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
+73        emb = t[:, None] * emb[None, :]
+74        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
+
+
+
+
+ +

Transform with the MLP

+
+
+
77        emb = self.act(self.lin1(emb))
+78        emb = self.lin2(emb)
+
+
+
+
+ + +
+
+
81        return emb
+
+
+
+
+ +

Residual block

+

A residual block has two convolution layers with group normalization. +Each resolution is processed with two residual blocks.

+
+
+
84class ResidualBlock(Module):
+
+
+
+
+ +
    +
  • in_channels is the number of input channels
  • +
  • out_channels is the number of output channels
  • +
  • time_channels is the number channels in the time step ($t$) embeddings
  • +
  • n_groups is the number of groups for group normalization
  • +
+
+
+
92    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
+
+
+
+
+ + +
+
+
99        super().__init__()
+
+
+
+
+ +

Group normalization and the first convolution layer

+
+
+
101        self.norm1 = nn.GroupNorm(n_groups, in_channels)
+102        self.act1 = Swish()
+103        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
+
+
+
+
+ +

Group normalization and the second convolution layer

+
+
+
106        self.norm2 = nn.GroupNorm(n_groups, out_channels)
+107        self.act2 = Swish()
+108        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
+
+
+
+
+ +

If the number of input channels is not equal to the number of output channels we have to +project the shortcut connection

+
+
+
112        if in_channels != out_channels:
+113            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
+114        else:
+115            self.shortcut = nn.Identity()
+
+
+
+
+ +

Linear layer for time embeddings

+
+
+
118        self.time_emb = nn.Linear(time_channels, out_channels)
+
+
+
+
+ +
    +
  • x has shape [batch_size, in_channels, height, width]
  • +
  • t has shape [batch_size, time_channels]
  • +
+
+
+
120    def forward(self, x: torch.Tensor, t: torch.Tensor):
+
+
+
+
+ +

First convolution layer

+
+
+
126        h = self.conv1(self.act1(self.norm1(x)))
+
+
+
+
+ +

Add time embeddings

+
+
+
128        h += self.time_emb(t)[:, :, None, None]
+
+
+
+
+ +

Second convolution layer

+
+
+
130        h = self.conv2(self.act2(self.norm2(h)))
+
+
+
+
+ +

Add the shortcut connection and return

+
+
+
133        return h + self.shortcut(x)
+
+
+
+
+ +

Attention block

+

This is similar to transformer multi-head attention.

+
+
+
136class AttentionBlock(Module):
+
+
+
+
+ +
    +
  • n_channels is the number of channels in the input
  • +
  • n_heads is the number of heads in multi-head attention
  • +
  • d_k is the number of dimensions in each head
  • +
  • n_groups is the number of groups for group normalization
  • +
+
+
+
143    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
+
+
+
+
+ + +
+
+
150        super().__init__()
+
+
+
+
+ +

Default d_k

+
+
+
153        if d_k is None:
+154            d_k = n_channels
+
+
+
+
+ +

Normalization layer

+
+
+
156        self.norm = nn.GroupNorm(n_groups, n_channels)
+
+
+
+
+ +

Projections for query, key and values

+
+
+
158        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
+
+
+
+
+ +

Linear layer for final transformation

+
+
+
160        self.output = nn.Linear(n_heads * d_k, n_channels)
+
+
+
+
+ +

Scale for dot-product attention

+
+
+
162        self.scale = d_k ** -0.5
+
+
+
+
+ + +
+
+
164        self.n_heads = n_heads
+165        self.d_k = d_k
+
+
+
+
+ +
    +
  • x has shape [batch_size, in_channels, height, width]
  • +
  • t has shape [batch_size, time_channels]
  • +
+
+
+
167    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
+
+
+
+
+ +

t is not used, but it is kept in the arguments so that the attention layer’s function signature +matches that of ResidualBlock.

+
+
+
174        _ = t
+
+
+
+
+ +

Get shape

+
+
+
176        batch_size, n_channels, height, width = x.shape
+
+
+
+
+ +

Change x to shape [batch_size, seq, n_channels]

+
+
+
178        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
+
+
+
+
+ +

Get query, key, and values (concatenated) and shape it to [batch_size, seq, n_heads, 3 * d_k]

+
+
+
180        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
+
+
+
+
+ +

Split query, key, and values. Each of them will have shape [batch_size, seq, n_heads, d_k]

+
+
+
182        q, k, v = torch.chunk(qkv, 3, dim=-1)
+
+
+
+
+ +

Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$

+
+
+
184        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
+
+
+
+
+ +

Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

+
+
+
186        attn = attn.softmax(dim=1)
+
+
+
+
+ +

Multiply by values

+
+
+
188        res = torch.einsum('bijh,bjhd->bihd', attn, v)
+
+
+
+
+ +

Reshape to [batch_size, seq, n_heads * d_k]

+
+
+
190        res = res.view(batch_size, -1, self.n_heads * self.d_k)
+
+
+
+
+ +

Transform to [batch_size, seq, n_channels]

+
+
+
192        res = self.output(res)
+
+
+
+
+ +

Add skip connection

+
+
+
195        res += x
+
+
+
+
+ +

Change to shape [batch_size, in_channels, height, width]

+
+
+
198        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
+
+
+
+
+ + +
+
+
201        return res
+
+
+
+
+ +

Down block

+

This combines ResidualBlock and AttentionBlock. These are used in the first half of U-Net at each resolution.

+
+
+
204class DownBlock(Module):
+
+
+
+
+ + +
+
+
211    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
+212        super().__init__()
+213        self.res = ResidualBlock(in_channels, out_channels, time_channels)
+214        if has_attn:
+215            self.attn = AttentionBlock(out_channels)
+216        else:
+217            self.attn = nn.Identity()
+
+
+
+
+ + +
+
+
219    def forward(self, x: torch.Tensor, t: torch.Tensor):
+220        x = self.res(x, t)
+221        x = self.attn(x)
+222        return x
+
+
+
+
+ +

Up block

+

This combines ResidualBlock and AttentionBlock. These are used in the second half of U-Net at each resolution.

+
+
+
225class UpBlock(Module):
+
+
+
+
+ + +
+
+
232    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
+233        super().__init__()
+
+
+
+
+ +

The input has in_channels + out_channels because we concatenate the output of the same resolution +from the first half of the U-Net

+
+
+
236        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
+237        if has_attn:
+238            self.attn = AttentionBlock(out_channels)
+239        else:
+240            self.attn = nn.Identity()
+
+
+
+
+ + +
+
+
242    def forward(self, x: torch.Tensor, t: torch.Tensor):
+243        x = self.res(x, t)
+244        x = self.attn(x)
+245        return x
+
+
+
+
+ +

Middle block

+

It combines a ResidualBlock, AttentionBlock, followed by another ResidualBlock. +This block is applied at the lowest resolution of the U-Net.

+
+
+
248class MiddleBlock(Module):
+
+
+
+
+ + +
+
+
256    def __init__(self, n_channels: int, time_channels: int):
+257        super().__init__()
+258        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
+259        self.attn = AttentionBlock(n_channels)
+260        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
+
+
+
+
+ + +
+
+
262    def forward(self, x: torch.Tensor, t: torch.Tensor):
+263        x = self.res1(x, t)
+264        x = self.attn(x)
+265        x = self.res2(x, t)
+266        return x
+
+
+
+
+ +

Scale up the feature map by $2 \times$

+
+
+
269class Upsample(nn.Module):
+
+
+
+
+ + +
+
+
274    def __init__(self, n_channels):
+275        super().__init__()
+276        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
+
+
+
+
+ + +
+
+
278    def forward(self, x: torch.Tensor, t: torch.Tensor):
+
+
+
+
+ +

t is not used, but it is kept in the arguments so that the function signature +matches that of ResidualBlock.

+
+
+
281        _ = t
+282        return self.conv(x)
+
+
+
+
+ +

Scale down the feature map by $\frac{1}{2} \times$

+
+
+
285class Downsample(nn.Module):
+
+
+
+
+ + +
+
+
290    def __init__(self, n_channels):
+291        super().__init__()
+292        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
+
+
+
+
+ + +
+
+
294    def forward(self, x: torch.Tensor, t: torch.Tensor):
+
+
+
+
+ +

t is not used, but it is kept in the arguments so that the function signature +matches that of ResidualBlock.

+
+
+
297        _ = t
+298        return self.conv(x)
+
+
+
+
+ +

U-Net

+
+
+
301class UNet(Module):
+
+
+
+
+ +
    +
  • image_channels is the number of channels in the image. $3$ for RGB.
  • +
  • n_channels is number of channels in the initial feature map that we transform the image into
  • +
  • ch_mults is the list of channel numbers at each resolution. The number of channels is ch_mults[i] * n_channels
  • +
  • is_attn is a list of booleans that indicate whether to use attention at each resolution
  • +
  • n_blocks is the number of UpDownBlocks at each resolution
  • +
+
+
+
306    def __init__(self, image_channels: int = 3, n_channels: int = 64,
+307                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
+308                 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
+309                 n_blocks: int = 2):
+
+
+
+
+ + +
+
+
317        super().__init__()
+
+
+
+
+ +

Number of resolutions

+
+
+
320        n_resolutions = len(ch_mults)
+
+
+
+
+ +

Project image into feature map

+
+
+
323        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
+
+
+
+
+ +

Time embedding layer. Time embedding has n_channels * 4 channels

+
+
+
326        self.time_emb = TimeEmbedding(n_channels * 4)
+
+
+
+
+ +

First half of U-Net - decreasing resolution

+
+
+
329        down = []
+
+
+
+
+ +

Number of channels

+
+
+
331        out_channels = in_channels = n_channels
+
+
+
+
+ +

For each resolution

+
+
+
333        for i in range(n_resolutions):
+
+
+
+
+ +

Number of output channels at this resolution

+
+
+
335            out_channels = in_channels * ch_mults[i]
+
+
+
+
+ +

Add n_blocks

+
+
+
337            for _ in range(n_blocks):
+338                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
+339                in_channels = out_channels
+
+
+
+
+ +

Down sample at all resolutions except the last

+
+
+
341            if i < n_resolutions - 1:
+342                down.append(Downsample(in_channels))
+
+
+
+
+ +

Combine the set of modules

+
+
+
345        self.down = nn.ModuleList(down)
+
+
+
+
+ +

Middle block

+
+
+
348        self.middle = MiddleBlock(out_channels, n_channels * 4, )
+
+
+
+
+ +

Second half of U-Net - increasing resolution

+
+
+
351        up = []
+
+
+
+
+ +

Number of channels

+
+
+
353        in_channels = out_channels
+
+
+
+
+ +

For each resolution

+
+
+
355        for i in reversed(range(n_resolutions)):
+
+
+
+
+ +

n_blocks at the same resolution

+
+
+
357            out_channels = in_channels
+358            for _ in range(n_blocks):
+359                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
+
+
+
+
+ +

Final block to reduce the number of channels

+
+
+
361            out_channels = in_channels // ch_mults[i]
+362            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
+363            in_channels = out_channels
+
+
+
+
+ +

Up sample at all resolutions except last

+
+
+
365            if i > 0:
+366                up.append(Upsample(in_channels))
+
+
+
+
+ +

Combine the set of modules

+
+
+
369        self.up = nn.ModuleList(up)
+
+
+
+
+ +

Final normalization and convolution layer

+
+
+
372        self.norm = nn.GroupNorm(8, n_channels)
+373        self.act = Swish()
+374        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
+
+
+
+
+ +
    +
  • x has shape [batch_size, in_channels, height, width]
  • +
  • t has shape [batch_size]
  • +
+
+
+
376    def forward(self, x: torch.Tensor, t: torch.Tensor):
+
+
+
+
+ +

Get time-step embeddings

+
+
+
383        t = self.time_emb(t)
+
+
+
+
+ +

Get image projection

+
+
+
386        x = self.image_proj(x)
+
+
+
+
+ +

h will store outputs at each resolution for skip connection

+
+
+
389        h = [x]
+
+
+
+
+ +

First half of U-Net

+
+
+
391        for m in self.down:
+392            x = m(x, t)
+393            h.append(x)
+
+
+
+
+ +

Middle (bottom)

+
+
+
396        x = self.middle(x, t)
+
+
+
+
+ +

Second half of U-Net

+
+
+
399        for m in self.up:
+400            if isinstance(m, Upsample):
+401                x = m(x, t)
+402            else:
+
+
+
+
+ +

Get the skip connection from first half of U-Net and concatenate

+
+
+
404                s = h.pop()
+405                x = torch.cat((x, s), dim=1)
+
+
+
+
+ + +
+
+
407                x = m(x, t)
+
+
+
+
+ +

Final normalization and convolution

+
+
+
410        return self.final(self.act(self.norm(x)))
+
+
+ +
+ + + + + + \ No newline at end of file diff --git a/docs/diffusion/ddpm/unet.png b/docs/diffusion/ddpm/unet.png new file mode 100644 index 00000000..303c7ffe Binary files /dev/null and b/docs/diffusion/ddpm/unet.png differ diff --git a/docs/diffusion/ddpm/utils.html b/docs/diffusion/ddpm/utils.html new file mode 100644 index 00000000..63f9d436 --- /dev/null +++ b/docs/diffusion/ddpm/utils.html @@ -0,0 +1,163 @@ + + + + + + + + + + + + + + + + + + + + + + + Utility functions for DDPM experiment + + + + + + + + +
+
+
+
+

+ home + diffusion + ddpm +

+

+ + + Github + + Twitter +

+
+
+
+
+ +

Utility functions for DDPM experiment

+
+
+
10import torch.utils.data
+
+
+
+
+ +

Gather consts for $t$ and reshape to feature map shape

+
+
+
13def gather(consts: torch.Tensor, t: torch.Tensor):
+
+
+
+
+ + +
+
+
15    c = consts.gather(-1, t)
+16    return c.reshape(-1, 1, 1, 1)
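A small usage sketch (values here are assumptions for illustration) showing why the reshape is needed: the gathered per-sample constants broadcast against feature maps of shape [batch_size, channels, height, width].

# A small usage sketch of gather with assumed values.
import torch
from labml_nn.diffusion.ddpm.utils import gather

alpha_bar = torch.linspace(0.9999, 0.01, 1000)   # e.g. per-step constants, shape [1000]
t = torch.tensor([0, 10, 999])                   # one time step per sample in the batch
c = gather(alpha_bar, t)                         # shape [3, 1, 1, 1]
x0 = torch.randn(3, 3, 32, 32)
xt_mean = c ** 0.5 * x0                          # broadcasts over channels and spatial dims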
+
+
+ +
+ + + + + + \ No newline at end of file diff --git a/docs/diffusion/index.html b/docs/diffusion/index.html new file mode 100644 index 00000000..9f94a4c6 --- /dev/null +++ b/docs/diffusion/index.html @@ -0,0 +1,142 @@ + + + + + + + + + + + + + + + + + + + + + + + Diffusion models + + + + + + + + +
+
+
+
+

+ home + diffusion +

+

+ + + Github + + Twitter +

+
+
+
+
+ +

Diffusion models

+ +
+
+
+
+
+ +
+ + + + + + \ No newline at end of file diff --git a/docs/index.html b/docs/index.html index 08c64984..3359f6d0 100644 --- a/docs/index.html +++ b/docs/index.html @@ -112,6 +112,10 @@ implementations.

  • Wasserstein GAN with Gradient Penalty
  • StyleGAN 2
  • +

    Diffusion models

    +

    Sketch RNN

    ✨ Graph Neural Networks