mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
U-Net (#137)
This commit is contained in:
170
labml_nn/unet/experiment.py
Normal file
170
labml_nn/unet/experiment.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""
|
||||
---
|
||||
title: Training a U-Net on Carvana dataset
|
||||
summary: >
|
||||
Code for training a U-Net model on Carvana dataset.
|
||||
---
|
||||
|
||||
# Training [U-Net](index.html)
|
||||
|
||||
This trains a [U-Net](index.html) model on [Carvana dataset](carvana.html).
|
||||
You can find the download instructions
|
||||
[on Kaggle](https://www.kaggle.com/competitions/carvana-image-masking-challenge/data).
|
||||
|
||||
Save the training images inside `carvana/train` folder and the masks in `carvana/train_masks` folder.
|
||||
|
||||
For simplicity, we do not do a training and validation split.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torchvision.transforms.functional
|
||||
from torch import nn
|
||||
|
||||
from labml import lab, tracker, experiment, monit
|
||||
from labml.configs import BaseConfigs
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_nn.unet.carvana import CarvanaDataset
|
||||
from labml_nn.unet import UNet
|
||||
|
||||
|
||||
class Configs(BaseConfigs):
|
||||
"""
|
||||
## Configurations
|
||||
"""
|
||||
# Device to train the model on.
|
||||
# [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
|
||||
# picks up an available CUDA device or defaults to CPU.
|
||||
device: torch.device = DeviceConfigs()
|
||||
|
||||
# [U-Net](index.html) model
|
||||
model: UNet
|
||||
|
||||
# Number of channels in the image. $3$ for RGB.
|
||||
image_channels: int = 3
|
||||
# Number of channels in the output mask. $1$ for binary mask.
|
||||
mask_channels: int = 1
|
||||
|
||||
# Batch size
|
||||
batch_size: int = 1
|
||||
# Learning rate
|
||||
learning_rate: float = 2.5e-4
|
||||
|
||||
# Number of training epochs
|
||||
epochs: int = 4
|
||||
|
||||
# Dataset
|
||||
dataset: CarvanaDataset
|
||||
# Dataloader
|
||||
data_loader: torch.utils.data.DataLoader
|
||||
|
||||
# Loss function
|
||||
loss_func = nn.BCELoss()
|
||||
# Sigmoid function for binary classification
|
||||
sigmoid = nn.Sigmoid()
|
||||
|
||||
# Adam optimizer
|
||||
optimizer: torch.optim.Adam
|
||||
|
||||
def init(self):
|
||||
# Initialize the [Carvana dataset](carvana.html)
|
||||
self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
|
||||
lab.get_data_path() / 'carvana' / 'train_masks')
|
||||
# Initialize the model
|
||||
self.model = UNet(self.image_channels, self.mask_channels).to(self.device)
|
||||
|
||||
# Create dataloader
|
||||
self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
|
||||
shuffle=True, pin_memory=True)
|
||||
# Create optimizer
|
||||
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
|
||||
|
||||
# Image logging
|
||||
tracker.set_image("sample", True)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, idx=-1):
|
||||
"""
|
||||
### Sample images
|
||||
"""
|
||||
|
||||
# Get a random sample
|
||||
x, _ = self.dataset[np.random.randint(len(self.dataset))]
|
||||
# Move data to device
|
||||
x = x.to(self.device)
|
||||
|
||||
# Get predicted mask
|
||||
mask = self.sigmoid(self.model(x[None, :]))
|
||||
# Crop the image to the size of the mask
|
||||
x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
|
||||
# Log samples
|
||||
tracker.save('sample', x * mask)
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
### Train for an epoch
|
||||
"""
|
||||
|
||||
# Iterate through the dataset.
|
||||
# Use [`mix`](https://docs.labml.ai/api/monit.html#labml.monit.mix)
|
||||
# to sample $50$ times per epoch.
|
||||
for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
|
||||
# Increment global step
|
||||
tracker.add_global_step()
|
||||
# Move data to device
|
||||
image, mask = image.to(self.device), mask.to(self.device)
|
||||
|
||||
# Make the gradients zero
|
||||
self.optimizer.zero_grad()
|
||||
# Get predicted mask logits
|
||||
logits = self.model(image)
|
||||
# Crop the target mask to the size of the logits. Size of the logits will be smaller if we
|
||||
# don't use padding in convolutional layers in the U-Net.
|
||||
mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
|
||||
# Calculate loss
|
||||
loss = self.loss_func(self.sigmoid(logits), mask)
|
||||
# Compute gradients
|
||||
loss.backward()
|
||||
# Take an optimization step
|
||||
self.optimizer.step()
|
||||
# Track the loss
|
||||
tracker.save('loss', loss)
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
### Training loop
|
||||
"""
|
||||
for _ in monit.loop(self.epochs):
|
||||
# Train the model
|
||||
self.train()
|
||||
# New line in the console
|
||||
tracker.new_line()
|
||||
# Save the model
|
||||
experiment.save_checkpoint()
|
||||
|
||||
|
||||
def main():
|
||||
# Create experiment
|
||||
experiment.create(name='unet')
|
||||
|
||||
# Create configurations
|
||||
configs = Configs()
|
||||
|
||||
# Set configurations. You can override the defaults by passing the values in the dictionary.
|
||||
experiment.configs(configs, {})
|
||||
|
||||
# Initialize
|
||||
configs.init()
|
||||
|
||||
# Set models for saving and loading
|
||||
experiment.add_pytorch_models({'model': configs.model})
|
||||
|
||||
# Start and run the training loop
|
||||
with experiment.start():
|
||||
configs.run()
|
||||
|
||||
|
||||
#
|
||||
if __name__ == '__main__':
|
||||
main()
|
Reference in New Issue
Block a user