This is the training code for the StyleGAN 2 model.

These are images generated after training for about 80K steps.
Our implementation is a minimalistic StyleGAN 2 model training code. Only single-GPU training is supported, to keep the implementation simple. We kept it under 500 lines of code, including the training loop.
Without DDP (distributed data parallel) and multi-gpu training it will not be possible to train the model for large resolutions (128+). If you want training code with fp16 and DDP take a look at lucidrains/stylegan2-pytorch.
We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan2 folder (the default `dataset_path` used by the code below).
import math
from pathlib import Path
from typing import Iterator, Tuple

import torch
import torch.utils.data
import torchvision
from PIL import Image

from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.train_valid import ModeState, hook_model_outputs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.utils import cycle_dataloader


class Dataset(torch.utils.data.Dataset):
    """
    ## Dataset

    Loads every `jpg` image found (recursively) under a folder, resizes it
    and converts it to a PyTorch tensor.
    """

    def __init__(self, path: str, image_size: int):
        """
        * `path` path to the folder containing the images
        * `image_size` size of the image
        """
        super().__init__()

        # Get the paths of all `jpg` files
        self.paths = list(Path(path).glob('**/*.jpg'))

        # Transformation
        self.transform = torchvision.transforms.Compose([
            # Resize the image
            torchvision.transforms.Resize(image_size),
            # Convert to PyTorch tensor
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images"""
        return len(self.paths)

    def __getitem__(self, index):
        """Get the `index`-th image"""
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)


class Configs(BaseConfigs):
    """
    ## Configurations

    Holds the models, losses, optimizers and hyper-parameters, and
    implements the single-GPU training loop.
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # StyleGAN 2 discriminator
    discriminator: Discriminator
    # StyleGAN 2 generator
    generator: Generator
    # Mapping network
    mapping_network: MappingNetwork

    # Discriminator and generator loss functions.
    # We use Wasserstein loss
    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

    # Optimizers
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

    # Gradient penalty regularization loss
    gradient_penalty = GradientPenalty()
    # Gradient penalty coefficient
    gradient_penalty_coefficient: float = 10.

    # Path length penalty
    path_length_penalty: PathLengthPenalty

    # Data loader
    loader: Iterator

    # Batch size
    batch_size: int = 32
    # Dimensionality of $z$ and $w$
    d_latent: int = 512
    # Height/width of the image
    image_size: int = 32
    # Number of layers in the mapping network
    mapping_network_layers: int = 8
    # Generator & Discriminator learning rate
    learning_rate: float = 1e-3
    # Mapping network learning rate ($100 \times$ lower than the others)
    mapping_network_learning_rate: float = 1e-5
    # Number of steps to accumulate gradients on. Use this to increase the effective batch size.
    gradient_accumulate_steps: int = 1
    # $\beta_1$ and $\beta_2$ for Adam optimizer
    adam_betas: Tuple[float, float] = (0.0, 0.99)
    # Probability of mixing styles
    style_mixing_prob: float = 0.9

    # Total number of training steps
    training_steps: int = 150_000

    # Number of blocks in the generator (calculated based on image resolution)
    n_gen_blocks: int

    # ### Lazy regularization
    # Instead of calculating the regularization losses, the paper proposes lazy regularization
    # where the regularization terms are calculated once in a while.
    # This improves the training efficiency a lot.

    # The interval at which to compute gradient penalty
    lazy_gradient_penalty_interval: int = 4
    # Path length penalty calculation interval
    lazy_path_penalty_interval: int = 32
    # Skip calculating path length penalty during the initial phase of training
    lazy_path_penalty_after: int = 5_000

    # How often to log generated images
    log_generated_interval: int = 500
    # How often to save model checkpoints
    save_checkpoint_interval: int = 2_000

    # Training mode state for logging activations
    mode: ModeState
    # Whether to log model layer outputs
    log_layer_outputs: bool = False

    # We trained this on CelebA-HQ dataset.
    # You can find the download instruction in this discussion on fast.ai.
    # Save the images inside the `data/stylegan2` folder.
    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

    def init(self):
        """
        ### Initialize

        Creates the dataset and cyclic data loader, the models, the losses
        and the optimizers, and configures the tracker.
        """
        # Create dataset
        dataset = Dataset(self.dataset_path, self.image_size)
        # Create data loader
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)
        # Continuous cyclic loader
        self.loader = cycle_dataloader(dataloader)

        # $\log_2$ of image resolution
        log_resolution = int(math.log2(self.image_size))

        # Create discriminator and generator
        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
        # Get number of generator blocks for creating style and noise inputs
        self.n_gen_blocks = self.generator.n_blocks
        # Create mapping network
        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
        # Create path length penalty loss
        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

        # Add model hooks to monitor layer outputs
        if self.log_layer_outputs:
            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
            hook_model_outputs(self.mode, self.generator, 'generator')
            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')

        # Discriminator and generator losses
        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

        # Create optimizers
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

        # Set tracker configurations
        tracker.set_image("generated", True)

    def get_w(self, batch_size: int):
        """
        ### Sample $w$

        This samples $z$ randomly and gets $w$ from the mapping network.

        We also apply style mixing sometimes where we generate two latent variables
        $z_1$ and $z_2$ and get corresponding $w_1$ and $w_2$.
        Then we randomly sample a cross-over point and apply $w_1$ to
        the generator blocks before the cross-over point and
        $w_2$ to the blocks after.

        The returned tensor has a leading dimension of size `n_gen_blocks`
        (one style per generator block).
        """
        # Mix styles
        if torch.rand(()).item() < self.style_mixing_prob:
            # Random cross-over point
            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
            # Sample $z_1$ and $z_2$
            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w_1$ and $w_2$
            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)
            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)
        # Without mixing
        else:
            # Sample $z$
            z = torch.randn(batch_size, self.d_latent).to(self.device)
            # Get $w$
            w = self.mapping_network(z)
            # Expand $w$ for the generator blocks
            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

    def get_noise(self, batch_size: int):
        """
        ### Generate noise

        This generates noise for each generator block.
        Returns a list with one `(n1, n2)` pair per block, where each noise
        tensor has shape `[batch_size, 1, resolution, resolution]` and
        `n1` is `None` for the first block.
        """
        # List to store noise
        noise = []
        # Noise resolution starts from $4$
        resolution = 4

        # Generate noise for each generator block
        for i in range(self.n_gen_blocks):
            # The first block has only one convolution
            if i == 0:
                n1 = None
            # Generate noise to add after the first convolution layer
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
            # Generate noise to add after the second convolution layer
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

            # Add noise tensors to the list
            noise.append((n1, n2))

            # Next block has $2 \times$ resolution
            resolution *= 2

        # Return noise tensors
        return noise

    def generate_images(self, batch_size: int):
        """
        ### Generate images

        This generates images using the generator.
        Returns the tuple `(images, w)`.
        """
        # Get $w$
        w = self.get_w(batch_size)
        # Get noise
        noise = self.get_noise(batch_size)

        # Generate images
        images = self.generator(w, noise)

        # Return images and $w$
        return images, w

    def step(self, idx: int):
        """
        ### Training Step

        Runs one discriminator update followed by one generator (and mapping
        network) update. `idx` is the zero-based index of the training step.
        """
        # Train the discriminator
        with monit.section('Discriminator'):
            # Reset gradients
            self.discriminator_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Update `mode`. Set whether to log activation
                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
                    # Sample images from generator
                    generated_images, _ = self.generate_images(self.batch_size)
                    # Discriminator classification for generated images
                    fake_output = self.discriminator(generated_images.detach())

                    # Get real images from the data loader
                    real_images = next(self.loader).to(self.device)
                    # We need to calculate gradients w.r.t. real images for gradient penalty
                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        real_images.requires_grad_()
                    # Discriminator classification for real images
                    real_output = self.discriminator(real_images)

                    # Get discriminator loss
                    real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                    disc_loss = real_loss + fake_loss

                    # Add gradient penalty
                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                        # Calculate and log gradient penalty
                        gp = self.gradient_penalty(real_images, real_output)
                        tracker.add('loss.gp', gp)
                        # Multiply by coefficient and add gradient penalty.
                        # The term is also scaled by the lazy interval
                        # (presumably to compensate for being applied only once per interval).
                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

                    # Compute gradients
                    disc_loss.backward()

                    # Log discriminator loss
                    tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log discriminator model parameters occasionally
                tracker.add('discriminator', self.discriminator)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
            # Take optimizer step
            self.discriminator_optimizer.step()

        # Train the generator
        with monit.section('Generator'):
            # Reset gradients
            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

            # Accumulate gradients for `gradient_accumulate_steps`
            for i in range(self.gradient_accumulate_steps):
                # Sample images from generator
                generated_images, w = self.generate_images(self.batch_size)
                # Discriminator classification for generated images
                fake_output = self.discriminator(generated_images)

                # Get generator loss
                gen_loss = self.generator_loss(fake_output)

                # Add path length penalty
                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
                    # Calculate path length penalty
                    plp = self.path_length_penalty(w, generated_images)
                    # Ignore if `nan`
                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

                # Calculate gradients
                gen_loss.backward()

                # Log generator loss
                tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:
                # Log generator and mapping network parameters occasionally
                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

            # Clip gradients for stabilization
            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

            # Take optimizer step
            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

        # Log generated images (six generated images followed by three real ones)
        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
        # Save model checkpoints
        if (idx + 1) % self.save_checkpoint_interval == 0:
            experiment.save_checkpoint()

        # Flush tracker
        tracker.save()

    def train(self):
        """
        ## Train model
        """
        # Loop for `training_steps`
        for i in monit.loop(self.training_steps):
            # Take a training step
            self.step(i)
            # Start a new tracker output line whenever generated images are logged
            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()


def main():
    """
    ### Train StyleGAN2
    """
    # Create an experiment
    experiment.create(name='stylegan2')
    # Create configurations object
    configs = Configs()

    # Set configurations and override some
    experiment.configs(configs, {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    })

    # Initialize
    configs.init()
    # Set models for saving and loading
    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                                  generator=configs.generator,
                                  discriminator=configs.discriminator)

    # Start the experiment
    with experiment.start():
        # Run the training loop
        configs.train()


if __name__ == '__main__':
    main()