| """
 | |
| ---
 | |
| title: Training a U-Net on Carvana dataset
 | |
| summary: >
 | |
|   Code for training a U-Net model on Carvana dataset.
 | |
| ---
 | |
| 
 | |
| # Training [U-Net](index.html)
 | |
| 
 | |
| This trains a [U-Net](index.html) model on [Carvana dataset](carvana.html).
 | |
| You can find the download instructions
 | |
| [on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data).
 | |
| 
 | |
| Save the training images inside `carvana/train` folder and the masks in `carvana/train_masks` folder.
 | |
| 
 | |
| For simplicity, we do not do a training and validation split.
 | |
| """

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.functional
from torch import nn

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_nn.unet.carvana import CarvanaDataset
from labml_nn.unet import UNet

class Configs(BaseConfigs):
    """
    ## Configurations
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    #  picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # [U-Net](index.html) model
    model: UNet

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Number of channels in the output mask. $1$ for binary mask.
    mask_channels: int = 1

    # Batch size
    batch_size: int = 1
    # Learning rate
    learning_rate: float = 2.5e-4

    # Number of training epochs
    epochs: int = 4

    # Dataset
    dataset: CarvanaDataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Loss function
    loss_func = nn.BCELoss()
    # Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        # Initialize the [Carvana dataset](carvana.html)
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
        # Initialize the model
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        ### Sample images
        """

        # Get a random sample
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
        # Move data to device
        x = x.to(self.device)

        # Get predicted mask
        mask = self.sigmoid(self.model(x[None, :]))
        # Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
        # Log samples
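        # (`x * mask` scales each pixel by the predicted mask probability, so the logged image
        # shows only the regions the model predicts as belonging to the car.)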
        tracker.save('sample', x * mask)

    def train(self):
        """
        ### Train for an epoch
        """

        # Iterate through the dataset.
        # Use [`mix`](https://docs.labml.ai/api/monit.html#labml.monit.mix)
        # to sample $50$ times per epoch.
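        # (`monit.mix` interleaves the two sources: batches from the data loader drive the
        # training step below, while the `self.sample` entries trigger image sampling at
        # evenly spaced points in the epoch.)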
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            image, mask = image.to(self.device), mask.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Get predicted mask logits
            logits = self.model(image)
            # Crop the target mask to the size of the logits. The logits will be smaller than
            # the target mask if we don't use padding in the convolutional layers of the U-Net.
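            # For example, with the valid (unpadded) $3 \times 3$ convolutions of the original
            # U-Net paper, each convolution trims $2$ pixels per spatial dimension, so a
            # $572 \times 572$ input gives a $388 \times 388$ output map.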
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
            # Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
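            # Note that `nn.Sigmoid` followed by `nn.BCELoss` computes the same quantity as
            # `nn.BCEWithLogitsLoss` applied to the raw logits; the latter folds the sigmoid into
            # the loss and is the more numerically stable option if you want to swap it in.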
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()


def main():
    # Create experiment
    experiment.create(name='unet')

    # Create configurations
    configs = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {})
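    # For example, `experiment.configs(configs, {'batch_size': 4})` would train with a batch
    # size of $4$ instead of the default.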

    # Initialize
    configs.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': configs.model})

    # Start and run the training loop
    with experiment.start():
        configs.run()


#
if __name__ == '__main__':
    main()
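
# To launch training, run this module directly. Assuming it lives at
# `labml_nn/unet/experiment.py` (the package layout implied by the imports above), that is:
#
#     python -m labml_nn.unet.experiment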
