This is the training code for the StyleGAN 2 model.

These are images generated after training for about 80K steps.

This is a minimalistic StyleGAN 2 training implementation. Only single-GPU training is supported, to keep the implementation simple. We managed to keep it at less than 500 lines of code, including the training loop.

Without DDP (DistributedDataParallel) and multi-GPU training it is not practical to train the model at larger resolutions (128×128 and above). If you want training code with fp16 and DDP, take a look at lucidrains/stylegan2-pytorch.
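For reference, the standard PyTorch pattern for multi-GPU training wraps each model in `DistributedDataParallel`. A minimal sketch, not part of this implementation; it assumes a single node with one process per GPU (e.g. launched via `torchrun`), and the values are hypothetical:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from labml_nn.gan.stylegan import Discriminator, Generator

# Hypothetical values for a 64x64 run
log_resolution, d_latent = 6, 512

dist.init_process_group(backend='nccl')
device = torch.device('cuda', dist.get_rank())
generator = DDP(Generator(log_resolution, d_latent).to(device), device_ids=[device.index])
discriminator = DDP(Discriminator(log_resolution).to(device), device_ids=[device.index])
```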
We trained this on the CelebA-HQ dataset. You can find download instructions in this discussion on fast.ai. Save the images inside the `data/stylegan2` folder.
```python
import math
from pathlib import Path
from typing import Iterator, Tuple

import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.trainer import ModeState
from labml_nn.utils import cycle_dataloader
```
The `Dataset` class loads the training images and resizes them to the given size. `path` is the path to the folder containing the images, and `image_size` is the size of the image.

```python
class Dataset(torch.utils.data.Dataset):
    def __init__(self, path: str, image_size: int):
        super().__init__()

        # Get the paths of all `jpg` files
        self.paths = [p for p in Path(path).glob(f'**/*.jpg')]

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the image
            torchvision.transforms.Resize(image_size),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        # Number of images
        return len(self.paths)

    def __getitem__(self, index):
        # Get the `index`-th image
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)
```
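A quick way to sanity-check the dataset (a sketch; assumes the images are already in place under `data/stylegan2`):

```python
dataset = Dataset(str(lab.get_data_path() / 'stylegan2'), image_size=32)
print(len(dataset))   # number of `.jpg` files found under the folder
print(dataset[0].shape)  # torch.Size([3, 32, 32]) for square RGB source images
```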
`Configs` defines the models, loss functions, optimizers, and hyperparameters.

```python
class Configs(BaseConfigs):
    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # StyleGAN2 discriminator, generator, and mapping network
    discriminator: Discriminator
    generator: Generator
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions. We use Wasserstein loss.
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # Gradient penalty and its coefficient $\gamma$
    gradient_penalty = GradientPenalty()
    gradient_penalty_coefficient: float = 10.

    # Path length penalty
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate (100x lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on. Use this to increase the effective
    # batch size (effective batch size = `batch_size * gradient_accumulate_steps`).
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for the Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated based on image resolution)
    n_gen_blocks: int
```

Instead of calculating the regularization losses at every step, the paper proposes lazy regularization, where the regularization terms are calculated only once in a while. This improves training efficiency a lot.
```python
    # The interval at which to compute the gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating the path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState

    # We trained this on the CelebA-HQ dataset.
    # Save the images inside the `data/stylegan2` folder.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')
```

`init` creates the dataset, data loader, models, losses, and optimizers.
```python
    def init(self):
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)
        # Continuous cyclic loader
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of image resolution
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

        # Set tracker configurations
        tracker.set_image("generated", True)
```
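The cyclic loader lets the training loop pull batches indefinitely with `next(self.loader)`, without ever hitting `StopIteration`. The actual `labml_nn.utils.cycle_dataloader` may differ in detail, but the idea is roughly:

```python
def cycle_dataloader(data_loader):
    # Yield batches forever, restarting the underlying loader after each epoch
    while True:
        for batch in data_loader:
            yield batch
```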
`get_w` samples $z$ randomly and gets $w$ from the mapping network. We also apply style mixing sometimes, where we generate two latent variables $z_1$ and $z_2$ and get the corresponding $w_1$ and $w_2$. Then we randomly sample a cross-over point and apply $w_1$ to the generator blocks before the cross-over point and $w_2$ to the blocks after.

```python
    def get_w(self, batch_size: int):
        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
```
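Whichever branch is taken, the returned tensor has one $w$ per generator block. A shape check (hypothetical values, assuming `configs.init()` has run with the defaults, giving `d_latent = 512` and `n_gen_blocks = 4`):

```python
w = configs.get_w(batch_size=8)
assert w.shape == (4, 8, 512)  # (n_gen_blocks, batch_size, d_latent)
```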
`get_noise` generates noise tensors for each generator block.

```python
    def get_noise(self, batch_size: int):
        # List to store noise
        noise = []
        # Noise resolution starts from $4 \times 4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ the resolution
            resolution *= 2

        # Return noise tensors
        return noise
```
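The first block adds noise only after its single convolution; each later block works at double the previous resolution. For example (hypothetical values, assuming four generator blocks as with 32×32 images), `get_noise(batch_size=8)` returns pairs with these shapes:

```python
# [(None,           (8, 1, 4, 4)),     first block: n1 is None
#  ((8, 1, 8, 8),   (8, 1, 8, 8)),
#  ((8, 1, 16, 16), (8, 1, 16, 16)),
#  ((8, 1, 32, 32), (8, 1, 32, 32))]
```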
`generate_images` generates images using the generator and returns them along with $w$.

```python
    def generate_images(self, batch_size: int):
        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w
```
`step` runs a single training step: it first trains the discriminator, then the generator together with the mapping network.

```python
    def step(self, idx: int):
        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, _ = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images.detach())

                # Get real images from the data loader
                real_images = next(self.loader).to(self.device)
                # We need to calculate gradients w.r.t. real images for gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    real_images.requires_grad_()
                # Discriminator classification for real images
                real_output = self.discriminator(real_images)

                # Get discriminator loss
                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                disc_loss = real_loss + fake_loss

                # Add gradient penalty
                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    # Calculate and log gradient penalty
                    gp = self.gradient_penalty(real_images, real_output)
                    tracker.add('loss.gp', gp)
                    # Multiply by coefficient and add gradient penalty
                    disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                # Compute gradients
                disc_loss.backward()

                # Log discriminator loss
                tracker.add('loss.discriminator', disc_loss)

            # Log discriminator model parameters occasionally
            if (idx + 1) % self.log_generated_interval == 0:
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()
```
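Note the scaling of the gradient penalty term above. The penalty enters the loss as $\frac{\gamma}{2} \cdot GP$, but with lazy regularization it is computed only once every $k$ = `lazy_gradient_penalty_interval` steps, so it is multiplied by $k$ to keep its average per-step contribution unchanged:

$$\frac{1}{k} \cdot \frac{\gamma}{2} \cdot k \cdot GP = \frac{\gamma}{2} \cdot GP$$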
```python
        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty
                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if nan
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            # Log generator and mapping network model parameters occasionally
            if (idx + 1) % self.log_generated_interval == 0:
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()
```
```python
        # Log generated images (six generated and three real)
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            # Save checkpoint (left as a no-op placeholder in this implementation)
            pass

        # Flush tracker
        tracker.save()
```

`train` loops for `training_steps` training steps.
```python
    def train(self):
        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            # Start a new line in the logs every `log_generated_interval` steps
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()
```

`main` creates the experiment and the configurations, and runs the training loop.
```python
def main():
    # Create an experiment
    experiment.create(name='stylegan2')
    # Create configurations object
    configs = Configs()

    # Set configurations and override some
    experiment.configs(configs, {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    })

    # Initialize
    configs.init()
    # Set models for saving and loading
    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                                  generator=configs.generator,
                                  discriminator=configs.discriminator)

    # Start the experiment
    with experiment.start():
        # Run the training loop
        configs.train()


if __name__ == '__main__':
    main()
```
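To experiment with other settings, override more of the configurations before `configs.init()`. A sketch with hypothetical values (`image_size` should be a power of two, since the number of generator blocks is derived from $\log_2$ of the resolution):

```python
experiment.configs(configs, {
    'device.cuda_device': 0,
    'image_size': 64,
    'batch_size': 16,                # hypothetical: halve the per-step batch size...
    'gradient_accumulate_steps': 2,  # ...while keeping the effective batch size at 32
    'training_steps': 200_000,       # hypothetical: train for longer
})
```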