This trains a U-Net model on the Carvana dataset. You can find the download instructions on Kaggle.
Save the training images inside the carvana/train folder and the masks in the carvana/train_masks folder.
For simplicity, we do not split the data into training and validation sets.
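The code below reads the data from the labml data path (lab.get_data_path()), so the expected layout is roughly the following (file names and extensions follow the Kaggle download, which typically ships the masks as GIFs):

carvana/
  train/          training images
  train_masks/    corresponding segmentation masks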
import numpy as np
import torchvision.transforms.functional

import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset
from torch import nn


class Configs(BaseConfigs):
Device to train the model on. DeviceConfigs
picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()
Number of channels in the image; 3 for RGB.
    image_channels: int = 3
Number of channels in the output mask; 1 for a binary mask.
    mask_channels: int = 1
Batch size
    batch_size: int = 1
Learning rate
    learning_rate: float = 2.5e-4
Number of training epochs
    epochs: int = 4
Dataset
    dataset: CarvanaDataset
Dataloader
    data_loader: torch.utils.data.DataLoader
Loss function
    loss_func = nn.BCELoss()
Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()
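As a side note, applying nn.BCELoss to sigmoid outputs is mathematically equivalent to applying nn.BCEWithLogitsLoss to the raw logits, which is numerically more stable. A sketch of that drop-in alternative (not what this experiment uses):
    # Alternative sketch: fold the sigmoid into the loss for numerical stability.
    # loss_func = nn.BCEWithLogitsLoss()
    # In train(), the loss would then be computed directly on the logits:
    # loss = self.loss_func(logits, mask)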
Adam optimizer
    optimizer: torch.optim.Adam
    def init(self):
Initialize the Carvana dataset
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
Initialize the model
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)
Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
Image logging
        tracker.set_image("sample", True)
    @torch.no_grad()
    def sample(self, idx=-1):
Get a random sample from the dataset
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
Move data to device
        x = x.to(self.device)
Get the predicted mask; x[None, :] adds a batch dimension
        mask = self.sigmoid(self.model(x[None, :]))
Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
Log samples
        tracker.save('sample', x * mask)
    def train(self):
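Iterate through the dataset with monit.mix; judging by its arguments, it interleaves batches from the 'Train' data loader with calls to self.sample over fifty indices, so sample images get logged at regular intervals during the epoch.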
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
Increment global step
            tracker.add_global_step()
Move data to device
            image, mask = image.to(self.device), mask.to(self.device)
Zero the gradients
            self.optimizer.zero_grad()
Get predicted mask logits
            logits = self.model(image)
Crop the target mask to the size of the logits. The logits are smaller than the input if we don't use padding in the convolutional layers of the U-Net (a shape sketch follows this method).
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
Compute gradients
            loss.backward()
Take an optimization step
            self.optimizer.step()
Track the loss
            tracker.save('loss', loss)
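To make the cropping step above concrete, here is a minimal shape sketch; the 572 and 388 sizes are the ones from the original U-Net paper (not the Carvana image sizes), and it reuses the imports at the top of this file.
    # mask = torch.rand(1, 1, 572, 572)    # hypothetical full-size target mask
    # logits = torch.rand(1, 1, 388, 388)  # hypothetical output of an unpadded U-Net
    # cropped = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
    # assert cropped.shape == (1, 1, 388, 388)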
    def run(self):
        for _ in monit.loop(self.epochs):
Train the model
            self.train()
New line in the console
            tracker.new_line()
Save the model
            experiment.save_checkpoint()
def main():
Create experiment
    experiment.create(name='unet')
Create configurations
    configs = Configs()
Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {})
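For instance, a hypothetical override of the defaults defined in Configs could look like this (the values are illustrative only):
    # experiment.configs(configs, {'batch_size': 4, 'learning_rate': 1e-4})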
Initialize
    configs.init()
Set models for saving and loading
    experiment.add_pytorch_models({'model': configs.model})
Start and run the training loop
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()