"""
|
|
---
|
|
title: StyleGAN 2 Model Training
|
|
summary: >
|
|
An annotated PyTorch implementation of StyleGAN2 model training code.
|
|
---
|
|
|
|
# [StyleGAN 2](index.html) Model Training
|
|
|
|
This is the training code for [StyleGAN 2](index.html) model.
|
|
|
|

|
|
|
|
---*These are $64 \times 64$ images generated after training for about 80K steps.*---
|
|
|
|
*Our implementation is a minimalistic StyleGAN 2 model training code.
|
|
Only single GPU training is supported to keep the implementation simple.
|
|
We managed to shrink it to keep it at less than 500 lines of code, including the training loop.*
|
|
|
|
*Without DDP (distributed data parallel) and multi-gpu training it will not be possible to train the model
|
|
for large resolutions (128+).
|
|
If you want training code with fp16 and DDP take a look at
|
|
[lucidrains/stylegan2-pytorch](https://github.com/lucidrains/stylegan2-pytorch).*
|
|
|
|
We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
|
|
You can find the download instruction in this
|
|
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
|
|
Save the images inside [`data/stylegan` folder](#dataset_path).
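
To sanity-check the dataset location, a minimal sketch (the path is the default
`dataset_path` below; adjust it if you saved the images elsewhere):

```python
from pathlib import Path

# Count the `jpg` files the `Dataset` class below will pick up
n_images = len(list(Path('data/stylegan2').glob('**/*.jpg')))
print(f'Found {n_images} training images')
```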
"""

import math
from pathlib import Path
from typing import Iterator, Tuple

import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.trainer import ModeState
from labml_nn.utils import cycle_dataloader


class Dataset(torch.utils.data.Dataset):
    """
    ## Dataset

    This loads the training dataset and resizes it to the given image size.
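
    A minimal usage sketch (assuming square images, as in CelebA-HQ, so that
    resizing yields `image_size` × `image_size` tensors):

    ```python
    dataset = Dataset('data/stylegan2', image_size=32)
    image = dataset[0]  # tensor of shape `[3, 32, 32]` with values in `[0, 1]`
    ```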
    """

    def __init__(self, path: str, image_size: int):
        """
        * `path` path to the folder containing the images
        * `image_size` size of the image
        """
        super().__init__()

        # Get the paths of all `jpg` files
        self.paths = [p for p in Path(path).glob('**/*.jpg')]

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the image
            torchvision.transforms.Resize(image_size),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images"""
        return len(self.paths)

    def __getitem__(self, index):
        """Get the `index`-th image"""
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)


class Configs(BaseConfigs):
    """
    ## Configurations
    """

    # Device to train the model on.
    # [`DeviceConfigs`](../../helpers/device.html)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # [StyleGAN2 Discriminator](index.html#discriminator)
    discriminator: Discriminator
    # [StyleGAN2 Generator](index.html#generator)
    generator: Generator
    # [Mapping network](index.html#mapping_network)
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions.
    # We use [Wasserstein loss](../wasserstein/index.html)
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # [Gradient Penalty Regularization Loss](index.html#gradient_penalty)
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient $\gamma$
    gradient_penalty_coefficient: float = 10.

    # [Path length penalty](index.html#path_length_penalty)
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & Discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate ($100 \times$ lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on. Use this to increase the effective batch size.
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated based on image resolution)
    n_gen_blocks: int

    # ### Lazy regularization
    # Instead of calculating the regularization losses at every step, the paper proposes lazy regularization,
    # where the regularization terms are computed only once in a while.
    # This improves training efficiency a lot.
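    #
    # To keep the overall strength of a lazily computed penalty roughly unchanged,
    # it is scaled by its interval when it is applied. For example, with
    # `gradient_penalty_coefficient = 10` and `lazy_gradient_penalty_interval = 4`,
    # every 4th step adds `0.5 * 10 * gp * 4` to the discriminator loss
    # (see the training step below).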

    # The interval at which to compute gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState

    # <a id="dataset_path"></a>
    # We trained this on the [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
    # You can find the download instructions in this
    # [discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
    # Save the images inside the `data/stylegan2` folder.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        ### Initialize
        """
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)
        # Continuous [cyclic loader](../../utils.html#cycle_dataloader)
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of image resolution
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

        # Set tracker configurations
        tracker.set_image("generated", True)

    def get_w(self, batch_size: int):
        """
        ### Sample $w$

        This samples $z$ randomly and gets $w$ from the mapping network.

        We also sometimes apply style mixing, where we generate two latent variables
        $z_1$ and $z_2$ and get the corresponding $w_1$ and $w_2$.
        Then we randomly sample a cross-over point and apply $w_1$ to
        the generator blocks before the cross-over point and
        $w_2$ to the blocks after.
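
        A sketch of the result (shapes are illustrative, assuming `n_gen_blocks = 5`):

        ```python
        w = configs.get_w(batch_size=4)
        # w.shape == (5, 4, 512), i.e. (n_gen_blocks, batch_size, d_latent);
        # with mixing, rows before the cross-over point hold w1 and the rest hold w2
        ```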
        """
        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

    def get_noise(self, batch_size: int):
        """
        ### Generate noise

        This generates noise for each [generator block](index.html#generator_block).
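
        A sketch of the returned structure (illustrative, for `image_size = 32`,
        where the blocks run over resolutions $4, 8, 16, 32$):

        ```python
        noise = configs.get_noise(batch_size=4)
        # -> [(None, n2), (n1, n2), (n1, n2), (n1, n2)]
        # block 0 gets a single 4×4 noise tensor; block `i` gets two tensors
        # of shape [batch_size, 1, r, r] with r = 4 * 2 ** i
        ```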
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one $3 \times 3$ convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ resolution
            resolution *= 2

        # Return noise tensors
        return noise

    def generate_images(self, batch_size: int):
        """
        ### Generate images

        This generates images using the generator.
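
        A hypothetical call (after `configs.init()` has run):

        ```python
        images, w = configs.generate_images(batch_size=8)
        # images: [8, 3, image_size, image_size]; w: [n_gen_blocks, 8, d_latent]
        ```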
        """
        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w

    def step(self, idx: int):
        """
        ### Training Step
        """
        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, _ = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images.detach())

                # Get real images from the data loader
                real_images = next(self.loader).to(self.device)
                # We need to calculate gradients w.r.t. real images for gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    real_images.requires_grad_()
                # Discriminator classification for real images
                real_output = self.discriminator(real_images)

                # Get discriminator loss
                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                disc_loss = real_loss + fake_loss

                # Add gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    # Calculate and log gradient penalty
                    gp = self.gradient_penalty(real_images, real_output)
                    tracker.add('loss.gp', gp)
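                    # Note: the $0.5$ factor comes from the $\frac{\gamma}{2}$ weighting of the
                    # penalty, and multiplying by `lazy_gradient_penalty_interval` compensates
                    # for the penalty only being applied every few steps (lazy regularization)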
                    # Multiply by coefficient and add gradient penalty
                    disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                # Compute gradients
                disc_loss.backward()

                # Log discriminator loss
                tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty
                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if `nan`
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log generator and mapping network parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images
        if (idx + 1) % self.log_generated_interval == 0:
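            # Save a small sample: six generated images followed by three real ones for comparison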
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            # Save checkpoint
            experiment.save_checkpoint()

        # Flush tracker
        tracker.save()

    def train(self):
        """
        ## Train model
        """
        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            #
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()


def main():
    """
    ### Train StyleGAN2
    """
    # Create an experiment
    experiment.create(name='stylegan2')
    # Create configurations object
    configs = Configs()

    # Set configurations and override some
    experiment.configs(configs, {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    })

    # Initialize
    configs.init()
    # Set models for saving and loading
    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                                  generator=configs.generator,
                                  discriminator=configs.discriminator)

    # Start the experiment
    with experiment.start():
        # Run the training loop
        configs.train()


#
if __name__ == '__main__':
    main()