stable diffusion

This commit is contained in:
Varuna Jayasiri
2022-09-15 11:28:53 +05:30
parent 7d1550dd67
commit 999adb0cfa
120 changed files with 13063 additions and 1242 deletions

View File

@ -8,4 +8,7 @@ summary: >
# Diffusion models
* [Denoising Diffusion Probabilistic Models (DDPM)](ddpm/index.html)
* [Stable Diffusion](stable_diffusion/index.html)
* [Latent Diffusion Model](stable_diffusion/latent_diffusion.html)
* [Denoising Diffusion Implicit Models (DDIM) Sampling](stable_diffusion/sampler/ddim.html)
"""

View File

@ -0,0 +1,48 @@
"""
---
title: Stable Diffusion
summary: >
Annotated PyTorch implementation/tutorial of stable diffusion.
---
# Stable Diffusion
This is based on official stable diffusion repository
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion).
We have kept the model structure unchanged so that the open-sourced weights can be loaded directly.
Our implementation does not contain training code.
### [PromptArt](https://promptart.labml.ai)
We have deployed a stable diffusion based image generation service
at [promptart.labml.ai](https://promptart.labml.ai)
### [Latent Diffusion Model](latent_diffusion.html)
The core is the [Latent Diffusion Model](latent_diffusion.html).
It consists of:
* [AutoEncoder](model/autoencoder.html)
* [U-Net](model/unet.html) with [attention](model/unet_attention.html)
The diffusion is conditioned based on [CLIP embeddings](model/clip_embedder.html).
### [Sampling Algorithms](sampler/index.html)
We have implemented the following [sampling algorithms](sampler/index.html):
* [Denoising Diffusion Probabilistic Models (DDPM) Sampling](sampler/ddpm.html)
* [Denoising Diffusion Implicit Models (DDIM) Sampling](sampler/ddim.html)
### [Example Scripts](scripts/index.html)
Here are the image generation scripts:
* [Generate images from text prompts](scripts/text_to_image.html)
* [Generate images based on a given image, guided by a prompt](scripts/image_to_image.html)
* [Modify parts of a given image based on a text prompt](scripts/in_paint.html)
#### [Utilities](util.html)
[`util.py`](util.html) defines the utility functions.
"""

View File

@ -0,0 +1,146 @@
"""
---
title: Latent Diffusion Models
summary: >
Annotated PyTorch implementation/tutorial of latent diffusion models from paper
High-Resolution Image Synthesis with Latent Diffusion Models
---
# Latent Diffusion Models
Latent diffusion models use an auto-encoder to map between image space and
latent space. The diffusion model works on the latent space, which makes it
a lot easier to train.
It is based on paper
[High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752).
They use a pre-trained auto-encoder and train the diffusion U-Net on the latent
space of the pre-trained auto-encoder.
For a simpler diffusion implementation refer to our [DDPM implementation](../ddpm/index.html).
We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional
from labml_nn.diffusion.stable_diffusion.model.autoencoder import Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
class DiffusionWrapper(nn.Module):
"""
*This is an empty wrapper class around the [U-Net](model/unet.html).
We keep this to have the same model structure as
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we do not have to map the checkpoint weights explicitly*.
"""
def __init__(self, diffusion_model: UNetModel):
super().__init__()
self.diffusion_model = diffusion_model
def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
return self.diffusion_model(x, time_steps, context)
class LatentDiffusion(nn.Module):
"""
## Latent diffusion model
This contains the following components:
* [AutoEncoder](model/autoencoder.html)
* [U-Net](model/unet.html) with [attention](model/unet_attention.html)
* [CLIP embeddings generator](model/clip_embedder.html)
"""
model: DiffusionWrapper
first_stage_model: Autoencoder
cond_stage_model: CLIPTextEmbedder
def __init__(self,
unet_model: UNetModel,
autoencoder: Autoencoder,
clip_embedder: CLIPTextEmbedder,
latent_scaling_factor: float,
n_steps: int,
linear_start: float,
linear_end: float,
):
"""
:param unet_model: is the [U-Net](model/unet.html) that predicts noise
$\epsilon_\text{cond}(x_t, c)$ in latent space
:param autoencoder: is the [AutoEncoder](model/autoencoder.html)
:param clip_embedder: is the [CLIP embeddings generator](model/clip_embedder.html)
:param latent_scaling_factor: is the scaling factor for the latent space. The encodings of
the autoencoder are scaled by this before feeding into the U-Net.
:param n_steps: is the number of diffusion steps $T$.
:param linear_start: is the start of the $\beta$ schedule.
:param linear_end: is the end of the $\beta$ schedule.
"""
super().__init__()
# Wrap the [U-Net](model/unet.html) to keep the same model structure as
# [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion).
self.model = DiffusionWrapper(unet_model)
# Auto-encoder and scaling factor
self.first_stage_model = autoencoder
self.latent_scaling_factor = latent_scaling_factor
# [CLIP embeddings generator](model/clip_embedder.html)
self.cond_stage_model = clip_embedder
# Number of steps $T$
self.n_steps = n_steps
# $\beta$ schedule
beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
# $\alpha_t = 1 - \beta_t$
alpha = 1. - beta
# $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
alpha_bar = torch.cumprod(alpha, dim=0)
self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
@property
def device(self):
"""
### Get model device
"""
return next(iter(self.model.parameters())).device
def get_text_conditioning(self, prompts: List[str]):
"""
### Get [CLIP embeddings](model/clip_embedder.html) for a list of text prompts
"""
return self.cond_stage_model(prompts)
def autoencoder_encode(self, image: torch.Tensor):
"""
### Get scaled latent space representation of the image
The encoder output is a distribution.
We sample from that and multiply by the scaling factor.
"""
return self.latent_scaling_factor * self.first_stage_model.encode(image).sample()
def autoencoder_decode(self, z: torch.Tensor):
"""
### Get image from the latent representation
We scale down by the scaling factor and then decode.
"""
return self.first_stage_model.decode(z / self.latent_scaling_factor)
def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
"""
### Predict noise
Predict noise given the latent representation $x_t$, time step $t$, and the
conditioning context $c$.
$$\epsilon_\text{cond}(x_t, c)$$
"""
return self.model(x, t, context)
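# A small standalone sketch of the $\beta$ schedule computed in
# `LatentDiffusion.__init__`. The `linear_start`/`linear_end` values here are
# illustrative assumptions, not necessarily the trained configuration.
def _beta_schedule_sketch(linear_start: float = 0.00085, linear_end: float = 0.0120, n_steps: int = 1000):
    # Linear schedule on $\sqrt{\beta_t}$, then squared
    beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
    # $\bar\alpha_t = \prod_{s=1}^t (1 - \beta_s)$; $\bar\alpha_T$ is close to zero,
    # so $x_T$ is (almost) pure noise
    alpha_bar = torch.cumprod(1. - beta, dim=0)
    return beta, alpha_bar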

View File

@ -0,0 +1,13 @@
"""
---
title: Modules used in stable diffusion
summary: >
Models and components for stable diffusion.
---
# [Stable Diffusion](../index.html) Models
* [AutoEncoder](autoencoder.html)
* [U-Net](unet.html) with [attention](unet_attention.html)
* [CLIP embedder](clip_embedder.html).
"""

View File

@ -0,0 +1,433 @@
"""
---
title: Autoencoder for Stable Diffusion
summary: >
Annotated PyTorch implementation/tutorial of the autoencoder
for stable diffusion.
---
# Autoencoder for [Stable Diffusion](../index.html)
This implements the auto-encoder model used to map between image space and latent space.
We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
class Autoencoder(nn.Module):
"""
## Autoencoder
This consists of the encoder and decoder modules.
"""
def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
"""
:param encoder: is the encoder
:param decoder: is the decoder
:param emb_channels: is the number of dimensions in the quantized embedding space
:param z_channels: is the number of channels in the embedding space
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
# Convolution to map from embedding space to
# quantized embedding space moments (mean and log variance)
self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
# Convolution to map from quantized embedding space back to
# embedding space
self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
def encode(self, img: torch.Tensor) -> 'GaussianDistribution':
"""
### Encode images to latent representation
:param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
"""
# Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_width]`
z = self.encoder(img)
# Get the moments in the quantized embedding space
moments = self.quant_conv(z)
# Return the distribution
return GaussianDistribution(moments)
def decode(self, z: torch.Tensor):
"""
### Decode images from latent representation
:param z: is the latent representation with shape `[batch_size, emb_channels, z_height, z_width]`
"""
# Map to embedding space from the quantized representation
z = self.post_quant_conv(z)
# Decode the image of shape `[batch_size, channels, height, width]`
return self.decoder(z)
class Encoder(nn.Module):
"""
## Encoder module
"""
def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
in_channels: int, z_channels: int):
"""
:param channels: is the number of channels in the first convolution layer
:param channel_multipliers: are the multiplicative factors for the number of channels in the
subsequent blocks
:param n_resnet_blocks: is the number of resnet layers at each resolution
:param in_channels: is the number of channels in the image
:param z_channels: is the number of channels in the embedding space
"""
super().__init__()
# Number of blocks of different resolutions.
# The resolution is halved at the end of each top-level block
n_resolutions = len(channel_multipliers)
# Initial $3 \times 3$ convolution layer that maps the image to `channels`
self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
# Number of channels in each top level block
channels_list = [m * channels for m in [1] + channel_multipliers]
# List of top-level blocks
self.down = nn.ModuleList()
# Create top-level blocks
for i in range(n_resolutions):
# Each top level block consists of multiple ResNet Blocks and down-sampling
resnet_blocks = nn.ModuleList()
# Add ResNet Blocks
for _ in range(n_resnet_blocks):
resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
channels = channels_list[i + 1]
# Top-level block
down = nn.Module()
down.block = resnet_blocks
# Down-sampling at the end of each top level block except the last
if i != n_resolutions - 1:
down.downsample = DownSample(channels)
else:
down.downsample = nn.Identity()
#
self.down.append(down)
# Final ResNet blocks with attention
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(channels, channels)
self.mid.attn_1 = AttnBlock(channels)
self.mid.block_2 = ResnetBlock(channels, channels)
# Map to embedding space with a $3 \times 3$ convolution
self.norm_out = normalization(channels)
self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
def forward(self, img: torch.Tensor):
"""
:param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
"""
# Map to `channels` with the initial convolution
x = self.conv_in(img)
# Top-level blocks
for down in self.down:
# ResNet Blocks
for block in down.block:
x = block(x)
# Down-sampling
x = down.downsample(x)
# Final ResNet blocks with attention
x = self.mid.block_1(x)
x = self.mid.attn_1(x)
x = self.mid.block_2(x)
# Normalize and map to embedding space
x = self.norm_out(x)
x = swish(x)
x = self.conv_out(x)
#
return x
class Decoder(nn.Module):
"""
## Decoder module
"""
def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
out_channels: int, z_channels: int):
"""
:param channels: is the number of channels in the final convolution layer
:param channel_multipliers: are the multiplicative factors for the number of channels in the
previous blocks, in reverse order
:param n_resnet_blocks: is the number of resnet layers at each resolution
:param out_channels: is the number of channels in the image
:param z_channels: is the number of channels in the embedding space
"""
super().__init__()
# Number of blocks of different resolutions.
# The resolution is doubled at the end of each top-level block
num_resolutions = len(channel_multipliers)
# Number of channels in each top level block, in the reverse order
channels_list = [m * channels for m in channel_multipliers]
# Number of channels in the top-level block
channels = channels_list[-1]
# Initial $3 \times 3$ convolution layer that maps the embedding space to `channels`
self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)
# ResNet blocks with attention
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(channels, channels)
self.mid.attn_1 = AttnBlock(channels)
self.mid.block_2 = ResnetBlock(channels, channels)
# List of top-level blocks
self.up = nn.ModuleList()
# Create top-level blocks
for i in reversed(range(num_resolutions)):
# Each top level block consists of multiple ResNet Blocks and up-sampling
resnet_blocks = nn.ModuleList()
# Add ResNet Blocks
for _ in range(n_resnet_blocks + 1):
resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
channels = channels_list[i]
# Top-level block
up = nn.Module()
up.block = resnet_blocks
# Up-sampling at the end of each top level block except the first
if i != 0:
up.upsample = UpSample(channels)
else:
up.upsample = nn.Identity()
# Prepend to be consistent with the checkpoint
self.up.insert(0, up)
# Map to image space with a $3 \times 3$ convolution
self.norm_out = normalization(channels)
self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
def forward(self, z: torch.Tensor):
"""
:param z: is the embedding tensor with shape `[batch_size, z_channels, z_height, z_width]`
"""
# Map to `channels` with the initial convolution
h = self.conv_in(z)
# ResNet blocks with attention
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# Top-level blocks
for up in reversed(self.up):
# ResNet Blocks
for block in up.block:
h = block(h)
# Up-sampling
h = up.upsample(h)
# Normalize and map to image space
h = self.norm_out(h)
h = swish(h)
img = self.conv_out(h)
#
return img
class GaussianDistribution:
"""
## Gaussian Distribution
"""
def __init__(self, parameters: torch.Tensor):
"""
:param parameters: are the means and log of variances of the embedding of shape
`[batch_size, z_channels * 2, z_height, z_width]`
"""
# Split mean and log of variance
self.mean, log_var = torch.chunk(parameters, 2, dim=1)
# Clamp the log of variances
self.log_var = torch.clamp(log_var, -30.0, 20.0)
# Calculate standard deviation
self.std = torch.exp(0.5 * self.log_var)
def sample(self):
# Sample from the distribution
return self.mean + self.std * torch.randn_like(self.std)
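# A quick sketch of how `GaussianDistribution` is used: the encoder output is
# split into mean and log-variance, and `sample()` draws a latent with the
# re-parameterization trick. The shapes here are illustrative.
def _gaussian_distribution_sketch():
    # Dummy moments for a batch of 2, with `z_channels = 4` and an `8x8` latent
    moments = torch.randn(2, 2 * 4, 8, 8)
    dist = GaussianDistribution(moments)
    # The sampled latent has shape `[2, 4, 8, 8]`
    return dist.sample()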
class AttnBlock(nn.Module):
"""
## Attention block
"""
def __init__(self, channels: int):
"""
:param channels: is the number of channels
"""
super().__init__()
# Group normalization
self.norm = normalization(channels)
# Query, key and value mappings
self.q = nn.Conv2d(channels, channels, 1)
self.k = nn.Conv2d(channels, channels, 1)
self.v = nn.Conv2d(channels, channels, 1)
# Final $1 \times 1$ convolution layer
self.proj_out = nn.Conv2d(channels, channels, 1)
# Attention scaling factor
self.scale = channels ** -0.5
def forward(self, x: torch.Tensor):
"""
:param x: is the tensor of shape `[batch_size, channels, height, width]`
"""
# Normalize `x`
x_norm = self.norm(x)
# Get query, key and value embeddings
q = self.q(x_norm)
k = self.k(x_norm)
v = self.v(x_norm)
# Reshape the query, key and value embeddings from
# `[batch_size, channels, height, width]` to
# `[batch_size, channels, height * width]`
b, c, h, w = q.shape
q = q.view(b, c, h * w)
k = k.view(b, c, h * w)
v = v.view(b, c, h * w)
# Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$
attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
attn = F.softmax(attn, dim=2)
# Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$
out = torch.einsum('bij,bcj->bci', attn, v)
# Reshape back to `[batch_size, channels, height, width]`
out = out.view(b, c, h, w)
# Final $1 \times 1$ convolution layer
out = self.proj_out(out)
# Add residual connection
return x + out
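# The einsums above are plain scaled dot-product attention over the
# `height * width` positions; a small sketch of the equivalence with `torch.bmm`:
def _attn_einsum_sketch():
    b, c, n = 2, 8, 16
    q, k = torch.randn(b, c, n), torch.randn(b, c, n)
    # `bci,bcj->bij` sums over the channel dimension, i.e. computes $Q^\top K$
    attn_einsum = torch.einsum('bci,bcj->bij', q, k)
    attn_bmm = torch.bmm(q.transpose(1, 2), k)
    assert torch.allclose(attn_einsum, attn_bmm, atol=1e-5)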
class UpSample(nn.Module):
"""
## Up-sampling layer
"""
def __init__(self, channels: int):
"""
:param channels: is the number of channels
"""
super().__init__()
# $3 \times 3$ convolution mapping
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, x: torch.Tensor):
"""
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
"""
# Up-sample by a factor of $2$
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
# Apply convolution
return self.conv(x)
class DownSample(nn.Module):
"""
## Down-sampling layer
"""
def __init__(self, channels: int):
"""
:param channels: is the number of channels
"""
super().__init__()
# $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
def forward(self, x: torch.Tensor):
"""
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
"""
# Add padding
x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
# Apply convolution
return self.conv(x)
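# A shape sketch for `DownSample`: the asymmetric `(0, 1, 0, 1)` padding pads
# only on the right and bottom, so the stride-2 valid convolution exactly
# halves the spatial size.
def _down_sample_sketch():
    down = DownSample(channels=32)
    x = torch.randn(1, 32, 64, 64)
    # Output shape is `[1, 32, 32, 32]`
    return down(x)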
class ResnetBlock(nn.Module):
"""
## ResNet Block
"""
def __init__(self, in_channels: int, out_channels: int):
"""
:param in_channels: is the number of channels in the input
:param out_channels: is the number of channels in the output
"""
super().__init__()
# First normalization and convolution layer
self.norm1 = normalization(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
# Second normalization and convolution layer
self.norm2 = normalization(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
# `in_channels` to `out_channels` mapping layer for residual connection
if in_channels != out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
else:
self.nin_shortcut = nn.Identity()
def forward(self, x: torch.Tensor):
"""
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
"""
h = x
# First normalization and convolution layer
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
# Second normalization and convolution layer
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
# Map and add residual
return self.nin_shortcut(x) + h
def swish(x: torch.Tensor):
"""
### Swish activation
$$x \cdot \sigma(x)$$
"""
return x * torch.sigmoid(x)
def normalization(channels: int):
"""
### Group normalization
This is a helper function, with a fixed number of groups and `eps`.
"""
return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
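# A minimal shape sanity-check for the autoencoder, using a small assumed
# configuration rather than the actual stable diffusion one.
def _autoencoder_sketch():
    encoder = Encoder(channels=32, channel_multipliers=[1, 2], n_resnet_blocks=1,
                      in_channels=3, z_channels=4)
    decoder = Decoder(channels=32, channel_multipliers=[1, 2], n_resnet_blocks=1,
                      out_channels=3, z_channels=4)
    autoencoder = Autoencoder(encoder=encoder, decoder=decoder, emb_channels=4, z_channels=4)
    # A `64x64` RGB image maps to a `[1, 4, 32, 32]` latent (one down-sampling level here)
    img = torch.randn(1, 3, 64, 64)
    z = autoencoder.encode(img).sample()
    # Decoding maps back to `[1, 3, 64, 64]`
    return autoencoder.decode(z)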

View File

@ -0,0 +1,50 @@
"""
---
title: CLIP Text Embedder
summary: >
CLIP embedder to get prompt embeddings for stable diffusion
---
# CLIP Text Embedder
This is used to get prompt embeddings for [stable diffusion](../index.html).
It uses the HuggingFace Transformers CLIP model.
"""
from typing import List
from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel
class CLIPTextEmbedder(nn.Module):
"""
## CLIP Text Embedder
"""
def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
"""
:param version: is the model version
:param device: is the device
:param max_length: is the max length of the tokenized prompt
"""
super().__init__()
# Load the tokenizer
self.tokenizer = CLIPTokenizer.from_pretrained(version)
# Load the CLIP transformer
self.transformer = CLIPTextModel.from_pretrained(version).eval()
self.device = device
self.max_length = max_length
def forward(self, prompts: List[str]):
"""
:param prompts: is the list of prompts to embed
"""
# Tokenize the prompts
batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
# Get token ids
tokens = batch_encoding["input_ids"].to(self.device)
# Get CLIP embeddings
return self.transformer(input_ids=tokens).last_hidden_state
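# A usage sketch (on CPU). This downloads the pre-trained weights from
# HuggingFace on first use; for this model version the output has shape
# `[batch_size, 77, 768]`.
def _clip_embedder_sketch():
    clip_embedder = CLIPTextEmbedder(device="cpu")
    # Prompt embeddings of shape `[1, 77, 768]`
    return clip_embedder(["a photograph of an astronaut riding a horse"])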

View File

@ -0,0 +1,343 @@
"""
---
title: U-Net for Stable Diffusion
summary: >
Annotated PyTorch implementation/tutorial of the U-Net in stable diffusion.
---
# U-Net for [Stable Diffusion](../index.html)
This implements the U-Net that
gives $\epsilon_\text{cond}(x_t, c)$
We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
import math
from typing import List
import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer
class UNetModel(nn.Module):
"""
## U-Net model
"""
def __init__(
self, *,
in_channels: int,
out_channels: int,
channels: int,
n_res_blocks: int,
attention_levels: List[int],
channel_multipliers: List[int],
n_heads: int,
tf_layers: int = 1,
d_cond: int = 768):
"""
:param in_channels: is the number of channels in the input feature map
:param out_channels: is the number of channels in the output feature map
:param channels: is the base channel count for the model
:param n_res_blocks: is the number of residual blocks at each level
:param attention_levels: are the levels at which attention should be performed
:param channel_multipliers: are the multiplicative factors for the number of channels at each level
:param n_heads: is the number of attention heads in the transformers
:param tf_layers: is the number of transformer layers in each [spatial transformer](unet_attention.html)
:param d_cond: is the size of the conditional embedding
"""
super().__init__()
self.channels = channels
# Number of levels
levels = len(channel_multipliers)
# Size of the time step embeddings
d_time_emb = channels * 4
self.time_embed = nn.Sequential(
nn.Linear(channels, d_time_emb),
nn.SiLU(),
nn.Linear(d_time_emb, d_time_emb),
)
# Input half of the U-Net
self.input_blocks = nn.ModuleList()
# Initial $3 \times 3$ convolution that maps the input to `channels`.
# The blocks are wrapped in `TimestepEmbedSequential` module because
# different modules have different forward function signatures;
# for example, convolution only accepts the feature map and
# residual blocks accept the feature map and time embedding.
# `TimestepEmbedSequential` calls them accordingly.
self.input_blocks.append(TimestepEmbedSequential(
nn.Conv2d(in_channels, channels, 3, padding=1)))
# Number of channels at each block in the input half of U-Net
input_block_channels = [channels]
# Number of channels at each level
channels_list = [channels * m for m in channel_multipliers]
# Prepare levels
for i in range(levels):
# Add the residual blocks and attentions
for _ in range(n_res_blocks):
# Residual block maps from previous number of channels to the number of
# channels in the current level
layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
channels = channels_list[i]
# Add transformer
if i in attention_levels:
layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
# Add them to the input half of the U-Net and keep track of the number of channels of
# its output
self.input_blocks.append(TimestepEmbedSequential(*layers))
input_block_channels.append(channels)
# Down-sample at all levels except the last
if i != levels - 1:
self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
input_block_channels.append(channels)
# The middle of the U-Net
self.middle_block = TimestepEmbedSequential(
ResBlock(channels, d_time_emb),
SpatialTransformer(channels, n_heads, tf_layers, d_cond),
ResBlock(channels, d_time_emb),
)
# Second half of the U-Net
self.output_blocks = nn.ModuleList([])
# Prepare levels in reverse order
for i in reversed(range(levels)):
# Add the residual blocks and attentions
for j in range(n_res_blocks + 1):
# Residual block maps from previous number of channels plus the
# skip connections from the input half of U-Net to the number of
# channels in the current level.
layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
channels = channels_list[i]
# Add transformer
if i in attention_levels:
layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
# Up-sample after the last residual block of every level,
# except the last one.
# Note that we are iterating in reverse; i.e. `i == 0` is the last.
if i != 0 and j == n_res_blocks:
layers.append(UpSample(channels))
# Add to the output half of the U-Net
self.output_blocks.append(TimestepEmbedSequential(*layers))
# Final normalization and $3 \times 3$ convolution
self.out = nn.Sequential(
normalization(channels),
nn.SiLU(),
nn.Conv2d(channels, out_channels, 3, padding=1),
)
def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
"""
## Create sinusoidal time step embeddings
:param time_steps: are the time steps of shape `[batch_size]`
:param max_period: controls the minimum frequency of the embeddings.
"""
# $\frac{c}{2}$; half of the channels are sine and the other half are cosine
half = self.channels // 2
# $\frac{1}{10000^{\frac{2i}{c}}}$
frequencies = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=time_steps.device)
# $\frac{t}{10000^{\frac{2i}{c}}}$
args = time_steps[:, None].float() * frequencies[None]
# $\cos\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$ and $\sin\Bigg(\frac{t}{10000^{\frac{2i}{c}}}\Bigg)$
return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):
"""
:param x: is the input feature map of shape `[batch_size, channels, width, height]`
:param time_steps: are the time steps of shape `[batch_size]`
:param cond: conditioning of shape `[batch_size, n_cond, d_cond]`
"""
# To store the input half outputs for skip connections
x_input_block = []
# Get time step embeddings
t_emb = self.time_step_embedding(time_steps)
t_emb = self.time_embed(t_emb)
# Input half of the U-Net
for module in self.input_blocks:
x = module(x, t_emb, cond)
x_input_block.append(x)
# Middle of the U-Net
x = self.middle_block(x, t_emb, cond)
# Output half of the U-Net
for module in self.output_blocks:
x = th.cat([x, x_input_block.pop()], dim=1)
x = module(x, t_emb, cond)
# Final normalization and $3 \times 3$ convolution
return self.out(x)
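# A forward-pass sketch with a tiny assumed configuration; the actual stable
# diffusion U-Net is much larger.
def _unet_sketch():
    unet = UNetModel(in_channels=4, out_channels=4, channels=64,
                     n_res_blocks=1, attention_levels=[1],
                     channel_multipliers=[1, 2], n_heads=2, tf_layers=1, d_cond=8)
    # A batch of one `16x16` latent with 4 channels
    x = torch.randn(1, 4, 16, 16)
    # Time steps of shape `[batch_size]` and conditioning of shape `[batch_size, n_cond, d_cond]`
    t = torch.tensor([10])
    cond = torch.randn(1, 3, 8)
    # The output has the same shape as the input: `[1, 4, 16, 16]`
    return unet(x, t, cond)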
class TimestepEmbedSequential(nn.Sequential):
"""
### Sequential block for modules with different inputs
This sequential module can be composed of different modules such as `ResBlock`,
`nn.Conv` and `SpatialTransformer`, and calls each of them with the matching signature.
"""
def forward(self, x, t_emb, cond=None):
for layer in self:
if isinstance(layer, ResBlock):
x = layer(x, t_emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, cond)
else:
x = layer(x)
return x
class UpSample(nn.Module):
"""
### Up-sampling layer
"""
def __init__(self, channels: int):
"""
:param channels: is the number of channels
"""
super().__init__()
# $3 \times 3$ convolution mapping
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, x: torch.Tensor):
"""
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
"""
# Up-sample by a factor of $2$
x = F.interpolate(x, scale_factor=2, mode="nearest")
# Apply convolution
return self.conv(x)
class DownSample(nn.Module):
"""
## Down-sampling layer
"""
def __init__(self, channels: int):
"""
:param channels: is the number of channels
"""
super().__init__()
# $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
def forward(self, x: torch.Tensor):
"""
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
"""
# Apply convolution
return self.op(x)
class ResBlock(nn.Module):
"""
## ResNet Block
"""
def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
"""
:param channels: is the number of input channels
:param d_t_emb: is the size of the time step embeddings
:param out_channels: is the number of output channels. Defaults to `channels`.
"""
super().__init__()
# `out_channels` not specified
if out_channels is None:
out_channels = channels
# First normalization and convolution
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
nn.Conv2d(channels, out_channels, 3, padding=1),
)
# Time step embeddings
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(d_t_emb, out_channels),
)
# Final convolution layer
self.out_layers = nn.Sequential(
normalization(out_channels),
nn.SiLU(),
nn.Dropout(0.),
nn.Conv2d(out_channels, out_channels, 3, padding=1)
)
# `channels` to `out_channels` mapping layer for residual connection
if out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = nn.Conv2d(channels, out_channels, 1)
def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
"""
:param x: is the input feature map with shape `[batch_size, channels, height, width]`
:param t_emb: is the time step embeddings of shape `[batch_size, d_t_emb]`
"""
# Initial convolution
h = self.in_layers(x)
# Time step embeddings
t_emb = self.emb_layers(t_emb).type(h.dtype)
# Add time step embeddings
h = h + t_emb[:, :, None, None]
# Final convolution
h = self.out_layers(h)
# Add skip connection
return self.skip_connection(x) + h
class GroupNorm32(nn.GroupNorm):
"""
### Group normalization with float32 casting
"""
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def normalization(channels):
"""
### Group normalization
This is a helper function, with a fixed number of groups.
"""
return GroupNorm32(32, channels)
def _test_time_embeddings():
"""
Test sinusoidal time step embeddings
"""
import matplotlib.pyplot as plt
plt.figure(figsize=(15, 5))
m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
channel_multipliers=[],
n_heads=1, tf_layers=1, d_cond=1)
te = m.time_step_embedding(torch.arange(0, 1000))
plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
plt.title("Time embeddings")
plt.show()
#
if __name__ == '__main__':
_test_time_embeddings()

View File

@ -0,0 +1,224 @@
"""
---
title: Transformer for Stable Diffusion U-Net
summary: >
Annotated PyTorch implementation/tutorial of the transformer
for U-Net in stable diffusion.
---
# Transformer for Stable Diffusion [U-Net](unet.html)
This implements the transformer module used in [U-Net](unet.html) that
gives $\epsilon_\text{cond}(x_t, c)$
We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
class SpatialTransformer(nn.Module):
"""
## Spatial Transformer
"""
def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
"""
:param channels: is the number of channels in the feature map
:param n_heads: is the number of attention heads
:param n_layers: is the number of transformer layers
:param d_cond: is the size of the conditional embedding
"""
super().__init__()
# Initial group normalization
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
# Initial $1 \times 1$ convolution
self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
# Transformer layers
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond) for _ in range(n_layers)]
)
# Final $1 \times 1$ convolution
self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, cond: torch.Tensor):
"""
:param x: is the feature map of shape `[batch_size, channels, height, width]`
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
# Get shape `[batch_size, channels, height, width]`
b, c, h, w = x.shape
# For residual connection
x_in = x
# Normalize
x = self.norm(x)
# Initial $1 \times 1$ convolution
x = self.proj_in(x)
# Transpose and reshape from `[batch_size, channels, height, width]`
# to `[batch_size, height * width, channels]`
x = x.permute(0, 2, 3, 1).view(b, h * w, c)
# Apply the transformer layers
for block in self.transformer_blocks:
x = block(x, cond)
# Reshape and transpose from `[batch_size, height * width, channels]`
# to `[batch_size, channels, height, width]`
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
# Final $1 \times 1$ convolution
x = self.proj_out(x)
# Add residual
return x + x_in
class BasicTransformerBlock(nn.Module):
"""
### Transformer Layer
"""
def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
"""
:param d_model: is the input embedding size
:param n_heads: is the number of attention heads
:param d_head: is the size of an attention head
:param d_cond: is the size of the conditional embeddings
"""
super().__init__()
# Self-attention layer and pre-norm layer
self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
self.norm1 = nn.LayerNorm(d_model)
# Cross attention layer and pre-norm layer
self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
self.norm2 = nn.LayerNorm(d_model)
# Feed-forward network and pre-norm layer
self.ff = FeedForward(d_model)
self.norm3 = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor, cond: torch.Tensor):
"""
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
# Self attention
x = self.attn1(self.norm1(x)) + x
# Cross-attention with conditioning
x = self.attn2(self.norm2(x), cond=cond) + x
# Feed-forward network
x = self.ff(self.norm3(x)) + x
#
return x
class CrossAttention(nn.Module):
"""
### Cross Attention Layer
This falls back to self-attention when conditional embeddings are not specified.
"""
def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
"""
:param d_model: is the input embedding size
:param n_heads: is the number of attention heads
:param d_head: is the size of an attention head
:param d_cond: is the size of the conditional embeddings
:param is_inplace: specifies whether to perform the attention softmax computation inplace to
save memory
"""
super().__init__()
self.is_inplace = is_inplace
self.n_heads = n_heads
# Attention scaling factor
self.scale = d_head ** -0.5
# Query, key and value mappings
d_attn = d_head * n_heads
self.to_q = nn.Linear(d_model, d_attn, bias=False)
self.to_k = nn.Linear(d_cond, d_attn, bias=False)
self.to_v = nn.Linear(d_cond, d_attn, bias=False)
# Final linear layer
self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
"""
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
# If `cond` is `None` we perform self attention
if cond is None:
cond = x
# Get query, key and value vectors
q = self.to_q(x)
k = self.to_k(cond)
v = self.to_v(cond)
# Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
q = q.view(*q.shape[:2], self.n_heads, -1)
k = k.view(*k.shape[:2], self.n_heads, -1)
v = v.view(*v.shape[:2], self.n_heads, -1)
# Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
# Compute softmax
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
if self.is_inplace:
half = attn.shape[0] // 2
attn[half:] = attn[half:].softmax(dim=-1)
attn[:half] = attn[:half].softmax(dim=-1)
else:
attn = attn.softmax(dim=-1)
# Compute attention output
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
out = torch.einsum('bhij,bjhd->bihd', attn, v)
# Reshape to `[batch_size, height * width, n_heads * d_head]`
out = out.reshape(*out.shape[:2], -1)
# Map to `[batch_size, height * width, d_model]` with a linear layer
return self.to_out(out)
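# A sketch of the self-attention fallback: when `cond` is `None`, the keys and
# values come from `x` itself, so `d_cond` must equal `d_model` in that case
# (as in `attn1` above). The shapes here are illustrative.
def _cross_attention_sketch():
    x = torch.randn(2, 10, 64)
    # Self-attention: `d_cond == d_model`, `cond` left as `None`
    self_attn = CrossAttention(d_model=64, d_cond=64, n_heads=4, d_head=16)
    y1 = self_attn(x)
    # Cross-attention with a conditioning sequence of 5 embeddings of size 32
    cross_attn = CrossAttention(d_model=64, d_cond=32, n_heads=4, d_head=16)
    y2 = cross_attn(x, cond=torch.randn(2, 5, 32))
    # Both outputs have shape `[2, 10, 64]`
    return y1, y2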
class FeedForward(nn.Module):
"""
### Feed-Forward Network
"""
def __init__(self, d_model: int, d_mult: int = 4):
"""
:param d_model: is the input embedding size
:param d_mult: is the multiplicative factor for the hidden layer size
"""
super().__init__()
self.net = nn.Sequential(
GeGLU(d_model, d_model * d_mult),
nn.Dropout(0.),
nn.Linear(d_model * d_mult, d_model)
)
def forward(self, x: torch.Tensor):
return self.net(x)
class GeGLU(nn.Module):
"""
### GeGLU Activation
$$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
"""
def __init__(self, d_in: int, d_out: int):
super().__init__()
# Combined linear projections $xW + b$ and $xV + c$
self.proj = nn.Linear(d_in, d_out * 2)
def forward(self, x: torch.Tensor):
# Get $xW + b$ and $xV + c$
x, gate = self.proj(x).chunk(2, dim=-1)
# $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
return x * F.gelu(gate)

View File

@ -0,0 +1,126 @@
"""
---
title: Sampling algorithms for stable diffusion
summary: >
Annotated PyTorch implementation/tutorial of
sampling algorithms
for stable diffusion model.
---
# Sampling algorithms for [stable diffusion](../index.html)
We have implemented the following sampling algorithms:
* [Denoising Diffusion Probabilistic Models (DDPM) Sampling](ddpm.html)
* [Denoising Diffusion Implicit Models (DDIM) Sampling](ddim.html)
"""
from typing import Optional, List
import torch
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
class DiffusionSampler:
"""
## Base class for sampling algorithms
"""
model: LatentDiffusion
def __init__(self, model: LatentDiffusion):
"""
:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
"""
super().__init__()
# Set the model $\epsilon_\text{cond}(x_t, c)$
self.model = model
# Get number of steps the model was trained with $T$
self.n_steps = model.n_steps
def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
"""
## Get $\epsilon(x_t, c)$
:param x: is $x_t$ of shape `[batch_size, channels, height, width]`
:param t: is $t$ of shape `[batch_size]`
:param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
"""
# When the scale $s = 1$
# $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
if uncond_cond is None or uncond_scale == 1.:
return self.model(x, t, c)
# Duplicate $x_t$ and $t$
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
# Concatenated $c$ and $c_u$
c_in = torch.cat([uncond_cond, c])
# Get $\epsilon_\text{cond}(x_t, c)$ and $\epsilon_\text{cond}(x_t, c_u)$
e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
# Calculate
# $$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$$
e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
#
return e_t
def sample(self,
shape: List[int],
cond: torch.Tensor,
repeat_noise: bool = False,
temperature: float = 1.,
x_last: Optional[torch.Tensor] = None,
uncond_scale: float = 1.,
uncond_cond: Optional[torch.Tensor] = None,
skip_steps: int = 0,
):
"""
### Sampling Loop
:param shape: is the shape of the generated images in the
form `[batch_size, channels, height, width]`
:param cond: is the conditional embeddings $c$
:param temperature: is the noise temperature (random noise gets multiplied by this)
:param x_last: is $x_T$. If not provided random noise will be used.
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
:param skip_steps: is the number of time steps to skip.
"""
raise NotImplementedError()
def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
orig: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
uncond_scale: float = 1.,
uncond_cond: Optional[torch.Tensor] = None,
):
"""
### Painting Loop
:param x: is $x_{T'}$ of shape `[batch_size, channels, height, width]`
:param cond: is the conditional embeddings $c$
:param t_start: is the sampling step to start from, $T'$
:param orig: is the original image in latent space that we are in-painting.
:param mask: is the mask to keep the original image.
:param orig_noise: is fixed noise to be added to the original image.
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
"""
raise NotImplementedError()
def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
"""
### Sample from $q(x_t|x_0)$
:param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
:param index: is the time step $t$ index
:param noise: is the noise, $\epsilon$
"""
raise NotImplementedError()
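# A small numeric check of the classifier-free guidance combination in
# `get_eps`: $\epsilon_u + s (\epsilon_c - \epsilon_u)$ is the same as
# $s \epsilon_c - (s - 1) \epsilon_u$. The tensor shapes are illustrative.
def _guidance_sketch(uncond_scale: float = 7.5):
    e_t_cond = torch.randn(2, 4, 8, 8)
    e_t_uncond = torch.randn(2, 4, 8, 8)
    e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
    rearranged = uncond_scale * e_t_cond - (uncond_scale - 1.) * e_t_uncond
    assert torch.allclose(e_t, rearranged, atol=1e-5)
    return e_t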

View File

@ -0,0 +1,300 @@
"""
---
title: Denoising Diffusion Implicit Models (DDIM) Sampling
summary: >
Annotated PyTorch implementation/tutorial of
Denoising Diffusion Implicit Models (DDIM) Sampling
for stable diffusion model.
---
# Denoising Diffusion Implicit Models (DDIM) Sampling
This implements DDIM sampling from the paper
[Denoising Diffusion Implicit Models](https://papers.labml.ai/paper/2010.02502)
"""
from typing import Optional, List
import numpy as np
import torch
from labml import monit
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
class DDIMSampler(DiffusionSampler):
"""
## DDIM Sampler
This extends the [`DiffusionSampler` base class](index.html).
DDIM samples images by repeatedly removing noise by sampling step by step using,
\begin{align}
x_{\tau_{i-1}} &= \sqrt{\alpha_{\tau_{i-1}}}\Bigg(
\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}
\Bigg) \\
&+ \sqrt{1 - \alpha_{\tau_{i- 1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) \\
&+ \sigma_{\tau_i} \epsilon_{\tau_i}
\end{align}
where $\epsilon_{\tau_i}$ is random noise,
$\tau$ is a subsequence of $[1,2,\dots,T]$ of length $S$,
and
$$\sigma_{\tau_i} =
\eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
\sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$
Note that $\alpha_t$ in the DDIM paper refers to ${\color{lightgreen}\bar\alpha_t}$ from [DDPM](ddpm.html).
"""
model: LatentDiffusion
def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
"""
:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
:param n_steps: is the number of DDIM sampling steps, $S$
:param ddim_discretize: specifies how to extract $\tau$ from $[1,2,\dots,T]$.
It can be either `uniform` or `quad`.
:param ddim_eta: is $\eta$ used to calculate $\sigma_{\tau_i}$. $\eta = 0$ makes the
sampling process deterministic.
"""
super().__init__(model)
# Number of steps, $T$
self.n_steps = model.n_steps
# Calculate $\tau$ to be uniformly distributed across $[1,2,\dots,T]$
if ddim_discretize == 'uniform':
c = self.n_steps // n_steps
self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1
# Calculate $\tau$ to be quadratically distributed across $[1,2,\dots,T]$
elif ddim_discretize == 'quad':
self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
else:
raise NotImplementedError(ddim_discretize)
with torch.no_grad():
# Get ${\color{lightgreen}\bar\alpha_t}$
alpha_bar = self.model.alpha_bar
# $\alpha_{\tau_i}$
self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
# $\sqrt{\alpha_{\tau_i}}$
self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
# $\alpha_{\tau_{i-1}}$
self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])
# $$\sigma_{\tau_i} =
# \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
# \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$
self.ddim_sigma = (ddim_eta *
((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
(1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)
# $\sqrt{1 - \alpha_{\tau_i}}$
self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5
@torch.no_grad()
def sample(self,
shape: List[int],
cond: torch.Tensor,
repeat_noise: bool = False,
temperature: float = 1.,
x_last: Optional[torch.Tensor] = None,
uncond_scale: float = 1.,
uncond_cond: Optional[torch.Tensor] = None,
skip_steps: int = 0,
):
"""
### Sampling Loop
:param shape: is the shape of the generated images in the
form `[batch_size, channels, height, width]`
:param cond: is the conditional embeddings $c$
:param temperature: is the noise temperature (random noise gets multiplied by this)
:param x_last: is $x_{\tau_S}$. If not provided random noise will be used.
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
:param skip_steps: is the number of time steps to skip $i'$. We start sampling from $S - i'$.
And `x_last` is then $x_{\tau_{S - i'}}$.
"""
# Get device and batch size
device = self.model.device
bs = shape[0]
# Get $x_{\tau_S}$
x = x_last if x_last is not None else torch.randn(shape, device=device)
# Time steps to sample at $\tau_{S - i'}, \tau_{S - i' - 1}, \dots, \tau_1$
time_steps = np.flip(self.time_steps)[skip_steps:]
for i, step in monit.enum('Sample', time_steps):
# Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
index = len(time_steps) - i - 1
# Time step $\tau_i$
ts = x.new_full((bs,), step, dtype=torch.long)
# Sample $x_{\tau_{i-1}}$
x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
repeat_noise=repeat_noise,
temperature=temperature,
uncond_scale=uncond_scale,
uncond_cond=uncond_cond)
# Return $x_0$
return x
@torch.no_grad()
def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
repeat_noise: bool = False,
temperature: float = 1.,
uncond_scale: float = 1.,
uncond_cond: Optional[torch.Tensor] = None):
"""
### Sample $x_{\tau_{i-1}}$
:param x: is $x_{\tau_i}$ of shape `[batch_size, channels, height, width]`
:param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
:param t: is $\tau_i$ of shape `[batch_size]`
:param step: is the step $\tau_i$ as an integer
:param index: is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
:param repeat_noise: specifies whether the noise should be the same for all samples in the batch
:param temperature: is the noise temperature (random noise gets multiplied by this)
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
"""
# Get $\epsilon_\theta(x_{\tau_i})$
e_t = self.get_eps(x, t, c,
uncond_scale=uncond_scale,
uncond_cond=uncond_cond)
# Calculate $x_{\tau_{i - 1}}$ and predicted $x_0$
x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
temperature=temperature,
repeat_noise=repeat_noise)
#
return x_prev, pred_x0, e_t
def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
temperature: float,
repeat_noise: bool):
"""
### Sample $x_{\tau_{i-1}}$ given $\epsilon_\theta(x_{\tau_i})$
"""
# $\alpha_{\tau_i}$
alpha = self.ddim_alpha[index]
# $\alpha_{\tau_{i-1}}$
alpha_prev = self.ddim_alpha_prev[index]
# $\sigma_{\tau_i}$
sigma = self.ddim_sigma[index]
# $\sqrt{1 - \alpha_{\tau_i}}$
sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]
# Current prediction for $x_0$,
# $$\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}$$
pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
# Direction pointing to $x_t$
# $$\sqrt{1 - \alpha_{\tau_{i- 1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})$$
dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t
# No noise is added, when $\eta = 0$
if sigma == 0.:
noise = 0.
# If same noise is used for all samples in the batch
elif repeat_noise:
noise = torch.randn((1, *x.shape[1:]), device=x.device)
# Different noise for each sample
else:
noise = torch.randn(x.shape, device=x.device)
# Multiply noise by the temperature
noise = noise * temperature
# \begin{align}
# x_{\tau_{i-1}} &= \sqrt{\alpha_{\tau_{i-1}}}\Bigg(
# \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}
# \Bigg) \\
# &+ \sqrt{1 - \alpha_{\tau_{i- 1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) \\
# &+ \sigma_{\tau_i} \epsilon_{\tau_i}
# \end{align}
x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise
#
return x_prev, pred_x0
@torch.no_grad()
def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
"""
### Sample from $q_{\sigma,\tau}(x_{\tau_i}|x_0)$
$$q_{\sigma,\tau}(x_t|x_0) =
\mathcal{N} \Big(x_t; \sqrt{\alpha_{\tau_i}} x_0, (1-\alpha_{\tau_i}) \mathbf{I} \Big)$$
:param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
:param index: is the time step $\tau_i$ index $i$
:param noise: is the noise, $\epsilon$
"""
# Random noise, if noise is not specified
if noise is None:
noise = torch.randn_like(x0)
# Sample from
# $$q_{\sigma,\tau}(x_t|x_0) =
# \mathcal{N} \Big(x_t; \sqrt{\alpha_{\tau_i}} x_0, (1-\alpha_{\tau_i}) \mathbf{I} \Big)$$
return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise
@torch.no_grad()
def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
orig: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
uncond_scale: float = 1.,
uncond_cond: Optional[torch.Tensor] = None,
):
"""
### Painting Loop
:param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]`
:param cond: is the conditional embeddings $c$
:param t_start: is the sampling step to start from, $S'$
:param orig: is the original image in latent space that we are in-painting.
If this is not provided, it'll be an image-to-image transformation.
:param mask: is the mask to keep the original image.
:param orig_noise: is fixed noise to be added to the original image.
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
"""
# Get batch size
bs = x.shape[0]
# Time steps to sample at $\tau_{S'}, \tau_{S' - 1}, \dots, \tau_1$
time_steps = np.flip(self.time_steps[:t_start])
for i, step in monit.enum('Paint', time_steps):
# Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
index = len(time_steps) - i - 1
# Time step $\tau_i$
ts = x.new_full((bs,), step, dtype=torch.long)
# Sample $x_{\tau_{i-1}}$
x, _, _ = self.p_sample(x, cond, ts, step, index=index,
uncond_scale=uncond_scale,
uncond_cond=uncond_cond)
# Replace the masked area with original image
if orig is not None:
# Get the $q_{\sigma,\tau}(x_{\tau_i}|x_0)$ for original image in latent space
orig_t = self.q_sample(orig, index, noise=orig_noise)
# Replace the masked area
x = orig_t * mask + x * (1 - mask)
#
return x
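# A small sketch of the two $\tau$ discretizations used in `__init__` above,
# with $T = 1000$ and $S = 50$ as illustrative values.
def _ddim_time_steps_sketch(T: int = 1000, S: int = 50):
    # Uniform: $\tau_i = 1, 1 + c, 1 + 2c, \dots$ with $c = \lfloor T / S \rfloor$
    c = T // S
    uniform = np.asarray(list(range(0, T, c))) + 1
    # Quadratic: steps are denser near $\tau_1$
    quad = ((np.linspace(0, np.sqrt(T * .8), S)) ** 2).astype(int) + 1
    # Both have $S$ elements; `uniform` is `[1, 21, 41, ..., 981]` for these values
    return uniform, quad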

View File

@ -0,0 +1,226 @@
"""
---
title: Denoising Diffusion Probabilistic Models (DDPM) Sampling
summary: >
Annotated PyTorch implementation/tutorial of
Denoising Diffusion Probabilistic Models (DDPM) Sampling
for stable diffusion model.
---
# Denoising Diffusion Probabilistic Models (DDPM) Sampling
For a simpler DDPM implementation refer to our [DDPM implementation](../../ddpm/index.html).
We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
"""
from typing import Optional, List
import numpy as np
import torch
from labml import monit
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
class DDPMSampler(DiffusionSampler):
"""
## DDPM Sampler
This extends the [`DiffusionSampler` base class](index.html).
DDPM samples images by repeatedly removing noise by sampling step by step from
$p_\theta(x_{t-1} | x_t)$,
\begin{align}
p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\
\mu_t(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta \\
\end{align}
"""
model: LatentDiffusion
def __init__(self, model: LatentDiffusion):
"""
:param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
"""
super().__init__(model)
# Sampling steps $1, 2, \dots, T$
self.time_steps = np.asarray(list(range(self.n_steps)))
with torch.no_grad():
# $\bar\alpha_t$
alpha_bar = self.model.alpha_bar
# $\beta_t$ schedule
beta = self.model.beta
# $\bar\alpha_{t-1}$
alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])
# $\sqrt{\bar\alpha_t}$
self.sqrt_alpha_bar = alpha_bar ** .5
# $\sqrt{1 - \bar\alpha_t}$
self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .5
# $\frac{1}{\sqrt{\bar\alpha_t}}$
self.sqrt_recip_alpha_bar = alpha_bar ** -.5
# $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .5
# $\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$
variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)
# Clamped log of $\tilde\beta_t$
self.log_var = torch.log(torch.clamp(variance, min=1e-20))
# $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)
# $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)
@torch.no_grad()
def sample(self,
shape: List[int],
cond: torch.Tensor,
repeat_noise: bool = False,
temperature: float = 1.,
x_last: Optional[torch.Tensor] = None,
uncond_scale: float = 1.,
uncond_cond: Optional[torch.Tensor] = None,
skip_steps: int = 0,
):
"""
### Sampling Loop
:param shape: is the shape of the generated images in the
form `[batch_size, channels, height, width]`
:param cond: is the conditional embeddings $c$
:param temperature: is the noise temperature (random noise gets multiplied by this)
:param x_last: is $x_T$. If not provided random noise will be used.
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
:param skip_steps: is the number of time steps to skip $t'$. We start sampling from $T - t'$.
And `x_last` is then $x_{T - t'}$.
"""
# Get device and batch size
device = self.model.device
bs = shape[0]
# Get $x_T$
x = x_last if x_last is not None else torch.randn(shape, device=device)
# Time steps to sample at $T - t', T - t' - 1, \dots, 1$
time_steps = np.flip(self.time_steps)[skip_steps:]
# Sampling loop
for step in monit.iterate('Sample', time_steps):
# Time step $t$
ts = x.new_full((bs,), step, dtype=torch.long)
# Sample $x_{t-1}$
x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
repeat_noise=repeat_noise,
temperature=temperature,
uncond_scale=uncond_scale,
uncond_cond=uncond_cond)
# Return $x_0$
return x
@torch.no_grad()
def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
repeat_noise: bool = False,
temperature: float = 1.,
uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):
"""
### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$
:param x: is $x_t$ of shape `[batch_size, channels, height, width]`
:param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
:param t: is $t$ of shape `[batch_size]`
:param step: is the step $t$ as an integer
:param repeat_noise: specifies whether the noise should be the same for all samples in the batch
:param temperature: is the noise temperature (random noise gets multiplied by this)
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
:param uncond_cond: is the conditional embedding for empty prompt $c_u$
"""
# Get $\epsilon_\theta$
e_t = self.get_eps(x, t, c,
uncond_scale=uncond_scale,
uncond_cond=uncond_cond)
# Get batch size
bs = x.shape[0]
# $\frac{1}{\sqrt{\bar\alpha_t}}$
sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])
# $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])
# Calculate $x_0$ with current $\epsilon_\theta$
#
# $$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta$$
x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t
# $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])
# $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$
mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])
# Calculate $\mu_t(x_t, t)$
#
# $$\mu_t(x_t, t) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
# + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$$
mean = mean_x0_coef * x0 + mean_xt_coef * x
# $\log \tilde\beta_t$
log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])
# Do not add noise when $t = 1$ (the final step of the sampling process).
# Note that `step` is `0` when $t = 1$.
if step == 0:
noise = 0
# If same noise is used for all samples in the batch
elif repeat_noise:
noise = torch.randn((1, *x.shape[1:]), device=x.device)
# Different noise for each sample
else:
noise = torch.randn(x.shape, device=x.device)
# Multiply noise by the temperature
noise = noise * temperature
# Sample from,
#
# $$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big)$$
x_prev = mean + (0.5 * log_var).exp() * noise
#
return x_prev, x0, e_t
@torch.no_grad()
def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
"""
### Sample from $q(x_t|x_0)$
$$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$
:param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
:param index: is the time step $t$ index
:param noise: is the noise, $\epsilon$
"""
# Random noise, if noise is not specified
if noise is None:
noise = torch.randn_like(x0)
# Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
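

def _example_text_to_image(model: LatentDiffusion, prompt: str, uncond_scale: float = 7.5):
    """
    *A minimal usage sketch (not part of the original repository) showing how this
    DDPM sampler is driven end-to-end. It assumes `model` is a loaded
    [`LatentDiffusion`](../latent_diffusion.html) model, for example from
    [`util.load_model`](../util.html), and that this file defines `DDPMSampler`
    as imported by the [example scripts](../scripts/index.html).*
    """
    # Initialize the sampler with the latent diffusion model
    sampler = DDPMSampler(model)
    # Prompt embedding $c$ and empty-prompt embedding $c_u$ for classifier-free guidance
    cond = model.get_text_conditioning([prompt])
    un_cond = model.get_text_conditioning([""])
    # Sample a latent of shape `[1, 4, 64, 64]`, which decodes to a $512 \times 512$ image
    x = sampler.sample(shape=[1, 4, 64, 64],
                       cond=cond,
                       uncond_scale=uncond_scale,
                       uncond_cond=un_cond)
    # Decode the latent with the [autoencoder](../model/autoencoder.html)
    return model.autoencoder_decode(x)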

View File

@ -0,0 +1,13 @@
"""
---
title: Scripts showing example usages of stable diffusion
summary: >
Annotated PyTorch implementation/tutorial of example usages of stable diffusion
---
# Scripts showing example usages of [stable diffusion](../index.html)
* [Prompt to image diffusion](text_to_image.html)
* [Image to image diffusion](image_to_image.html)
* [In-painting](in_paint.html)
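
The scripts can also be used programmatically. Here is a minimal sketch for text-to-image;
the module path is an assumption based on this repository's layout, and the checkpoint
location matches the one used in the scripts' CLIs:

    from labml import lab
    from labml_nn.diffusion.stable_diffusion.scripts.text_to_image import Txt2Img

    txt2img = Txt2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
                      sampler_name='ddim')
    txt2img(dest_path='outputs', prompt='a painting of a cute monkey playing guitar')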
"""

View File

@ -0,0 +1,149 @@
"""
---
title: Generate images using stable diffusion with a prompt from a given image
summary: >
Generate images using stable diffusion with a prompt from a given image
---
# Generate images using [stable diffusion](../index.html) with a prompt from a given image
"""
import argparse
from pathlib import Path
import torch
from labml import lab, monit
from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
from labml_nn.diffusion.stable_diffusion.util import load_model, load_img, save_images, set_seed
class Img2Img:
"""
### Image to image class
"""
def __init__(self, *, checkpoint_path: Path,
ddim_steps: int = 50,
ddim_eta: float = 0.0):
"""
:param checkpoint_path: is the path of the checkpoint
:param ddim_steps: is the number of sampling steps
:param ddim_eta: is the [DDIM sampling](../sampler/ddim.html) $\eta$ constant
"""
self.ddim_steps = ddim_steps
# Load [latent diffusion model](../latent_diffusion.html)
self.model = load_model(checkpoint_path)
# Get device
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Move the model to device
self.model.to(self.device)
# Initialize [DDIM sampler](../sampler/ddim.html)
self.sampler = DDIMSampler(self.model,
n_steps=ddim_steps,
ddim_eta=ddim_eta)
@torch.no_grad()
def __call__(self, *,
dest_path: str,
orig_img: str,
strength: float,
batch_size: int = 3,
prompt: str,
uncond_scale: float = 5.0,
):
"""
:param dest_path: is the path to store the generated images
:param orig_img: is the image to transform
:param strength: specifies the fraction of diffusion steps applied to the original image; $1.0$ destroys the original completely
:param batch_size: is the number of images to generate in a batch
:param prompt: is the prompt to generate images with
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
"""
# Make a batch of prompts
prompts = batch_size * [prompt]
# Load image
orig_image = load_img(orig_img).to(self.device)
# Encode the image in the latent space and make `batch_size` copies of it
orig = self.model.autoencoder_encode(orig_image).repeat(batch_size, 1, 1, 1)
# Get the number of steps to diffuse the original
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_index = int(strength * self.ddim_steps)
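# For example, `strength=0.75` with the default $50$ DDIM steps gives `t_index = 37`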
# AMP auto casting
with torch.cuda.amp.autocast():
# If the unconditional guidance scale is not $1$, get the embeddings for empty prompts (no conditioning).
if uncond_scale != 1.0:
un_cond = self.model.get_text_conditioning(batch_size * [""])
else:
un_cond = None
# Get the prompt embeddings
cond = self.model.get_text_conditioning(prompts)
# Add noise to the original image
x = self.sampler.q_sample(orig, t_index)
# Reconstruct from the noisy image
x = self.sampler.paint(x, cond, t_index,
uncond_scale=uncond_scale,
uncond_cond=un_cond)
# Decode the image from the [autoencoder](../model/autoencoder.html)
images = self.model.autoencoder_decode(x)
# Save images
save_images(images, dest_path, 'img_')
def main():
"""
### CLI
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a cute monkey playing guitar",
help="the prompt to render"
)
parser.add_argument(
"--orig-img",
type=str,
nargs="?",
help="path to the input image"
)
parser.add_argument("--batch_size", type=int, default=4, help="batch size", )
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps")
parser.add_argument("--scale", type=float, default=5.0,
help="unconditional guidance scale: "
"eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
parser.add_argument("--strength", type=float, default=0.75,
help="strength for noise: "
" 1.0 corresponds to full destruction of information in init image")
opt = parser.parse_args()
set_seed(42)
img2img = Img2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
ddim_steps=opt.steps)
with monit.section('Generate'):
img2img(
dest_path='outputs',
orig_img=opt.orig_img,
strength=opt.strength,
batch_size=opt.batch_size,
prompt=opt.prompt,
uncond_scale=opt.scale)
#
if __name__ == "__main__":
main()

View File

@ -0,0 +1,166 @@
"""
---
title: In-paint images using stable diffusion with a prompt
summary: >
In-paint images using stable diffusion with a prompt
---
# In-paint images using [stable diffusion](../index.html) with a prompt
"""
import argparse
from pathlib import Path
from typing import Optional
import torch
from labml import lab, monit
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
from labml_nn.diffusion.stable_diffusion.util import load_model, save_images, load_img, set_seed
class InPaint:
"""
### Image in-painting class
"""
model: LatentDiffusion
sampler: DiffusionSampler
def __init__(self, *, checkpoint_path: Path,
ddim_steps: int = 50,
ddim_eta: float = 0.0):
"""
:param checkpoint_path: is the path of the checkpoint
:param ddim_steps: is the number of sampling steps
:param ddim_eta: is the [DDIM sampling](../sampler/ddim.html) $\eta$ constant
"""
self.ddim_steps = ddim_steps
# Load [latent diffusion model](../latent_diffusion.html)
self.model = load_model(checkpoint_path)
# Get device
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Move the model to device
self.model.to(self.device)
# Initialize [DDIM sampler](../sampler/ddim.html)
self.sampler = DDIMSampler(self.model,
n_steps=ddim_steps,
ddim_eta=ddim_eta)
@torch.no_grad()
def __call__(self, *,
dest_path: str,
orig_img: str,
strength: float,
batch_size: int = 3,
prompt: str,
uncond_scale: float = 5.0,
mask: Optional[torch.Tensor] = None,
):
"""
:param dest_path: is the path to store the generated images
:param orig_img: is the image to transform
:param strength: specifies the fraction of diffusion steps applied to the original image; $1.0$ destroys the original completely
:param batch_size: is the number of images to generate in a batch
:param prompt: is the prompt to generate images with
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
"""
# Make a batch of prompts
prompts = batch_size * [prompt]
# Load image
orig_image = load_img(orig_img).to(self.device)
# Encode the image in the latent space and make `batch_size` copies of it
orig = self.model.autoencoder_encode(orig_image).repeat(batch_size, 1, 1, 1)
# If `mask` is not provided,
# we set a sample mask to preserve the bottom half of the image
if mask is None:
mask = torch.zeros_like(orig, device=self.device)
mask[:, :, mask.shape[2] // 2:, :] = 1.
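# (a value of `1.` marks the latent positions to keep from the original image)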
else:
mask = mask.to(self.device)
# Random noise used to diffuse the original image
orig_noise = torch.randn(orig.shape, device=self.device)
# Get the number of steps to diffuse the original
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_index = int(strength * self.ddim_steps)
# AMP auto casting
with torch.cuda.amp.autocast():
# If the unconditional guidance scale is not $1$, get the embeddings for empty prompts (no conditioning).
if uncond_scale != 1.0:
un_cond = self.model.get_text_conditioning(batch_size * [""])
else:
un_cond = None
# Get the prompt embeddings
cond = self.model.get_text_conditioning(prompts)
# Add noise to the original image
x = self.sampler.q_sample(orig, t_index, noise=orig_noise)
# Reconstruct from the noisy image, while preserving the masked area
x = self.sampler.paint(x, cond, t_index,
orig=orig,
mask=mask,
orig_noise=orig_noise,
uncond_scale=uncond_scale,
uncond_cond=un_cond)
# Decode the image from the [autoencoder](../model/autoencoder.html)
images = self.model.autoencoder_decode(x)
# Save images
save_images(images, dest_path, 'paint_')
def main():
"""
### CLI
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a cute monkey playing guitar",
help="the prompt to render"
)
parser.add_argument(
"--orig-img",
type=str,
nargs="?",
help="path to the input image"
)
parser.add_argument("--batch_size", type=int, default=4, help="batch size", )
parser.add_argument("--steps", type=int, default=50, help="number of sampling steps")
parser.add_argument("--scale", type=float, default=5.0,
help="unconditional guidance scale: "
"eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
parser.add_argument("--strength", type=float, default=0.75,
help="strength for noise: "
" 1.0 corresponds to full destruction of information in init image")
opt = parser.parse_args()
set_seed(42)
in_paint = InPaint(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
ddim_steps=opt.steps)
with monit.section('Generate'):
in_paint(dest_path='outputs',
orig_img=opt.orig_img,
strength=opt.strength,
batch_size=opt.batch_size,
prompt=opt.prompt,
uncond_scale=opt.scale)
#
if __name__ == "__main__":
main()

View File

@ -0,0 +1,151 @@
"""
---
title: Generate images using stable diffusion with a prompt
summary: >
Generate images using stable diffusion with a prompt
---
# Generate images using [stable diffusion](../index.html) with a prompt
"""
import argparse
import os
from pathlib import Path
import torch
from labml import lab, monit
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.sampler.ddim import DDIMSampler
from labml_nn.diffusion.stable_diffusion.sampler.ddpm import DDPMSampler
from labml_nn.diffusion.stable_diffusion.util import load_model, save_images, set_seed
class Txt2Img:
"""
### Text to image class
"""
model: LatentDiffusion
def __init__(self, *,
checkpoint_path: Path,
sampler_name: str,
n_steps: int = 50,
ddim_eta: float = 0.0,
):
"""
:param checkpoint_path: is the path of the checkpoint
:param sampler_name: is the name of the [sampler](../sampler/index.html)
:param n_steps: is the number of sampling steps
:param ddim_eta: is the [DDIM sampling](../sampler/ddim.html) $\eta$ constant
"""
# Load [latent diffusion model](../latent_diffusion.html)
self.model = load_model(checkpoint_path)
# Get device
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Move the model to device
self.model.to(self.device)
# Initialize [sampler](../sampler/index.html)
if sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model,
n_steps=n_steps,
ddim_eta=ddim_eta)
elif sampler_name == 'ddpm':
self.sampler = DDPMSampler(self.model)
@torch.no_grad()
def __call__(self, *,
dest_path: str,
batch_size: int = 3,
prompt: str,
h: int = 512, w: int = 512,
uncond_scale: float = 7.5,
):
"""
:param dest_path: is the path to store the generated images
:param batch_size: is the number of images to generate in a batch
:param prompt: is the prompt to generate images with
:param h: is the height of the image
:param w: is the width of the image
:param uncond_scale: is the unconditional guidance scale $s$. This is used for
$\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
"""
# Number of channels in the latent space
c = 4
# Image to latent space resolution reduction
f = 8
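# For example, a $512 \times 512$ image corresponds to a `4 x 64 x 64` latent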
# Make a batch of prompts
prompts = batch_size * [prompt]
# AMP auto casting
with torch.cuda.amp.autocast():
# If the unconditional guidance scale is not $1$, get the embeddings for empty prompts (no conditioning).
if uncond_scale != 1.0:
un_cond = self.model.get_text_conditioning(batch_size * [""])
else:
un_cond = None
# Get the prompt embeddings
cond = self.model.get_text_conditioning(prompts)
# [Sample in the latent space](../sampler/index.html).
# `x` will be of shape `[batch_size, c, h / f, w / f]`
x = self.sampler.sample(cond=cond,
shape=[batch_size, c, h // f, w // f],
uncond_scale=uncond_scale,
uncond_cond=un_cond)
# Decode the image from the [autoencoder](../model/autoencoder.html)
images = self.model.autoencoder_decode(x)
# Save images
save_images(images, dest_path, 'txt_')
def main():
"""
### CLI
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
)
parser.add_argument("--batch_size", type=int, default=4, help="batch size", )
parser.add_argument(
'--sampler',
dest='sampler_name',
choices=['ddim', 'ddpm'],
default='ddim',
help='Set the sampler.',
)
parser.add_argument("--steps", type=int, default=50, help="number of sampling steps", )
parser.add_argument("--scale", type=float, default=7.5,
help="unconditional guidance scale: "
"eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")
opt = parser.parse_args()
set_seed(42)
txt2img = Txt2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
sampler_name=opt.sampler_name,
n_steps=opt.steps)
with monit.section('Generate'):
txt2img(dest_path='outputs',
batch_size=opt.batch_size,
prompt=opt.prompt,
uncond_scale=opt.scale)
#
if __name__ == "__main__":
main()

View File

@ -0,0 +1,151 @@
"""
---
title: Utility functions for stable diffusion
summary: >
Utility functions for stable diffusion
---
# Utility functions for [stable diffusion](index.html)
"""
import os
import random
from pathlib import Path
from typing import Optional
import PIL
import numpy as np
import torch
from PIL import Image
from labml import monit
from labml.logger import inspect
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
def set_seed(seed: int):
"""
### Set random seeds
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def load_model(path: Optional[Path] = None) -> LatentDiffusion:
"""
### Load [`LatentDiffusion` model](latent_diffusion.html)
"""
# Initialize the autoencoder
with monit.section('Initialize autoencoder'):
encoder = Encoder(z_channels=4,
in_channels=3,
channels=128,
channel_multipliers=[1, 2, 4, 4],
n_resnet_blocks=2)
decoder = Decoder(out_channels=3,
z_channels=4,
channels=128,
channel_multipliers=[1, 2, 4, 4],
n_resnet_blocks=2)
autoencoder = Autoencoder(emb_channels=4,
encoder=encoder,
decoder=decoder,
z_channels=4)
# Initialize the CLIP text embedder
with monit.section('Initialize CLIP Embedder'):
clip_text_embedder = CLIPTextEmbedder()
# Initialize the U-Net
with monit.section('Initialize U-Net'):
unet_model = UNetModel(in_channels=4,
out_channels=4,
channels=320,
attention_levels=[0, 1, 2],
n_res_blocks=2,
channel_multipliers=[1, 2, 4, 4],
n_heads=8,
tf_layers=1,
d_cond=768)
# Initialize the Latent Diffusion model
with monit.section('Initialize Latent Diffusion model'):
model = LatentDiffusion(linear_start=0.00085,
linear_end=0.0120,
n_steps=1000,
latent_scaling_factor=0.18215,
autoencoder=autoencoder,
clip_embedder=clip_text_embedder,
unet_model=unet_model)
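# These hyper-parameters correspond to the stable diffusion v1 checkpoints
# (e.g. `sd-v1-4.ckpt` used by the example scripts), so the released weights load directly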
# Load the checkpoint
with monit.section(f"Loading model from {path}"):
checkpoint = torch.load(path, map_location="cpu")
# Set model state
with monit.section('Load state'):
missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)
# Debugging output
inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
_expand=True)
#
model.eval()
return model
def load_img(path: str):
"""
### Load an image
This loads an image from a file and returns a PyTorch tensor.
:param path: is the path of the image
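The image is resized so that its sides are multiples of $32$, mapped to $[-1, 1]$,
and returned as a tensor of shape `[1, 3, height, width]`.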
"""
# Open Image
image = Image.open(path).convert("RGB")
# Get image size
w, h = image.size
# Resize to a multiple of 32
w = w - w % 32
h = h - h % 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
# Convert to numpy and map from `[0, 255]` to `[-1, 1]`
image = np.array(image).astype(np.float32) * (2. / 255.0) - 1
# Transpose to shape `[batch_size, channels, height, width]`
image = image[None].transpose(0, 3, 1, 2)
# Convert to torch
return torch.from_numpy(image)
def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
"""
### Save images
:param images: is the tensor with images of shape `[batch_size, channels, height, width]`
:param dest_path: is the folder to save images in
:param prefix: is the prefix to add to file names
:param img_format: is the image format
"""
# Create the destination folder
os.makedirs(dest_path, exist_ok=True)
# Map images to `[0, 1]` space and clip
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
# Transpose to `[batch_size, height, width, channels]` and convert to numpy
images = images.cpu().permute(0, 2, 3, 1).numpy()
# Save images
for i, img in enumerate(images):
img = Image.fromarray((255. * img).astype(np.uint8))
img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)

View File

@ -8,7 +8,8 @@ summary: >
# GPT-NeoX Checkpoints
"""
from typing import Dict, Union, Tuple
from pathlib import Path
from typing import Dict, Union, Tuple, Optional
import torch
from torch import nn
@ -19,12 +20,23 @@ from labml.utils.download import download_file
# Parent url
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
# Download path
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
if not CHECKPOINTS_DOWNLOAD_PATH.exists():
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None
# Download path
def get_checkpoints_download_path():
global _CHECKPOINTS_DOWNLOAD_PATH
if _CHECKPOINTS_DOWNLOAD_PATH is not None:
return _CHECKPOINTS_DOWNLOAD_PATH
_CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
_CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)
return _CHECKPOINTS_DOWNLOAD_PATH
def get_files_to_download(n_layers: int = 44):
@ -65,7 +77,7 @@ def download(n_layers: int = 44):
# Log
logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
# Download
download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)
def load_checkpoint_files(files: Tuple[str, str]):
@ -75,7 +87,7 @@ def load_checkpoint_files(files: Tuple[str, str]):
:param files: pair of files to load
:return: the loaded parameter tensors
"""
checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
with monit.section('Load checkpoint'):
data = [torch.load(checkpoint_path / f) for f in files]