This trains a U-Net model on the Carvana dataset. You can find the download instructions on Kaggle.
Save the training images inside the carvana/train folder and the masks in the carvana/train_masks folder.
For simplicity, we do not do a training and validation split.
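Before training, the data layout can be sanity-checked with a couple of lines like these; a minimal sketch, assuming labml's default data path and the folder names used by Configs.init below.

from labml import lab

# Both folders must exist before training starts
data_path = lab.get_data_path() / 'carvana'
assert (data_path / 'train').is_dir(), 'missing training images'
assert (data_path / 'train_masks').is_dir(), 'missing training masks'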
import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.functional
from torch import nn

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_nn.unet.carvana import CarvanaDataset
from labml_nn.unet import UNet


class Configs(BaseConfigs):

Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.

    device: torch.device = DeviceConfigs()

Number of channels in the image. 3 for RGB.

    image_channels: int = 3

Number of channels in the output mask. 1 for a binary mask.

    mask_channels: int = 1

Batch size

    batch_size: int = 1

Learning rate

    learning_rate: float = 2.5e-4

Number of training epochs

    epochs: int = 4

Dataset

    dataset: CarvanaDataset

Dataloader

    data_loader: torch.utils.data.DataLoader

Loss function

    loss_func = nn.BCELoss()

Sigmoid function for binary classification

    sigmoid = nn.Sigmoid()
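An aside: the model outputs raw logits, and the sigmoid is applied separately before nn.BCELoss. The two steps could equivalently be fused with nn.BCEWithLogitsLoss, which is numerically more stable. A minimal sketch of that alternative, with dummy tensors that are not part of this experiment:

import torch
from torch import nn

loss_func = nn.BCEWithLogitsLoss()
# Dummy logits and a dummy binary target mask, only to illustrate the call
logits = torch.randn(2, 1, 388, 388)
target = torch.randint(0, 2, (2, 1, 388, 388)).float()
loss = loss_func(logits, target)  # the sigmoid is applied internally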
Adam optimizer

    optimizer: torch.optim.Adam

    def init(self):

Initialize the Carvana dataset

        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')

Initialize the model

        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

Create dataloader

        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)

Create optimizer

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

Image logging

        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):

Get a random sample

        x, _ = self.dataset[np.random.randint(len(self.dataset))]

Move data to device

        x = x.to(self.device)

Get predicted mask

        mask = self.sigmoid(self.model(x[None, :]))

Crop the image to the size of the mask

        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])

Log samples

        tracker.save('sample', x * mask)
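Note that the logged sample is the image weighted by the soft probabilities from the sigmoid. If a hard segmentation were wanted instead, the probabilities could be thresholded; a small sketch, where the 0.5 threshold and the tensor shape are assumptions for illustration:

import torch

probabilities = torch.rand(1, 1, 388, 388)   # stands in for self.sigmoid(self.model(x[None, :]))
binary_mask = (probabilities > 0.5).float()  # hard 0/1 mask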
    def train(self):
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):

Increment global step

            tracker.add_global_step()

Move data to device

            image, mask = image.to(self.device), mask.to(self.device)

Make the gradients zero

            self.optimizer.zero_grad()

Get predicted mask logits

            logits = self.model(image)

Crop the target mask to the size of the logits. The logits will be smaller than the input if we don't use padding in the convolutional layers of the U-Net.

            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
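To see why this crop is needed, here is the size arithmetic, assuming the architecture of the original U-Net paper (two unpadded 3x3 convolutions per block, four 2x2 max-pools on the contracting path and four up-convolutions on the expansive path), which is what the cropping implies:

def unet_output_size(size: int) -> int:
    # Each unpadded 3x3 convolution trims 2 pixels from each spatial dimension
    for _ in range(4):          # contracting path
        size = (size - 4) // 2  # two 3x3 convolutions, then a 2x2 max-pool
    size = size - 4             # the two bottleneck convolutions
    for _ in range(4):          # expansive path
        size = size * 2 - 4     # a 2x2 up-convolution, then two 3x3 convolutions
    return size

print(unet_output_size(572))  # 388, the input/output sizes from the original paper

So a 572×572 input produces a 388×388 output, and the target mask must be center-cropped to the logits before the loss is computed.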
Calculate loss

            loss = self.loss_func(self.sigmoid(logits), mask)

Compute gradients

            loss.backward()

Take an optimization step

            self.optimizer.step()

Track the loss

            tracker.save('loss', loss)

    def run(self):
        for _ in monit.loop(self.epochs):

Train the model

            self.train()

New line in the console

            tracker.new_line()

Save the model

            experiment.save_checkpoint()
def main():

Create experiment

    experiment.create(name='unet')

Create configurations

    configs = Configs()

Set configurations. You can override the defaults by passing the values in the dictionary.

    experiment.configs(configs, {})
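For example, instead of the empty dictionary above, overrides could be passed like this; the keys are Configs attribute names, and the values here are arbitrary, only for illustration:

    experiment.configs(configs, {
        'batch_size': 4,
        'epochs': 10,
        'learning_rate': 1e-4,
    })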
Initialize

    configs.init()

Set models for saving and loading

    experiment.add_pytorch_models({'model': configs.model})

Start and run the training loop

    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()