This is a PyTorch implementation of the paper Analyzing and Improving the Image Quality of StyleGAN which introduces StyleGAN 2. StyleGAN 2 is an improvement over StyleGAN from the paper A Style-Based Generator Architecture for Generative Adversarial Networks. And StyleGAN is based on Progressive GAN from the paper Progressive Growing of GANs for Improved Quality, Stability, and Variation. All three papers are from the same authors from NVIDIA AI.
Our implementation is a minimalistic StyleGAN 2 training code. Only single-GPU training is supported, to keep the implementation simple. We kept it under 500 lines of code, including the training loop.
🏃 Here's the training code: experiment.py

These are images generated after training for about 80K steps.
We'll first introduce the three papers at a high level.
Generative adversarial networks have two components: the generator and the discriminator. The generator network takes a random latent vector ($z$) and tries to generate a realistic image. The discriminator network tries to differentiate the real images from generated images. When we train the two networks together, the generator starts generating images indistinguishable from real images.
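As a rough sketch of the adversarial training loop (a toy 1-D example with hypothetical stand-in networks and the standard binary cross-entropy GAN losses, not the StyleGAN 2 code below):

import torch
import torch.nn as nn

# Toy generator and discriminator (stand-ins, not the StyleGAN 2 modules below)
generator = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
discriminator = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1))
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

real = torch.randn(64, 2) + 3.  # stand-in "real" data
z = torch.randn(64, 16)         # random latent vectors

# Discriminator step: push real samples towards 1 and generated samples towards 0
d_loss = criterion(discriminator(real), torch.ones(64, 1)) + \
         criterion(discriminator(generator(z).detach()), torch.zeros(64, 1))
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# Generator step: try to make the discriminator classify generated samples as real
g_loss = criterion(discriminator(generator(z)), torch.ones(64, 1))
opt_g.zero_grad()
g_loss.backward()
opt_g.step()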
Progressive GAN generates high-resolution images (up to $1024 \times 1024$). It does so by progressively increasing the image size. First, it trains a network that produces a $4 \times 4$ image, then $8 \times 8$, then $16 \times 16$, and so on up to the desired image resolution.
At each resolution, the generator network produces an image in latent space which is converted into RGB with a $1 \times 1$ convolution. When we progress from a lower resolution to a higher resolution (say from $4 \times 4$ to $8 \times 8$) we scale the latent image by $2\times$ and add a new block (two $3 \times 3$ convolution layers) and a new $1 \times 1$ layer to get RGB. The transition is done smoothly by adding a residual connection to the $2\times$ scaled RGB image of the lower resolution. The weight of this residual connection is slowly reduced, to let the new block take over.
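Roughly (my notation, not from the papers), during the transition the output is blended as

$$\text{rgb} = (1 - \alpha)\,\text{Upsample}\big(\text{toRGB}_{\text{old}}(x)\big) + \alpha\,\text{toRGB}_{\text{new}}\big(\text{block}_{\text{new}}(\text{Upsample}(x))\big)$$

where $\alpha$ grows from $0$ to $1$ as training at the new resolution progresses.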
The discriminator is a mirror image of the generator network. The progressive growth of the discriminator is done similarly.
$2\times$ and $0.5\times$ denote feature map resolution up-scaling and down-scaling. $4 \times 4$, $8 \times 8$, ... denote the feature map resolution at the generator or discriminator block. Each discriminator and generator block consists of 2 convolution layers with leaky ReLU activations.
They use minibatch standard deviation to increase variation and equalized learning rate, both of which we discuss below in the implementation. They also use pixel-wise normalization, where the feature vector at each pixel is normalized. They apply this to all the convolution layer outputs (except RGB).
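For reference, pixel-wise normalization in the Progressive GAN paper normalizes the feature vector at each pixel as

$$b_{x,y} = \frac{a_{x,y}}{\sqrt{\tfrac{1}{N} \sum_{j=0}^{N-1} \big(a^j_{x,y}\big)^2 + \epsilon}}$$

where $a_{x,y}$ is the feature vector at pixel $(x, y)$, $N$ is the number of feature maps, and $\epsilon = 10^{-8}$.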
StyleGAN improves the generator of Progressive GAN keeping the discriminator architecture the same.
It maps the random latent vector ($z \in \mathcal{Z}$) into a different latent space ($w \in \mathcal{W}$), with an 8-layer neural network. This gives an intermediate latent space $\mathcal{W}$ where the factors of variation are more linear (disentangled).
Then $w$ is transformed into two vectors (styles) per layer, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$, and used for scaling and shifting (biasing) in each layer with the $\text{AdaIN}$ operator (normalize and scale):

$$\text{AdaIN}(x_i, y_i) = y_{s,i} \, \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$
To prevent the generator from assuming adjacent styles are correlated, they randomly use different styles for different blocks. That is, they sample two latent vectors $z_1, z_2$ and the corresponding $w_1, w_2$, and use $w_1$ based styles for some blocks and $w_2$ based styles for the remaining blocks, picked randomly.
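A minimal sketch of this style mixing (illustrative shapes; $w_1$ and $w_2$ stand in for mapping-network outputs):

import torch

# Hypothetical sizes: 6 generator blocks, batch of 4, 512-dimensional latent space
n_blocks, batch_size, d_latent = 6, 4, 512

# Two intermediate latent vectors per sample
w1 = torch.randn(batch_size, d_latent)
w2 = torch.randn(batch_size, d_latent)

# Pick a random crossover block; blocks before it use w1, the rest use w2
cross_over = int(torch.randint(1, n_blocks, (1,)))
w = torch.stack([w1 if i < cross_over else w2 for i in range(n_blocks)])
# w has shape [n_blocks, batch_size, d_latent], one style input per block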
Noise is made available to each block which helps the generator create more realistic images. Noise is scaled per channel by a learned weight.
All the up and down-sampling operations are accompanied by bilinear smoothing.
$A$ denotes a linear layer. $B$ denotes a broadcast and scaling operation (the noise is a single channel). StyleGAN also uses progressive growing like Progressive GAN.
StyleGAN 2 changes both the generator and the discriminator of StyleGAN.
They remove the $\text{AdaIN}$ operator and replace it with a weight modulation and demodulation step. This is supposed to remove what they call droplet artifacts that are present in generated images, which are caused by the normalization in the $\text{AdaIN}$ operator. The style vector per layer is calculated from $w_i \in \mathcal{W}$ as $s_i = f_{A_i}(w_i)$.
Then the convolution weights $w$ are modulated as follows. ($w$ here on refers to weights, not the intermediate latent space; we are sticking to the same notation as the paper.)

$$w'_{i,j,k} = s_i \cdot w_{i,j,k}$$

Then it's demodulated by normalizing,

$$w''_{i,j,k} = \frac{w'_{i,j,k}}{\sqrt{\sum_{i,k} \big(w'_{i,j,k}\big)^2 + \epsilon}}$$

where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index.
Path length regularization encourages a fixed-size step in $\mathcal{W}$ to result in a non-zero, fixed-magnitude change in the generated image.
StyleGAN 2 uses residual connections (with down-sampling) in the discriminator and skip connections in the generator with up-sampling (the RGB outputs from each layer are added; there are no residual connections within the feature maps). They show with experiments that the contribution of low-resolution layers is higher at the beginning of training and then the high-resolution layers take over.
import math
from typing import Tuple, Optional, List

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

This is an MLP with 8 linear layers. The mapping network maps the latent vector $z \in \mathcal{Z}$ to an intermediate latent space $w \in \mathcal{W}$. $\mathcal{W}$ space will be disentangled from the image space, where the factors of variation become more linear.
class MappingNetwork(nn.Module):
    def __init__(self, features: int, n_layers: int):
        """
        features is the number of features in $z$ and $w$.
        n_layers is the number of layers in the mapping network.
        """
        super().__init__()

        # Create the MLP
        layers = []
        for i in range(n_layers):
            # Equalized learning-rate linear layer
            layers.append(EqualizedLinear(features, features))
            # Leaky ReLU
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor):
        # Normalize $z$
        z = F.normalize(z, dim=1)
        # Map $z$ to $w$
        return self.net(z)
$A$ denotes a linear layer. $B$ denotes a broadcast and scaling operation (the noise is a single channel). toRGB also has a style modulation which is not shown in the diagram to keep it simple.
The generator starts with a learned constant. Then it has a series of blocks. The feature map resolution is doubled at each block. Each block outputs an RGB image, and they are scaled up and summed to get the final RGB image.
class Generator(nn.Module):
    def __init__(self, log_resolution: int, d_latent: int, n_features: int = 32, max_features: int = 512):
        """
        log_resolution is the $\log_2$ of the image resolution.
        d_latent is the dimensionality of $w$.
        n_features is the number of features in the convolution layer at the highest resolution (final block).
        max_features is the maximum number of features in any generator block.
        """
        super().__init__()

        # Calculate the number of features for each block,
        # e.g. [512, 512, 256, 128, 64, 32] for a 128x128 output
        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]
        # Number of generator blocks
        self.n_blocks = len(features)

        # Trainable $4 \times 4$ constant
        self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))

        # First style block for the $4 \times 4$ resolution and a layer to get RGB
        self.style_block = StyleBlock(d_latent, features[0], features[0])
        self.to_rgb = ToRGB(d_latent, features[0])

        # Generator blocks
        blocks = [GeneratorBlock(d_latent, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
        self.blocks = nn.ModuleList(blocks)

        # $2\times$ up-sampling layer. The feature space is up-sampled at each block
        self.up_sample = UpSample()

    def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):
        """
        w is $w$. In order to mix styles (use different $w$ for different layers), we provide
        a separate $w$ for each generator block. It has shape [n_blocks, batch_size, d_latent].
        input_noise is the noise for each block. It's a list of pairs of noise tensors because
        each block (except the initial) has two noise inputs after each convolution layer
        (see the diagram).
        """
        # Get batch size
        batch_size = w.shape[1]

        # Expand the learned constant to match the batch size
        x = self.initial_constant.expand(batch_size, -1, -1, -1)

        # The first style block
        x = self.style_block(x, w[0], input_noise[0][1])
        # Get the first RGB image
        rgb = self.to_rgb(x, w[0])

        # Evaluate the rest of the blocks
        for i in range(1, self.n_blocks):
            # Up-sample the feature map
            x = self.up_sample(x)
            # Run it through the generator block
            x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i])
            # Up-sample the RGB image and add the RGB output from the block
            rgb = self.up_sample(rgb) + rgb_new

        # Return the final RGB image
        return rgb
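A minimal usage sketch (not part of the implementation; the hyper-parameters are illustrative and it assumes all the classes in this file have been defined): generate a batch of $128 \times 128$ images with log_resolution=7.

import torch

batch_size, d_latent, log_resolution = 4, 512, 7
mapping = MappingNetwork(d_latent, 8)
generator = Generator(log_resolution, d_latent)

# Map z to w and use the same w for every block (no style mixing here)
z = torch.randn(batch_size, d_latent)
w = mapping(z)[None, :, :].expand(generator.n_blocks, -1, -1)

# One pair of noise tensors per block; the resolution starts at 4x4 and doubles per block
noise, resolution = [], 4
for _ in range(generator.n_blocks):
    noise.append((torch.randn(batch_size, 1, resolution, resolution),
                  torch.randn(batch_size, 1, resolution, resolution)))
    resolution *= 2

images = generator(w, noise)  # [batch_size, 3, 128, 128]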
$A$ denotes a linear layer. $B$ denotes a broadcast and scaling operation (the noise is a single channel). toRGB also has a style modulation which is not shown in the diagram to keep it simple.
The generator block consists of two style blocks ($3 \times 3$ convolutions with style modulation) and an RGB output.
class GeneratorBlock(nn.Module):
    def __init__(self, d_latent: int, in_features: int, out_features: int):
        """
        d_latent is the dimensionality of $w$.
        in_features is the number of features in the input feature map.
        out_features is the number of features in the output feature map.
        """
        super().__init__()

        # First style block changes the feature map size to out_features
        self.style_block1 = StyleBlock(d_latent, in_features, out_features)
        # Second style block
        self.style_block2 = StyleBlock(d_latent, out_features, out_features)

        # toRGB layer
        self.to_rgb = ToRGB(d_latent, out_features)

    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]):
        """
        x is the input feature map of shape [batch_size, in_features, height, width].
        w is $w$ with shape [batch_size, d_latent].
        noise is a tuple of two noise tensors of shape [batch_size, 1, height, width].
        """
        # First style block with the first noise tensor.
        # The output is of shape [batch_size, out_features, height, width]
        x = self.style_block1(x, w, noise[0])
        # Second style block with the second noise tensor.
        # The output is of shape [batch_size, out_features, height, width]
        x = self.style_block2(x, w, noise[1])

        # Get the RGB image
        rgb = self.to_rgb(x, w)

        # Return the feature map and the RGB image
        return x, rgb

$A$ denotes a linear layer. $B$ denotes a broadcast and scaling operation (the noise is a single channel).
Style block has a weight modulation convolution layer.
class StyleBlock(nn.Module):
    def __init__(self, d_latent: int, in_features: int, out_features: int):
        """
        d_latent is the dimensionality of $w$.
        in_features is the number of features in the input feature map.
        out_features is the number of features in the output feature map.
        """
        super().__init__()

        # Get the style vector from $w$ (denoted by $A$ in the diagram)
        # with an equalized learning-rate linear layer
        self.to_style = EqualizedLinear(d_latent, in_features, bias=1.0)
        # Weight modulated convolution layer
        self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)
        # Noise scale
        self.scale_noise = nn.Parameter(torch.zeros(1))
        # Bias
        self.bias = nn.Parameter(torch.zeros(out_features))
        # Activation function
        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor]):
        """
        x is the input feature map of shape [batch_size, in_features, height, width].
        w is $w$ with shape [batch_size, d_latent].
        noise is a tensor of shape [batch_size, 1, height, width].
        """
        # Get the style vector $s$
        s = self.to_style(w)
        # Weight modulated convolution
        x = self.conv(x, s)
        # Scale and add the noise
        if noise is not None:
            x = x + self.scale_noise[None, :, None, None] * noise
        # Add the bias and evaluate the activation function
        return self.activation(x + self.bias[None, :, None, None])

ToRGB generates an RGB image from a feature map with a $1 \times 1$ convolution (with style modulation but without demodulation).

class ToRGB(nn.Module):
    def __init__(self, d_latent: int, features: int):
        """
        d_latent is the dimensionality of $w$.
        features is the number of features in the feature map.
        """
        super().__init__()

        # Get the style vector from $w$ (denoted by $A$ in the diagram)
        # with an equalized learning-rate linear layer
        self.to_style = EqualizedLinear(d_latent, features, bias=1.0)
        # Weight modulated convolution layer without demodulation
        self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)
        # Bias
        self.bias = nn.Parameter(torch.zeros(3))
        # Activation function
        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x: torch.Tensor, w: torch.Tensor):
        """
        x is the input feature map of shape [batch_size, in_features, height, width].
        w is $w$ with shape [batch_size, d_latent].
        """
        # Get the style vector
        style = self.to_style(w)
        # Weight modulated convolution
        x = self.conv(x, style)
        # Add the bias and evaluate the activation function
        return self.activation(x + self.bias[None, :, None, None])

This layer scales the convolution weights by the style vector and demodulates by normalizing them.
class Conv2dWeightModulate(nn.Module):
    def __init__(self, in_features: int, out_features: int, kernel_size: int,
                 demodulate: bool = True, eps: float = 1e-8):
        """
        in_features is the number of features in the input feature map.
        out_features is the number of features in the output feature map.
        kernel_size is the size of the convolution kernel.
        demodulate is a flag for whether to normalize weights by their standard deviation.
        eps is the $\epsilon$ for normalizing.
        """
        super().__init__()
        # Number of output features
        self.out_features = out_features
        # Whether to normalize weights
        self.demodulate = demodulate
        # Padding size
        self.padding = (kernel_size - 1) // 2

        # Weights parameter with equalized learning rate
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
        self.eps = eps

    def forward(self, x: torch.Tensor, s: torch.Tensor):
        """
        x is the input feature map of shape [batch_size, in_features, height, width].
        s is the style based scaling tensor of shape [batch_size, in_features].
        """
        # Get batch size, height and width
        b, _, h, w = x.shape

        # Reshape the scales
        s = s[:, None, :, None, None]
        # Get the learning-rate equalized weights
        weights = self.weight()[None, :, :, :, :]
        # Modulate: $w'_{i,j,k} = s_i \cdot w_{i,j,k}$, where $i$ is the input channel,
        # $j$ is the output channel, and $k$ is the kernel index.
        # The result has shape [batch_size, out_features, in_features, kernel_size, kernel_size]
        weights = weights * s

        # Demodulate
        if self.demodulate:
            # $\sigma_j^{-1} = 1 / \sqrt{\sum_{i,k} (w'_{i,j,k})^2 + \epsilon}$
            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            # $w''_{i,j,k} = w'_{i,j,k} \cdot \sigma_j^{-1}$
            weights = weights * sigma_inv

        # Reshape x
        x = x.reshape(1, -1, h, w)

        # Reshape the weights
        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.out_features, *ws)

        # Use grouped convolution to efficiently calculate the convolution with a sample-wise
        # kernel, i.e. we have a different kernel (weights) for each sample in the batch
        x = F.conv2d(x, weights, padding=self.padding, groups=b)

        # Reshape x to [batch_size, out_features, height, width] and return
        return x.reshape(-1, self.out_features, h, w)
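A quick shape check for this layer (illustrative numbers; it assumes the classes in this file have been defined): each sample in the batch gets its own modulated kernel, computed from its style vector.

import torch

conv = Conv2dWeightModulate(in_features=8, out_features=16, kernel_size=3)
x = torch.randn(4, 8, 32, 32)  # feature map
s = torch.randn(4, 8)          # per-sample style scales
y = conv(x, s)                 # -> [4, 16, 32, 32]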
The discriminator first transforms the image to a feature map of the same resolution and then runs it through a series of blocks with residual connections. The resolution is down-sampled by $2\times$ at each block while doubling the number of features.

class Discriminator(nn.Module):
    def __init__(self, log_resolution: int, n_features: int = 64, max_features: int = 512):
        """
        log_resolution is the $\log_2$ of the image resolution.
        n_features is the number of features in the convolution layer at the highest resolution (first block).
        max_features is the maximum number of features in any discriminator block.
        """
        super().__init__()

        # Layer to convert the RGB image to a feature map with n_features features
        self.from_rgb = nn.Sequential(
            EqualizedConv2d(3, n_features, 1),
            nn.LeakyReLU(0.2, True),
        )

        # Calculate the number of features for each block,
        # e.g. [64, 128, 256, 512, 512, 512] for a 128x128 image
        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]
        # Number of discriminator blocks
        n_blocks = len(features) - 1
        # Discriminator blocks
        blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
        self.blocks = nn.Sequential(*blocks)

        # Mini-batch standard deviation layer
        self.std_dev = MiniBatchStdDev()
        # Number of features after adding the standard deviations map
        final_features = features[-1] + 1
        # Final $3 \times 3$ convolution layer
        self.conv = EqualizedConv2d(final_features, final_features, 3)
        # Final linear layer to get the classification score
        self.final = EqualizedLinear(2 * 2 * final_features, 1)

    def forward(self, x: torch.Tensor):
        """
        x is the input image of shape [batch_size, 3, height, width].
        """
        # Try to normalize the image (this is totally optional, but sped up the early training a little)
        x = x - 0.5
        # Convert from RGB
        x = self.from_rgb(x)
        # Run through the discriminator blocks
        x = self.blocks(x)

        # Calculate and append the mini-batch standard deviation
        x = self.std_dev(x)
        # $3 \times 3$ convolution
        x = self.conv(x)
        # Flatten
        x = x.reshape(x.shape[0], -1)
        # Return the classification score
        return self.final(x)
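A minimal usage sketch (illustrative numbers, matching the generator sketch above; it assumes the classes in this file have been defined): score a batch of $128 \times 128$ images.

import torch

discriminator = Discriminator(log_resolution=7)
images = torch.randn(8, 3, 128, 128)
scores = discriminator(images)  # -> [8, 1]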
The discriminator block consists of two $3 \times 3$ convolutions with a residual connection.

class DiscriminatorBlock(nn.Module):
    def __init__(self, in_features, out_features):
        """
        in_features is the number of features in the input feature map.
        out_features is the number of features in the output feature map.
        """
        super().__init__()

        # Down-sampling and $1 \times 1$ convolution layer for the residual connection
        self.residual = nn.Sequential(DownSample(),
                                      EqualizedConv2d(in_features, out_features, kernel_size=1))

        # Two $3 \times 3$ convolutions
        self.block = nn.Sequential(
            EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
        )

        # Down-sampling layer
        self.down_sample = DownSample()

        # Scaling factor $\frac{1}{\sqrt{2}}$ after adding the residual
        self.scale = 1 / math.sqrt(2)

    def forward(self, x):
        # Get the residual connection
        residual = self.residual(x)

        # Convolutions
        x = self.block(x)
        # Down-sample
        x = self.down_sample(x)

        # Add the residual and scale
        return (x + residual) * self.scale

Mini-batch standard deviation calculates the standard deviation across a mini-batch (or subgroups within the mini-batch) for each feature in the feature map. Then it takes the mean of all the standard deviations and appends it to the feature map as one extra feature.
class MiniBatchStdDev(nn.Module):
    def __init__(self, group_size: int = 4):
        """
        group_size is the number of samples to calculate the standard deviation across.
        """
        super().__init__()
        self.group_size = group_size

    def forward(self, x: torch.Tensor):
        """
        x is the feature map.
        """
        # Check that the batch size is divisible by the group size
        assert x.shape[0] % self.group_size == 0
        # Split the samples into groups of group_size; we flatten the feature map to a
        # single dimension since we want to calculate the standard deviation for each feature
        grouped = x.view(self.group_size, -1)
        # Calculate the standard deviation for each feature among group_size samples
        std = torch.sqrt(grouped.var(dim=0) + 1e-8)
        # Get the mean standard deviation
        std = std.mean().view(1, 1, 1, 1)
        # Expand the standard deviation to append it to the feature map
        b, _, h, w = x.shape
        std = std.expand(b, -1, h, w)
        # Append (concatenate) the standard deviations to the feature map
        return torch.cat([x, std], dim=1)

The down-sample operation smoothens each feature channel and scales it $2\times$ down using bilinear interpolation. This is based on the paper Making Convolutional Networks Shift-Invariant Again.
class DownSample(nn.Module):
    def __init__(self):
        super().__init__()
        # Smoothing layer
        self.smooth = Smooth()

    def forward(self, x: torch.Tensor):
        # Smoothing or blurring
        x = self.smooth(x)
        # Scale down
        return F.interpolate(x, (x.shape[2] // 2, x.shape[3] // 2), mode='bilinear', align_corners=False)

The up-sample operation scales the image up by $2\times$ and smoothens each feature channel. This is based on the paper Making Convolutional Networks Shift-Invariant Again.
class UpSample(nn.Module):
    def __init__(self):
        super().__init__()
        # Up-sampling layer
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # Smoothing layer
        self.smooth = Smooth()

    def forward(self, x: torch.Tensor):
        # Up-sample and smoothen
        return self.smooth(self.up_sample(x))

This layer blurs each feature channel.

class Smooth(nn.Module):
    def __init__(self):
        super().__init__()
        # Blurring kernel
        kernel = [[1, 2, 1],
                  [2, 4, 2],
                  [1, 2, 1]]
        # Convert the kernel to a PyTorch tensor
        kernel = torch.tensor([[kernel]], dtype=torch.float)
        # Normalize the kernel
        kernel /= kernel.sum()
        # Save the kernel as a fixed parameter (no gradient updates)
        self.kernel = nn.Parameter(kernel, requires_grad=False)
        # Padding layer
        self.pad = nn.ReplicationPad2d(1)

    def forward(self, x: torch.Tensor):
        # Get the shape of the input feature map
        b, c, h, w = x.shape
        # Reshape for smoothening
        x = x.view(-1, 1, h, w)

        # Add padding
        x = self.pad(x)

        # Smoothen (blur) with the kernel
        x = F.conv2d(x, self.kernel)

        # Reshape and return
        return x.view(b, c, h, w)

This uses learning-rate equalized weights for a linear layer.
class EqualizedLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: float = 0.):
        """
        in_features is the number of features in the input feature map.
        out_features is the number of features in the output feature map.
        bias is the bias initialization constant.
        """
        super().__init__()
        # Learning-rate equalized weights
        self.weight = EqualizedWeight([out_features, in_features])
        # Bias
        self.bias = nn.Parameter(torch.ones(out_features) * bias)

    def forward(self, x: torch.Tensor):
        # Linear transformation
        return F.linear(x, self.weight(), bias=self.bias)

This uses learning-rate equalized weights for a convolution layer.
class EqualizedConv2d(nn.Module):
    def __init__(self, in_features: int, out_features: int,
                 kernel_size: int, padding: int = 0):
        """
        in_features is the number of features in the input feature map.
        out_features is the number of features in the output feature map.
        kernel_size is the size of the convolution kernel.
        padding is the padding to be added on both sides of each size dimension.
        """
        super().__init__()
        # Padding size
        self.padding = padding
        # Learning-rate equalized weights
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
        # Bias
        self.bias = nn.Parameter(torch.ones(out_features))

    def forward(self, x: torch.Tensor):
        # Convolution
        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)

This is based on the equalized learning rate introduced in the Progressive GAN paper. Instead of initializing weights at $\mathcal{N}(0, c)$, they initialize weights to $\mathcal{N}(0, 1)$ and then multiply them by $c$ when using them:

$$w_i = c \hat{w}_i$$

The gradients on the stored parameters $\hat{w}$ get multiplied by $c$, but this doesn't have an effect since optimizers such as Adam normalize them by a running mean of the squared gradients.

The optimizer updates on $\hat{w}$ are proportionate to the learning rate $\lambda$. But the effective weights $w$ get updated proportionately to $c \lambda$. Without the equalized learning rate, the effective weights would get updated proportionately to just $\lambda$.

So we are effectively scaling the learning rate by $c$ for these weight parameters.
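As a concrete example (my numbers, assuming the fan-in based He scaling used in the code below): for a $3 \times 3$ convolution with 512 input channels the fan-in is $512 \cdot 3 \cdot 3 = 4608$, so

$$c = \frac{1}{\sqrt{4608}} \approx 0.0147$$

and the stored weights are drawn from $\mathcal{N}(0, 1)$ while the effective weights $w_i = c \hat{w}_i$ behave as if initialized with standard deviation $c$.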
class EqualizedWeight(nn.Module):
    def __init__(self, shape: List[int]):
        """
        shape is the shape of the weight parameter.
        """
        super().__init__()

        # He initialization constant $c$
        self.c = 1 / math.sqrt(np.prod(shape[1:]))
        # Initialize the weights with $\mathcal{N}(0, 1)$
        self.weight = nn.Parameter(torch.randn(shape))

    def forward(self):
        # Multiply the weights by $c$ and return
        return self.weight * self.c

This is the $R_1$ regularization penalty from the paper Which Training Methods for GANs do actually Converge?.

$$R_1(\psi) = \frac{\gamma}{2} \mathbb{E}_{p_\mathcal{D}(x)} \Big[ \big\Vert \nabla_x D_\psi(x) \big\Vert^2 \Big]$$

That is, we try to reduce the L2 norm of the gradients of the discriminator with respect to images, for real images ($p_\mathcal{D}$).
class GradientPenalty(nn.Module):
    def forward(self, x: torch.Tensor, d: torch.Tensor):
        """
        x is $x \sim p_\mathcal{D}$.
        d is $D(x)$.
        """
        # Get the batch size
        batch_size = x.shape[0]

        # Calculate the gradients of $D(x)$ with respect to $x$.
        # grad_outputs is set to ones since we want the gradients of $D(x)$,
        # and we need to create and retain the graph since we have to compute gradients
        # with respect to the weights on this loss.
        gradients, *_ = torch.autograd.grad(outputs=d,
                                            inputs=x,
                                            grad_outputs=d.new_ones(d.shape),
                                            create_graph=True)

        # Reshape the gradients to calculate the norm
        gradients = gradients.reshape(batch_size, -1)
        # Calculate the norm $\Vert \nabla_x D(x) \Vert_2$
        norm = gradients.norm(2, dim=-1)
        # Return the loss $\Vert \nabla_x D(x) \Vert_2^2$
        return torch.mean(norm ** 2)

This regularization encourages a fixed-size step in $\mathcal{W}$ to result in a fixed-magnitude change in the image.

$$\mathbb{E}_{w \sim f(z),\, y \sim \mathcal{N}(0, \mathbf{I})} \Big( \big\Vert \mathbf{J}^\top_w y \big\Vert_2 - a \Big)^2$$
where $\mathbf{J}_w$ is the Jacobian $\mathbf{J}_w = \frac{\partial g}{\partial w}$, $w$ are sampled from $w \in \mathcal{W}$ from the mapping network, and $y$ are images with noise $\mathcal{N}(0, \mathbf{I})$.

$a$ is the exponential moving average of $\big\Vert \mathbf{J}^\top_w y \big\Vert_2$ as the training progresses.

$\mathbf{J}^\top_w y$ is calculated without explicitly calculating the Jacobian, using

$$\mathbf{J}^\top_w y = \nabla_w \big( g(w) \cdot y \big)$$
class PathLengthPenalty(nn.Module):
    def __init__(self, beta: float):
        """
        beta is the constant $\beta$ used to calculate the exponential moving average $a$.
        """
        super().__init__()
        self.beta = beta
        # Number of steps calculated, $N$
        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)
        # Exponential sum of $\Vert \mathbf{J}^\top_w y \Vert_2$,
        # where each term is the value of it at the $i$-th step of training
        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)

    def forward(self, w: torch.Tensor, x: torch.Tensor):
        """
        w is the batch of $w$ of shape [batch_size, d_latent].
        x are the generated images of shape [batch_size, 3, height, width].
        """
        # Get the device
        device = x.device
        # Get the number of pixels
        image_size = x.shape[2] * x.shape[3]
        # Calculate $y \sim \mathcal{N}(0, \mathbf{I})$
        y = torch.randn(x.shape, device=device)
        # Calculate $(g(w) \cdot y)$ and normalize by the square root of the image size.
        # This scaling is not mentioned in the paper but was present in their implementation.
        output = (x * y).sum() / math.sqrt(image_size)

        # Calculate the gradients to get $\mathbf{J}^\top_w y$
        gradients, *_ = torch.autograd.grad(outputs=output,
                                            inputs=w,
                                            grad_outputs=torch.ones(output.shape, device=device),
                                            create_graph=True)

        # Calculate the L2-norm of $\mathbf{J}^\top_w y$
        norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt()

        # Regularize after the first step
        if self.steps > 0:
            # Calculate $a$
            a = self.exp_sum_a / (1 - self.beta ** self.steps)
            # Calculate the penalty $(\Vert \mathbf{J}^\top_w y \Vert_2 - a)^2$
            loss = torch.mean((norm - a) ** 2)
        else:
            # Return a dummy loss if we can't calculate $a$
            loss = norm.new_tensor(0)

        # Calculate the mean of $\Vert \mathbf{J}^\top_w y \Vert_2$
        mean = norm.mean().detach()
        # Update the exponential sum
        self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)
        # Increment $N$
        self.steps.add_(1.)

        # Return the penalty
        return loss