""" --- title: Training a U-Net on Carvana dataset summary: > Code for training a U-Net model on Carvana dataset. --- # Training [U-Net](index.html) This trains a [U-Net](index.html) model on [Carvana dataset](carvana.html). You can find the download instructions [on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data). Save the training images inside `carvana/train` folder and the masks in `carvana/train_masks` folder. For simplicity, we do not do a training and validation split. """ import numpy as np import torchvision.transforms.functional import torch import torch.utils.data from labml import lab, tracker, experiment, monit from labml.configs import BaseConfigs from labml_nn.helpers.device import DeviceConfigs from labml_nn.unet import UNet from labml_nn.unet.carvana import CarvanaDataset from torch import nn class Configs(BaseConfigs): """ ## Configurations """ # Device to train the model on. # [`DeviceConfigs`](../helpers/device.html) # picks up an available CUDA device or defaults to CPU. device: torch.device = DeviceConfigs() # [U-Net](index.html) model model: UNet # Number of channels in the image. $3$ for RGB. image_channels: int = 3 # Number of channels in the output mask. $1$ for binary mask. mask_channels: int = 1 # Batch size batch_size: int = 1 # Learning rate learning_rate: float = 2.5e-4 # Number of training epochs epochs: int = 4 # Dataset dataset: CarvanaDataset # Dataloader data_loader: torch.utils.data.DataLoader # Loss function loss_func = nn.BCELoss() # Sigmoid function for binary classification sigmoid = nn.Sigmoid() # Adam optimizer optimizer: torch.optim.Adam def init(self): # Initialize the [Carvana dataset](carvana.html) self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train', lab.get_data_path() / 'carvana' / 'train_masks') # Initialize the model self.model = UNet(self.image_channels, self.mask_channels).to(self.device) # Create dataloader self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True) # Create optimizer self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) # Image logging tracker.set_image("sample", True) @torch.no_grad() def sample(self, idx=-1): """ ### Sample images """ # Get a random sample x, _ = self.dataset[np.random.randint(len(self.dataset))] # Move data to device x = x.to(self.device) # Get predicted mask mask = self.sigmoid(self.model(x[None, :])) # Crop the image to the size of the mask x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]]) # Log samples tracker.save('sample', x * mask) def train(self): """ ### Train for an epoch """ # Iterate through the dataset. # Use [`mix`](https://docs.labml.ai/api/monit.html#labml.monit.mix) # to sample $50$ times per epoch. for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))): # Increment global step tracker.add_global_step() # Move data to device image, mask = image.to(self.device), mask.to(self.device) # Make the gradients zero self.optimizer.zero_grad() # Get predicted mask logits logits = self.model(image) # Crop the target mask to the size of the logits. Size of the logits will be smaller if we # don't use padding in convolutional layers in the U-Net. mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]]) # Calculate loss loss = self.loss_func(self.sigmoid(logits), mask) # Compute gradients loss.backward() # Take an optimization step self.optimizer.step() # Track the loss tracker.save('loss', loss) def run(self): """ ### Training loop """ for _ in monit.loop(self.epochs): # Train the model self.train() # New line in the console tracker.new_line() # Save the model def main(): # Create experiment experiment.create(name='unet') # Create configurations configs = Configs() # Set configurations. You can override the defaults by passing the values in the dictionary. experiment.configs(configs, {}) # Initialize configs.init() # Set models for saving and loading experiment.add_pytorch_models({'model': configs.model}) # Start and run the training loop with experiment.start(): configs.run() # if __name__ == '__main__': main()