mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			250 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			250 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| ---
 | |
| title: Denoising Diffusion Probabilistic Models (DDPM) training
 | |
| summary: >
 | |
|   Training code for
 | |
|   Denoising Diffusion Probabilistic Model.
 | |
| ---
 | |
| 
 | |
| # [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training
 | |
| 
 | |
This trains a DDPM based model on CelebA HQ dataset. You can find the download instructions in this
 | |
| [discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
 | |
| Save the images inside [`data/celebA` folder](#dataset_path).
 | |
| 
 | |
The paper had used an exponential moving average of the model with a decay of $0.9999$. We have skipped this for
 | |
| simplicity.
 | |
| """
 | |
| from typing import List
 | |
| 
 | |
| import torch
 | |
| import torch.utils.data
 | |
| import torchvision
 | |
| from PIL import Image
 | |
| 
 | |
| from labml import lab, tracker, experiment, monit
 | |
| from labml.configs import BaseConfigs, option
 | |
| from labml_helpers.device import DeviceConfigs
 | |
| from labml_nn.diffusion.ddpm import DenoiseDiffusion
 | |
| from labml_nn.diffusion.ddpm.unet import UNet
 | |
| 
 | |
| 
 | |
class Configs(BaseConfigs):
    """
    ## Configurations

    Training configuration for DDPM: builds the U-Net noise predictor
    $\textcolor{cyan}{\epsilon_\theta}$, the diffusion process, the dataset,
    data loader and optimizer, and runs the training/sampling loop.
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    #  picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\textcolor{cyan}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size (images are `image_size` x `image_size`)
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        ### Initialize models, dataset, data loader and optimizer

        Called once after the configurations are set, before training starts.
        """
        # Create $\textcolor{cyan}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader; `pin_memory` speeds up host-to-GPU transfers
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging: mark the 'sample' key so tracker saves it as images
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images

        Generate `n_samples` images by starting from pure noise $x_T$ and
        applying the learned reverse process for $T$ steps.
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$ runs from $T - 1$ down to $0$
                t = self.n_steps - t_ - 1
                # Sample from $\textcolor{cyan}{p_\theta}(x_{t-1}|x_t)$;
                # the time step is broadcast to every sample in the batch
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train

        One epoch over the dataset.
        """

        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss (noise-prediction MSE from the diffusion model)
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        Alternate training and sampling for `epochs` iterations, saving a
        checkpoint at the end of every epoch.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()
 | |
class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset

    Loads `.jpg` images from the [`data/celebA` folder](#dataset_path),
    resizes them (shorter side to `image_size`) and converts them to
    tensors in `[0, 1]`.
    """

    def __init__(self, image_size: int):
        """
        * `image_size` is the target size to resize images to
        """
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of files, sorted so that dataset indexing is deterministic.
        # (`glob` yields files in filesystem order, which is not guaranteed;
        # also dropped the pointless f-string prefix from the pattern.)
        self._files = sorted(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Size of the dataset
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get an image as a `[3, height, width]` tensor
        """
        # Convert to RGB so grayscale/CMYK JPEGs also yield 3 channels,
        # matching the `image_channels = 3` the model expects
        img = Image.open(self._files[index]).convert('RGB')
        return self._transform(img)
| 
 | |
| 
 | |
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Build the CelebA HQ dataset sized per the configuration
    """
    dataset = CelebADataset(c.image_size)
    return dataset
 | |
| 
 | |
| 
 | |
class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset

    Thin wrapper over `torchvision`'s MNIST that resizes images to
    `image_size` and returns only the image from `__getitem__`,
    dropping the label.
    """

    def __init__(self, image_size):
        # Resize to the configured size and convert to a tensor in [0, 1]
        super().__init__(
            str(lab.get_data_path()),
            train=True,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Resize(image_size),
                torchvision.transforms.ToTensor(),
            ]),
        )

    def __getitem__(self, item):
        # Keep the image, discard the label
        image, _ = super().__getitem__(item)
        return image
 | |
| 
 | |
| 
 | |
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Build the MNIST dataset sized per the configuration
    """
    dataset = MNISTDataset(c.image_size)
    return dataset
 | |
| 
 | |
| 
 | |
def main():
    """
    Create, configure and run the DDPM training experiment.
    """
    # Create experiment
    experiment.create(name='diffuse')

    # Create configurations
    configs = Configs()

    # Set configurations; override defaults by adding entries to the dictionary
    experiment.configs(configs, {})

    # Initialize models, dataset and optimizer
    configs.init()

    # Register models for checkpoint saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Start the experiment and run the training loop
    with experiment.start():
        configs.run()


#
if __name__ == '__main__':
    main()
 | 
