This trains a U-Net model on the Carvana dataset. You can find the download instructions on Kaggle.
Save the training images inside the carvana/train folder and the masks in the carvana/train_masks folder.
For simplicity, we do not do a training and validation split.
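As a rough sketch of where the training code below looks for the data (the exact path depends on your labml configuration, and the file extensions are just the usual Kaggle ones):

from labml import lab

# Expected layout under the labml data path (assumed; adjust to wherever you
# extracted the Kaggle archives):
#   <data path>/carvana/train/        - training images, e.g. <id>.jpg
#   <data path>/carvana/train_masks/  - corresponding masks, e.g. <id>_mask.gif
print(lab.get_data_path() / 'carvana' / 'train')
print(lab.get_data_path() / 'carvana' / 'train_masks')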
import numpy as np
import torchvision.transforms.functional

import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset
from torch import nn


class Configs(BaseConfigs):

Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()
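DeviceConfigs is a labml helper; if you were wiring this up by hand, the plain PyTorch equivalent would be roughly:

import torch

# Roughly what DeviceConfigs resolves to: the first CUDA device if one is
# available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')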
Number of channels in the image; 3 for RGB.
    image_channels: int = 3

Number of channels in the output mask; 1 for a binary mask.
    mask_channels: int = 1

Batch size
    batch_size: int = 1

Learning rate
    learning_rate: float = 2.5e-4

Number of training epochs
    epochs: int = 4

Dataset
    dataset: CarvanaDataset

Dataloader
    data_loader: torch.utils.data.DataLoader

Loss function
    loss_func = nn.BCELoss()

Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()
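A side note, not part of this experiment: since the loss is computed on sigmoid(logits), nn.BCEWithLogitsLoss would be a numerically more stable alternative that fuses the sigmoid into the loss. A minimal sketch:

import torch
from torch import nn

# Alternative (not what the code above uses): BCEWithLogitsLoss applies the
# sigmoid internally, avoiding the explicit sigmoid + BCELoss pair.
loss_func = nn.BCEWithLogitsLoss()
logits = torch.randn(1, 1, 8, 8)                    # raw model outputs
target = torch.randint(0, 2, (1, 1, 8, 8)).float()  # binary mask
loss = loss_func(logits, target)                    # == BCELoss()(sigmoid(logits), target)
print(loss.item())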
Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):

Initialize the Carvana dataset
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')

Initialize the model
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)

Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
Image logging
        tracker.set_image("sample", True)

Sample and log a predicted mask for a random image from the dataset.
    @torch.no_grad()
    def sample(self, idx=-1):

Get a random sample
        x, _ = self.dataset[np.random.randint(len(self.dataset))]

Move data to device
        x = x.to(self.device)

Get predicted mask
        mask = self.sigmoid(self.model(x[None, :]))

Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])

Log samples
        tracker.save('sample', x * mask)

Train the model for one epoch. monit.mix interleaves the training data loader with calls to self.sample, so sample predictions are logged at intervals throughout the epoch.
    def train(self):
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
Increment global step
            tracker.add_global_step()

Move data to device
            image, mask = image.to(self.device), mask.to(self.device)

Make the gradients zero
            self.optimizer.zero_grad()

Get predicted mask logits
            logits = self.model(image)

Crop the target mask to the size of the logits. The logits will be smaller than the input if we don't use padding in the convolutional layers of the U-Net (see the size-check sketch after this class).
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)

Compute gradients
            loss.backward()

Take an optimization step
            self.optimizer.step()

Track the loss
            tracker.save('loss', loss)

Training loop
    def run(self):
        for _ in monit.loop(self.epochs):
Train the model
            self.train()

New line in the console
            tracker.new_line()

Save the model
            experiment.save_checkpoint()
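To make the center-crop in train() concrete, here is a small sanity check you could run. It is a sketch under assumptions: the 572x572 -> 388x388 shrinkage only holds if the U-Net uses unpadded 3x3 convolutions as in the original paper; with padded convolutions the sizes match and the crop is a no-op.

import torch
from labml_nn.unet import UNet

# With unpadded (valid) convolutions every 3x3 conv trims 2 pixels, so the
# predicted mask ends up smaller than the input image. That is why the sample
# image and the target mask are center-cropped to the logits' spatial size.
model = UNet(3, 1)  # (image_channels, mask_channels), matching Configs above
x = torch.randn(1, 3, 572, 572)
with torch.no_grad():
    logits = model(x)
print(tuple(x.shape[2:]), '->', tuple(logits.shape[2:]))  # e.g. (572, 572) -> (388, 388) if unpadded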
def main():

Create experiment
    experiment.create(name='unet')

Create configurations
    configs = Configs()

Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {})
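For example, to override some of the defaults, replace the call above with something like this (the keys are Configs attribute names; the values are purely illustrative):

# Hypothetical overrides of the defaults defined in Configs.
experiment.configs(configs, {
    'batch_size': 4,
    'learning_rate': 1e-4,
    'epochs': 8,
})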
Initialize
    configs.init()

Set models for saving and loading
    experiment.add_pytorch_models({'model': configs.model})

Start and run the training loop
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()