Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git

Commit: stable diffusion
@@ -8,4 +8,7 @@ summary: >

# Diffusion models

* [Denoising Diffusion Probabilistic Models (DDPM)](ddpm/index.html)
* [Stable Diffusion](stable_diffusion/index.html)
* [Latent Diffusion Model](stable_diffusion/latent_diffusion.html)
* [Denoising Diffusion Implicit Models (DDIM) Sampling](stable_diffusion/sampler/ddim.html)
"""
labml_nn/diffusion/stable_diffusion/__init__.py (new file, 48 lines)
@@ -0,0 +1,48 @@
"""
---
title: Stable Diffusion
summary: >
 Annotated PyTorch implementation/tutorial of stable diffusion.
---

# Stable Diffusion

This is based on the official stable diffusion repository
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion).
We have kept the model structure the same so that the open-sourced weights can be loaded directly.
Our implementation does not contain training code.

### [PromptArt](https://promptart.labml.ai)

We have deployed a stable diffusion based image generation service
at [promptart.labml.ai](https://promptart.labml.ai).

### [Latent Diffusion Model](latent_diffusion.html)

The core is the [Latent Diffusion Model](latent_diffusion.html).
It consists of:

* [AutoEncoder](model/autoencoder.html)
* [U-Net](model/unet.html) with [attention](model/unet_attention.html)

The diffusion is conditioned on [CLIP embeddings](model/clip_embedder.html).

### [Sampling Algorithms](sampler/index.html)

We have implemented the following [sampling algorithms](sampler/index.html):

* [Denoising Diffusion Probabilistic Models (DDPM) Sampling](sampler/ddpm.html)
* [Denoising Diffusion Implicit Models (DDIM) Sampling](sampler/ddim.html)

### [Example Scripts](scripts/index.html)

Here are the image generation scripts:

* [Generate images from text prompts](scripts/text_to_image.html)
* [Generate images based on a given image, guided by a prompt](scripts/image_to_image.html)
* [Modify parts of a given image based on a text prompt](scripts/in_paint.html)

#### [Utilities](util.html)

[`util.py`](util.html) defines the utility functions.
"""
labml_nn/diffusion/stable_diffusion/latent_diffusion.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""
---
title: Latent Diffusion Models
summary: >
 Annotated PyTorch implementation/tutorial of latent diffusion models from the paper
 High-Resolution Image Synthesis with Latent Diffusion Models
---

# Latent Diffusion Models

Latent diffusion models use an auto-encoder to map between image space and
latent space. The diffusion model works on the latent space, which makes it
a lot easier to train.
It is based on the paper
[High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752).

They use a pre-trained auto-encoder and train the diffusion U-Net on the latent
space of the pre-trained auto-encoder.

For a simpler diffusion implementation refer to our [DDPM implementation](../ddpm/index.html).
We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
"""

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional

from labml_nn.diffusion.stable_diffusion.model.autoencoder import Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel


class DiffusionWrapper(nn.Module):
|
||||
"""
|
||||
*This is an empty wrapper class around the [U-Net](model/unet.html).
|
||||
We keep this to have the same model structure as
|
||||
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
|
||||
so that we do not have to map the checkpoint weights explicitly*.
|
||||
"""
|
||||
|
||||
def __init__(self, diffusion_model: UNetModel):
|
||||
super().__init__()
|
||||
self.diffusion_model = diffusion_model
|
||||
|
||||
def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
|
||||
return self.diffusion_model(x, time_steps, context)
|
||||
|
||||
|
||||
class LatentDiffusion(nn.Module):
|
||||
"""
|
||||
## Latent diffusion model
|
||||
|
||||
This contains following components:
|
||||
|
||||
* [AutoEncoder](model/autoencoder.html)
|
||||
* [U-Net](model/unet.html) with [attention](model/unet_attention.html)
|
||||
* [CLIP embeddings generator](model/clip_embedder.html)
|
||||
"""
|
||||
model: DiffusionWrapper
|
||||
first_stage_model: Autoencoder
|
||||
cond_stage_model: CLIPTextEmbedder
|
||||
|
||||
def __init__(self,
|
||||
unet_model: UNetModel,
|
||||
autoencoder: Autoencoder,
|
||||
clip_embedder: CLIPTextEmbedder,
|
||||
latent_scaling_factor: float,
|
||||
n_steps: int,
|
||||
linear_start: float,
|
||||
linear_end: float,
|
||||
):
|
||||
"""
|
||||
:param unet_model: is the [U-Net](model/unet.html) that predicts noise
|
||||
$\epsilon_\text{cond}(x_t, c)$, in latent space
|
||||
:param autoencoder: is the [AutoEncoder](model/autoencoder.html)
|
||||
:param clip_embedder: is the [CLIP embeddings generator](model/clip_embedder.html)
|
||||
:param latent_scaling_factor: is the scaling factor for the latent space. The encodings of
|
||||
the autoencoder are scaled by this before feeding into the U-Net.
|
||||
:param n_steps: is the number of diffusion steps $T$.
|
||||
:param linear_start: is the start of the $\beta$ schedule.
|
||||
:param linear_end: is the end of the $\beta$ schedule.
|
||||
"""
|
||||
super().__init__()
|
||||
# Wrap the [U-Net](model/unet.html) to keep the same model structure as
|
||||
# [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion).
|
||||
self.model = DiffusionWrapper(unet_model)
|
||||
# Auto-encoder and scaling factor
|
||||
self.first_stage_model = autoencoder
|
||||
self.latent_scaling_factor = latent_scaling_factor
|
||||
# [CLIP embeddings generator](model/clip_embedder.html)
|
||||
self.cond_stage_model = clip_embedder
|
||||
|
||||
# Number of steps $T$
|
||||
self.n_steps = n_steps
|
||||
|
||||
# $\beta$ schedule
|
||||
beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
|
||||
self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
|
||||
# $\alpha_t = 1 - \beta_t$
|
||||
alpha = 1. - beta
|
||||
# $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
|
||||
alpha_bar = torch.cumprod(alpha, dim=0)
|
||||
self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
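# With these, the forward process is $q(x_t \vert x_0) = \mathcal{N}\big(\sqrt{\bar\alpha_t} x_0, (1 - \bar\alpha_t)\mathbf{I}\big)$,
# so a noisy latent can be drawn directly as $x_t = \sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t} \epsilon$
# (see the [DDPM implementation](../ddpm/index.html) for the derivation).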
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""
|
||||
### Get model device
|
||||
"""
|
||||
return next(iter(self.model.parameters())).device
|
||||
|
||||
def get_text_conditioning(self, prompts: List[str]):
|
||||
"""
|
||||
### Get [CLIP embeddings](model/clip_embedder.html) for a list of text prompts
|
||||
"""
|
||||
return self.cond_stage_model(prompts)
|
||||
|
||||
def autoencoder_encode(self, image: torch.Tensor):
|
||||
"""
|
||||
### Get scaled latent space representation of the image
|
||||
|
||||
The encoder output is a distribution.
|
||||
We sample from that and multiply by the scaling factor.
|
||||
"""
|
||||
return self.latent_scaling_factor * self.first_stage_model.encode(image).sample()
|
||||
|
||||
def autoencoder_decode(self, z: torch.Tensor):
|
||||
"""
|
||||
### Get image from the latent representation
|
||||
|
||||
We scale down by the scaling factor and then decode.
|
||||
"""
|
||||
return self.first_stage_model.decode(z / self.latent_scaling_factor)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
|
||||
"""
|
||||
### Predict noise
|
||||
|
||||
Predict noise given the latent representation $x_t$, time step $t$, and the
|
||||
conditioning context $c$.
|
||||
|
||||
$$\epsilon_\text{cond}(x_t, c)$$
|
||||
"""
|
||||
return self.model(x, t, context)
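# The following is a usage sketch added for illustration; it is not part of the original file.
# The sizes are small illustrative assumptions (not the released checkpoint configuration),
# and the CLIP embedder is replaced by a random conditioning tensor to keep the sketch offline.
def _example_latent_diffusion():
    from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder

    # A tiny auto-encoder mapping 3-channel images to 4-channel latents and back
    autoencoder = Autoencoder(
        encoder=Encoder(channels=32, channel_multipliers=[1, 2], n_resnet_blocks=1,
                        in_channels=3, z_channels=4),
        decoder=Decoder(channels=32, channel_multipliers=[1, 2], n_resnet_blocks=1,
                        out_channels=3, z_channels=4),
        emb_channels=4, z_channels=4)
    # A tiny U-Net operating on the 4-channel latents
    unet = UNetModel(in_channels=4, out_channels=4, channels=64, n_res_blocks=1,
                     attention_levels=[1], channel_multipliers=[1, 2],
                     n_heads=2, tf_layers=1, d_cond=768)
    model = LatentDiffusion(unet_model=unet, autoencoder=autoencoder,
                            clip_embedder=None,  # stand-in for `CLIPTextEmbedder` in this sketch
                            latent_scaling_factor=0.18215,
                            n_steps=1000, linear_start=0.00085, linear_end=0.0120)

    # Encode an image to latent space, predict noise at a time step, and decode back
    img = torch.randn(1, 3, 64, 64)
    z = model.autoencoder_encode(img)           # `[1, 4, 32, 32]`
    t = torch.full((1,), 999, dtype=torch.long)
    cond = torch.randn(1, 77, 768)              # stand-in for CLIP prompt embeddings
    eps = model(z, t, cond)                     # predicted noise, same shape as `z`
    out = model.autoencoder_decode(z)           # `[1, 3, 64, 64]`
    assert eps.shape == z.shape and out.shape == img.shape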
labml_nn/diffusion/stable_diffusion/model/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""
---
title: Modules used in stable diffusion
summary: >
 Models and components for stable diffusion.
---

# [Stable Diffusion](../index.html) Models

* [AutoEncoder](autoencoder.html)
* [U-Net](unet.html) with [attention](unet_attention.html)
* [CLIP embedder](clip_embedder.html)
"""
labml_nn/diffusion/stable_diffusion/model/autoencoder.py (new file, 433 lines)
@@ -0,0 +1,433 @@
"""
---
title: Autoencoder for Stable Diffusion
summary: >
 Annotated PyTorch implementation/tutorial of the autoencoder
 for stable diffusion.
---

# Autoencoder for [Stable Diffusion](../index.html)

This implements the auto-encoder model used to map between image space and latent space.

We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Autoencoder(nn.Module):
|
||||
"""
|
||||
## Autoencoder
|
||||
|
||||
This consists of the encoder and decoder modules.
|
||||
"""
|
||||
|
||||
def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
|
||||
"""
|
||||
:param encoder: is the encoder
|
||||
:param decoder: is the decoder
|
||||
:param emb_channels: is the number of dimensions in the quantized embedding space
|
||||
:param z_channels: is the number of channels in the embedding space
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
# Convolution to map from embedding space to
|
||||
# quantized embedding space moments (mean and log variance)
|
||||
self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
|
||||
# Convolution to map from quantized embedding space back to
|
||||
# embedding space
|
||||
self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
|
||||
|
||||
def encode(self, img: torch.Tensor) -> 'GaussianDistribution':
|
||||
"""
|
||||
### Encode images to latent representation
|
||||
|
||||
:param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
|
||||
"""
|
||||
# Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
|
||||
z = self.encoder(img)
|
||||
# Get the moments in the quantized embedding space
|
||||
moments = self.quant_conv(z)
|
||||
# Return the distribution
|
||||
return GaussianDistribution(moments)
|
||||
|
||||
def decode(self, z: torch.Tensor):
|
||||
"""
|
||||
### Decode images from latent representation
|
||||
|
||||
:param z: is the latent representation with shape `[batch_size, emb_channels, z_height, z_height]`
|
||||
"""
|
||||
# Map to embedding space from the quantized representation
|
||||
z = self.post_quant_conv(z)
|
||||
# Decode the image of shape `[batch_size, channels, height, width]`
|
||||
return self.decoder(z)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""
|
||||
## Encoder module
|
||||
"""
|
||||
|
||||
def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
|
||||
in_channels: int, z_channels: int):
|
||||
"""
|
||||
:param channels: is the number of channels in the first convolution layer
|
||||
:param channel_multipliers: are the multiplicative factors for the number of channels in the
|
||||
subsequent blocks
|
||||
:param n_resnet_blocks: is the number of resnet layers at each resolution
|
||||
:param in_channels: is the number of channels in the image
|
||||
:param z_channels: is the number of channels in the embedding space
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Number of blocks of different resolutions.
# The resolution is halved at the end of each top-level block
|
||||
n_resolutions = len(channel_multipliers)
|
||||
|
||||
# Initial $3 \times 3$ convolution layer that maps the image to `channels`
|
||||
self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
|
||||
|
||||
# Number of channels in each top level block
|
||||
channels_list = [m * channels for m in [1] + channel_multipliers]
|
||||
|
||||
# List of top-level blocks
|
||||
self.down = nn.ModuleList()
|
||||
# Create top-level blocks
|
||||
for i in range(n_resolutions):
|
||||
# Each top level block consists of multiple ResNet Blocks and down-sampling
|
||||
resnet_blocks = nn.ModuleList()
|
||||
# Add ResNet Blocks
|
||||
for _ in range(n_resnet_blocks):
|
||||
resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
|
||||
channels = channels_list[i + 1]
|
||||
# Top-level block
|
||||
down = nn.Module()
|
||||
down.block = resnet_blocks
|
||||
# Down-sampling at the end of each top level block except the last
|
||||
if i != n_resolutions - 1:
|
||||
down.downsample = DownSample(channels)
|
||||
else:
|
||||
down.downsample = nn.Identity()
|
||||
#
|
||||
self.down.append(down)
|
||||
|
||||
# Final ResNet blocks with attention
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(channels, channels)
|
||||
self.mid.attn_1 = AttnBlock(channels)
|
||||
self.mid.block_2 = ResnetBlock(channels, channels)
|
||||
|
||||
# Map to embedding space with a $3 \times 3$ convolution
|
||||
self.norm_out = normalization(channels)
|
||||
self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, img: torch.Tensor):
|
||||
"""
|
||||
:param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
|
||||
"""
|
||||
|
||||
# Map to `channels` with the initial convolution
|
||||
x = self.conv_in(img)
|
||||
|
||||
# Top-level blocks
|
||||
for down in self.down:
|
||||
# ResNet Blocks
|
||||
for block in down.block:
|
||||
x = block(x)
|
||||
# Down-sampling
|
||||
x = down.downsample(x)
|
||||
|
||||
# Final ResNet blocks with attention
|
||||
x = self.mid.block_1(x)
|
||||
x = self.mid.attn_1(x)
|
||||
x = self.mid.block_2(x)
|
||||
|
||||
# Normalize and map to embedding space
|
||||
x = self.norm_out(x)
|
||||
x = swish(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
#
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""
|
||||
## Decoder module
|
||||
"""
|
||||
|
||||
def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
|
||||
out_channels: int, z_channels: int):
|
||||
"""
|
||||
:param channels: is the number of channels in the final convolution layer
|
||||
:param channel_multipliers: are the multiplicative factors for the number of channels in the
|
||||
previous blocks, in reverse order
|
||||
:param n_resnet_blocks: is the number of resnet layers at each resolution
|
||||
:param out_channels: is the number of channels in the image
|
||||
:param z_channels: is the number of channels in the embedding space
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Number of blocks of different resolutions.
# The resolution is doubled at the end of each top-level block
|
||||
num_resolutions = len(channel_multipliers)
|
||||
|
||||
# Number of channels in each top level block, in the reverse order
|
||||
channels_list = [m * channels for m in channel_multipliers]
|
||||
|
||||
# Number of channels in the top-level block
|
||||
channels = channels_list[-1]
|
||||
|
||||
# Initial $3 \times 3$ convolution layer that maps the embedding space to `channels`
|
||||
self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)
|
||||
|
||||
# ResNet blocks with attention
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(channels, channels)
|
||||
self.mid.attn_1 = AttnBlock(channels)
|
||||
self.mid.block_2 = ResnetBlock(channels, channels)
|
||||
|
||||
# List of top-level blocks
|
||||
self.up = nn.ModuleList()
|
||||
# Create top-level blocks
|
||||
for i in reversed(range(num_resolutions)):
|
||||
# Each top level block consists of multiple ResNet Blocks and up-sampling
|
||||
resnet_blocks = nn.ModuleList()
|
||||
# Add ResNet Blocks
|
||||
for _ in range(n_resnet_blocks + 1):
|
||||
resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
|
||||
channels = channels_list[i]
|
||||
# Top-level block
|
||||
up = nn.Module()
|
||||
up.block = resnet_blocks
|
||||
# Up-sampling at the end of each top level block except the first
|
||||
if i != 0:
|
||||
up.upsample = UpSample(channels)
|
||||
else:
|
||||
up.upsample = nn.Identity()
|
||||
# Prepend to be consistent with the checkpoint
|
||||
self.up.insert(0, up)
|
||||
|
||||
# Map to image space with a $3 \times 3$ convolution
|
||||
self.norm_out = normalization(channels)
|
||||
self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z: torch.Tensor):
|
||||
"""
|
||||
:param z: is the embedding tensor with shape `[batch_size, z_channels, z_height, z_height]`
|
||||
"""
|
||||
|
||||
# Map to `channels` with the initial convolution
|
||||
h = self.conv_in(z)
|
||||
|
||||
# ResNet blocks with attention
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
|
||||
# Top-level blocks
|
||||
for up in reversed(self.up):
|
||||
# ResNet Blocks
|
||||
for block in up.block:
|
||||
h = block(h)
|
||||
# Up-sampling
|
||||
h = up.upsample(h)
|
||||
|
||||
# Normalize and map to image space
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
img = self.conv_out(h)
|
||||
|
||||
#
|
||||
return img
|
||||
|
||||
|
||||
class GaussianDistribution:
|
||||
"""
|
||||
## Gaussian Distribution
|
||||
"""
|
||||
|
||||
def __init__(self, parameters: torch.Tensor):
|
||||
"""
|
||||
:param parameters: are the means and log of variances of the embedding of shape
|
||||
`[batch_size, z_channels * 2, z_height, z_height]`
|
||||
"""
|
||||
# Split mean and log of variance
|
||||
self.mean, log_var = torch.chunk(parameters, 2, dim=1)
|
||||
# Clamp the log of variances
|
||||
self.log_var = torch.clamp(log_var, -30.0, 20.0)
|
||||
# Calculate standard deviation
|
||||
self.std = torch.exp(0.5 * self.log_var)
|
||||
|
||||
def sample(self):
|
||||
# Sample from the distribution
|
||||
return self.mean + self.std * torch.randn_like(self.std)
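# This is the reparameterization trick: $z = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, \mathbf{I})$,
# which keeps the sample differentiable with respect to the predicted mean and variance.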
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
"""
|
||||
## Attention block
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
"""
|
||||
:param channels: is the number of channels
|
||||
"""
|
||||
super().__init__()
|
||||
# Group normalization
|
||||
self.norm = normalization(channels)
|
||||
# Query, key and value mappings
|
||||
self.q = nn.Conv2d(channels, channels, 1)
|
||||
self.k = nn.Conv2d(channels, channels, 1)
|
||||
self.v = nn.Conv2d(channels, channels, 1)
|
||||
# Final $1 \times 1$ convolution layer
|
||||
self.proj_out = nn.Conv2d(channels, channels, 1)
|
||||
# Attention scaling factor
|
||||
self.scale = channels ** -0.5
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: is the tensor of shape `[batch_size, channels, height, width]`
|
||||
"""
|
||||
# Normalize `x`
|
||||
x_norm = self.norm(x)
|
||||
# Get query, key and value embeddings
|
||||
q = self.q(x_norm)
|
||||
k = self.k(x_norm)
|
||||
v = self.v(x_norm)
|
||||
|
||||
# Reshape the query, key and value embeddings from
|
||||
# `[batch_size, channels, height, width]` to
|
||||
# `[batch_size, channels, height * width]`
|
||||
b, c, h, w = q.shape
|
||||
q = q.view(b, c, h * w)
|
||||
k = k.view(b, c, h * w)
|
||||
v = v.view(b, c, h * w)
|
||||
|
||||
# Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$
|
||||
attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
|
||||
attn = F.softmax(attn, dim=2)
|
||||
|
||||
# Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$
|
||||
out = torch.einsum('bij,bcj->bci', attn, v)
|
||||
|
||||
# Reshape back to `[batch_size, channels, height, width]`
|
||||
out = out.view(b, c, h, w)
|
||||
# Final $1 \times 1$ convolution layer
|
||||
out = self.proj_out(out)
|
||||
|
||||
# Add residual connection
|
||||
return x + out
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
"""
|
||||
## Up-sampling layer
|
||||
"""
|
||||
def __init__(self, channels: int):
|
||||
"""
|
||||
:param channels: is the number of channels
|
||||
"""
|
||||
super().__init__()
|
||||
# $3 \times 3$ convolution mapping
|
||||
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
||||
"""
|
||||
# Up-sample by a factor of $2$
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
# Apply convolution
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class DownSample(nn.Module):
|
||||
"""
|
||||
## Down-sampling layer
|
||||
"""
|
||||
def __init__(self, channels: int):
|
||||
"""
|
||||
:param channels: is the number of channels
|
||||
"""
|
||||
super().__init__()
|
||||
# $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
|
||||
self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
||||
"""
|
||||
# Pad the right and bottom edges by one pixel (the convolution itself has no padding)
|
||||
x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
|
||||
# Apply convolution
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
"""
|
||||
## ResNet Block
|
||||
"""
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
"""
|
||||
:param in_channels: is the number of channels in the input
|
||||
:param out_channels: is the number of channels in the output
|
||||
"""
|
||||
super().__init__()
|
||||
# First normalization and convolution layer
|
||||
self.norm1 = normalization(in_channels)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
|
||||
# Second normalization and convolution layer
|
||||
self.norm2 = normalization(out_channels)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
|
||||
# `in_channels` to `out_channels` mapping layer for residual connection
|
||||
if in_channels != out_channels:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
|
||||
else:
|
||||
self.nin_shortcut = nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
||||
"""
|
||||
|
||||
h = x
|
||||
|
||||
# First normalization and convolution layer
|
||||
h = self.norm1(h)
|
||||
h = swish(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
# Second normalization and convolution layer
|
||||
h = self.norm2(h)
|
||||
h = swish(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
# Map and add residual
|
||||
return self.nin_shortcut(x) + h
|
||||
|
||||
|
||||
def swish(x: torch.Tensor):
|
||||
"""
|
||||
### Swish activation
|
||||
|
||||
$$x \cdot \sigma(x)$$
|
||||
"""
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def normalization(channels: int):
|
||||
"""
|
||||
### Group normalization
|
||||
|
||||
This is a helper function, with a fixed number of groups and `eps`.
|
||||
"""
|
||||
return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
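# The following is a shape-check sketch added for illustration; it is not part of the original file.
# The sizes are small illustrative assumptions, much smaller than the released checkpoint configuration.
def _example_autoencoder():
    encoder = Encoder(channels=32, channel_multipliers=[1, 2], n_resnet_blocks=1,
                      in_channels=3, z_channels=4)
    decoder = Decoder(channels=32, channel_multipliers=[1, 2], n_resnet_blocks=1,
                      out_channels=3, z_channels=4)
    autoencoder = Autoencoder(encoder=encoder, decoder=decoder, emb_channels=4, z_channels=4)

    img = torch.randn(1, 3, 64, 64)
    # `encode` returns a Gaussian over the latent space; `sample()` draws from it
    z = autoencoder.encode(img).sample()   # `[1, 4, 32, 32]` -- spatial size halved once
    out = autoencoder.decode(z)            # back to `[1, 3, 64, 64]`
    assert out.shape == img.shape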
labml_nn/diffusion/stable_diffusion/model/clip_embedder.py (new file, 50 lines)
@@ -0,0 +1,50 @@
"""
---
title: CLIP Text Embedder
summary: >
 CLIP embedder to get prompt embeddings for stable diffusion
---

# CLIP Text Embedder

This is used to get prompt embeddings for [stable diffusion](../index.html).
It uses the HuggingFace Transformers CLIP model.
"""

from typing import List

from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel


class CLIPTextEmbedder(nn.Module):
    """
    ## CLIP Text Embedder
    """

    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
        """
        :param version: is the model version
        :param device: is the device
        :param max_length: is the max length of the tokenized prompt
        """
        super().__init__()
        # Load the tokenizer
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        # Load the CLIP transformer
        self.transformer = CLIPTextModel.from_pretrained(version).eval()

        self.device = device
        self.max_length = max_length

    def forward(self, prompts: List[str]):
        """
        :param prompts: are the list of prompts to embed
        """
        # Tokenize the prompts
        batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # Get token ids
        tokens = batch_encoding["input_ids"].to(self.device)
        # Get CLIP embeddings
        return self.transformer(input_ids=tokens).last_hidden_state
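

# The following is a usage sketch added for illustration; it is not part of the original file.
# Note that constructing `CLIPTextEmbedder` downloads the HuggingFace
# `openai/clip-vit-large-patch14` weights on first use.
def _example_clip_embedder():
    embedder = CLIPTextEmbedder(device="cpu")
    # Prompts are padded/truncated to 77 tokens; the output is the last hidden state of the
    # CLIP text transformer, one 768-dimensional vector per token.
    cond = embedder(["a photograph of an astronaut riding a horse"])
    assert cond.shape == (1, 77, 768)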
labml_nn/diffusion/stable_diffusion/model/unet.py (new file, 343 lines)
@@ -0,0 +1,343 @@
"""
---
title: U-Net for Stable Diffusion
summary: >
 Annotated PyTorch implementation/tutorial of the U-Net in stable diffusion.
---

# U-Net for [Stable Diffusion](../index.html)

This implements the U-Net that gives $\epsilon_\text{cond}(x_t, c)$.

We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
## U-Net model
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
channels: int,
|
||||
n_res_blocks: int,
|
||||
attention_levels: List[int],
|
||||
channel_multipliers: List[int],
|
||||
n_heads: int,
|
||||
tf_layers: int = 1,
|
||||
d_cond: int = 768):
|
||||
"""
|
||||
:param in_channels: is the number of channels in the input feature map
|
||||
:param out_channels: is the number of channels in the output feature map
|
||||
:param channels: is the base channel count for the model
|
||||
:param n_res_blocks: number of residual blocks at each level
|
||||
:param attention_levels: are the levels at which attention should be performed
|
||||
:param channel_multipliers: are the multiplicative factors for number of channels for each level
|
||||
:param n_heads: the number of attention heads in the transformers
|
||||
"""
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
||||
# Number of levels
|
||||
levels = len(channel_multipliers)
|
||||
# Size of the time embeddings
|
||||
d_time_emb = channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(channels, d_time_emb),
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_time_emb, d_time_emb),
|
||||
)
|
||||
|
||||
# Input half of the U-Net
|
||||
self.input_blocks = nn.ModuleList()
|
||||
# Initial $3 \times 3$ convolution that maps the input to `channels`.
|
||||
# The blocks are wrapped in `TimestepEmbedSequential` module because
|
||||
# different modules have different forward function signatures;
|
||||
# for example, convolution only accepts the feature map and
|
||||
# residual blocks accept the feature map and time embedding.
|
||||
# `TimestepEmbedSequential` calls them accordingly.
|
||||
self.input_blocks.append(TimestepEmbedSequential(
|
||||
nn.Conv2d(in_channels, channels, 3, padding=1)))
|
||||
# Number of channels at each block in the input half of U-Net
|
||||
input_block_channels = [channels]
|
||||
# Number of channels at each level
|
||||
channels_list = [channels * m for m in channel_multipliers]
|
||||
# Prepare levels
|
||||
for i in range(levels):
|
||||
# Add the residual blocks and attentions
|
||||
for _ in range(n_res_blocks):
|
||||
# Residual block maps from previous number of channels to the number of
|
||||
# channels in the current level
|
||||
layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
|
||||
channels = channels_list[i]
|
||||
# Add transformer
|
||||
if i in attention_levels:
|
||||
layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
|
||||
# Add them to the input half of the U-Net and keep track of the number of channels of
|
||||
# its output
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
input_block_channels.append(channels)
|
||||
# Down-sample at all levels except the last
|
||||
if i != levels - 1:
|
||||
self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
|
||||
input_block_channels.append(channels)
|
||||
|
||||
# The middle of the U-Net
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(channels, d_time_emb),
|
||||
SpatialTransformer(channels, n_heads, tf_layers, d_cond),
|
||||
ResBlock(channels, d_time_emb),
|
||||
)
|
||||
|
||||
# Second half of the U-Net
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
# Prepare levels in reverse order
|
||||
for i in reversed(range(levels)):
|
||||
# Add the residual blocks and attentions
|
||||
for j in range(n_res_blocks + 1):
|
||||
# Residual block maps from previous number of channels plus the
|
||||
# skip connections from the input half of U-Net to the number of
|
||||
# channels in the current level.
|
||||
layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
|
||||
channels = channels_list[i]
|
||||
# Add transformer
|
||||
if i in attention_levels:
|
||||
layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
|
||||
# Up-sample at every level after last residual block
|
||||
# except the last one.
|
||||
# Note that we are iterating in reverse; i.e. `i == 0` is the last.
|
||||
if i != 0 and j == n_res_blocks:
|
||||
layers.append(UpSample(channels))
|
||||
# Add to the output half of the U-Net
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
|
||||
# Final normalization and $3 \times 3$ convolution
|
||||
self.out = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(channels, out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
|
||||
"""
|
||||
## Create sinusoidal time step embeddings
|
||||
|
||||
:param time_steps: are the time steps of shape `[batch_size]`
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
"""
|
||||
# $\frac{c}{2}$; half the channels are sin and the other half is cos,
|
||||
half = self.channels // 2
|
||||
# $\frac{1}{10000^{\frac{2i}{c}}}$
|
||||
frequencies = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=time_steps.device)
|
||||
# $\frac{t}{10000^{\frac{2i}{c}}}$
|
||||
args = time_steps[:, None].float() * frequencies[None]
|
||||
# $\cos\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$ and $\sin\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$
|
||||
return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
|
||||
def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):
|
||||
"""
|
||||
:param x: is the input feature map of shape `[batch_size, channels, width, height]`
|
||||
:param time_steps: are the time steps of shape `[batch_size]`
|
||||
:param cond: conditioning of shape `[batch_size, n_cond, d_cond]`
|
||||
"""
|
||||
# To store the input half outputs for skip connections
|
||||
x_input_block = []
|
||||
|
||||
# Get time step embeddings
|
||||
t_emb = self.time_step_embedding(time_steps)
|
||||
t_emb = self.time_embed(t_emb)
|
||||
|
||||
# Input half of the U-Net
|
||||
for module in self.input_blocks:
|
||||
x = module(x, t_emb, cond)
|
||||
x_input_block.append(x)
|
||||
# Middle of the U-Net
|
||||
x = self.middle_block(x, t_emb, cond)
|
||||
# Output half of the U-Net
|
||||
for module in self.output_blocks:
|
||||
x = th.cat([x, x_input_block.pop()], dim=1)
|
||||
x = module(x, t_emb, cond)
|
||||
|
||||
# Final normalization and $3 \times 3$ convolution
|
||||
return self.out(x)
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential):
|
||||
"""
|
||||
### Sequential block for modules with different inputs
|
||||
|
||||
This sequential module can be composed of different modules such as `ResBlock`,
`nn.Conv` and `SpatialTransformer`, and calls each with the matching signature.
|
||||
"""
|
||||
|
||||
def forward(self, x, t_emb, cond=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, ResBlock):
|
||||
x = layer(x, t_emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, cond)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
"""
|
||||
### Up-sampling layer
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
"""
|
||||
:param channels: is the number of channels
|
||||
"""
|
||||
super().__init__()
|
||||
# $3 \times 3$ convolution mapping
|
||||
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
||||
"""
|
||||
# Up-sample by a factor of $2$
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
# Apply convolution
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class DownSample(nn.Module):
|
||||
"""
|
||||
## Down-sampling layer
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
"""
|
||||
:param channels: is the number of channels
|
||||
"""
|
||||
super().__init__()
|
||||
# $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
|
||||
self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
||||
"""
|
||||
# Apply convolution
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""
|
||||
## ResNet Block
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
|
||||
"""
|
||||
:param channels: the number of input channels
|
||||
:param d_t_emb: the size of timestep embeddings
|
||||
:param out_channels: is the number of output channels. Defaults to `channels`.
|
||||
"""
|
||||
super().__init__()
|
||||
# `out_channels` not specified
|
||||
if out_channels is None:
|
||||
out_channels = channels
|
||||
|
||||
# First normalization and convolution
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(channels, out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
# Time step embeddings
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_t_emb, out_channels),
|
||||
)
|
||||
# Final convolution layer
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(0.),
|
||||
nn.Conv2d(out_channels, out_channels, 3, padding=1)
|
||||
)
|
||||
|
||||
# `channels` to `out_channels` mapping layer for residual connection
|
||||
if out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = nn.Conv2d(channels, out_channels, 1)
|
||||
|
||||
def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
|
||||
"""
|
||||
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
|
||||
:param t_emb: is the time step embeddings of shape `[batch_size, d_t_emb]`
|
||||
"""
|
||||
# Initial convolution
|
||||
h = self.in_layers(x)
|
||||
# Time step embeddings
|
||||
t_emb = self.emb_layers(t_emb).type(h.dtype)
|
||||
# Add time step embeddings
|
||||
h = h + t_emb[:, :, None, None]
|
||||
# Final convolution
|
||||
h = self.out_layers(h)
|
||||
# Add skip connection
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
"""
|
||||
### Group normalization with float32 casting
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
### Group normalization
|
||||
|
||||
This is a helper function, with a fixed number of groups.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
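# The following is a shape-check sketch added for illustration; it is not part of the original file.
# The configuration is a small illustrative assumption (the released checkpoints use a much larger one).
def _example_unet_shapes():
    unet = UNetModel(in_channels=4, out_channels=4, channels=64, n_res_blocks=1,
                     attention_levels=[1], channel_multipliers=[1, 2],
                     n_heads=2, tf_layers=1, d_cond=768)
    x = torch.randn(1, 4, 32, 32)                   # latents
    time_steps = torch.full((1,), 999, dtype=torch.long)
    cond = torch.randn(1, 77, 768)                  # prompt embeddings
    # The predicted noise has the same shape as the input latents
    assert unet(x, time_steps, cond).shape == x.shape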
|
||||
|
||||
|
||||
def _test_time_embeddings():
|
||||
"""
|
||||
Test sinusoidal time step embeddings
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.figure(figsize=(15, 5))
|
||||
m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
|
||||
channel_multipliers=[],
|
||||
n_heads=1, tf_layers=1, d_cond=1)
|
||||
te = m.time_step_embedding(torch.arange(0, 1000))
|
||||
plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
|
||||
plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
|
||||
plt.title("Time embeddings")
|
||||
plt.show()
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
_test_time_embeddings()
labml_nn/diffusion/stable_diffusion/model/unet_attention.py (new file, 224 lines)
@@ -0,0 +1,224 @@
"""
---
title: Transformer for Stable Diffusion U-Net
summary: >
 Annotated PyTorch implementation/tutorial of the transformer
 for the U-Net in stable diffusion.
---

# Transformer for Stable Diffusion [U-Net](unet.html)

This implements the transformer module used in the [U-Net](unet.html) that
gives $\epsilon_\text{cond}(x_t, c)$.

We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
## Spatial Transformer
|
||||
"""
|
||||
def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
|
||||
"""
|
||||
:param channels: is the number of channels in the feature map
|
||||
:param n_heads: is the number of attention heads
|
||||
:param n_layers: is the number of transformer layers
|
||||
:param d_cond: is the size of the conditional embedding
|
||||
"""
|
||||
super().__init__()
|
||||
# Initial group normalization
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
|
||||
# Initial $1 \times 1$ convolution
|
||||
self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# Transformer layers
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond) for _ in range(n_layers)]
|
||||
)
|
||||
|
||||
# Final $1 \times 1$ convolution
|
||||
self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor, cond: torch.Tensor):
|
||||
"""
|
||||
:param x: is the feature map of shape `[batch_size, channels, height, width]`
|
||||
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
||||
"""
|
||||
# Get shape `[batch_size, channels, height, width]`
|
||||
b, c, h, w = x.shape
|
||||
# For residual connection
|
||||
x_in = x
|
||||
# Normalize
|
||||
x = self.norm(x)
|
||||
# Initial $1 \times 1$ convolution
|
||||
x = self.proj_in(x)
|
||||
# Transpose and reshape from `[batch_size, channels, height, width]`
|
||||
# to `[batch_size, height * width, channels]`
|
||||
x = x.permute(0, 2, 3, 1).view(b, h * w, c)
|
||||
# Apply the transformer layers
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, cond)
|
||||
# Reshape and transpose from `[batch_size, height * width, channels]`
|
||||
# to `[batch_size, channels, height, width]`
|
||||
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
||||
# Final $1 \times 1$ convolution
|
||||
x = self.proj_out(x)
|
||||
# Add residual
|
||||
return x + x_in
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
"""
|
||||
### Transformer Layer
|
||||
"""
|
||||
def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
|
||||
"""
|
||||
:param d_model: is the input embedding size
|
||||
:param n_heads: is the number of attention heads
|
||||
:param d_head: is the size of an attention head
|
||||
:param d_cond: is the size of the conditional embeddings
|
||||
"""
|
||||
super().__init__()
|
||||
# Self-attention layer and pre-norm layer
|
||||
self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
# Cross attention layer and pre-norm layer
|
||||
self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
# Feed-forward network and pre-norm layer
|
||||
self.ff = FeedForward(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, x: torch.Tensor, cond: torch.Tensor):
|
||||
"""
|
||||
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
|
||||
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
||||
"""
|
||||
# Self attention
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
# Cross-attention with conditioning
|
||||
x = self.attn2(self.norm2(x), cond=cond) + x
|
||||
# Feed-forward network
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
#
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
"""
|
||||
### Cross Attention Layer
|
||||
|
||||
This falls back to self-attention when the conditional embeddings are not specified.
|
||||
"""
|
||||
def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
|
||||
"""
|
||||
:param d_model: is the input embedding size
|
||||
:param n_heads: is the number of attention heads
|
||||
:param d_head: is the size of an attention head
|
||||
:param d_cond: is the size of the conditional embeddings
|
||||
:param is_inplace: specifies whether to perform the attention softmax computation inplace to
|
||||
save memory
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.is_inplace = is_inplace
|
||||
self.n_heads = n_heads
|
||||
|
||||
# Attention scaling factor
|
||||
self.scale = d_head ** -0.5
|
||||
|
||||
# Query, key and value mappings
|
||||
d_attn = d_head * n_heads
|
||||
self.to_q = nn.Linear(d_model, d_attn, bias=False)
|
||||
self.to_k = nn.Linear(d_cond, d_attn, bias=False)
|
||||
self.to_v = nn.Linear(d_cond, d_attn, bias=False)
|
||||
|
||||
# Final linear layer
|
||||
self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
|
||||
|
||||
def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
|
||||
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
||||
"""
|
||||
|
||||
# If `cond` is `None` we perform self attention
|
||||
if cond is None:
|
||||
cond = x
|
||||
|
||||
# Get query, key and value vectors
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(cond)
|
||||
v = self.to_v(cond)
|
||||
|
||||
# Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
|
||||
q = q.view(*q.shape[:2], self.n_heads, -1)
|
||||
k = k.view(*k.shape[:2], self.n_heads, -1)
|
||||
v = v.view(*v.shape[:2], self.n_heads, -1)
|
||||
|
||||
# Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
|
||||
attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
|
||||
|
||||
# Compute softmax
|
||||
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
|
||||
if self.is_inplace:
|
||||
half = attn.shape[0] // 2
|
||||
attn[half:] = attn[half:].softmax(dim=-1)
|
||||
attn[:half] = attn[:half].softmax(dim=-1)
|
||||
else:
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
# Compute attention output
|
||||
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
||||
out = torch.einsum('bhij,bjhd->bihd', attn, v)
|
||||
# Reshape to `[batch_size, height * width, n_heads * d_head]`
|
||||
out = out.reshape(*out.shape[:2], -1)
|
||||
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
### Feed-Forward Network
|
||||
"""
|
||||
def __init__(self, d_model: int, d_mult: int = 4):
|
||||
"""
|
||||
:param d_model: is the input embedding size
|
||||
:param d_mult: is multiplicative factor for the hidden layer size
|
||||
"""
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
GeGLU(d_model, d_model * d_mult),
|
||||
nn.Dropout(0.),
|
||||
nn.Linear(d_model * d_mult, d_model)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class GeGLU(nn.Module):
|
||||
"""
|
||||
### GeGLU Activation
|
||||
|
||||
$$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
|
||||
"""
|
||||
def __init__(self, d_in: int, d_out: int):
|
||||
super().__init__()
|
||||
# Combined linear projections $xW + b$ and $xV + c$
|
||||
self.proj = nn.Linear(d_in, d_out * 2)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Get $xW + b$ and $xV + c$
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
# $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
|
||||
return x * F.gelu(gate)
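# The following is a shape-check sketch added for illustration; it is not part of the original file.
# The sizes are illustrative assumptions.
def _example_spatial_transformer():
    transformer = SpatialTransformer(channels=64, n_heads=2, n_layers=1, d_cond=768)
    x = torch.randn(1, 64, 16, 16)     # feature map inside the U-Net
    cond = torch.randn(1, 77, 768)     # e.g. CLIP prompt embeddings
    # The transformer attends over the 16 * 16 = 256 spatial positions and cross-attends
    # to the 77 conditioning vectors; the output shape matches the input feature map.
    assert transformer(x, cond).shape == x.shape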
labml_nn/diffusion/stable_diffusion/sampler/__init__.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""
---
title: Sampling algorithms for stable diffusion
summary: >
 Annotated PyTorch implementation/tutorial of
 sampling algorithms
 for stable diffusion model.
---

# Sampling algorithms for [stable diffusion](../index.html)

We have implemented the following sampling algorithms:

* [Denoising Diffusion Probabilistic Models (DDPM) Sampling](ddpm.html)
* [Denoising Diffusion Implicit Models (DDIM) Sampling](ddim.html)
"""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
|
||||
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
|
||||
|
||||
|
||||
class DiffusionSampler:
|
||||
"""
|
||||
## Base class for sampling algorithms
|
||||
"""
|
||||
model: LatentDiffusion
|
||||
|
||||
def __init__(self, model: LatentDiffusion):
|
||||
"""
|
||||
:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
|
||||
"""
|
||||
super().__init__()
|
||||
# Set the model $\epsilon_\text{cond}(x_t, c)$
|
||||
self.model = model
|
||||
# Get number of steps the model was trained with $T$
|
||||
self.n_steps = model.n_steps
|
||||
|
||||
def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
|
||||
uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
|
||||
"""
|
||||
## Get $\epsilon(x_t, c)$
|
||||
|
||||
:param x: is $x_t$ of shape `[batch_size, channels, height, width]`
|
||||
:param t: is $t$ of shape `[batch_size]`
|
||||
:param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
||||
"""
|
||||
# When the scale $s = 1$
|
||||
# $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
|
||||
if uncond_cond is None or uncond_scale == 1.:
|
||||
return self.model(x, t, c)
|
||||
|
||||
# Duplicate $x_t$ and $t$
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
# Concatenated $c$ and $c_u$
|
||||
c_in = torch.cat([uncond_cond, c])
|
||||
# Get $\epsilon_\text{cond}(x_t, c)$ and $\epsilon_\text{cond}(x_t, c_u)$
|
||||
e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
|
||||
# Calculate
|
||||
# $$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$$
|
||||
e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
|
||||
|
||||
#
|
||||
return e_t
|
||||
|
||||
def sample(self,
|
||||
shape: List[int],
|
||||
cond: torch.Tensor,
|
||||
repeat_noise: bool = False,
|
||||
temperature: float = 1.,
|
||||
x_last: Optional[torch.Tensor] = None,
|
||||
uncond_scale: float = 1.,
|
||||
uncond_cond: Optional[torch.Tensor] = None,
|
||||
skip_steps: int = 0,
|
||||
):
|
||||
"""
|
||||
### Sampling Loop
|
||||
|
||||
:param shape: is the shape of the generated images in the
|
||||
form `[batch_size, channels, height, width]`
|
||||
:param cond: is the conditional embeddings $c$
|
||||
:param temperature: is the noise temperature (random noise gets multiplied by this)
|
||||
:param x_last: is $x_T$. If not provided random noise will be used.
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
||||
:param skip_steps: is the number of time steps to skip.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
|
||||
orig: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
|
||||
uncond_scale: float = 1.,
|
||||
uncond_cond: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
### Painting Loop
|
||||
|
||||
:param x: is $x_{T'}$ of shape `[batch_size, channels, height, width]`
|
||||
:param cond: is the conditional embeddings $c$
|
||||
:param t_start: is the sampling step to start from, $T'$
|
||||
:param orig: is the original image in latent space that we are in-painting.
|
||||
:param mask: is the mask to keep the original image.
|
||||
:param orig_noise: is fixed noise to be added to the original image.
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
### Sample from $q(x_t|x_0)$
|
||||
|
||||
:param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
|
||||
:param index: is the time step $t$ index
|
||||
:param noise: is the noise, $\epsilon$
|
||||
"""
|
||||
raise NotImplementedError()
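# The following is a sketch added for illustration; it is not part of the original file.
# It exercises the classifier-free guidance computation in `get_eps` with a stub in place
# of `LatentDiffusion`; the stub and the tensor sizes are illustrative assumptions.
def _example_classifier_free_guidance():
    class _StubModel(torch.nn.Module):
        # Pretend to be a trained latent diffusion model
        n_steps = 1000

        def forward(self, x, t, c):
            # A fake noise prediction that depends on the conditioning
            return x * 0. + c.mean(dim=(1, 2))[:, None, None, None]

    sampler = DiffusionSampler(_StubModel())
    x = torch.randn(2, 4, 8, 8)                    # latents $x_t$
    t = torch.full((2,), 999, dtype=torch.long)    # time steps $t$
    c = torch.ones(2, 77, 768)                     # prompt embeddings $c$
    c_u = torch.zeros(2, 77, 768)                  # empty-prompt embeddings $c_u$
    # $s = 7.5$ amplifies the conditional prediction relative to the unconditional one;
    # internally both are computed in a single forward pass over a doubled batch.
    e_t = sampler.get_eps(x, t, c, uncond_scale=7.5, uncond_cond=c_u)
    assert e_t.shape == x.shape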
labml_nn/diffusion/stable_diffusion/sampler/ddim.py (new file, 300 lines)
@@ -0,0 +1,300 @@
"""
---
title: Denoising Diffusion Implicit Models (DDIM) Sampling
summary: >
 Annotated PyTorch implementation/tutorial of
 Denoising Diffusion Implicit Models (DDIM) Sampling
 for stable diffusion model.
---

# Denoising Diffusion Implicit Models (DDIM) Sampling

This implements DDIM sampling from the paper
[Denoising Diffusion Implicit Models](https://papers.labml.ai/paper/2010.02502).
"""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from labml import monit
|
||||
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
|
||||
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
|
||||
|
||||
|
||||
class DDIMSampler(DiffusionSampler):
|
||||
"""
|
||||
## DDIM Sampler
|
||||
|
||||
This extends the [`DiffusionSampler` base class](index.html).
|
||||
|
||||
DDIM samples images by repeatedly removing noise, sampling step by step using,
|
||||
|
||||
\begin{align}
|
||||
x_{\tau_{i-1}} &= \sqrt{\alpha_{\tau_{i-1}}}\Bigg(
|
||||
\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}
|
||||
\Bigg) \\
|
||||
&+ \sqrt{1 - \alpha_{\tau_{i- 1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) \\
|
||||
&+ \sigma_{\tau_i} \epsilon_{\tau_i}
|
||||
\end{align}
|
||||
|
||||
where $\epsilon_{\tau_i}$ is random noise,
|
||||
$\tau$ is a subsequence of $[1,2,\dots,T]$ of length $S$,
|
||||
and
|
||||
$$\sigma_{\tau_i} =
|
||||
\eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
|
||||
\sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$
|
||||
|
||||
Note that, $\alpha_t$ in DDIM paper refers to ${\color{lightgreen}\bar\alpha_t}$ from [DDPM](ddpm.html).
|
||||
"""
|
||||
|
||||
model: LatentDiffusion
|
||||
|
||||
def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
|
||||
"""
|
||||
:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
|
||||
:param n_steps: is the number of DDIM sampling steps, $S$
|
||||
:param ddim_discretize: specifies how to extract $\tau$ from $[1,2,\dots,T]$.
|
||||
It can be either `uniform` or `quad`.
|
||||
:param ddim_eta: is $\eta$ used to calculate $\sigma_{\tau_i}$. $\eta = 0$ makes the
|
||||
sampling process deterministic.
|
||||
"""
|
||||
super().__init__(model)
|
||||
# Number of steps, $T$
|
||||
self.n_steps = model.n_steps
|
||||
|
||||
# Calculate $\tau$ to be uniformly distributed across $[1,2,\dots,T]$
|
||||
if ddim_discretize == 'uniform':
|
||||
c = self.n_steps // n_steps
|
||||
self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1
|
||||
# Calculate $\tau$ to be quadratically distributed across $[1,2,\dots,T]$
|
||||
elif ddim_discretize == 'quad':
|
||||
self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
|
||||
else:
|
||||
raise NotImplementedError(ddim_discretize)
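# For example, with $T = 1000$ and $S = 50$, `uniform` gives $\tau = [1, 21, 41, \dots, 981]$.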
|
||||
|
||||
with torch.no_grad():
|
||||
# Get ${\color{lightgreen}\bar\alpha_t}$
|
||||
alpha_bar = self.model.alpha_bar
|
||||
|
||||
# $\alpha_{\tau_i}$
|
||||
self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
|
||||
# $\sqrt{\alpha_{\tau_i}}$
|
||||
self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
|
||||
# $\alpha_{\tau_{i-1}}$
|
||||
self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])
|
||||
|
||||
# $$\sigma_{\tau_i} =
|
||||
# \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
|
||||
# \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$
|
||||
self.ddim_sigma = (ddim_eta *
|
||||
((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
|
||||
(1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)
|
||||
|
||||
# $\sqrt{1 - \alpha_{\tau_i}}$
|
||||
self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
shape: List[int],
|
||||
cond: torch.Tensor,
|
||||
repeat_noise: bool = False,
|
||||
temperature: float = 1.,
|
||||
x_last: Optional[torch.Tensor] = None,
|
||||
uncond_scale: float = 1.,
|
||||
uncond_cond: Optional[torch.Tensor] = None,
|
||||
skip_steps: int = 0,
|
||||
):
|
||||
"""
|
||||
### Sampling Loop
|
||||
|
||||
:param shape: is the shape of the generated images in the
|
||||
form `[batch_size, channels, height, width]`
|
||||
:param cond: is the conditional embeddings $c$
|
||||
:param temperature: is the noise temperature (random noise gets multiplied by this)
|
||||
:param x_last: is $x_{\tau_S}$. If not provided random noise will be used.
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
||||
:param skip_steps: is the number of time steps to skip $i'$. We start sampling from $S - i'$.
|
||||
And `x_last` is then $x_{\tau_{S - i'}}$.
|
||||
"""
|
||||
|
||||
# Get device and batch size
|
||||
device = self.model.device
|
||||
bs = shape[0]
|
||||
|
||||
# Get $x_{\tau_S}$
|
||||
x = x_last if x_last is not None else torch.randn(shape, device=device)
|
||||
|
||||
# Time steps to sample at $\tau_{S - i'}, \tau_{S - i' - 1}, \dots, \tau_1$
|
||||
time_steps = np.flip(self.time_steps)[skip_steps:]
|
||||
|
||||
for i, step in monit.enum('Sample', time_steps):
|
||||
# Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
|
||||
index = len(time_steps) - i - 1
|
||||
# Time step $\tau_i$
|
||||
ts = x.new_full((bs,), step, dtype=torch.long)
|
||||
|
||||
# Sample $x_{\tau_{i-1}}$
|
||||
x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
|
||||
repeat_noise=repeat_noise,
|
||||
temperature=temperature,
|
||||
uncond_scale=uncond_scale,
|
||||
uncond_cond=uncond_cond)
|
||||
|
||||
# Return $x_0$
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
|
||||
repeat_noise: bool = False,
|
||||
temperature: float = 1.,
|
||||
uncond_scale: float = 1.,
|
||||
uncond_cond: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
### Sample $x_{\tau_{i-1}}$
|
||||
|
||||
:param x: is $x_{\tau_i}$ of shape `[batch_size, channels, height, width]`
|
||||
:param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
|
||||
:param t: is $\tau_i$ of shape `[batch_size]`
|
||||
:param step: is the step $\tau_i$ as an integer
|
||||
:param index: is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
|
||||
:param repeat_noise: specifies whether the noise should be the same for all samples in the batch
|
||||
:param temperature: is the noise temperature (random noise gets multiplied by this)
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
||||
"""
|
||||
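# A rough sketch (assuming the base class implementation) of how `get_eps`
# combines the conditional and unconditional predictions for classifier-free guidance:
#
#     e_t_uncond = model(x, t, uncond_cond)
#     e_t_cond = model(x, t, c)
#     e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)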
|
||||
# Get $\epsilon_\theta(x_{\tau_i})$
|
||||
e_t = self.get_eps(x, t, c,
|
||||
uncond_scale=uncond_scale,
|
||||
uncond_cond=uncond_cond)
|
||||
|
||||
# Calculate $x_{\tau_{i - 1}}$ and predicted $x_0$
|
||||
x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
|
||||
temperature=temperature,
|
||||
repeat_noise=repeat_noise)
|
||||
|
||||
#
|
||||
return x_prev, pred_x0, e_t
|
||||
|
||||
def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
|
||||
temperature: float,
|
||||
repeat_noise: bool):
|
||||
"""
|
||||
### Sample $x_{\tau_{i-1}}$ given $\epsilon_\theta(x_{\tau_i})$
|
||||
"""
|
||||
|
||||
# $\alpha_{\tau_i}$
|
||||
alpha = self.ddim_alpha[index]
|
||||
# $\alpha_{\tau_{i-1}}$
|
||||
alpha_prev = self.ddim_alpha_prev[index]
|
||||
# $\sigma_{\tau_i}$
|
||||
sigma = self.ddim_sigma[index]
|
||||
# $\sqrt{1 - \alpha_{\tau_i}}$
|
||||
sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]
|
||||
|
||||
# Current prediction for $x_0$,
|
||||
# $$\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}$$
|
||||
pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
|
||||
# Direction pointing to $x_t$
|
||||
# $$\sqrt{1 - \alpha_{\tau_{i- 1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})$$
|
||||
dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t
|
||||
|
||||
# No noise is added when $\eta = 0$
|
||||
if sigma == 0.:
|
||||
noise = 0.
|
||||
# If same noise is used for all samples in the batch
|
||||
elif repeat_noise:
|
||||
noise = torch.randn((1, *x.shape[1:]), device=x.device)
|
||||
# Different noise for each sample
|
||||
else:
|
||||
noise = torch.randn(x.shape, device=x.device)
|
||||
|
||||
# Multiply noise by the temperature
|
||||
noise = noise * temperature
|
||||
|
||||
# \begin{align}
|
||||
# x_{\tau_{i-1}} &= \sqrt{\alpha_{\tau_{i-1}}}\Bigg(
|
||||
# \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}
|
||||
# \Bigg) \\
|
||||
# &+ \sqrt{1 - \alpha_{\tau_{i- 1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) \\
|
||||
# &+ \sigma_{\tau_i} \epsilon_{\tau_i}
|
||||
# \end{align}
|
||||
x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise
|
||||
|
||||
#
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
### Sample from $q_{\sigma,\tau}(x_{\tau_i}|x_0)$
|
||||
|
||||
$$q_{\sigma,\tau}(x_t|x_0) =
|
||||
\mathcal{N} \Big(x_t; \sqrt{\alpha_{\tau_i}} x_0, (1-\alpha_{\tau_i}) \mathbf{I} \Big)$$
|
||||
|
||||
:param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
|
||||
:param index: is the time step $\tau_i$ index $i$
|
||||
:param noise: is the noise, $\epsilon$
|
||||
"""
|
||||
|
||||
# Random noise, if noise is not specified
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
|
||||
# Sample from
|
||||
# $$q_{\sigma,\tau}(x_t|x_0) =
|
||||
# \mathcal{N} \Big(x_t; \sqrt{\alpha_{\tau_i}} x_0, (1-\alpha_{\tau_i}) \mathbf{I} \Big)$$
|
||||
return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
|
||||
orig: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
|
||||
uncond_scale: float = 1.,
|
||||
uncond_cond: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
### Painting Loop
|
||||
|
||||
:param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]`
|
||||
:param cond: is the conditional embeddings $c$
|
||||
:param t_start: is the sampling step to start from, $S'$
|
||||
:param orig: is the original image in latent space that we are in-painting.
|
||||
If this is not provided, it'll be an image-to-image transformation.
|
||||
:param mask: is the mask to keep the original image.
|
||||
:param orig_noise: is fixed noise to be added to the original image.
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
||||
"""
|
||||
# Get batch size
|
||||
bs = x.shape[0]
|
||||
|
||||
# Time steps to sample at $\tau_{S'}, \tau_{S' - 1}, \dots, \tau_1$
|
||||
time_steps = np.flip(self.time_steps[:t_start])
|
||||
|
||||
for i, step in monit.enum('Paint', time_steps):
|
||||
# Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
|
||||
index = len(time_steps) - i - 1
|
||||
# Time step $\tau_i$
|
||||
ts = x.new_full((bs,), step, dtype=torch.long)
|
||||
|
||||
# Sample $x_{\tau_{i-1}}$
|
||||
x, _, _ = self.p_sample(x, cond, ts, step, index=index,
|
||||
uncond_scale=uncond_scale,
|
||||
uncond_cond=uncond_cond)
|
||||
|
||||
# Replace the masked area with original image
|
||||
if orig is not None:
|
||||
# Get the $q_{\sigma,\tau}(x_{\tau_i}|x_0)$ sample for the original image in latent space
|
||||
orig_t = self.q_sample(orig, index, noise=orig_noise)
|
||||
# Replace the masked area with the noised original image (the mask is $1$ where the original should be kept)
|
||||
x = orig_t * mask + x * (1 - mask)
|
||||
|
||||
#
|
||||
return x
|
||||
226
labml_nn/diffusion/stable_diffusion/sampler/ddpm.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""
|
||||
---
|
||||
title: Denoising Diffusion Probabilistic Models (DDPM) Sampling
|
||||
summary: >
|
||||
Annotated PyTorch implementation/tutorial of
|
||||
Denoising Diffusion Probabilistic Models (DDPM) Sampling
|
||||
for stable diffusion model.
|
||||
---
|
||||
|
||||
# Denoising Diffusion Probabilistic Models (DDPM) Sampling
|
||||
|
||||
For a simpler DDPM implementation refer to our [DDPM implementation](../../ddpm/index.html).
|
||||
We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from labml import monit
|
||||
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
|
||||
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
|
||||
|
||||
|
||||
class DDPMSampler(DiffusionSampler):
|
||||
"""
|
||||
## DDPM Sampler
|
||||
|
||||
This extends the [`DiffusionSampler` base class](index.html).
|
||||
|
||||
DDPM samples images by repeatedly removing noise, sampling step by step from
|
||||
$p_\theta(x_{t-1} | x_t)$,
|
||||
|
||||
\begin{align}
|
||||
|
||||
p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\
|
||||
|
||||
\mu_t(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
|
||||
+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
|
||||
|
||||
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\
|
||||
|
||||
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta \\
|
||||
|
||||
\end{align}
|
||||
"""
|
||||
|
||||
model: LatentDiffusion
|
||||
|
||||
def __init__(self, model: LatentDiffusion):
|
||||
"""
|
||||
:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
|
||||
"""
|
||||
super().__init__(model)
|
||||
|
||||
# Sampling steps $1, 2, \dots, T$
|
||||
self.time_steps = np.asarray(list(range(self.n_steps)))
|
||||
|
||||
with torch.no_grad():
|
||||
# $\bar\alpha_t$
|
||||
alpha_bar = self.model.alpha_bar
|
||||
# $\beta_t$ schedule
|
||||
beta = self.model.beta
|
||||
# $\bar\alpha_{t-1}$
|
||||
alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])
|
||||
|
||||
# $\sqrt{\bar\alpha_t}$
|
||||
self.sqrt_alpha_bar = alpha_bar ** .5
|
||||
# $\sqrt{1 - \bar\alpha_t}$
|
||||
self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .5
|
||||
# $\frac{1}{\sqrt{\bar\alpha_t}}$
|
||||
self.sqrt_recip_alpha_bar = alpha_bar ** -.5
|
||||
# $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
|
||||
self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .5
|
||||
|
||||
# $\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$
|
||||
variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)
|
||||
# Clamped log of $\tilde\beta_t$ (clamped because $\tilde\beta_1 = 0$)
|
||||
self.log_var = torch.log(torch.clamp(variance, min=1e-20))
|
||||
# $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
|
||||
self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)
|
||||
# $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
|
||||
self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
shape: List[int],
|
||||
cond: torch.Tensor,
|
||||
repeat_noise: bool = False,
|
||||
temperature: float = 1.,
|
||||
x_last: Optional[torch.Tensor] = None,
|
||||
uncond_scale: float = 1.,
|
||||
uncond_cond: Optional[torch.Tensor] = None,
|
||||
skip_steps: int = 0,
|
||||
):
|
||||
"""
|
||||
### Sampling Loop
|
||||
|
||||
:param shape: is the shape of the generated images in the
|
||||
form `[batch_size, channels, height, width]`
|
||||
:param cond: is the conditional embeddings $c$
|
||||
:param temperature: is the noise temperature (random noise gets multiplied by this)
|
||||
:param x_last: is $x_T$. If not provided random noise will be used.
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
||||
:param skip_steps: is the number of time steps to skip $t'$. We start sampling from $T - t'$.
|
||||
And `x_last` is then $x_{T - t'}$.
|
||||
"""
|
||||
|
||||
# Get device and batch size
|
||||
device = self.model.device
|
||||
bs = shape[0]
|
||||
|
||||
# Get $x_T$
|
||||
x = x_last if x_last is not None else torch.randn(shape, device=device)
|
||||
|
||||
# Time steps to sample at $T - t', T - t' - 1, \dots, 1$
|
||||
time_steps = np.flip(self.time_steps)[skip_steps:]
|
||||
|
||||
# Sampling loop
|
||||
for step in monit.iterate('Sample', time_steps):
|
||||
# Time step $t$
|
||||
ts = x.new_full((bs,), step, dtype=torch.long)
|
||||
|
||||
# Sample $x_{t-1}$
|
||||
x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
|
||||
repeat_noise=repeat_noise,
|
||||
temperature=temperature,
|
||||
uncond_scale=uncond_scale,
|
||||
uncond_cond=uncond_cond)
|
||||
|
||||
# Return $x_0$
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
|
||||
repeat_noise: bool = False,
|
||||
temperature: float = 1.,
|
||||
uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$
|
||||
|
||||
:param x: is $x_t$ of shape `[batch_size, channels, height, width]`
|
||||
:param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
|
||||
:param t: is $t$ of shape `[batch_size]`
|
||||
:param step: is the step $t$ as an integer
|
||||
:param repeat_noise: specifies whether the noise should be the same for all samples in the batch
|
||||
:param temperature: is the noise temperature (random noise gets multiplied by this)
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
|
||||
"""
|
||||
|
||||
# Get $\epsilon_\theta$
|
||||
e_t = self.get_eps(x, t, c,
|
||||
uncond_scale=uncond_scale,
|
||||
uncond_cond=uncond_cond)
|
||||
|
||||
# Get batch size
|
||||
bs = x.shape[0]
|
||||
|
||||
# $\frac{1}{\sqrt{\bar\alpha_t}}$
|
||||
sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])
|
||||
# $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
|
||||
sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])
|
||||
|
||||
# Calculate $x_0$ with current $\epsilon_\theta$
|
||||
#
|
||||
# $$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta$$
|
||||
x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t
|
||||
|
||||
# $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
|
||||
mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])
|
||||
# $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
|
||||
mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])
|
||||
|
||||
# Calculate $\mu_t(x_t, t)$
|
||||
#
|
||||
# $$\mu_t(x_t, t) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
|
||||
# + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$$
|
||||
mean = mean_x0_coef * x0 + mean_xt_coef * x
|
||||
# $\log \tilde\beta_t$
|
||||
log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])
|
||||
|
||||
# Do not add noise when $t = 1$ (the final step of the sampling process).
|
||||
# Note that `step` is `0` when $t = 1$.
|
||||
if step == 0:
|
||||
noise = 0
|
||||
# If same noise is used for all samples in the batch
|
||||
elif repeat_noise:
|
||||
noise = torch.randn((1, *x.shape[1:]), device=x.device)
|
||||
# Different noise for each sample
|
||||
else:
|
||||
noise = torch.randn(x.shape, device=x.device)
|
||||
|
||||
# Multiply noise by the temperature
|
||||
noise = noise * temperature
|
||||
|
||||
# Sample from,
|
||||
#
|
||||
# $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$
|
||||
x_prev = mean + (0.5 * log_var).exp() * noise
|
||||
|
||||
#
|
||||
return x_prev, x0, e_t
|
||||
|
||||
@torch.no_grad()
|
||||
def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
### Sample from $q(x_t|x_0)$
|
||||
|
||||
$$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$
|
||||
|
||||
:param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
|
||||
:param index: is the time step $t$ index
|
||||
:param noise: is the noise, $\epsilon$
|
||||
"""
|
||||
|
||||
# Random noise, if noise is not specified
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
|
||||
# Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
|
||||
return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
|
||||
13
labml_nn/diffusion/stable_diffusion/scripts/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
"""
|
||||
---
|
||||
title: Scripts showing example usages of stable diffusion
|
||||
summary: >
|
||||
Annotated PyTorch implementation/tutorial of example usages of stable diffusion
|
||||
---
|
||||
|
||||
# Scripts showing example usages of [stable diffusion](../index.html)
|
||||
|
||||
* [Prompt to image diffusion](text_to_image.html)
|
||||
* [Image to image diffusion](image_to_image.html)
|
||||
* [In-painting](in_paint.html)
|
||||
"""
|
||||
149
labml_nn/diffusion/stable_diffusion/scripts/image_to_image.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""
|
||||
---
|
||||
title: Generate images using stable diffusion with a prompt from a given image
|
||||
summary: >
|
||||
Generate images using stable diffusion with a prompt from a given image
|
||||
---
|
||||
|
||||
# Generate images using [stable diffusion](../index.html) with a prompt from a given image
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from labml import lab, monit
|
||||
from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
|
||||
from labml_nn.diffusion.stable_diffusion.util import load_model, load_img, save_images, set_seed
|
||||
|
||||
|
||||
class Img2Img:
|
||||
"""
|
||||
### Image to image class
|
||||
"""
|
||||
|
||||
def __init__(self, *, checkpoint_path: Path,
|
||||
ddim_steps: int = 50,
|
||||
ddim_eta: float = 0.0):
|
||||
"""
|
||||
:param checkpoint_path: is the path of the checkpoint
|
||||
:param ddim_steps: is the number of sampling steps
|
||||
:param ddim_eta: is the [DDIM sampling](../sampler/ddim.html) $\eta$ constant
|
||||
"""
|
||||
self.ddim_steps = ddim_steps
|
||||
|
||||
# Load [latent diffusion model](../latent_diffusion.html)
|
||||
self.model = load_model(checkpoint_path)
|
||||
# Get device
|
||||
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
# Move the model to device
|
||||
self.model.to(self.device)
|
||||
|
||||
# Initialize [DDIM sampler](../sampler/ddim.html)
|
||||
self.sampler = DDIMSampler(self.model,
|
||||
n_steps=ddim_steps,
|
||||
ddim_eta=ddim_eta)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, *,
|
||||
dest_path: str,
|
||||
orig_img: str,
|
||||
strength: float,
|
||||
batch_size: int = 3,
|
||||
prompt: str,
|
||||
uncond_scale: float = 5.0,
|
||||
):
|
||||
"""
|
||||
:param dest_path: is the path to store the generated images
|
||||
:param orig_img: is the image to transform
|
||||
:param strength: specifies how much of the original image should not be preserved
|
||||
:param batch_size: is the number of images to generate in a batch
|
||||
:param prompt: is the prompt to generate images with
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
"""
|
||||
# Make a batch of prompts
|
||||
prompts = batch_size * [prompt]
|
||||
# Load image
|
||||
orig_image = load_img(orig_img).to(self.device)
|
||||
# Encode the image in the latent space and make `batch_size` copies of it
|
||||
orig = self.model.autoencoder_encode(orig_image).repeat(batch_size, 1, 1, 1)
|
||||
|
||||
# Get the number of steps to diffuse the original
|
||||
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_index = int(strength * self.ddim_steps)
|
||||
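# For example, with the default 50 DDIM steps and `strength = 0.75`,
# `t_index = 37`, so roughly three quarters of the diffusion steps are
# re-run on the noised latent.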
|
||||
# AMP auto casting
|
||||
with torch.cuda.amp.autocast():
|
||||
# If the unconditional guidance scale is not $1$, get the embeddings for empty prompts (no conditioning).
|
||||
if uncond_scale != 1.0:
|
||||
un_cond = self.model.get_text_conditioning(batch_size * [""])
|
||||
else:
|
||||
un_cond = None
|
||||
# Get the prompt embeddings
|
||||
cond = self.model.get_text_conditioning(prompts)
|
||||
# Add noise to the original image
|
||||
x = self.sampler.q_sample(orig, t_index)
|
||||
# Reconstruct from the noisy image
|
||||
x = self.sampler.paint(x, cond, t_index,
|
||||
uncond_scale=uncond_scale,
|
||||
uncond_cond=un_cond)
|
||||
# Decode the image from the [autoencoder](../model/autoencoder.html)
|
||||
images = self.model.autoencoder_decode(x)
|
||||
|
||||
# Save images
|
||||
save_images(images, dest_path, 'img_')
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
### CLI
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a cute monkey playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--orig-img",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="path to the input image"
|
||||
)
|
||||
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="batch size", )
|
||||
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps")
|
||||
|
||||
parser.add_argument("--scale", type=float, default=5.0,
|
||||
help="unconditional guidance scale: "
|
||||
"eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
|
||||
|
||||
parser.add_argument("--strength", type=float, default=0.75,
|
||||
help="strength for noise: "
|
||||
" 1.0 corresponds to full destruction of information in init image")
|
||||
|
||||
opt = parser.parse_args()
|
||||
set_seed(42)
|
||||
|
||||
img2img = Img2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
|
||||
ddim_steps=opt.steps)
|
||||
|
||||
with monit.section('Generate'):
|
||||
img2img(
|
||||
dest_path='outputs',
|
||||
orig_img=opt.orig_img,
|
||||
strength=opt.strength,
|
||||
batch_size=opt.batch_size,
|
||||
prompt=opt.prompt,
|
||||
uncond_scale=opt.scale)
|
||||
|
||||
|
||||
#
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
166
labml_nn/diffusion/stable_diffusion/scripts/in_paint.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""
|
||||
---
|
||||
title: In-paint images using stable diffusion with a prompt
|
||||
summary: >
|
||||
In-paint images using stable diffusion with a prompt
|
||||
---
|
||||
|
||||
# In-paint images using [stable diffusion](../index.html) with a prompt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from labml import lab, monit
|
||||
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
|
||||
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
|
||||
from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
|
||||
from labml_nn.diffusion.stable_diffusion.util import load_model, save_images, load_img, set_seed
|
||||
|
||||
|
||||
class InPaint:
|
||||
"""
|
||||
### Image in-painting class
|
||||
"""
|
||||
model: LatentDiffusion
|
||||
sampler: DiffusionSampler
|
||||
|
||||
def __init__(self, *, checkpoint_path: Path,
|
||||
ddim_steps: int = 50,
|
||||
ddim_eta: float = 0.0):
|
||||
"""
|
||||
:param checkpoint_path: is the path of the checkpoint
|
||||
:param ddim_steps: is the number of sampling steps
|
||||
:param ddim_eta: is the [DDIM sampling](../sampler/ddim.html) $\eta$ constant
|
||||
"""
|
||||
self.ddim_steps = ddim_steps
|
||||
|
||||
# Load [latent diffusion model](../latent_diffusion.html)
|
||||
self.model = load_model(checkpoint_path)
|
||||
# Get device
|
||||
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
# Move the model to device
|
||||
self.model.to(self.device)
|
||||
|
||||
# Initialize [DDIM sampler](../sampler/ddim.html)
|
||||
self.sampler = DDIMSampler(self.model,
|
||||
n_steps=ddim_steps,
|
||||
ddim_eta=ddim_eta)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, *,
|
||||
dest_path: str,
|
||||
orig_img: str,
|
||||
strength: float,
|
||||
batch_size: int = 3,
|
||||
prompt: str,
|
||||
uncond_scale: float = 5.0,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
:param dest_path: is the path to store the generated images
|
||||
:param orig_img: is the image to in-paint
|
||||
:param strength: specifies how much of the original image should not be preserved
|
||||
:param batch_size: is the number of images to generate in a batch
|
||||
:param prompt: is the prompt to generate images with
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
"""
|
||||
# Make a batch of prompts
|
||||
prompts = batch_size * [prompt]
|
||||
# Load image
|
||||
orig_image = load_img(orig_img).to(self.device)
|
||||
# Encode the image in the latent space and make `batch_size` copies of it
|
||||
orig = self.model.autoencoder_encode(orig_image).repeat(batch_size, 1, 1, 1)
|
||||
# If `mask` is not provided,
|
||||
# we set a sample mask to preserve the bottom half of the image
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(orig, device=self.device)
|
||||
mask[:, :, mask.shape[2] // 2:, :] = 1.
|
||||
else:
|
||||
mask = mask.to(self.device)
|
||||
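# The mask is $1$ where the original image should be kept and $0$ where the
# model is free to in-paint, matching the blend in the sampler's `paint()`.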
# Fixed noise to be added to the original image
|
||||
orig_noise = torch.randn(orig.shape, device=self.device)
|
||||
|
||||
# Get the number of steps to diffuse the original
|
||||
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_index = int(strength * self.ddim_steps)
|
||||
|
||||
# AMP auto casting
|
||||
with torch.cuda.amp.autocast():
|
||||
# If the unconditional guidance scale is not $1$, get the embeddings for empty prompts (no conditioning).
|
||||
if uncond_scale != 1.0:
|
||||
un_cond = self.model.get_text_conditioning(batch_size * [""])
|
||||
else:
|
||||
un_cond = None
|
||||
# Get the prompt embeddings
|
||||
cond = self.model.get_text_conditioning(prompts)
|
||||
# Add noise to the original image
|
||||
x = self.sampler.q_sample(orig, t_index, noise=orig_noise)
|
||||
# Reconstruct from the noisy image, while preserving the masked area
|
||||
x = self.sampler.paint(x, cond, t_index,
|
||||
orig=orig,
|
||||
mask=mask,
|
||||
orig_noise=orig_noise,
|
||||
uncond_scale=uncond_scale,
|
||||
uncond_cond=un_cond)
|
||||
# Decode the image from the [autoencoder](../model/autoencoder.html)
|
||||
images = self.model.autoencoder_decode(x)
|
||||
|
||||
# Save images
|
||||
save_images(images, dest_path, 'paint_')
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
### CLI
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a cute monkey playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--orig-img",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="path to the input image"
|
||||
)
|
||||
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="batch size", )
|
||||
parser.add_argument("--steps", type=int, default=50, help="number of sampling steps")
|
||||
|
||||
parser.add_argument("--scale", type=float, default=5.0,
|
||||
help="unconditional guidance scale: "
|
||||
"eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
|
||||
|
||||
parser.add_argument("--strength", type=float, default=0.75,
|
||||
help="strength for noise: "
|
||||
" 1.0 corresponds to full destruction of information in init image")
|
||||
|
||||
opt = parser.parse_args()
|
||||
set_seed(42)
|
||||
|
||||
in_paint = InPaint(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
|
||||
ddim_steps=opt.steps)
|
||||
|
||||
with monit.section('Generate'):
|
||||
in_paint(dest_path='outputs',
|
||||
orig_img=opt.orig_img,
|
||||
strength=opt.strength,
|
||||
batch_size=opt.batch_size,
|
||||
prompt=opt.prompt,
|
||||
uncond_scale=opt.scale)
|
||||
|
||||
|
||||
#
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
151
labml_nn/diffusion/stable_diffusion/scripts/text_to_image.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""
|
||||
---
|
||||
title: Generate images using stable diffusion with a prompt
|
||||
summary: >
|
||||
Generate images using stable diffusion with a prompt
|
||||
---
|
||||
|
||||
# Generate images using [stable diffusion](../index.html) with a prompt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from labml import lab, monit
|
||||
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
|
||||
from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
|
||||
from labml_nn.diffusion.stable_diffusion.sampler.ddpm import DDPMSampler
|
||||
from labml_nn.diffusion.stable_diffusion.util import load_model, save_images, set_seed
|
||||
|
||||
|
||||
class Txt2Img:
|
||||
"""
|
||||
### Text to image class
|
||||
"""
|
||||
model: LatentDiffusion
|
||||
|
||||
def __init__(self, *,
|
||||
checkpoint_path: Path,
|
||||
sampler_name: str,
|
||||
n_steps: int = 50,
|
||||
ddim_eta: float = 0.0,
|
||||
):
|
||||
"""
|
||||
:param checkpoint_path: is the path of the checkpoint
|
||||
:param sampler_name: is the name of the [sampler](../sampler/index.html)
|
||||
:param n_steps: is the number of sampling steps
|
||||
:param ddim_eta: is the [DDIM sampling](../sampler/ddim.html) $\eta$ constant
|
||||
"""
|
||||
# Load [latent diffusion model](../latent_diffusion.html)
|
||||
self.model = load_model(checkpoint_path)
|
||||
# Get device
|
||||
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
# Move the model to device
|
||||
self.model.to(self.device)
|
||||
|
||||
# Initialize [sampler](../sampler/index.html)
|
||||
if sampler_name == 'ddim':
|
||||
self.sampler = DDIMSampler(self.model,
|
||||
n_steps=n_steps,
|
||||
ddim_eta=ddim_eta)
|
||||
elif sampler_name == 'ddpm':
|
||||
self.sampler = DDPMSampler(self.model)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, *,
|
||||
dest_path: str,
|
||||
batch_size: int = 3,
|
||||
prompt: str,
|
||||
h: int = 512, w: int = 512,
|
||||
uncond_scale: float = 7.5,
|
||||
):
|
||||
"""
|
||||
:param dest_path: is the path to store the generated images
|
||||
:param batch_size: is the number of images to generate in a batch
|
||||
:param prompt: is the prompt to generate images with
|
||||
:param h: is the height of the image
|
||||
:param w: is the width of the image
|
||||
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
|
||||
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
|
||||
"""
|
||||
# Number of channels in the image
|
||||
c = 4
|
||||
# Image to latent space resolution reduction
|
||||
f = 8
|
||||
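# For example, a 512x512 image corresponds to a 4x64x64 latent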
|
||||
# Make a batch of prompts
|
||||
prompts = batch_size * [prompt]
|
||||
|
||||
# AMP auto casting
|
||||
with torch.cuda.amp.autocast():
|
||||
# If the unconditional guidance scale is not $1$, get the embeddings for empty prompts (no conditioning).
|
||||
if uncond_scale != 1.0:
|
||||
un_cond = self.model.get_text_conditioning(batch_size * [""])
|
||||
else:
|
||||
un_cond = None
|
||||
# Get the prompt embeddings
|
||||
cond = self.model.get_text_conditioning(prompts)
|
||||
# [Sample in the latent space](../sampler/index.html).
|
||||
# `x` will be of shape `[batch_size, c, h / f, w / f]`
|
||||
x = self.sampler.sample(cond=cond,
|
||||
shape=[batch_size, c, h // f, w // f],
|
||||
uncond_scale=uncond_scale,
|
||||
uncond_cond=un_cond)
|
||||
# Decode the image from the [autoencoder](../model/autoencoder.html)
|
||||
images = self.model.autoencoder_decode(x)
|
||||
|
||||
# Save images
|
||||
save_images(images, dest_path, 'txt_')
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
### CLI
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
|
||||
parser.add_argument("--batch_size", type=int, default=4, help="batch size", )
|
||||
|
||||
parser.add_argument(
|
||||
'--sampler',
|
||||
dest='sampler_name',
|
||||
choices=['ddim', 'ddpm'],
|
||||
default='ddim',
|
||||
help='Set the sampler.',
|
||||
)
|
||||
|
||||
parser.add_argument("--steps", type=int, default=50, help="number of sampling steps", )
|
||||
|
||||
parser.add_argument("--scale", type=float, default=7.5,
|
||||
help="unconditional guidance scale: "
|
||||
"eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
set_seed(42)
|
||||
|
||||
txt2img = Txt2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
|
||||
sampler_name=opt.sampler_name,
|
||||
n_steps=opt.steps)
|
||||
|
||||
with monit.section('Generate'):
|
||||
txt2img(dest_path='outputs',
|
||||
batch_size=opt.batch_size,
|
||||
prompt=opt.prompt,
|
||||
uncond_scale=opt.scale)
|
||||
|
||||
|
||||
#
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
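# Example usage (a sketch; assumes the checkpoint has been downloaded to
# `<data path>/stable-diffusion/sd-v1-4.ckpt`):
#
#     python -m labml_nn.diffusion.stable_diffusion.scripts.text_to_image \
#         --prompt "a photograph of an astronaut riding a horse" \
#         --sampler ddim --steps 50 --scale 7.5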
151
labml_nn/diffusion/stable_diffusion/util.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""
|
||||
---
|
||||
title: Utility functions for stable diffusion
|
||||
summary: >
|
||||
Utility functions for stable diffusion
|
||||
---
|
||||
|
||||
# Utility functions for [stable diffusion](index.html)
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from labml import monit
|
||||
from labml.logger import inspect
|
||||
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
|
||||
from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
|
||||
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
|
||||
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""
|
||||
### Set random seeds
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def load_model(path: Path = None) -> LatentDiffusion:
|
||||
"""
|
||||
### Load [`LatentDiffusion` model](latent_diffusion.html)
|
||||
"""
|
||||
|
||||
# Initialize the autoencoder
|
||||
with monit.section('Initialize autoencoder'):
|
||||
encoder = Encoder(z_channels=4,
|
||||
in_channels=3,
|
||||
channels=128,
|
||||
channel_multipliers=[1, 2, 4, 4],
|
||||
n_resnet_blocks=2)
|
||||
|
||||
decoder = Decoder(out_channels=3,
|
||||
z_channels=4,
|
||||
channels=128,
|
||||
channel_multipliers=[1, 2, 4, 4],
|
||||
n_resnet_blocks=2)
|
||||
|
||||
autoencoder = Autoencoder(emb_channels=4,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
z_channels=4)
|
||||
|
||||
# Initialize the CLIP text embedder
|
||||
with monit.section('Initialize CLIP Embedder'):
|
||||
clip_text_embedder = CLIPTextEmbedder()
|
||||
|
||||
# Initialize the U-Net
|
||||
with monit.section('Initialize U-Net'):
|
||||
unet_model = UNetModel(in_channels=4,
|
||||
out_channels=4,
|
||||
channels=320,
|
||||
attention_levels=[0, 1, 2],
|
||||
n_res_blocks=2,
|
||||
channel_multipliers=[1, 2, 4, 4],
|
||||
n_heads=8,
|
||||
tf_layers=1,
|
||||
d_cond=768)
|
||||
|
||||
# Initialize the Latent Diffusion model
|
||||
with monit.section('Initialize Latent Diffusion model'):
|
||||
model = LatentDiffusion(linear_start=0.00085,
|
||||
linear_end=0.0120,
|
||||
n_steps=1000,
|
||||
latent_scaling_factor=0.18215,
|
||||
|
||||
autoencoder=autoencoder,
|
||||
clip_embedder=clip_text_embedder,
|
||||
unet_model=unet_model)
|
||||
|
||||
# Load the checkpoint
|
||||
with monit.section(f"Loading model from {path}"):
|
||||
checkpoint = torch.load(path, map_location="cpu")
|
||||
|
||||
# Set model state
|
||||
with monit.section('Load state'):
|
||||
missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)
|
||||
|
||||
# Debugging output
|
||||
inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
|
||||
_expand=True)
|
||||
|
||||
#
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_img(path: str):
|
||||
"""
|
||||
### Load an image
|
||||
|
||||
This loads an image from a file and returns a PyTorch tensor.
|
||||
|
||||
:param path: is the path of the image
|
||||
"""
|
||||
# Open Image
|
||||
image = Image.open(path).convert("RGB")
|
||||
# Get image size
|
||||
w, h = image.size
|
||||
# Resize to a multiple of 32
|
||||
w = w - w % 32
|
||||
h = h - h % 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
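# For example, a 500x333 image is resized to 480x320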
# Convert to numpy and map from `[0, 255]` to `[-1, 1]`
|
||||
image = np.array(image).astype(np.float32) * (2. / 255.0) - 1
|
||||
# Transpose to shape `[batch_size, channels, height, width]`
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
# Convert to torch
|
||||
return torch.from_numpy(image)
|
||||
|
||||
|
||||
def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
|
||||
"""
|
||||
### Save images
|
||||
|
||||
:param images: is the tensor with images of shape `[batch_size, channels, height, width]`
|
||||
:param dest_path: is the folder to save images in
|
||||
:param prefix: is the prefix to add to file names
|
||||
:param img_format: is the image format
|
||||
"""
|
||||
|
||||
# Create the destination folder
|
||||
os.makedirs(dest_path, exist_ok=True)
|
||||
|
||||
# Map images to `[0, 1]` space and clip
|
||||
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
# Transpose to `[batch_size, height, width, channels]` and convert to numpy
|
||||
images = images.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# Save images
|
||||
for i, img in enumerate(images):
|
||||
img = Image.fromarray((255. * img).astype(np.uint8))
|
||||
img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)
|
||||
@ -8,7 +8,8 @@ summary: >
|
||||
# GPT-NeoX Checkpoints
|
||||
|
||||
"""
|
||||
from typing import Dict, Union, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union, Tuple, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -19,12 +20,23 @@ from labml.utils.download import download_file
|
||||
|
||||
# Parent url
|
||||
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
|
||||
# Download path
|
||||
|
||||
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
|
||||
if not CHECKPOINTS_DOWNLOAD_PATH.exists():
|
||||
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
|
||||
inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
|
||||
_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None
|
||||
|
||||
|
||||
# Download path
|
||||
def get_checkpoints_download_path():
|
||||
global _CHECKPOINTS_DOWNLOAD_PATH
|
||||
|
||||
if _CHECKPOINTS_DOWNLOAD_PATH is not None:
|
||||
return _CHECKPOINTS_DOWNLOAD_PATH
|
||||
|
||||
_CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
|
||||
if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
|
||||
_CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
|
||||
inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)
|
||||
|
||||
return _CHECKPOINTS_DOWNLOAD_PATH
|
||||
|
||||
|
||||
def get_files_to_download(n_layers: int = 44):
|
||||
@ -65,7 +77,7 @@ def download(n_layers: int = 44):
|
||||
# Log
|
||||
logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
|
||||
# Download
|
||||
download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
|
||||
download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)
|
||||
|
||||
|
||||
def load_checkpoint_files(files: Tuple[str, str]):
|
||||
@ -75,7 +87,7 @@ def load_checkpoint_files(files: Tuple[str, str]):
|
||||
:param files: pair of files to load
|
||||
:return: the loaded parameter tensors
|
||||
"""
|
||||
checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
|
||||
checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
|
||||
with monit.section('Load checkpoint'):
|
||||
data = [torch.load(checkpoint_path / f) for f in files]
|
||||
|
||||
|
||||