import os
import random
from pathlib import Path

import PIL
import numpy as np
import torch
from PIL import Image

from labml import monit
from labml.logger import inspect
from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
from labml_nn.diffusion.stable_diffusion.model.autoencoder import Encoder, Decoder, Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel


def set_seed(seed: int):
    """
    Set the seeds of all random number generators used here.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
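
# A minimal usage sketch: call `set_seed` once before sampling so results are
# repeatable; it seeds the `random`, `numpy`, CPU and CUDA generators together:
#
#     set_seed(42)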


def load_model(path: Path = None) -> LatentDiffusion:
    """
    Load the LatentDiffusion model from a checkpoint.

    :param path: is the path of the checkpoint
    """

    # Initialize the autoencoder
    with monit.section('Initialize autoencoder'):
        encoder = Encoder(z_channels=4,
                          in_channels=3,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        decoder = Decoder(out_channels=3,
                          z_channels=4,
                          channels=128,
                          channel_multipliers=[1, 2, 4, 4],
                          n_resnet_blocks=2)

        autoencoder = Autoencoder(emb_channels=4,
                                  encoder=encoder,
                                  decoder=decoder,
                                  z_channels=4)
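
        # With four channel multipliers the encoder downsamples three times,
        # a factor of 2^3 = 8, so e.g. a 512x512 RGB image maps to a 4x64x64 latent.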

    # Initialize the CLIP text embedder
    with monit.section('Initialize CLIP Embedder'):
        clip_text_embedder = CLIPTextEmbedder()

    # Initialize the U-Net
    with monit.section('Initialize U-Net'):
        unet_model = UNetModel(in_channels=4,
                               out_channels=4,
                               channels=320,
                               attention_levels=[0, 1, 2],
                               n_res_blocks=2,
                               channel_multipliers=[1, 2, 4, 4],
                               n_heads=8,
                               tf_layers=1,
                               d_cond=768)
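        # `d_cond=768` matches the width of the CLIP text embeddings that
        # condition the U-Net via cross-attention.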

    # Initialize the Latent Diffusion model
    with monit.section('Initialize Latent Diffusion model'):
        model = LatentDiffusion(linear_start=0.00085,
                                linear_end=0.0120,
                                n_steps=1000,
                                latent_scaling_factor=0.18215,

                                autoencoder=autoencoder,
                                clip_embedder=clip_text_embedder,
                                unet_model=unet_model)

    # Load the checkpoint
    with monit.section(f"Loading model from {path}"):
        checkpoint = torch.load(path, map_location="cpu")

    # Set model state
    with monit.section('Load state'):
        missing_keys, extra_keys = model.load_state_dict(checkpoint["state_dict"], strict=False)
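        # `strict=False` lets loading proceed when the checkpoint and this
        # implementation disagree on some keys; the mismatches are reported below.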

    # Debugging output
    inspect(global_step=checkpoint.get('global_step', -1), missing_keys=missing_keys, extra_keys=extra_keys,
            _expand=True)

    model.eval()
    return model
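

# A minimal usage sketch (the checkpoint path is hypothetical; any Stable
# Diffusion v1 checkpoint whose `state_dict` matches these hyper-parameters
# should load):
#
#     model = load_model(Path('checkpoints/sd-v1-4.ckpt'))
#     model.to('cuda' if torch.cuda.is_available() else 'cpu')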


def load_img(path: str):
    """
    This loads an image from a file and returns a PyTorch tensor.

    :param path: is the path of the image
    """
    # Open the image
    image = Image.open(path).convert("RGB")
    # Get the image size
    w, h = image.size
    # Resize to a multiple of 32
    w = w - w % 32
    h = h - h % 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    # Convert to numpy and map from [0, 255] to [-1, 1]
    image = np.array(image).astype(np.float32) * (2. / 255.0) - 1
    # Transpose to shape [batch_size, channels, height, width]
    image = image[None].transpose(0, 3, 1, 2)
    # Convert to torch
    return torch.from_numpy(image)
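

# A round-trip sketch (the file names are hypothetical); `load_img` and the
# `save_images` helper below use the same [-1, 1] value range:
#
#     img = load_img('inputs/photo.jpg')  # shape [1, 3, height, width]
#     save_images(img, 'outputs', prefix='check_')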


def save_images(images: torch.Tensor, dest_path: str, prefix: str = '', img_format: str = 'jpeg'):
    """
    Save a batch of images to a folder.

    :param images: is the tensor with images of shape [batch_size, channels, height, width]
    :param dest_path: is the folder to save images in
    :param prefix: is the prefix to add to file names
    :param img_format: is the image format
    """

    # Create the destination folder
    os.makedirs(dest_path, exist_ok=True)

    # Map images from [-1, 1] to [0, 1] and clip
    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)

    # Transpose to [batch_size, height, width, channels] and convert to numpy
    images = images.cpu().permute(0, 2, 3, 1).numpy()

    # Save images
    for i, img in enumerate(images):
        img = Image.fromarray((255. * img).astype(np.uint8))
        img.save(os.path.join(dest_path, f"{prefix}{i:05}.{img_format}"), format=img_format)
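

# Example call (the destination folder is hypothetical); `img_format` is passed
# straight through to PIL's `Image.save`, so any format PIL can encode works:
#
#     save_images(samples, dest_path='outputs/samples', prefix='sample_', img_format='png')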