Varuna Jayasiri
2022-07-25 16:55:06 +05:30
committed by GitHub
parent 5c13fcbbb7
commit 948e436854
13 changed files with 2225 additions and 6 deletions

labml_nn/unet/experiment.py Normal file

@@ -0,0 +1,170 @@
"""
---
title: Training a U-Net on Carvana dataset
summary: >
Code for training a U-Net model on Carvana dataset.
---
# Training [U-Net](index.html)
This trains a [U-Net](index.html) model on [Carvana dataset](carvana.html).
You can find the download instructions
[on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data).
Save the training images inside `carvana/train` folder and the masks in `carvana/train_masks` folder.
For simplicity, we do not do a training and validation split.
"""
import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.functional
from torch import nn

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_nn.unet.carvana import CarvanaDataset
from labml_nn.unet import UNet


class Configs(BaseConfigs):
    """
    ## Configurations
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    # picks up an available CUDA device or defaults to the CPU.
    device: torch.device = DeviceConfigs()
    # [U-Net](index.html) model
    model: UNet
    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Number of channels in the output mask. $1$ for a binary mask.
    mask_channels: int = 1
    # Batch size
    batch_size: int = 1
    # Learning rate
    learning_rate: float = 2.5e-4
    # Number of training epochs
    epochs: int = 4
    # Dataset
    dataset: CarvanaDataset
    # Data loader
    data_loader: torch.utils.data.DataLoader
    # Loss function
    loss_func = nn.BCELoss()
    # Sigmoid to convert the model's logit output to probabilities
    sigmoid = nn.Sigmoid()
    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        # Initialize the [Carvana dataset](carvana.html)
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
        # Initialize the model
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)
        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        # Image logging
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        ### Sample images
        """
        # Get a random sample from the dataset
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
        # Move data to device
        x = x.to(self.device)
        # Get the predicted mask; `x[None, :]` adds a batch dimension
        mask = self.sigmoid(self.model(x[None, :]))
        # Crop the image to the size of the predicted mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
        # Log the image masked by the predicted mask
        tracker.save('sample', x * mask)

    def train(self):
        """
        ### Train for an epoch
        """
        # Iterate through the dataset.
        # Use [`mix`](https://docs.labml.ai/api/monit.html#labml.monit.mix)
        # to call `self.sample` $50$ times per epoch, interleaved with the training batches.
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            image, mask = image.to(self.device), mask.to(self.device)
            # Zero the gradients
            self.optimizer.zero_grad()
            # Get predicted mask logits
            logits = self.model(image)
            # Crop the target mask to the size of the logits. The logits will be smaller than
            # the input if we don't use padding in the convolutional layers of the U-Net.
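            # (For instance, the original U-Net of Ronneberger et al. uses unpadded $3 \times 3$
            # convolutions, so a $572 \times 572$ input produces a $388 \times 388$ output.)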
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
            # Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()


def main():
    # Create experiment
    experiment.create(name='unet')
    # Create configurations
    configs = Configs()
    # Set configurations. You can override the defaults by passing values in the dictionary.
    experiment.configs(configs, {})
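    # A hypothetical override (not part of the original experiment) could look like this;
    # any attribute declared on `Configs` can be set through the dictionary:
    #     experiment.configs(configs, {'batch_size': 4, 'epochs': 8})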
    # Initialize
    configs.init()
    # Set models for saving and loading
    experiment.add_pytorch_models({'model': configs.model})

    # Start and run the training loop
    with experiment.start():
        configs.run()


#
if __name__ == '__main__':
    main()