diff --git a/.gitignore b/.gitignore index 9619c610..c82e990f 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,5 @@ labml_helpers labml_samples data logs -html/ \ No newline at end of file +html/ +diagrams/ \ No newline at end of file diff --git a/Makefile b/Makefile index ab61b2d8..e6c25a42 100644 --- a/Makefile +++ b/Makefile @@ -23,14 +23,11 @@ uninstall: ## Uninstall docs: ## Render annotated HTML find ./docs/ -name "*.html" -type f -delete + find ./docs/ -name "*.svg" -type f -delete python utils/sitemap.py + python utils/diagrams.py cd labml_nn; pylit --remove_empty_sections --title_md -t ../../../pylit/templates/nn -d ../docs -w * -pages-old: ## Copy to lab-ml site - cd labml_nn; pylit --remove_empty_sections --title_md -t ../../../pylit/templates/nn_old -d ../html/labml_nn * - @cd ../pages; git pull - cp -r html/* ../pages/ - help: ## Show this help. @fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##//' diff --git a/docs/activations/index.html b/docs/activations/index.html index 8e818fcf..560b440f 100644 --- a/docs/activations/index.html +++ b/docs/activations/index.html @@ -95,19 +95,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/activations/swish.html b/docs/activations/swish.html index eb7cbe84..123ea8ee 100644 --- a/docs/activations/swish.html +++ b/docs/activations/swish.html @@ -134,19 +134,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/capsule_networks/index.html b/docs/capsule_networks/index.html index 84cb2c05..9db61c64 100644 --- a/docs/capsule_networks/index.html +++ b/docs/capsule_networks/index.html @@ -465,19 +465,46 @@ of $\mathcal{L}_k$ for for all $k$.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/capsule_networks/mnist.html b/docs/capsule_networks/mnist.html index 42ff668a..f01d20d0 100644 --- a/docs/capsule_networks/mnist.html +++ b/docs/capsule_networks/mnist.html @@ -549,19 +549,46 @@ take it through decoder to get reconstruction

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/capsule_networks/readme.html b/docs/capsule_networks/readme.html index 7cb46a86..ceed2e82 100644 --- a/docs/capsule_networks/readme.html +++ b/docs/capsule_networks/readme.html @@ -108,19 +108,46 @@ confusions I had with the paper.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/cnn/cnn_visualization.html b/docs/cnn/cnn_visualization.html index ac0ead78..ea5574ac 100644 --- a/docs/cnn/cnn_visualization.html +++ b/docs/cnn/cnn_visualization.html @@ -454,19 +454,46 @@ plt.imshow(sobel_h(act[0][idx-1]), cmap=plt.cm.gray)

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/cnn/cross_validation.html b/docs/cnn/cross_validation.html index 9215c235..d238c069 100644 --- a/docs/cnn/cross_validation.html +++ b/docs/cnn/cross_validation.html @@ -246,19 +246,46 @@ from nutsml import *

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/cnn/models/cnn.html b/docs/cnn/models/cnn.html index 42c07de8..7c7dd2e1 100644 --- a/docs/cnn/models/cnn.html +++ b/docs/cnn/models/cnn.html @@ -479,19 +479,46 @@ Calculate the output shape after applying a convolution

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/cnn/ray_tune.html b/docs/cnn/ray_tune.html index 51dc8d5e..f842fe6c 100644 --- a/docs/cnn/ray_tune.html +++ b/docs/cnn/ray_tune.html @@ -273,19 +273,46 @@ ASHA (Asynchronous Successive Halving Algorithm) scheduler displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/cnn/utils/cv_train.html b/docs/cnn/utils/cv_train.html index 1bba00d4..7c854827 100644 --- a/docs/cnn/utils/cv_train.html +++ b/docs/cnn/utils/cv_train.html @@ -500,19 +500,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/cnn/utils/dataloader.html b/docs/cnn/utils/dataloader.html index d47437df..622535dd 100644 --- a/docs/cnn/utils/dataloader.html +++ b/docs/cnn/utils/dataloader.html @@ -313,19 +313,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/cnn/utils/train.html b/docs/cnn/utils/train.html index f967270e..93cf1412 100644 --- a/docs/cnn/utils/train.html +++ b/docs/cnn/utils/train.html @@ -552,19 +552,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/experiments/cifar10.html b/docs/experiments/cifar10.html index a9a45378..6e1a0e15 100644 --- a/docs/experiments/cifar10.html +++ b/docs/experiments/cifar10.html @@ -108,19 +108,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/experiments/index.html b/docs/experiments/index.html index b5799458..f1b1177c 100644 --- a/docs/experiments/index.html +++ b/docs/experiments/index.html @@ -84,19 +84,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/experiments/mnist.html b/docs/experiments/mnist.html index b393f096..96d47fbd 100644 --- a/docs/experiments/mnist.html +++ b/docs/experiments/mnist.html @@ -416,19 +416,46 @@ This will keep the accuracy metric stats separate for training and validation. + \ No newline at end of file diff --git a/docs/experiments/nlp_autoregression.html b/docs/experiments/nlp_autoregression.html index 50542f7e..43b06a1f 100644 --- a/docs/experiments/nlp_autoregression.html +++ b/docs/experiments/nlp_autoregression.html @@ -965,19 +965,46 @@ We need to transpose it to be sequence first.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/cycle_gan/index.html b/docs/gan/cycle_gan/index.html index 3fb584f1..37cd1ef5 100644 --- a/docs/gan/cycle_gan/index.html +++ b/docs/gan/cycle_gan/index.html @@ -1940,19 +1940,46 @@ in the call to experiment.configs

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/cycle_gan/readme.html b/docs/gan/cycle_gan/readme.html index b288e2a9..a6c1f8c9 100644 --- a/docs/gan/cycle_gan/readme.html +++ b/docs/gan/cycle_gan/readme.html @@ -98,19 +98,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/dcgan/index.html b/docs/gan/dcgan/index.html index 12f0ef33..1f9c0f37 100644 --- a/docs/gan/dcgan/index.html +++ b/docs/gan/dcgan/index.html @@ -370,19 +370,46 @@ generator and discriminator networks

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/dcgan/readme.html b/docs/gan/dcgan/readme.html index 07410ed2..b85d9be6 100644 --- a/docs/gan/dcgan/readme.html +++ b/docs/gan/dcgan/readme.html @@ -98,19 +98,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/index.html b/docs/gan/index.html index 45cccd83..22423086 100644 --- a/docs/gan/index.html +++ b/docs/gan/index.html @@ -102,19 +102,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/original/experiment.html b/docs/gan/original/experiment.html index 33469fd0..4eef9668 100644 --- a/docs/gan/original/experiment.html +++ b/docs/gan/original/experiment.html @@ -599,19 +599,46 @@ Default of 0.9 fails.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/original/index.html b/docs/gan/original/index.html index 6944be4d..ba8bef6f 100644 --- a/docs/gan/original/index.html +++ b/docs/gan/original/index.html @@ -309,19 +309,46 @@ the above gradient.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/original/readme.html b/docs/gan/original/readme.html index 08587a98..46bfeb4e 100644 --- a/docs/gan/original/readme.html +++ b/docs/gan/original/readme.html @@ -98,19 +98,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/stylegan/discriminator_block.svg b/docs/gan/stylegan/discriminator_block.svg new file mode 100644 index 00000000..f9ed820b --- /dev/null +++ b/docs/gan/stylegan/discriminator_block.svg @@ -0,0 +1,85 @@ +
Downsample
Downsample
3x3 Conv
3x3 Conv
3x3 Conv
3x3 Conv
Downsample
Downsample
1x1 Conv
1x1 Conv
+
+
Viewer does not support full SVG 1.1
\ No newline at end of file diff --git a/docs/gan/stylegan/experiment.html b/docs/gan/stylegan/experiment.html new file mode 100644 index 00000000..2a7ad67b --- /dev/null +++ b/docs/gan/stylegan/experiment.html @@ -0,0 +1,1663 @@ + + + + + + + + + + + + + + + + + + + + + + + Style GAN 2 Model Training + + + + + + + + +
+
+
+
+

+ home + gan + stylegan +

+

+ + + Github + + Join Slack + + Twitter +

+
+
+
+
+ +

Style GAN 2 Model Training

+

This is the training code for the Style GAN 2 model.

+

Generated Images

+

These are $64 \times 64$ images generated after training for about 80K steps.

+

Our implementation is a minimalistic Style GAN2 model training code. Only single-GPU training is supported to keep the implementation simple. We managed to keep it under 500 lines of code, including the training loop.

+

Without DDP (distributed data parallel) and multi-GPU training it will not be possible to train the model at large resolutions (128+). If you want training code with fp16 and DDP, take a look at lucidrains/stylegan2-pytorch.

+

We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan folder.

+
+
+
31import math
+32from pathlib import Path
+33from typing import Iterator, Tuple
+34
+35import torch
+36import torch.utils.data
+37import torchvision
+38from PIL import Image
+39
+40from labml import tracker, lab, monit, experiment
+41from labml.configs import BaseConfigs
+42from labml_helpers.device import DeviceConfigs
+43from labml_helpers.train_valid import ModeState, hook_model_outputs
+44from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
+45from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
+46from labml_nn.utils import cycle_dataloader
+
+
+
+
+ +

Dataset

+

This loads the training dataset and resizes it to the given image size.

+
+
+
49class Dataset(torch.utils.data.Dataset):
+
+
+
+
+ +
    +
  • path path to the folder containing the images
  • +
  • image_size size of the image
  • +
+
+
+
56    def __init__(self, path: str, image_size: int):
+
+
+
+
+ + +
+
+
61        super().__init__()
+
+
+
+
+ +

Get the paths of all jpg files

+
+
+
64        self.paths = [p for p in Path(path).glob(f'**/*.jpg')]
+
+
+
+
+ +

Transformation

+
+
+
67        self.transform = torchvision.transforms.Compose([
+
+
+
+
+ +

Resize the image

+
+
+
69            torchvision.transforms.Resize(image_size),
+
+
+
+
+ +

Convert to PyTorch tensor

+
+
+
71            torchvision.transforms.ToTensor(),
+72        ])
+
+
+
+
+ +

Number of images

+
+
+
74    def __len__(self):
+
+
+
+
+ + +
+
+
76        return len(self.paths)
+
+
+
+
+ +

Get the index-th image

+
+
+
78    def __getitem__(self, index):
+
+
+
+
+ + +
+
+
80        path = self.paths[index]
+81        img = Image.open(path)
+82        return self.transform(img)
+
+
+
+
+ +

Configurations

+
+
+
85class Configs(BaseConfigs):
+
+
+
+
+ +

Device to train the model on. +DeviceConfigs + picks up an available CUDA device or defaults to CPU.

+
+
+
93    device: torch.device = DeviceConfigs()
+
+
+
+ +
+
96    discriminator: Discriminator
+
+
+
+ +
+
98    generator: Generator
+
+
+
+
+ +

Mapping network

+
+
+
100    mapping_network: MappingNetwork
+
+
+
+
+ +

Discriminator and generator loss functions. +We use Wasserstein loss

+
+
+
104    discriminator_loss: DiscriminatorLoss
+105    generator_loss: GeneratorLoss
+
+
+
+
+ +

Optimizers

+
+
+
108    generator_optimizer: torch.optim.Adam
+109    discriminator_optimizer: torch.optim.Adam
+110    mapping_network_optimizer: torch.optim.Adam
+
+
+
+ +
+
113    gradient_penalty = GradientPenalty()
+
+
+
+
+ +

Gradient penalty coefficient $\gamma$

+
+
+
115    gradient_penalty_coefficient: float = 10.
+
+
+
+ +
+
118    path_length_penalty: PathLengthPenalty
+
+
+
+
+ +

Data loader

+
+
+
121    loader: Iterator
+
+
+
+
+ +

Batch size

+
+
+
124    batch_size: int = 32
+
+
+
+
+ +

Dimensionality of $z$ and $w$

+
+
+
126    d_latent: int = 512
+
+
+
+
+ +

Height/width of the image

+
+
+
128    image_size: int = 32
+
+
+
+
+ +

Number of layers in the mapping network

+
+
+
130    mapping_network_layers: int = 8
+
+
+
+
+ +

Generator & Discriminator learning rate

+
+
+
132    learning_rate: float = 1e-3
+
+
+
+
+ +

Mapping network learning rate ($100 \times$ lower than the others)

+
+
+
134    mapping_network_learning_rate: float = 1e-5
+
+
+
+
+ +

Number of steps to accumulate gradients on. Use this to increase the effective batch size.

+
+
+
136    gradient_accumulate_steps: int = 1
+
+
+
+
+ +

$\beta_1$ and $\beta_2$ for Adam optimizer

+
+
+
138    adam_betas: Tuple[float, float] = (0.0, 0.99)
+
+
+
+
+ +

Probability of mixing styles

+
+
+
140    style_mixing_prob: float = 0.9
+
+
+
+
+ +

Total number of training steps

+
+
+
143    training_steps: int = 150_000
+
+
+
+
+ +

Number of blocks in the generator (calculated based on image resolution)

+
+
+
146    n_gen_blocks: int
+
+
+
+
+ +

Lazy regularization

+

Instead of calculating the regularization losses, the paper proposes lazy regularization +where the regularization terms are calculated once in a while. +This improves the training efficiency a lot.
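As a rough sketch (the helper name below is ours, for illustration only; the real logic lives in the step function further down), the penalty is computed only every few steps and scaled up by the interval so its average contribution stays the same:

```python
def lazy_penalty_weight(step: int, interval: int, coefficient: float) -> float:
    """Return the weight to apply to a regularization term at this step.

    The expensive penalty is only computed every `interval` steps, so it is
    multiplied by `interval` to keep its time-averaged contribution unchanged.
    """
    if (step + 1) % interval == 0:
        return coefficient * interval
    return 0.0
```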

+
+
+
+
+
+
+
+ +

The interval at which to compute gradient penalty

+
+
+
154    lazy_gradient_penalty_interval: int = 4
+
+
+
+
+ +

Path length penalty calculation interval

+
+
+
156    lazy_path_penalty_interval: int = 32
+
+
+
+
+ +

Skip calculating path length penalty during the initial phase of training

+
+
+
158    lazy_path_penalty_after: int = 5_000
+
+
+
+
+ +

How often to log generated images

+
+
+
161    log_generated_interval: int = 500
+
+
+
+
+ +

How often to save model checkpoints

+
+
+
163    save_checkpoint_interval: int = 2_000
+
+
+
+
+ +

Training mode state for logging activations

+
+
+
166    mode: ModeState
+
+
+
+
+ +

Whether to log model layer outputs

+
+
+
168    log_layer_outputs: bool = False
+
+
+
+
+ +

We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan folder.

+
+
+
175    dataset_path: str = str(lab.get_data_path() / 'stylegan2')
+
+
+
+
+ +

Initialize

+
+
+
177    def init(self):
+
+
+
+
+ +

Create dataset

+
+
+
182        dataset = Dataset(self.dataset_path, self.image_size)
+
+
+
+
+ +

Create data loader

+
+
+
184        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=32,
+185                                                 shuffle=True, drop_last=True, pin_memory=True)
+
+
+
+
+ +

Continuous cyclic loader

+
+
+
187        self.loader = cycle_dataloader(dataloader)
+
+
+
+
+ +

$\log_2$ of image resolution

+
+
+
190        log_resolution = int(math.log2(self.image_size))
+
+
+
+
+ +

Create discriminator and generator

+
+
+
193        self.discriminator = Discriminator(log_resolution).to(self.device)
+194        self.generator = Generator(log_resolution, self.d_latent).to(self.device)
+
+
+
+
+ +

Get number of generator blocks for creating style and noise inputs

+
+
+
196        self.n_gen_blocks = self.generator.n_blocks
+
+
+
+
+ +

Create mapping network

+
+
+
198        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
+
+
+
+
+ +

Create path length penalty loss

+
+
+
200        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)
+
+
+
+
+ +

Add model hooks to monitor layer outputs

+
+
+
203        if self.log_layer_outputs:
+204            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
+205            hook_model_outputs(self.mode, self.generator, 'generator')
+206            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')
+
+
+
+
+ +

Discriminator and generator losses

+
+
+
209        self.discriminator_loss = DiscriminatorLoss().to(self.device)
+210        self.generator_loss = GeneratorLoss().to(self.device)
+
+
+
+
+ +

Create optimizers

+
+
+
213        self.discriminator_optimizer = torch.optim.Adam(
+214            self.discriminator.parameters(),
+215            lr=self.learning_rate, betas=self.adam_betas
+216        )
+217        self.generator_optimizer = torch.optim.Adam(
+218            self.generator.parameters(),
+219            lr=self.learning_rate, betas=self.adam_betas
+220        )
+221        self.mapping_network_optimizer = torch.optim.Adam(
+222            self.mapping_network.parameters(),
+223            lr=self.mapping_network_learning_rate, betas=self.adam_betas
+224        )
+
+
+
+
+ +

Set tracker configurations

+
+
+
227        tracker.set_image("generated", True)
+
+
+
+
+ +

Sample $w$

+

This samples $z$ randomly and gets $w$ from the mapping network.

+

We also apply style mixing sometimes where we generate two latent variables +$z_1$ and $z_2$ and get corresponding $w_1$ and $w_2$. +Then we randomly sample a cross-over point and apply $w_1$ to +the generator blocks before the cross-over point and +$w_2$ to the blocks after.

+
+
+
229    def get_w(self, batch_size: int):
+
+
+
+
+ +

Mix styles

+
+
+
243        if torch.rand(()).item() < self.style_mixing_prob:
+
+
+
+
+ +

Random cross-over point

+
+
+
245            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
+
+
+
+
+ +

Sample $z_1$ and $z_2$

+
+
+
247            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
+248            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
+
+
+
+
+ +

Get $w_1$ and $w_2$

+
+
+
250            w1 = self.mapping_network(z1)
+251            w2 = self.mapping_network(z2)
+
+
+
+
+ +

Expand $w_1$ and $w_2$ for the generator blocks and concatenate

+
+
+
253            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
+254            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
+255            return torch.cat((w1, w2), dim=0)
+
+
+
+
+ +

Without mixing

+
+
+
257        else:
+
+
+
+
+ +

Sample $z$

+
+
+
259            z = torch.randn(batch_size, self.d_latent).to(self.device)
+
+
+
+
+ +

Get $w$

+
+
+
261            w = self.mapping_network(z)
+
+
+
+
+ +

Expand $w$ for the generator blocks

+
+
+
263            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
+
+
+
+
+ +

Generate noise

+

This generates noise for each generator block

+
+
+
265    def get_noise(self, batch_size: int):
+
+
+
+
+ +

List to store noise

+
+
+
272        noise = []
+
+
+
+
+ +

Noise resolution starts from $4$

+
+
+
274        resolution = 4
+
+
+
+
+ +

Generate noise for each generator block

+
+
+
277        for i in range(self.n_gen_blocks):
+
+
+
+
+ +

The first block has only one $3 \times 3$ convolution

+
+
+
279            if i == 0:
+280                n1 = None
+
+
+
+
+ +

Generate noise to add after the first convolution layer

+
+
+
282            else:
+283                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
+
+
+
+
+ +

Generate noise to add after the second convolution layer

+
+
+
285            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
+
+
+
+
+ +

Add noise tensors to the list

+
+
+
288            noise.append((n1, n2))
+
+
+
+
+ +

Next block has $2 \times$ resolution

+
+
+
291            resolution *= 2
+
+
+
+
+ +

Return noise tensors

+
+
+
294        return noise
+
+
+
+
+ +

Generate images

+

This generates images using the generator

+
+
+
296    def generate_images(self, batch_size: int):
+
+
+
+
+ +

Get $w$

+
+
+
304        w = self.get_w(batch_size)
+
+
+
+
+ +

Get noise

+
+
+
306        noise = self.get_noise(batch_size)
+
+
+
+
+ +

Generate images

+
+
+
309        images = self.generator(w, noise)
+
+
+
+
+ +

Return images and $w$

+
+
+
312        return images, w
+
+
+
+
+ +

Training Step

+
+
+
314    def step(self, idx: int):
+
+
+
+
+ +

Train the discriminator

+
+
+
320        with monit.section('Discriminator'):
+
+
+
+
+ +

Reset gradients

+
+
+
322            self.discriminator_optimizer.zero_grad()
+
+
+
+
+ +

Accumulate gradients for gradient_accumulate_steps

+
+
+
325            for i in range(self.gradient_accumulate_steps):
+
+
+
+
+ +

Update mode. Set whether to log activation

+
+
+
327                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
+
+
+
+
+ +

Sample images from generator

+
+
+
329                    generated_images, _ = self.generate_images(self.batch_size)
+
+
+
+
+ +

Discriminator classification for generated images

+
+
+
331                    fake_output = self.discriminator(generated_images.detach())
+
+
+
+
+ +

Get real images from the data loader

+
+
+
334                    real_images = next(self.loader).to(self.device)
+
+
+
+
+ +

We need to calculate gradients w.r.t. real images for gradient penalty

+
+
+
336                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
+337                        real_images.requires_grad_()
+
+
+
+
+ +

Discriminator classification for real images

+
+
+
339                    real_output = self.discriminator(real_images)
+
+
+
+
+ +

Get discriminator loss

+
+
+
342                    real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
+343                    disc_loss = real_loss + fake_loss
+
+
+
+
+ +

Add gradient penalty

+
+
+
346                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
+
+
+
+
+ +

Calculate and log gradient penalty

+
+
+
348                        gp = self.gradient_penalty(real_images, real_output)
+349                        tracker.add('loss.gp', gp)
+
+
+
+
+ +

Multiply by coefficient and add gradient penalty

+
+
+
351                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval
+
+
+
+
+ +

Compute gradients

+
+
+
354                    disc_loss.backward()
+
+
+
+
+ +

Log discriminator loss

+
+
+
357                    tracker.add('loss.discriminator', disc_loss)
+358
+359            if (idx + 1) % self.log_generated_interval == 0:
+
+
+
+
+ +

Log discriminator model parameters occasionally

+
+
+
361                tracker.add('discriminator', self.discriminator)
+
+
+
+
+ +

Clip gradients for stabilization

+
+
+
364            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
+
+
+
+
+ +

Take optimizer step

+
+
+
366            self.discriminator_optimizer.step()
+
+
+
+
+ +

Train the generator

+
+
+
369        with monit.section('Generator'):
+
+
+
+
+ +

Reset gradients

+
+
+
371            self.generator_optimizer.zero_grad()
+372            self.mapping_network_optimizer.zero_grad()
+
+
+
+
+ +

Accumulate gradients for gradient_accumulate_steps

+
+
+
375            for i in range(self.gradient_accumulate_steps):
+
+
+
+
+ +

Sample images from generator

+
+
+
377                generated_images, w = self.generate_images(self.batch_size)
+
+
+
+
+ +

Discriminator classification for generated images

+
+
+
379                fake_output = self.discriminator(generated_images)
+
+
+
+
+ +

Get generator loss

+
+
+
382                gen_loss = self.generator_loss(fake_output)
+
+
+
+
+ +

Add path length penalty

+
+
+
385                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
+
+
+
+
+ +

Calculate path length penalty

+
+
+
387                    plp = self.path_length_penalty(w, generated_images)
+
+
+
+
+ +

Ignore if nan

+
+
+
389                    if not torch.isnan(plp):
+390                        tracker.add('loss.plp', plp)
+391                        gen_loss = gen_loss + plp
+
+
+
+
+ +

Calculate gradients

+
+
+
394                gen_loss.backward()
+
+
+
+
+ +

Log generator loss

+
+
+
397                tracker.add('loss.generator', gen_loss)
+398
+399            if (idx + 1) % self.log_generated_interval == 0:
+
+
+
+
+ +

Log generator and mapping network model parameters occasionally

+
+
+
401                tracker.add('generator', self.generator)
+402                tracker.add('mapping_network', self.mapping_network)
+
+
+
+
+ +

Clip gradients for stabilization

+
+
+
405            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
+406            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)
+
+
+
+
+ +

Take optimizer step

+
+
+
409            self.generator_optimizer.step()
+410            self.mapping_network_optimizer.step()
+
+
+
+
+ +

Log generated images

+
+
+
413        if (idx + 1) % self.log_generated_interval == 0:
+414            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
+
+
+
+
+ +

Save model checkpoints

+
+
+
416        if (idx + 1) % self.save_checkpoint_interval == 0:
+417            experiment.save_checkpoint()
+
+
+
+
+ +

Flush tracker

+
+
+
420        tracker.save()
+
+
+
+
+ +

Train model

+
+
+
422    def train(self):
+
+
+
+
+ +

Loop for training_steps

+
+
+
428        for i in monit.loop(self.training_steps):
+
+
+
+
+ +

Take a training step

+
+
+
430            self.step(i)
+
+
+
+
+ + +
+
+
432            if (i + 1) % self.log_generated_interval == 0:
+433                tracker.new_line()
+
+
+
+
+ +

Train StyleGAN2

+
+
+
436def main():
+
+
+
+
+ +

Create an experiment

+
+
+
442    experiment.create(name='stylegan2')
+
+
+
+
+ +

Create configurations object

+
+
+
444    configs = Configs()
+
+
+
+
+ +

Set configurations and override some

+
+
+
447    experiment.configs(configs, {
+448        'device.cuda_device': 0,
+449        'image_size': 64,
+450        'log_generated_interval': 200
+451    })
+
+
+
+
+ +

Initialize

+
+
+
454    configs.init()
+
+
+
+
+ +

Set models for saving and loading

+
+
+
456    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
+457                                  generator=configs.generator,
+458                                  discriminator=configs.discriminator)
+
+
+
+
+ +

Start the experiment

+
+
+
461    with experiment.start():
+
+
+
+
+ +

Run the training loop

+
+
+
463        configs.train()
+
+
+
+
+ + +
+
+
466if __name__ == '__main__':
+467    main()
+
+
+
+ + + + + + + \ No newline at end of file diff --git a/docs/gan/stylegan/generated_64.png b/docs/gan/stylegan/generated_64.png new file mode 100644 index 00000000..d57a3f2c Binary files /dev/null and b/docs/gan/stylegan/generated_64.png differ diff --git a/docs/gan/stylegan/generator_block.svg b/docs/gan/stylegan/generator_block.svg new file mode 100644 index 00000000..23c1c3ed --- /dev/null +++ b/docs/gan/stylegan/generator_block.svg @@ -0,0 +1,85 @@ +
3X3 Conv
3X3 Conv
+
+
B
B
A
A
Demod
Demod
Mod
Mod
weights
weights
3X3 Conv
3X3 Conv
+
+
B
B
A
A
Demod
Demod
Mod
Mod
weights
weights
bias
bias
bias
bias
+
+
toRGB
toRGB
feature map
featur...
rgb
rgb
Viewer does not support full SVG 1.1
\ No newline at end of file diff --git a/docs/gan/stylegan/index.html b/docs/gan/stylegan/index.html new file mode 100644 index 00000000..83eae8b1 --- /dev/null +++ b/docs/gan/stylegan/index.html @@ -0,0 +1,2685 @@ + + + + + + + + + + + + + + + + + + + + + + + Style GAN 2 + + + + + + + + +
+
+
+
+

+ home + gan + stylegan +

+

+ + + Github + + Join Slack + + Twitter +

+
+
+
+
+ +

Style GAN 2

+

This is a PyTorch implementation of the paper Analyzing and Improving the Image Quality of StyleGAN, which introduces Style GAN2. Style GAN2 is an improvement over Style GAN from the paper A Style-Based Generator Architecture for Generative Adversarial Networks. And Style GAN is based on Progressive GAN from the paper Progressive Growing of GANs for Improved Quality, Stability, and Variation. All three papers are by the same authors at NVIDIA AI.

+

Our implementation is a minimalistic Style GAN2 model training code. Only single-GPU training is supported to keep the implementation simple. We managed to keep it under 500 lines of code, including the training loop.

+

🏃 Here’s the training code: experiment.py.

+

Generated Images

+

These are $64 \times 64$ images generated after training for about 80K steps.

+

We’ll first introduce the three papers at a high level.

+

Generative Adversarial Networks

+

Generative adversarial networks have two components; the generator and the discriminator. +The generator network takes a random latent vector ($z \in \mathcal{Z}$) + and tries to generate a realistic image. +The discriminator network tries to differentiate the real images from generated images. +When we train the two networks together the generator starts generating images indistinguishable from real images.
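For reference (not specific to this implementation, which uses Wasserstein loss instead), the original GAN training objective is the minimax game
$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p(z)}\big[\log\big(1 - D(G(z))\big)\big]$$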

+

Progressive GAN

+

Progressive GAN generates high-resolution images (of size $1080 \times 1080$). It does so by progressively increasing the image size. First, it trains a network that produces a $4 \times 4$ image, then $8 \times 8$, then a $16 \times 16$ image, and so on up to the desired image resolution.

+

At each resolution, the generator network produces an image in latent space which is converted into RGB, +with a $1 \times 1$ convolution. +When we progress from a lower resolution to a higher resolution + (say from $4 \times 4$ to $8 \times 8$ ) we scale the latent image by $2\times$ + and add a new block (two $3 \times 3$ convolution layers) + and a new $1 \times 1$ layer to get RGB. +The transition is done smoothly by adding a residual connection to + the $2\times$ scaled $4 \times 4$ RGB image. +The weight of this residual connection is slowly reduced, to let the new block take over.
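Here is a minimal sketch of that smooth transition (our own illustration with assumed names; StyleGAN2 below does not use progressive growing, so nothing like this appears in the implementation):

```python
import torch
import torch.nn.functional as F

def fade_in(alpha: float, old_rgb: torch.Tensor, new_rgb: torch.Tensor) -> torch.Tensor:
    """Blend the up-scaled RGB output of the previous resolution with the new block's RGB.

    `alpha` grows from 0 to 1 during the transition, slowly letting the new block take over.
    """
    # Up-sample the old RGB image to the new resolution (e.g. 4x4 -> 8x8)
    old_up = F.interpolate(old_rgb, scale_factor=2, mode='nearest')
    # Residual connection whose weight decays as alpha goes to 1
    return (1. - alpha) * old_up + alpha * new_rgb
```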

+

The discriminator is a mirror image of the generator network. +The progressive growth of the discriminator is done similarly.

+

progressive_gan.svg

+

$2\times$ and $0.5\times$ denote feature map resolution scaling up and down. $4\times 4$, $8\times 8$, … denote the feature map resolution at the generator or discriminator block. Each discriminator and generator block consists of 2 convolution layers with leaky ReLU activations.

+

They use minibatch standard deviation to increase variation and equalized learning rate, which we discuss below in the implementation. They also use pixel-wise normalization, where the feature vector at each pixel is normalized. They apply this to all the convolution layer outputs (except RGB).
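For illustration (our own sketch; pixel-wise normalization is a Progressive GAN detail and is not used in the StyleGAN2 code below), pixel-wise normalization normalizes the feature vector at every spatial position:

```python
import torch

def pixel_norm(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Normalize the feature vector at each pixel of a [batch, channels, height, width] map."""
    return x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)
```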

+

Style GAN

+

Style GAN improves the generator of Progressive GAN while keeping the discriminator architecture the same.

+

Mapping Network

+

It maps the random latent vector ($z \in \mathcal{Z}$) + into a different latent space ($w \in \mathcal{W}$), + with an 8-layer neural network. +This gives an intermediate latent space $\mathcal{W}$ +where the factors of variations are more linear (disentangled).

+

AdaIN

+

Then $w$ is transformed into two vectors (styles) per layer $i$, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$, and used for scaling and shifting (biasing) in each layer with the $\text{AdaIN}$ operator (normalize and scale): $$\text{AdaIN}(x_i, y_i) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$
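A small sketch of the $\text{AdaIN}$ operation (our own illustration; StyleGAN2 below replaces this with weight modulation):

```python
import torch

def adain(x: torch.Tensor, y_s: torch.Tensor, y_b: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize each feature map of x ([batch, channels, height, width]) per sample,
    then scale by the style y_s and shift by the style bias y_b (both [batch, channels])."""
    mean = x.mean(dim=(2, 3), keepdim=True)
    std = x.std(dim=(2, 3), keepdim=True)
    x_hat = (x - mean) / (std + eps)
    return y_s[:, :, None, None] * x_hat + y_b[:, :, None, None]
```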

+

Style Mixing

+

To prevent the generator from assuming adjacent styles are correlated, they randomly use different styles for different blocks. That is, they sample two latent vectors $(z_1, z_2)$ and corresponding $(w_1, w_2)$ and use $w_1$ based styles for some blocks and $w_2$ based styles for the other blocks, chosen randomly.

+

Stochastic Variation

+

Noise is made available to each block which helps the generator create more realistic images. +Noise is scaled per channel by a learned weight.

+

Bilinear Up and Down Sampling

+

All the up and down-sampling operations are accompanied by bilinear smoothing.

+

style_gan.svg

+

$A$ denotes a linear layer. +$B$ denotes a broadcast and scaling operation (noise is a single channel). +Style GAN also uses progressive growing like Progressive GAN

+

Style GAN 2

+

Style GAN 2 changes both the generator and the discriminator of Style GAN.

+

Weight Modulation and Demodulation

+

They remove the $\text{AdaIN}$ operator and replace it with the weight modulation and demodulation step. This is supposed to fix what they call droplet artifacts that are present in generated images, which are caused by the normalization in the $\text{AdaIN}$ operator. The style vector per layer is calculated from $w_i \in \mathcal{W}$ as $s_i = f_{A_i}(w_i)$.

+

Then the convolution weights $w$ are modulated as follows. ($w$ from here on refers to the weights, not the intermediate latent space; we stick to the same notation as the paper.)

+

$$w'_{i, j, k} = s_i \cdot w_{i, j, k}$$ Then it's demodulated by normalizing, $$w''_{i, j, k} = \frac{w'_{i, j, k}}{\sqrt{\sum_{i, k} \big(w'_{i, j, k}\big)^2 + \epsilon}}$$ where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index.

+

Path Length Regularization

+

Path length regularization encourages a fixed-size step in $\mathcal{W}$ to result in a non-zero, + fixed-magnitude change in the generated image.
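A rough sketch of how this penalty can be computed (our own illustration under simplifying assumptions; the actual PathLengthPenalty module used by the training code also keeps an exponential moving average of the path lengths):

```python
import math
import torch

def path_length_penalty_sketch(w: torch.Tensor, images: torch.Tensor, a: float) -> torch.Tensor:
    """w is [batch_size, d_latent] (must be part of the graph that produced `images`),
    `images` is [batch_size, 3, height, width], and `a` is a running mean of path lengths."""
    # Random image-space direction, scaled so the expectation is resolution independent
    y = torch.randn_like(images) / math.sqrt(images.shape[2] * images.shape[3])
    # Jacobian-vector product J_w^T y via autograd
    grad, = torch.autograd.grad(outputs=(images * y).sum(), inputs=w, create_graph=True)
    # Penalize deviation of the path length from the running mean
    return ((grad.norm(2, dim=-1) - a) ** 2).mean()
```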

+

No Progressive Growing

+

StyleGAN2 uses residual connections (with down-sampling) in the discriminator and skip connections in the generator with up-sampling (the RGB outputs from each layer are added - no residual connections in feature maps). They show with experiments that the contribution of low-resolution layers is higher at the beginning of training and then the high-resolution layers take over.

+
+
+
148import math
+149from typing import Tuple, Optional, List
+150
+151import numpy as np
+152import torch
+153import torch.nn.functional as F
+154import torch.utils.data
+155from torch import nn
+
+
+
+
+ +

+

Mapping Network

+

Mapping Network

+

This is an MLP with 8 linear layers. The mapping network maps the latent vector $z \in \mathcal{Z}$ to an intermediate latent space $w \in \mathcal{W}$. $\mathcal{W}$ space will be disentangled from the image space, where the factors of variation become more linear.

+
+
+
158class MappingNetwork(nn.Module):
+
+
+
+
+ +
    +
  • features is the number of features in $z$ and $w$
  • +
  • n_layers is the number of layers in the mapping network.
  • +
+
+
+
172    def __init__(self, features: int, n_layers: int):
+
+
+
+
+ + +
+
+
177        super().__init__()
+
+
+
+
+ +

Create the MLP

+
+
+
180        layers = []
+181        for i in range(n_layers):
+
+
+
+ +
+
183            layers.append(EqualizedLinear(features, features))
+
+
+
+
+ +

Leaky Relu

+
+
+
185            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
+186
+187        self.net = nn.Sequential(*layers)
+
+
+
+
+ + +
+
+
189    def forward(self, z: torch.Tensor):
+
+
+
+
+ +

Normalize $z$

+
+
+
191        z = F.normalize(z, dim=1)
+
+
+
+
+ +

Map $z$ to $w$

+
+
+
193        return self.net(z)
+
+
+
+
+ +

+

StyleGAN2 Generator

+

Generator

+

$A$ denotes a linear layer. +$B$ denotes a broadcast and scaling operation (noise is a single channel). +toRGB also has a style modulation which is not shown in the diagram to keep it simple.

+

The generator starts with a learned constant. Then it has a series of blocks. The feature map resolution is doubled at each block. Each block outputs an RGB image, and these are scaled up and summed to get the final RGB image.

+
+
+
196class Generator(nn.Module):
+
+
+
+
+ +
    +
  • log_resolution is the $\log_2$ of image resolution
  • +
  • d_latent is the dimensionality of $w$
  • +
  • n_features number of features in the convolution layer at the highest resolution (final block)
  • +
  • max_features maximum number of features in any generator block
  • +
+
+
+
212    def __init__(self, log_resolution: int, d_latent: int, n_features: int = 32, max_features: int = 512):
+
+
+
+
+ + +
+
+
219        super().__init__()
+
+
+
+
+ +

Calculate the number of features for each block

+

Something like [512, 512, 256, 128, 64, 32]

+
+
+
224        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]
+
+
+
+
+ +

Number of generator blocks

+
+
+
226        self.n_blocks = len(features)
+
+
+
+
+ +

Trainable $4 \times 4$ constant

+
+
+
229        self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))
+
+
+
+
+ +

First style block for $4 \times 4$ resolution and layer to get RGB

+
+
+
232        self.style_block = StyleBlock(d_latent, features[0], features[0])
+233        self.to_rgb = ToRGB(d_latent, features[0])
+
+
+
+
+ +

Generator blocks

+
+
+
236        blocks = [GeneratorBlock(d_latent, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
+237        self.blocks = nn.ModuleList(blocks)
+
+
+
+
+ +

$2 \times$ up-sampling layer. The feature map is up-sampled at each block.

+
+
+
241        self.up_sample = UpSample()
+
+
+
+
+ +
    +
• w is $w$. In order to mix styles (use different $w$ for different layers), we provide a separate $w$ for each generator block. It has shape [n_blocks, batch_size, d_latent].
  • +
• input_noise is the noise for each block. It's a list of pairs of noise tensors because each block (except the initial) has two noise inputs, one after each convolution layer (see the diagram).
  • +
+
+
+
243    def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):
+
+
+
+
+ +

Get batch size

+
+
+
253        batch_size = w.shape[1]
+
+
+
+
+ +

Expand the learned constant to match batch size

+
+
+
256        x = self.initial_constant.expand(batch_size, -1, -1, -1)
+
+
+
+
+ +

The first style block

+
+
+
259        x = self.style_block(x, w[0], input_noise[0][1])
+
+
+
+
+ +

Get first rgb image

+
+
+
261        rgb = self.to_rgb(x, w[0])
+
+
+
+
+ +

Evaluate rest of the blocks

+
+
+
264        for i in range(1, self.n_blocks):
+
+
+
+
+ +

Up sample the feature map

+
+
+
266            x = self.up_sample(x)
+
+
+
+
+ +

Run it through the generator block

+
+
+
268            x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i])
+
+
+
+
+ +

Up sample the RGB image and add to the rgb from the block

+
+
+
270            rgb = self.up_sample(rgb) + rgb_new
+
+
+
+
+ +

Return the final RGB image

+
+
+
273        return rgb
+
+
+
+
+ +

+

Generator Block

+

Generator block

+

$A$ denotes a linear layer. +$B$ denotes a broadcast and scaling operation (noise is a single channel). +toRGB also has a style modulation which is not shown in the diagram to keep it simple.

+

The generator block consists of two style blocks ($3 \times 3$ convolutions with style modulation) +and an RGB output.

+
+
+
276class GeneratorBlock(nn.Module):
+
+
+
+
+ +
    +
  • d_latent is the dimensionality of $w$
  • +
  • in_features is the number of features in the input feature map
  • +
  • out_features is the number of features in the output feature map
  • +
+
+
+
291    def __init__(self, d_latent: int, in_features: int, out_features: int):
+
+
+
+
+ + +
+
+
297        super().__init__()
+
+
+
+
+ +

First style block changes the feature map size to out_features

+
+
+
300        self.style_block1 = StyleBlock(d_latent, in_features, out_features)
+
+
+
+
+ +

Second style block

+
+
+
302        self.style_block2 = StyleBlock(d_latent, out_features, out_features)
+
+
+
+
+ +

toRGB layer

+
+
+
305        self.to_rgb = ToRGB(d_latent, out_features)
+
+
+
+
+ +
    +
  • x is the input feature map of shape [batch_size, in_features, height, width]
  • +
  • w is $w$ with shape [batch_size, d_latent]
  • +
  • noise is a tuple of two noise tensors of shape [batch_size, 1, height, width]
  • +
+
+
+
307    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]):
+
+
+
+
+ +

First style block with first noise tensor. +The output is of shape [batch_size, out_features, height, width]

+
+
+
315        x = self.style_block1(x, w, noise[0])
+
+
+
+
+ +

Second style block with second noise tensor. +The output is of shape [batch_size, out_features, height, width]

+
+
+
318        x = self.style_block2(x, w, noise[1])
+
+
+
+
+ +

Get RGB image

+
+
+
321        rgb = self.to_rgb(x, w)
+
+
+
+
+ +

Return feature map and rgb image

+
+
+
324        return x, rgb
+
+
+
+
+ +

+

Style Block

+

Style block

+

$A$ denotes a linear layer. +$B$ denotes a broadcast and scaling operation (noise is single channel).

+

Style block has a weight modulation convolution layer.

+
+
+
327class StyleBlock(nn.Module):
+
+
+
+
+ +
    +
  • d_latent is the dimensionality of $w$
  • +
  • in_features is the number of features in the input feature map
  • +
  • out_features is the number of features in the output feature map
  • +
+
+
+
340    def __init__(self, d_latent: int, in_features: int, out_features: int):
+
+
+
+
+ + +
+
+
346        super().__init__()
+
+
+
+
+ +

Get style vector from $w$ (denoted by $A$ in the diagram) with +an equalized learning-rate linear layer

+
+
+
349        self.to_style = EqualizedLinear(d_latent, in_features, bias=1.0)
+
+
+
+
+ +

Weight modulated convolution layer

+
+
+
351        self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)
+
+
+
+
+ +

Noise scale

+
+
+
353        self.scale_noise = nn.Parameter(torch.zeros(1))
+
+
+
+
+ +

Bias

+
+
+
355        self.bias = nn.Parameter(torch.zeros(out_features))
+
+
+
+
+ +

Activation function

+
+
+
358        self.activation = nn.LeakyReLU(0.2, True)
+
+
+
+
+ +
    +
  • x is the input feature map of shape [batch_size, in_features, height, width]
  • +
  • w is $w$ with shape [batch_size, d_latent]
  • +
  • noise is a tensor of shape [batch_size, 1, height, width]
  • +
+
+
+
360    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor]):
+
+
+
+
+ +

Get style vector $s$

+
+
+
367        s = self.to_style(w)
+
+
+
+
+ +

Weight modulated convolution

+
+
+
369        x = self.conv(x, s)
+
+
+
+
+ +

Scale and add noise

+
+
+
371        if noise is not None:
+372            x = x + self.scale_noise[None, :, None, None] * noise
+
+
+
+
+ +

Add bias and evaluate activation function

+
+
+
374        return self.activation(x + self.bias[None, :, None, None])
+
+
+
+
+ +

+

To RGB

+

To RGB

+

$A$ denotes a linear layer.

+

Generates an RGB image from a feature map using $1 \times 1$ convolution.

+
+
+
377class ToRGB(nn.Module):
+
+
+
+
+ +
    +
  • d_latent is the dimensionality of $w$
  • +
  • features is the number of features in the feature map
  • +
+
+
+
389    def __init__(self, d_latent: int, features: int):
+
+
+
+
+ + +
+
+
394        super().__init__()
+
+
+
+
+ +

Get style vector from $w$ (denoted by $A$ in the diagram) with +an equalized learning-rate linear layer

+
+
+
397        self.to_style = EqualizedLinear(d_latent, features, bias=1.0)
+
+
+
+
+ +

Weight modulated convolution layer without demodulation

+
+
+
400        self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)
+
+
+
+
+ +

Bias

+
+
+
402        self.bias = nn.Parameter(torch.zeros(1))
+
+
+
+
+ +

Activation function

+
+
+
404        self.activation = nn.LeakyReLU(0.2, True)
+
+
+
+
+ +
    +
  • x is the input feature map of shape [batch_size, in_features, height, width]
  • +
  • w is $w$ with shape [batch_size, d_latent]
  • +
+
+
+
406    def forward(self, x: torch.Tensor, w: torch.Tensor):
+
+
+
+
+ +

Get style vector $s$

+
+
+
412        style = self.to_style(w)
+
+
+
+
+ +

Weight modulated convolution

+
+
+
414        x = self.conv(x, style)
+
+
+
+
+ +

Add bias and evaluate activation function

+
+
+
416        return self.activation(x + self.bias[None, :, None, None])
+
+
+
+
+ +

Convolution with Weight Modulation and Demodulation

+

This layer scales the convolution weights by the style vector and demodulates by normalizing it.

+
+
+
419class Conv2dWeightModulate(nn.Module):
+
+
+
+
+ +
    +
  • in_features is the number of features in the input feature map
  • +
  • out_features is the number of features in the output feature map
  • +
  • kernel_size is the size of the convolution kernel
  • +
• demodulate is a flag indicating whether to normalize the weights by their standard deviation
  • +
  • eps is the $\epsilon$ for normalizing
  • +
+
+
+
426    def __init__(self, in_features: int, out_features: int, kernel_size: int,
+427                 demodulate: bool = True, eps: float = 1e-8):
+
+
+
+
+ + +
+
+
435        super().__init__()
+
+
+
+
+ +

Number of output features

+
+
+
437        self.out_features = out_features
+
+
+
+
+ +

Whether to normalize weights

+
+
+
439        self.demodulate = demodulate
+
+
+
+
+ +

Padding size

+
+
+
441        self.padding = (kernel_size - 1) // 2
+
+
+
+ +
+
444        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
+
+
+
+
+ +

$\epsilon$

+
+
+
446        self.eps = eps
+
+
+
+
+ +
    +
  • x is the input feature map of shape [batch_size, in_features, height, width]
  • +
  • s is style based scaling tensor of shape [batch_size, in_features]
  • +
+
+
+
448    def forward(self, x: torch.Tensor, s: torch.Tensor):
+
+
+
+
+ +

Get batch size, height and width

+
+
+
455        b, _, h, w = x.shape
+
+
+
+
+ +

Reshape the scales

+
+
+
458        s = s[:, None, :, None, None]
+
+
+
+ +
+
460        weights = self.weight()[None, :, :, :, :]
+
+
+
+
+ +

$$w'_{i, j, k} = s_i \cdot w_{i, j, k}$$ where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index.

+

The result has shape [batch_size, out_features, in_features, kernel_size, kernel_size]

+
+
+
465        weights = weights * s
+
+
+
+
+ +

Demodulate

+
+
+
468        if self.demodulate:
+
+
+
+
+ +

+ +

+
+
+
470            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
+
+
+
+
+ +

+ +

+
+
+
472            weights = weights * sigma_inv
+
+
+
+
+ +

Reshape x

+
+
+
475        x = x.reshape(1, -1, h, w)
+
+
+
+
+ +

Reshape weights

+
+
+
478        _, _, *ws = weights.shape
+479        weights = weights.reshape(b * self.out_features, *ws)
+
+
+
+
+ +

Use grouped convolution to efficiently calculate the convolution with sample wise kernel. +i.e. we have a different kernel (weights) for each sample in the batch

+
+
+
483        x = F.conv2d(x, weights, padding=self.padding, groups=b)
+
+
+
+
+ +

Reshape x to [batch_size, out_features, height, width] and return

+
+
+
486        return x.reshape(-1, self.out_features, h, w)
+
+
+
+
+ +

+

Style GAN2 Discriminator

+

Discriminator

+

Discriminator first transforms the image to a feature map of the same resolution and then +runs it through a series of blocks with residual connections. +The resolution is down-sampled by $2 \times$ at each block while doubling the +number of features.

+
+
+
489class Discriminator(nn.Module):
+
+
+
+
+ +
    +
  • log_resolution is the $\log_2$ of image resolution
  • +
  • n_features number of features in the convolution layer at the highest resolution (first block)
  • +
  • max_features maximum number of features in any generator block
  • +
+
+
+
502    def __init__(self, log_resolution: int, n_features: int = 64, max_features: int = 512):
+
+
+
+
+ + +
+
+
508        super().__init__()
+
+
+
+
+ +

Layer to convert RGB image to a feature map with n_features number of features.

+
+
+
511        self.from_rgb = nn.Sequential(
+512            EqualizedConv2d(3, n_features, 1),
+513            nn.LeakyReLU(0.2, True),
+514        )
+
+
+
+
+ +

Calculate the number of features for each block.

+

Something like [64, 128, 256, 512, 512, 512].

+
+
+
519        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]
+
+
+
+
+ +

Number of discriminator blocks

+
+
+
521        n_blocks = len(features) - 1
+
+
+
+
+ +

Discriminator blocks

+
+
+
523        blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
+524        self.blocks = nn.Sequential(*blocks)
+
+
+
+ +
+
527        self.std_dev = MiniBatchStdDev()
+
+
+
+
+ +

Number of features after adding the standard deviations map

+
+
+
529        final_features = features[-1] + 1
+
+
+
+
+ +

Final $3 \times 3$ convolution layer

+
+
+
531        self.conv = EqualizedConv2d(final_features, final_features, 3)
+
+
+
+
+ +

Final linear layer to get the classification

+
+
+
533        self.final = EqualizedLinear(2 * 2 * final_features, 1)
+
+
+
+
+ +
    +
  • x is the input image of shape [batch_size, 3, height, width]
  • +
+
+
+
535    def forward(self, x: torch.Tensor):
+
+
+
+
+ +

Try to normalize the image (this is totally optional, but sped up the early training a little)

+
+
+
541        x = x - 0.5
+
+
+
+
+ +

Convert from RGB

+
+
+
543        x = self.from_rgb(x)
+
+
+
+
+ +

Run through the discriminator blocks

+
+
+
545        x = self.blocks(x)
+
+
+
+
+ +

Calculate and append mini-batch standard deviation

+
+
+
548        x = self.std_dev(x)
+
+
+
+
+ +

$3 \times 3$ convolution

+
+
+
550        x = self.conv(x)
+
+
+
+
+ +

Flatten

+
+
+
552        x = x.reshape(x.shape[0], -1)
+
+
+
+
+ +

Return the classification score

+
+
+
554        return self.final(x)
+
+
+
+
+ +

+

Discriminator Block

+

Discriminator block

+

Discriminator block consists of two $3 \times 3$ convolutions with a residual connection.

+
+
+
557class DiscriminatorBlock(nn.Module):
+
+
+
+
+ +
    +
  • in_features is the number of features in the input feature map
  • +
  • out_features is the number of features in the output feature map
  • +
+
+
+
567    def __init__(self, in_features, out_features):
+
+
+
+
+ + +
+
+
572        super().__init__()
+
+
+
+
+ +

Down-sampling and $1 \times 1$ convolution layer for the residual connection

+
+
+
574        self.residual = nn.Sequential(DownSample(),
+575                                      EqualizedConv2d(in_features, out_features, kernel_size=1))
+
+
+
+
+ +

Two $3 \times 3$ convolutions

+
+
+
578        self.block = nn.Sequential(
+579            EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
+580            nn.LeakyReLU(0.2, True),
+581            EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
+582            nn.LeakyReLU(0.2, True),
+583        )
+
+
+
+
+ +

Down-sampling layer

+
+
+
586        self.down_sample = DownSample()
+
+
+
+
+ +

Scaling factor $\frac{1}{\sqrt 2}$ after adding the residual

+
+
+
589        self.scale = 1 / math.sqrt(2)
+
+
+
+
+ + +
+
+
591    def forward(self, x):
+
+
+
+
+ +

Get the residual connection

+
+
+
593        residual = self.residual(x)
+
+
+
+
+ +

Convolutions

+
+
+
596        x = self.block(x)
+
+
+
+
+ +

Down-sample

+
+
+
598        x = self.down_sample(x)
+
+
+
+
+ +

Add the residual and scale

+
+
+
601        return (x + residual) * self.scale
+
+
+
+
+ +

+

Mini-batch Standard Deviation

+

Mini-batch standard deviation calculates the standard deviation across a mini-batch (or subgroups within the mini-batch) for each feature in the feature map. Then it takes the mean of all the standard deviations and appends it to the feature map as one extra feature.

+
+
+
604class MiniBatchStdDev(nn.Module):
+
+
+
+
+ +
    +
  • group_size is the number of samples to calculate standard deviation across.
  • +
+
+
+
616    def __init__(self, group_size: int = 4):
+
+
+
+
+ + +
+
+
620        super().__init__()
+621        self.group_size = group_size
+
+
+
+
+ +
    +
  • x is the feature map
  • +
+
+
+
623    def forward(self, x: torch.Tensor):
+
+
+
+
+ +

Check if the batch size is divisible by the group size

+
+
+
628        assert x.shape[0] % self.group_size == 0
+
+
+
+
+ +

Split the samples into groups of group_size, we flatten the feature map to a single dimension +since we want to calculate the standard deviation for each feature.

+
+
+
631        grouped = x.view(self.group_size, -1)
+
+
+
+
+ +

Calculate the standard deviation for each feature among group_size samples $$\sigma = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \big(x_i - \bar{x}\big)^2 + \epsilon}$$

+
+
+
635        std = torch.sqrt(grouped.var(dim=0) + 1e-8)
+
+
+
+
+ +

Get the mean standard deviation

+
+
+
637        std = std.mean().view(1, 1, 1, 1)
+
+
+
+
+ +

Expand the standard deviation to append to the feature map

+
+
+
639        b, _, h, w = x.shape
+640        std = std.expand(b, -1, h, w)
+
+
+
+
+ +

Append (concatenate) the standard deviations to the feature map

+
+
+
642        return torch.cat([x, std], dim=1)
+
+
+
+
+ +

+

Down-sample

+

The down-sample operation smoothens each feature channel and scales it down by $2 \times$ using bilinear interpolation. This is based on the paper Making Convolutional Networks Shift-Invariant Again.

+
+
+
645class DownSample(nn.Module):
+
+
+
+
+ + +
+
+
656    def __init__(self):
+657        super().__init__()
+
+
+
+
+ +

Smoothing layer

+
+
+
659        self.smooth = Smooth()
+
+
+
+
+ + +
+
+
661    def forward(self, x: torch.Tensor):
+
+
+
+
+ +

Smoothing or blurring

+
+
+
663        x = self.smooth(x)
+
+
+
+
+ +

Scaled down

+
+
+
665        return F.interpolate(x, (x.shape[2] // 2, x.shape[3] // 2), mode='bilinear', align_corners=False)
+
+
+
+
+ +

+

Up-sample

+

The up-sample operation scales the image up by $2 \times$ and smoothens each feature channel. +This is based on the paper + Making Convolutional Networks Shift-Invariant Again.

+
+
+
668class UpSample(nn.Module):
+
+
+
+
+ + +
+
+
678    def __init__(self):
+679        super().__init__()
+
+
+
+
+ +

Up-sampling layer

+
+
+
681        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+
+
+
+
+ +

Smoothing layer

+
+
+
683        self.smooth = Smooth()
+
+
+
+
+ + +
+
+
685    def forward(self, x: torch.Tensor):
+
+
+
+
+ +

Up-sample and smoothen

+
+
+
687        return self.smooth(self.up_sample(x))
+
+
+
+
+ +

+

Smoothing Layer

+

This layer blurs each channel

+
+
+
690class Smooth(nn.Module):
+
+
+
+
+ + +
+
+
698    def __init__(self):
+699        super().__init__()
+
+
+
+
+ +

Blurring kernel

+
+
+
701        kernel = [[1, 2, 1],
+702                  [2, 4, 2],
+703                  [1, 2, 1]]
+
+
+
+
+ +

Convert the kernel to a PyTorch tensor

+
+
+
705        kernel = torch.tensor([[kernel]], dtype=torch.float)
+
+
+
+
+ +

Normalize the kernel

+
+
+
707        kernel /= kernel.sum()
+
+
+
+
+ +

Save kernel as a fixed parameter (no gradient updates)

+
+
+
709        self.kernel = nn.Parameter(kernel, requires_grad=False)
+
+
+
+
+ +

Padding layer

+
+
+
711        self.pad = nn.ReplicationPad2d(1)
+
+
+
+
+ + +
+
+
713    def forward(self, x: torch.Tensor):
+
+
+
+
+ +

Get shape of the input feature map

+
+
+
715        b, c, h, w = x.shape
+
+
+
+
+ +

Reshape for smoothening

+
+
+
717        x = x.view(-1, 1, h, w)
+
+
+
+
+ +

Add padding

+
+
+
720        x = self.pad(x)
+
+
+
+
+ +

Smoothen (blur) with the kernel

+
+
+
723        x = F.conv2d(x, self.kernel)
+
+
+
+
+ +

Reshape and return

+
+
+
726        return x.view(b, c, h, w)
+
+
+
+
+ +

+

Learning-rate Equalized Linear Layer

+

This uses learning-rate equalized weights for a linear layer.

+
+
+
729class EqualizedLinear(nn.Module):
+
+
+
+
+ +
    +
  • in_features is the number of features in the input feature map
  • +
  • out_features is the number of features in the output feature map
  • +
  • bias is the bias initialization constant
  • +
+
+
+
737    def __init__(self, in_features: int, out_features: int, bias: float = 0.):
+
+
+
+
+ + +
+
+
744        super().__init__()
+
+
+
+ +
+
746        self.weight = EqualizedWeight([out_features, in_features])
+
+
+
+
+ +

Bias

+
+
+
748        self.bias = nn.Parameter(torch.ones(out_features) * bias)
+
+
+
+
+ + +
+
+
750    def forward(self, x: torch.Tensor):
+
+
+
+
+ +

Linear transformation

+
+
+
752        return F.linear(x, self.weight(), bias=self.bias)
+
+
+
+
+ +

+

Learning-rate Equalized 2D Convolution Layer

+

This uses learning-rate equalized weights for a convolution layer.

+
+
+
755class EqualizedConv2d(nn.Module):
+
+
+
+
+ +
  • in_features is the number of features in the input feature map
  • out_features is the number of features in the output feature map
  • kernel_size is the size of the convolution kernel
  • padding is the padding to be added on both sides of each size dimension
+
+
+
763    def __init__(self, in_features: int, out_features: int,
+764                 kernel_size: int, padding: int = 0):
+
+
+
+
+ + +
+
+
771        super().__init__()
+
+
+
+
+ +

Padding size

+
+
+
773        self.padding = padding
+
+
+
+ +
+
775        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
+
+
+
+
+ +

Bias

+
+
+
777        self.bias = nn.Parameter(torch.ones(out_features))
+
+
+
+
+ + +
+
+
779    def forward(self, x: torch.Tensor):
+
+
+
+
+ +

Convolution

+
+
+
781        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)
+
+
+
+
+ +

+

Learning-rate Equalized Weights Parameter

+

This is based on equalized learning rate introduced in the Progressive GAN paper. Instead of initializing weights at $\mathcal{N}(0,c)$ they initialize weights to $\mathcal{N}(0, 1)$ and then multiply them by $c$ when using it, $$w_i = c \hat{w}_i$$

+

The gradients on stored parameters $\hat{w}$ get multiplied by $c$ but this doesn't have an effect since optimizers such as Adam normalize them by a running mean of the squared gradients.

+

The optimizer updates on $\hat{w}$ are proportional to the learning rate $\lambda$. But the effective weights $w$ get updated in proportion to $c \lambda$. Without equalized learning rate, the effective weights would get updated in proportion to just $\lambda$.

+

So we are effectively scaling the learning rate by $c$ for these weight parameters.
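As a worked example (not from the paper): a linear layer with $512$ input features has $c = 1/\sqrt{512} \approx 0.044$, so its effective weights $w = c\hat{w}$ move in proportion to $0.044\,\lambda$ per step instead of $\lambda$.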

+
+
+
784class EqualizedWeight(nn.Module):
+
+
+
+
+ +
  • shape is the shape of the weight parameter
+
+
+
804    def __init__(self, shape: List[int]):
+
+
+
+
+ + +
+
+
808        super().__init__()
+
+
+
+
+ +

He initialization constant

+
+
+
811        self.c = 1 / math.sqrt(np.prod(shape[1:]))
+
+
+
+
+ +

Initialize the weights with $\mathcal{N}(0, 1)$

+
+
+
813        self.weight = nn.Parameter(torch.randn(shape))
+
+
+
+
+ +

Weight multiplication coefficient

+
+
+
+
+
+
+
+ + +
+
+
816    def forward(self):
+
+
+
+
+ +

Multiply the weights by $c$ and return

+
+
+
818        return self.weight * self.c
+
+
+
+
+ +

+

Gradient Penalty

+

This is the $R_1$ regularization penalty from the paper Which Training Methods for GANs do actually Converge?.

+

$$R_1(\psi) = \frac{\gamma}{2} \mathbb{E}_{p_\mathcal{D}(x)} \Big[\Vert \nabla_x D_\psi(x) \Vert^2\Big]$$

+

That is, we try to reduce the L2 norm of gradients of the discriminator with respect to images, for real images ($P_\mathcal{D}$).

+
+
+
821class GradientPenalty(nn.Module):
+
+
+
+
+ +
  • x is $x \sim \mathcal{D}$
  • d is $D(x)$
+
+
+
836    def forward(self, x: torch.Tensor, d: torch.Tensor):
+
+
+
+
+ +

Get batch size

+
+
+
843        batch_size = x.shape[0]
+
+
+
+
+ +

Calculate gradients of $D(x)$ with respect to $x$. grad_outputs is set to $1$ since we want the gradients of $D(x)$, and we need to create (and retain) the graph since we have to compute gradients with respect to the weights on this loss.

+
+
+
849        gradients, *_ = torch.autograd.grad(outputs=d,
+850                                            inputs=x,
+851                                            grad_outputs=d.new_ones(d.shape),
+852                                            create_graph=True)
+
+
+
+
+ +

Reshape gradients to calculate the norm

+
+
+
855        gradients = gradients.reshape(batch_size, -1)
+
+
+
+
+ +

Calculate the norm $\Vert \nabla_{x} D(x) \Vert_2$

+
+
+
857        norm = gradients.norm(2, dim=-1)
+
+
+
+
+ +

Return the loss $\mathbb{E}\big[\Vert \nabla_x D_\psi(x) \Vert_2^2\big]$

+
+
+
859        return torch.mean(norm ** 2)
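A minimal usage sketch (discriminator, real_images, adversarial_loss and gamma are illustrative names, not defined in this file):

    real_images.requires_grad_()                      # gradients w.r.t. the input images are needed
    d = discriminator(real_images)                    # D(x) for the real batch
    r1 = GradientPenalty()(real_images, d)            # mean squared gradient norm
    disc_loss = adversarial_loss + 0.5 * gamma * r1   # the gamma/2 weighting is applied by the caller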
+
+
+
+
+ +

+

Path Length Penalty

+

This regularization encourages a fixed-size step in $w$ to result in a fixed-magnitude change in the image.

+

$$\mathbb{E}_{w \sim f(z), y \sim \mathcal{N}(0, \mathbf{I})} \Big[\big(\Vert \mathbf{J}^\top_{w} y \Vert_2 - a\big)^2\Big]$$

+

where $\mathbf{J}_w$ is the Jacobian $\mathbf{J}_w = \frac{\partial g}{\partial w}$, $w \in \mathcal{W}$ are samples from the mapping network, and $y$ are random images whose pixels are sampled from $\mathcal{N}(0, \mathbf{I})$.

+

$a$ is the exponential moving average of $\Vert \mathbf{J}^\top_{w} y \Vert_2$ as the training progresses.

+

$\mathbf{J}^\top_{w} y$ is calculated without explicitly computing the Jacobian, using $$\mathbf{J}^\top_{w} y = \nabla_w \big(g(w) \cdot y \big)$$

+
+
+
862class PathLengthPenalty(nn.Module):
+
+
+
+
+ +
  • beta is the constant $\beta$ used to calculate the exponential moving average $a$
+
+
+
885    def __init__(self, beta: float):
+
+
+
+
+ + +
+
+
889        super().__init__()
+
+
+
+
+ +

$\beta$

+
+
+
892        self.beta = beta
+
+
+
+
+ +

Number of steps calculated $N$

+
+
+
894        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)
+
+
+
+
+ +

Exponential sum of $\Vert \mathbf{J}^\top_{w} y \Vert_2$, $$\sum^N_{i=1} (1 - \beta) \beta^{N - i} \big[\Vert \mathbf{J}^\top_{w} y \Vert_2\big]_i$$ where $\big[\Vert \mathbf{J}^\top_{w} y \Vert_2\big]_i$ is its value at the $i$-th training step

+
+
+
898        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)
+
+
+
+
+ +
  • w is the batch of $w$ of shape [batch_size, d_latent]
  • x are the generated images of shape [batch_size, 3, height, width]
+
+
+
900    def forward(self, w: torch.Tensor, x: torch.Tensor):
+
+
+
+
+ +

Get the device

+
+
+
907        device = x.device
+
+
+
+
+ +

Get number of pixels

+
+
+
909        image_size = x.shape[2] * x.shape[3]
+
+
+
+
+ +

Sample $y \sim \mathcal{N}(0, \mathbf{I})$

+
+
+
911        y = torch.randn(x.shape, device=device)
+
+
+
+
+ +

Calculate $\big(g(w) \cdot y \big)$ and normalize by the square root of the image size. This scaling is not mentioned in the paper but was present in their implementation.

+
+
+
915        output = (x * y).sum() / math.sqrt(image_size)
+
+
+
+
+ +

Calculate gradients to get $\mathbf{J}^\top_{w} y$

+
+
+
918        gradients, *_ = torch.autograd.grad(outputs=output,
+919                                            inputs=w,
+920                                            grad_outputs=torch.ones(output.shape, device=device),
+921                                            create_graph=True)
+
+
+
+
+ +

Calculate L2-norm of $\mathbf{J}^\top_{w} y$

+
+
+
924        norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt()
+
+
+
+
+ +

Regularize after first step

+
+
+
927        if self.steps > 0:
+
+
+
+
+ +

Calculate $a$ as the bias-corrected moving average $$a = \frac{1}{1 - \beta^N} \sum^N_{i=1} (1 - \beta) \beta^{N - i} \big[\Vert \mathbf{J}^\top_{w} y \Vert_2\big]_i$$

+
+
+
930            a = self.exp_sum_a / (1 - self.beta ** self.steps)
+
+
+
+
+ +

Calculate the penalty $$\mathbb{E}\Big[\big(\Vert \mathbf{J}^\top_{w} y \Vert_2 - a\big)^2\Big]$$

+
+
+
934            loss = torch.mean((norm - a) ** 2)
+935        else:
+
+
+
+
+ +

Return a dummy loss if we can’t calculate $a$

+
+
+
937            loss = norm.new_tensor(0)
+
+
+
+
+ +

Calculate the mean of $\Vert \mathbf{J}^\top_{w} y \Vert_2$

+
+
+
940        mean = norm.mean().detach()
+
+
+
+
+ +

Update exponential sum

+
+
+
942        self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)
+
+
+
+
+ +

Increment $N$

+
+
+
944        self.steps.add_(1.)
+
+
+
+
+ +

Return the penalty

+
+
+
947        return loss
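A minimal usage sketch (plp_weight and the lazy-regularization interval are illustrative; generator, the mapping-network output w, and input_noise are as defined in this file):

    plp = PathLengthPenalty(beta=0.99)
    # in the generator update, typically applied only every few steps (lazy regularization)
    images = generator(w, input_noise)                 # w comes from the mapping network and requires grad
    gen_loss = gen_loss + plp_weight * plp(w, images)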
+
+
+
+ + + + + + + \ No newline at end of file diff --git a/docs/gan/stylegan/mapping_network.svg b/docs/gan/stylegan/mapping_network.svg new file mode 100644 index 00000000..69219405 --- /dev/null +++ b/docs/gan/stylegan/mapping_network.svg @@ -0,0 +1,85 @@ +
[mapping_network.svg — diagram labels: z, Norm, Linear (repeated), More layers, w]
\ No newline at end of file diff --git a/docs/gan/stylegan/progressive_gan.svg b/docs/gan/stylegan/progressive_gan.svg new file mode 100644 index 00000000..cc150e61 --- /dev/null +++ b/docs/gan/stylegan/progressive_gan.svg @@ -0,0 +1,87 @@ +
[progressive_gan.svg — diagram labels: Generator, Discriminator, Progressive Growing; blocks 4x4, 8x8, ...; 2x and 0.5x scaling; toRGB, fromRGB; 1-⍺ blending]
\ No newline at end of file diff --git a/docs/gan/stylegan/style_block.svg b/docs/gan/stylegan/style_block.svg new file mode 100644 index 00000000..01e57797 --- /dev/null +++ b/docs/gan/stylegan/style_block.svg @@ -0,0 +1,85 @@ +
[style_block.svg — diagram labels: feature map, A, Mod, Demod, weights, 3x3 Conv, bias, B]
\ No newline at end of file diff --git a/docs/gan/stylegan/style_gan.svg b/docs/gan/stylegan/style_gan.svg new file mode 100644 index 00000000..dfe963b4 --- /dev/null +++ b/docs/gan/stylegan/style_gan.svg @@ -0,0 +1,85 @@ +
[style_gan.svg — diagram labels: Mapping Network (z, Norm, Linear, More layers, w), 512x4x4 Const, 3x3 Conv, A, AdaIN, B, noise, Upsample, 4x4, 8x8, more layers, toRGB]
\ No newline at end of file diff --git a/docs/gan/stylegan/style_gan2.svg b/docs/gan/stylegan/style_gan2.svg new file mode 100644 index 00000000..6eef9967 --- /dev/null +++ b/docs/gan/stylegan/style_gan2.svg @@ -0,0 +1,85 @@ +
[style_gan2.svg — diagram labels: 512x4x4 Const, 3x3 Conv, A, Mod, Demod, weights, bias, B, Upsample, 4x4, 8x8, more layers, toRGB]
\ No newline at end of file diff --git a/docs/gan/stylegan/style_gan2_disc.svg b/docs/gan/stylegan/style_gan2_disc.svg new file mode 100644 index 00000000..8ff194d8 --- /dev/null +++ b/docs/gan/stylegan/style_gan2_disc.svg @@ -0,0 +1,85 @@ +
[style_gan2_disc.svg — diagram labels: fromRGB, 1024x1024, 3x3 Conv, 1x1 Conv, Downsample, 8x8, More layers, Minibatch StdDev, 4x4, 2x2, Flatten, Classify]
\ No newline at end of file diff --git a/docs/gan/stylegan/to_rgb.svg b/docs/gan/stylegan/to_rgb.svg new file mode 100644 index 00000000..32d7ae2d --- /dev/null +++ b/docs/gan/stylegan/to_rgb.svg @@ -0,0 +1,85 @@ +
[to_rgb.svg — diagram labels: feature map, A, Mod, weights, 3x3 Conv, bias]
\ No newline at end of file diff --git a/docs/gan/wasserstein/experiment.html b/docs/gan/wasserstein/experiment.html index c0f76c87..763c6a65 100644 --- a/docs/gan/wasserstein/experiment.html +++ b/docs/gan/wasserstein/experiment.html @@ -199,19 +199,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/wasserstein/gradient_penalty/experiment.html b/docs/gan/wasserstein/gradient_penalty/experiment.html index 40676390..84917eb7 100644 --- a/docs/gan/wasserstein/gradient_penalty/experiment.html +++ b/docs/gan/wasserstein/gradient_penalty/experiment.html @@ -332,19 +332,46 @@ includes gradient penalty.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/wasserstein/gradient_penalty/index.html b/docs/gan/wasserstein/gradient_penalty/index.html index 0c67e878..6270f7c9 100644 --- a/docs/gan/wasserstein/gradient_penalty/index.html +++ b/docs/gan/wasserstein/gradient_penalty/index.html @@ -221,19 +221,46 @@ with respect to weight on this loss.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/wasserstein/gradient_penalty/readme.html b/docs/gan/wasserstein/gradient_penalty/readme.html index 11202230..f2b36fba 100644 --- a/docs/gan/wasserstein/gradient_penalty/readme.html +++ b/docs/gan/wasserstein/gradient_penalty/readme.html @@ -110,19 +110,46 @@ proposal a better way to improve Lipschitz constraint, a gradient penalty.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/wasserstein/index.html b/docs/gan/wasserstein/index.html index 6107fd28..9261fcf3 100644 --- a/docs/gan/wasserstein/index.html +++ b/docs/gan/wasserstein/index.html @@ -251,19 +251,46 @@ so we minimize, displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/gan/wasserstein/readme.html b/docs/gan/wasserstein/readme.html index 5ce53e36..6560d224 100644 --- a/docs/gan/wasserstein/readme.html +++ b/docs/gan/wasserstein/readme.html @@ -98,19 +98,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/hypernetworks/experiment.html b/docs/hypernetworks/experiment.html index 20e0bd2f..9048fd4e 100644 --- a/docs/hypernetworks/experiment.html +++ b/docs/hypernetworks/experiment.html @@ -343,19 +343,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/hypernetworks/hyper_lstm.html b/docs/hypernetworks/hyper_lstm.html index 2196d261..591be297 100644 --- a/docs/hypernetworks/hyper_lstm.html +++ b/docs/hypernetworks/hyper_lstm.html @@ -741,19 +741,46 @@ Rest of the layers get the input from the layer below

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/hypernetworks/index.html b/docs/hypernetworks/index.html index c6d59acb..a3099fd4 100644 --- a/docs/hypernetworks/index.html +++ b/docs/hypernetworks/index.html @@ -95,19 +95,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/index.html b/docs/index.html index 498059d7..7938bf98 100644 --- a/docs/index.html +++ b/docs/index.html @@ -172,19 +172,46 @@ implementations.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/lstm/index.html b/docs/lstm/index.html index 16708835..be066378 100644 --- a/docs/lstm/index.html +++ b/docs/lstm/index.html @@ -471,19 +471,46 @@ Rest of the layers get the input from the layer below

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/batch_channel_norm/index.html b/docs/normalization/batch_channel_norm/index.html index 7332e8dc..ff1649f1 100644 --- a/docs/normalization/batch_channel_norm/index.html +++ b/docs/normalization/batch_channel_norm/index.html @@ -645,19 +645,46 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}] displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/batch_norm/cifar10.html b/docs/normalization/batch_norm/cifar10.html index 2ab0fbb8..0f4be662 100644 --- a/docs/normalization/batch_norm/cifar10.html +++ b/docs/normalization/batch_norm/cifar10.html @@ -246,19 +246,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/batch_norm/index.html b/docs/normalization/batch_norm/index.html index 5a7cdd58..6113f695 100644 --- a/docs/normalization/batch_norm/index.html +++ b/docs/normalization/batch_norm/index.html @@ -472,19 +472,46 @@ i.e. the means for each feature $\mathbb{E}[(x^{(k)})^2]$

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/batch_norm/mnist.html b/docs/normalization/batch_norm/mnist.html index edd68833..9a5f5889 100644 --- a/docs/normalization/batch_norm/mnist.html +++ b/docs/normalization/batch_norm/mnist.html @@ -318,19 +318,46 @@ and set a new function to calculate the model.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/batch_norm/readme.html b/docs/normalization/batch_norm/readme.html index 8d490274..b119a47f 100644 --- a/docs/normalization/batch_norm/readme.html +++ b/docs/normalization/batch_norm/readme.html @@ -162,19 +162,46 @@ a CNN classifier that uses batch normalization for MNIST dataset.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/group_norm/experiment.html b/docs/normalization/group_norm/experiment.html index 6e69db00..fc6d942f 100644 --- a/docs/normalization/group_norm/experiment.html +++ b/docs/normalization/group_norm/experiment.html @@ -356,19 +356,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/group_norm/index.html b/docs/normalization/group_norm/index.html index 69b70cca..927c0eb3 100644 --- a/docs/normalization/group_norm/index.html +++ b/docs/normalization/group_norm/index.html @@ -395,19 +395,46 @@ $Var[x_{(i_N, i_G)}] = \mathbb{E}[x^2_{(i_N, i_G)}] - \mathbb{E}[x_{(i_N, i_G)}] displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/group_norm/readme.html b/docs/normalization/group_norm/readme.html index 743354c1..fbadaf09 100644 --- a/docs/normalization/group_norm/readme.html +++ b/docs/normalization/group_norm/readme.html @@ -112,19 +112,46 @@ all channels within each group.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/index.html b/docs/normalization/index.html index 6014ae8f..af176eaf 100644 --- a/docs/normalization/index.html +++ b/docs/normalization/index.html @@ -103,19 +103,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/instance_norm/experiment.html b/docs/normalization/instance_norm/experiment.html index 51af5ad8..37d5cf60 100644 --- a/docs/normalization/instance_norm/experiment.html +++ b/docs/normalization/instance_norm/experiment.html @@ -337,19 +337,46 @@ style transfer and this is only a demo.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/instance_norm/index.html b/docs/normalization/instance_norm/index.html index 44165a2f..34d5cbf2 100644 --- a/docs/normalization/instance_norm/index.html +++ b/docs/normalization/instance_norm/index.html @@ -350,19 +350,46 @@ i.e. the means for each feature $\mathbb{E}[(x_{t,i}^2]$

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/instance_norm/readme.html b/docs/normalization/instance_norm/readme.html index 47df63d5..fb7400c2 100644 --- a/docs/normalization/instance_norm/readme.html +++ b/docs/normalization/instance_norm/readme.html @@ -102,19 +102,46 @@ introduces instance normalization which does that.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/layer_norm/index.html b/docs/normalization/layer_norm/index.html index 9bb9c800..3a8888a1 100644 --- a/docs/normalization/layer_norm/index.html +++ b/docs/normalization/layer_norm/index.html @@ -352,19 +352,46 @@ i.e. the means for each element $\mathbb{E}[X^2]$

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/layer_norm/readme.html b/docs/normalization/layer_norm/readme.html index 7df71ed8..6dea7db6 100644 --- a/docs/normalization/layer_norm/readme.html +++ b/docs/normalization/layer_norm/readme.html @@ -116,19 +116,46 @@ Layer normalization does it for each batch across all elements.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/weight_standardization/conv2d.html b/docs/normalization/weight_standardization/conv2d.html index 3f7504ae..96d884f7 100644 --- a/docs/normalization/weight_standardization/conv2d.html +++ b/docs/normalization/weight_standardization/conv2d.html @@ -182,19 +182,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/weight_standardization/experiment.html b/docs/normalization/weight_standardization/experiment.html index be25927f..9b253b78 100644 --- a/docs/normalization/weight_standardization/experiment.html +++ b/docs/normalization/weight_standardization/experiment.html @@ -249,19 +249,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/weight_standardization/index.html b/docs/normalization/weight_standardization/index.html index 1fd20786..58774fe2 100644 --- a/docs/normalization/weight_standardization/index.html +++ b/docs/normalization/weight_standardization/index.html @@ -213,19 +213,46 @@ and $I$ is the number of input channels times the kernel size ($I = C_{in} \time displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/normalization/weight_standardization/readme.html b/docs/normalization/weight_standardization/readme.html index 09f5e947..a8aa3c59 100644 --- a/docs/normalization/weight_standardization/readme.html +++ b/docs/normalization/weight_standardization/readme.html @@ -100,19 +100,46 @@ We also have an displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/ada_belief.html b/docs/optimizers/ada_belief.html index 2b635e1f..a38a871c 100644 --- a/docs/optimizers/ada_belief.html +++ b/docs/optimizers/ada_belief.html @@ -448,19 +448,46 @@ $\color{cyan}{s_t} + \color{red}{\epsilon}$ in place of $v_t$.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/adam.html b/docs/optimizers/adam.html index d20cc3a0..0c7cc7d2 100644 --- a/docs/optimizers/adam.html +++ b/docs/optimizers/adam.html @@ -558,19 +558,46 @@ is what we should specify as the hyper-parameter.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/adam_warmup.html b/docs/optimizers/adam_warmup.html index 53fabc0b..bfa65785 100644 --- a/docs/optimizers/adam_warmup.html +++ b/docs/optimizers/adam_warmup.html @@ -199,19 +199,46 @@ where $w$ is the number of warmup steps.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/adam_warmup_cosine_decay.html b/docs/optimizers/adam_warmup_cosine_decay.html index f7de4b9c..6cb5783a 100644 --- a/docs/optimizers/adam_warmup_cosine_decay.html +++ b/docs/optimizers/adam_warmup_cosine_decay.html @@ -249,19 +249,46 @@ where $w$ is the number of warmup steps.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/amsgrad.html b/docs/optimizers/amsgrad.html index 37925234..49bd1f02 100644 --- a/docs/optimizers/amsgrad.html +++ b/docs/optimizers/amsgrad.html @@ -526,19 +526,46 @@ You can see that AMSGrad converges to true optimal $x = -1$

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/configs.html b/docs/optimizers/configs.html index ec4b2a2f..0c650059 100644 --- a/docs/optimizers/configs.html +++ b/docs/optimizers/configs.html @@ -395,19 +395,46 @@ i.e. weight decay is not added to gradients

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/index.html b/docs/optimizers/index.html index 3022a8b6..d9e1ae76 100644 --- a/docs/optimizers/index.html +++ b/docs/optimizers/index.html @@ -531,19 +531,46 @@ when the decay is performed directly on the parameter. If this is false the actu displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/mnist_experiment.html b/docs/optimizers/mnist_experiment.html index 40544253..0f6f2213 100644 --- a/docs/optimizers/mnist_experiment.html +++ b/docs/optimizers/mnist_experiment.html @@ -414,19 +414,46 @@ We can change the optimizer type and hyper-parameters using configurations.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/noam.html b/docs/optimizers/noam.html index dc574bf1..9fdfa52a 100644 --- a/docs/optimizers/noam.html +++ b/docs/optimizers/noam.html @@ -233,19 +233,46 @@ where $w$ is the number of warmup steps.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/performance_test.html b/docs/optimizers/performance_test.html index 5ceae84a..eba217ae 100644 --- a/docs/optimizers/performance_test.html +++ b/docs/optimizers/performance_test.html @@ -146,19 +146,46 @@ MyAdam...[DONE] 1,192.89ms displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/radam.html b/docs/optimizers/radam.html index 7f9ec311..232d9d0a 100644 --- a/docs/optimizers/radam.html +++ b/docs/optimizers/radam.html @@ -629,19 +629,46 @@ $\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t$

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/optimizers/readme.html b/docs/optimizers/readme.html index 48879226..31ea6c0d 100644 --- a/docs/optimizers/readme.html +++ b/docs/optimizers/readme.html @@ -104,19 +104,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/pylit.css b/docs/pylit.css index e7ed7044..3b70fffe 100644 --- a/docs/pylit.css +++ b/docs/pylit.css @@ -459,28 +459,77 @@ span.lineno { color: #bd93f9; } -:root { - --blue: #007bff; - --indigo: #6610f2; - --purple: #6f42c1; - --pink: #e83e8c; - --red: #dc3545; - --orange: #fd7e14; - --yellow: #ffc107; - --green: #28a745; - --teal: #20c997; - --cyan: #17a2b8; - --white: #fff; - --gray: #6c757d; - --gray-dark: #343a40; - --primary: #007bff; - --secondary: #6c757d; - --success: #28a745; - --info: #17a2b8; - --warning: #ffc107; - --danger: #dc3545; - --light: #f8f9fa; - --dark: #343a40; +p > img { + max-height: 240px; + max-width: 240px; + border-radius: 5px; + cursor: pointer; + transition: 0.3s; +} +p > img:hover { + opacity: 0.7; +} + +#modal { + position: fixed; + z-index: 1000; + left: 0; + top: 0; + right: 0; + bottom: 0; + overflow: scroll; + background-color: rgba(0, 0, 0, 0.9); +} +#modal > div { + padding: 100px 10px 10px 10px; +} +#modal > div > img { + margin: auto; + display: block; + width: 80%; + max-width: 700px; +} +#modal > div > p { + margin: auto; + display: block; + width: 80%; + max-width: 700px; + text-align: center; + color: #ccc; + padding: 10px 0; + height: 150px; +} +#modal > div > img, #modal > div > p { + animation-name: zoom; + animation-duration: 0.6s; +} +@keyframes zoom { + from { + transform: scale(0); + } + to { + transform: scale(1); + } +} +#modal > span.close { + position: absolute; + top: 15px; + right: 35px; + color: #f1f1f1; + font-size: 40px; + font-weight: bold; + transition: 0.3s; +} +#modal > span.close:hover, #modal > span.close:focus { + color: #bbb; + text-decoration: none; + cursor: pointer; +} + +@media only screen and (max-width: 700px) { + #modal > img { + width: 100%; + } } /*# sourceMappingURL=pylit.css.map */ diff --git a/docs/recurrent_highway_networks/index.html b/docs/recurrent_highway_networks/index.html index 19c22847..0819c0c4 100644 --- a/docs/recurrent_highway_networks/index.html +++ b/docs/recurrent_highway_networks/index.html @@ -460,19 +460,46 @@ Rest of the layers get the input from the layer below

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/index.html b/docs/resnets/index.html index b90ff0a6..57da1527 100644 --- a/docs/resnets/index.html +++ b/docs/resnets/index.html @@ -84,19 +84,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/models/index.html b/docs/resnets/models/index.html index 65187660..e1da8625 100644 --- a/docs/resnets/models/index.html +++ b/docs/resnets/models/index.html @@ -85,19 +85,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/models/mlp.html b/docs/resnets/models/mlp.html index c105f42d..1754bcde 100644 --- a/docs/resnets/models/mlp.html +++ b/docs/resnets/models/mlp.html @@ -304,19 +304,46 @@ Also convert into float for FC layer

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/models/resnet.html b/docs/resnets/models/resnet.html index 1c068168..c015b100 100644 --- a/docs/resnets/models/resnet.html +++ b/docs/resnets/models/resnet.html @@ -531,19 +531,46 @@ Calculate the output shape after applying a convolution

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/pretrained_nets.html b/docs/resnets/pretrained_nets.html index 2c4afcf7..71cf11f5 100644 --- a/docs/resnets/pretrained_nets.html +++ b/docs/resnets/pretrained_nets.html @@ -254,19 +254,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/resnet_net.html b/docs/resnets/resnet_net.html index ee261ff3..1fa29455 100644 --- a/docs/resnets/resnet_net.html +++ b/docs/resnets/resnet_net.html @@ -253,19 +253,46 @@ Calculate the input shape

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/utils/index.html b/docs/resnets/utils/index.html index 22ebf200..73e81d54 100644 --- a/docs/resnets/utils/index.html +++ b/docs/resnets/utils/index.html @@ -85,19 +85,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/utils/labelsmoothing.html b/docs/resnets/utils/labelsmoothing.html index f651c6c6..acf7a0de 100644 --- a/docs/resnets/utils/labelsmoothing.html +++ b/docs/resnets/utils/labelsmoothing.html @@ -139,19 +139,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/utils/train.html b/docs/resnets/utils/train.html index 1d01632a..3f3bba24 100644 --- a/docs/resnets/utils/train.html +++ b/docs/resnets/utils/train.html @@ -376,19 +376,46 @@ from torch.utils.data.sampler import SubsetRandomSampler

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/resnets/utils/utils.html b/docs/resnets/utils/utils.html index 4de8c275..afa01710 100644 --- a/docs/resnets/utils/utils.html +++ b/docs/resnets/utils/utils.html @@ -224,19 +224,46 @@ subplot integers: displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/dqn/experiment.html b/docs/rl/dqn/experiment.html index 8ea78c0d..b7a7814f 100644 --- a/docs/rl/dqn/experiment.html +++ b/docs/rl/dqn/experiment.html @@ -930,19 +930,46 @@ Gradients shouldn’t propagate for these

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/dqn/index.html b/docs/rl/dqn/index.html index ff168d72..ded1f393 100644 --- a/docs/rl/dqn/index.html +++ b/docs/rl/dqn/index.html @@ -337,19 +337,46 @@ mean squared error loss because it is less sensitive to outliers

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/dqn/model.html b/docs/rl/dqn/model.html index c80d2a6c..a7dc0232 100644 --- a/docs/rl/dqn/model.html +++ b/docs/rl/dqn/model.html @@ -322,19 +322,46 @@ $512$ features

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/dqn/replay_buffer.html b/docs/rl/dqn/replay_buffer.html index 2e4bca65..fea0b3d0 100644 --- a/docs/rl/dqn/replay_buffer.html +++ b/docs/rl/dqn/replay_buffer.html @@ -775,19 +775,46 @@ to get the index of actual value

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/game.html b/docs/rl/game.html index e45e712c..a55a9694 100644 --- a/docs/rl/game.html +++ b/docs/rl/game.html @@ -461,19 +461,46 @@ i.e, each channel is a frame.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/index.html b/docs/rl/index.html index 28656992..6aaacbd5 100644 --- a/docs/rl/index.html +++ b/docs/rl/index.html @@ -109,19 +109,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/ppo/experiment.html b/docs/rl/ppo/experiment.html index a826c552..a2037b1c 100644 --- a/docs/rl/ppo/experiment.html +++ b/docs/rl/ppo/experiment.html @@ -1282,19 +1282,46 @@ You can change this while the experiment is running. displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/ppo/gae.html b/docs/rl/ppo/gae.html index cb7ca228..976623f4 100644 --- a/docs/rl/ppo/gae.html +++ b/docs/rl/ppo/gae.html @@ -245,19 +245,46 @@ The performance of the model was improving displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/ppo/index.html b/docs/rl/ppo/index.html index a894a8ea..03444f71 100644 --- a/docs/rl/ppo/index.html +++ b/docs/rl/ppo/index.html @@ -337,19 +337,46 @@ V^{\pi_\theta}_{CLIP}(s_t) displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/rl/ppo/readme.html b/docs/rl/ppo/readme.html index e9e38a9b..c29a21c0 100644 --- a/docs/rl/ppo/readme.html +++ b/docs/rl/ppo/readme.html @@ -110,19 +110,46 @@ The experiment uses Generalized Ad displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/sitemap.xml b/docs/sitemap.xml index c729294a..80a4d1a5 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -104,6 +104,20 @@ + + https://nn.labml.ai/gan/stylegan/index.html + 2021-05-21T16:30:00+00:00 + 1.00 + + + + + https://nn.labml.ai/gan/stylegan/experiment.html + 2021-05-21T16:30:00+00:00 + 1.00 + + + https://nn.labml.ai/gan/cycle_gan/experiment.html 2021-05-07T16:30:00+00:00 @@ -127,7 +141,7 @@ https://nn.labml.ai/gan/index.html - 2021-05-07T16:30:00+00:00 + 2021-05-09T16:30:00+00:00 1.00 @@ -428,7 +442,7 @@ https://nn.labml.ai/index.html - 2021-05-07T16:30:00+00:00 + 2021-05-09T16:30:00+00:00 1.00 @@ -862,7 +876,7 @@ https://nn.labml.ai/utils.html - 2021-02-17T16:30:00+00:00 + 2021-05-21T16:30:00+00:00 1.00 diff --git a/docs/sketch_rnn/index.html b/docs/sketch_rnn/index.html index ab6a70bd..13554a06 100644 --- a/docs/sketch_rnn/index.html +++ b/docs/sketch_rnn/index.html @@ -2103,19 +2103,46 @@ Paper had suggested 1e-4.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/compressive/experiment.html b/docs/transformers/compressive/experiment.html index 94b8fbe8..4bcba56a 100644 --- a/docs/transformers/compressive/experiment.html +++ b/docs/transformers/compressive/experiment.html @@ -1309,19 +1309,46 @@ Memories that were compressed are needed for the reconstruction loss computation displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/compressive/index.html b/docs/transformers/compressive/index.html index be44b9cb..e1790845 100644 --- a/docs/transformers/compressive/index.html +++ b/docs/transformers/compressive/index.html @@ -927,19 +927,46 @@ The parameters of $f_c^{(i)}$ are the only parameters not detached from gradient displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/compressive/readme.html b/docs/transformers/compressive/readme.html index 66b73b03..92901073 100644 --- a/docs/transformers/compressive/readme.html +++ b/docs/transformers/compressive/readme.html @@ -129,19 +129,46 @@ model on the Tiny Shakespeare dataset.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/configs.html b/docs/transformers/configs.html index e6993833..c4e80918 100644 --- a/docs/transformers/configs.html +++ b/docs/transformers/configs.html @@ -950,19 +950,46 @@ are calculated.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/fast_weights/experiment.html b/docs/transformers/fast_weights/experiment.html index 0e7d4c35..728da48b 100644 --- a/docs/transformers/fast_weights/experiment.html +++ b/docs/transformers/fast_weights/experiment.html @@ -356,19 +356,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/fast_weights/index.html b/docs/transformers/fast_weights/index.html index e18314aa..c24a6d5b 100644 --- a/docs/transformers/fast_weights/index.html +++ b/docs/transformers/fast_weights/index.html @@ -885,19 +885,46 @@ y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})} displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/fast_weights/readme.html b/docs/transformers/fast_weights/readme.html index b716be72..341c11ab 100644 --- a/docs/transformers/fast_weights/readme.html +++ b/docs/transformers/fast_weights/readme.html @@ -103,19 +103,46 @@ and a notebook for training a fast weights transformer on the Tiny Shakespeare d displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/fast_weights/token_wise.html b/docs/transformers/fast_weights/token_wise.html index 00dbefa0..2474a782 100644 --- a/docs/transformers/fast_weights/token_wise.html +++ b/docs/transformers/fast_weights/token_wise.html @@ -514,19 +514,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/feed_forward.html b/docs/transformers/feed_forward.html index 9420233f..feccc807 100644 --- a/docs/transformers/feed_forward.html +++ b/docs/transformers/feed_forward.html @@ -305,19 +305,46 @@ depending on whether it is gated

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/feedback/README.html b/docs/transformers/feedback/README.html index d3ac481c..abfcabe6 100644 --- a/docs/transformers/feedback/README.html +++ b/docs/transformers/feedback/README.html @@ -122,19 +122,46 @@ We implemented a custom PyTorch function to improve performance.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/feedback/experiment.html b/docs/transformers/feedback/experiment.html index d9e56278..9346fe70 100644 --- a/docs/transformers/feedback/experiment.html +++ b/docs/transformers/feedback/experiment.html @@ -404,19 +404,46 @@ where the keys and values are precalculated.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/feedback/index.html b/docs/transformers/feedback/index.html index 46d9a107..2249c9f1 100644 --- a/docs/transformers/feedback/index.html +++ b/docs/transformers/feedback/index.html @@ -1695,19 +1695,46 @@ This is the weights parameter for that.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/glu_variants/experiment.html b/docs/transformers/glu_variants/experiment.html index 57e7b31e..73e70aae 100644 --- a/docs/transformers/glu_variants/experiment.html +++ b/docs/transformers/glu_variants/experiment.html @@ -437,19 +437,46 @@ implementation

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/glu_variants/index.html b/docs/transformers/glu_variants/index.html index 83b154ff..06a5da59 100644 --- a/docs/transformers/glu_variants/index.html +++ b/docs/transformers/glu_variants/index.html @@ -100,19 +100,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/glu_variants/simple.html b/docs/transformers/glu_variants/simple.html index 7ed19603..89c9b7da 100644 --- a/docs/transformers/glu_variants/simple.html +++ b/docs/transformers/glu_variants/simple.html @@ -1113,19 +1113,46 @@ a linear layer to generate logits.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/gpt/index.html b/docs/transformers/gpt/index.html index 01a130ef..4bb11bd0 100644 --- a/docs/transformers/gpt/index.html +++ b/docs/transformers/gpt/index.html @@ -849,19 +849,46 @@ per epoch

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/index.html b/docs/transformers/index.html index 243f572e..a9ef14de 100644 --- a/docs/transformers/index.html +++ b/docs/transformers/index.html @@ -133,19 +133,46 @@ It does single GPU training but we implement the concept of switching as describ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/knn/build_index.html b/docs/transformers/knn/build_index.html index f93f18c6..411d26f8 100644 --- a/docs/transformers/knn/build_index.html +++ b/docs/transformers/knn/build_index.html @@ -594,19 +594,46 @@ doesn’t store full vectors.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/knn/eval_knn.html b/docs/transformers/knn/eval_knn.html index f532bfce..fbae5e12 100644 --- a/docs/transformers/knn/eval_knn.html +++ b/docs/transformers/knn/eval_knn.html @@ -554,19 +554,46 @@ each of the weights

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/knn/index.html b/docs/transformers/knn/index.html index 36c9a387..cf8a01ed 100644 --- a/docs/transformers/knn/index.html +++ b/docs/transformers/knn/index.html @@ -119,19 +119,46 @@ of disk space for the index.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/knn/train_model.html b/docs/transformers/knn/train_model.html index 14c0dddb..3e71b8fc 100644 --- a/docs/transformers/knn/train_model.html +++ b/docs/transformers/knn/train_model.html @@ -486,19 +486,46 @@ final token generator from configurable transformer

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/label_smoothing_loss.html b/docs/transformers/label_smoothing_loss.html index 0bd0d15e..42006422 100644 --- a/docs/transformers/label_smoothing_loss.html +++ b/docs/transformers/label_smoothing_loss.html @@ -214,19 +214,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/mha.html b/docs/transformers/mha.html index a4553a63..d749979f 100644 --- a/docs/transformers/mha.html +++ b/docs/transformers/mha.html @@ -574,19 +574,46 @@ $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/models.html b/docs/transformers/models.html index 15f793cb..cbd5f4a8 100644 --- a/docs/transformers/models.html +++ b/docs/transformers/models.html @@ -705,19 +705,46 @@ Initialize parameters with Glorot / fan_avg.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/positional_encoding.html b/docs/transformers/positional_encoding.html index 584ee620..f536f01e 100644 --- a/docs/transformers/positional_encoding.html +++ b/docs/transformers/positional_encoding.html @@ -265,19 +265,46 @@ PE_{p,2i + 1} &= cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg) displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/relative_mha.html b/docs/transformers/relative_mha.html index fd3f7414..85a9cca3 100644 --- a/docs/transformers/relative_mha.html +++ b/docs/transformers/relative_mha.html @@ -4,7 +4,7 @@ - + @@ -85,19 +85,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/switch/experiment.html b/docs/transformers/switch/experiment.html index ba225c83..bd34fbb2 100644 --- a/docs/transformers/switch/experiment.html +++ b/docs/transformers/switch/experiment.html @@ -806,19 +806,46 @@ set to something small like $\alpha = 0.01$.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/switch/index.html b/docs/transformers/switch/index.html index 279ab732..15ed6649 100644 --- a/docs/transformers/switch/index.html +++ b/docs/transformers/switch/index.html @@ -711,19 +711,46 @@ with handling extra outputs of switch feedforward module.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/switch/readme.html b/docs/transformers/switch/readme.html index ea5d5a4d..5b18542b 100644 --- a/docs/transformers/switch/readme.html +++ b/docs/transformers/switch/readme.html @@ -119,19 +119,46 @@ discusses dropping tokens when routing is not balanced.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/utils.html b/docs/transformers/utils.html index 972ea98b..5f28119e 100644 --- a/docs/transformers/utils.html +++ b/docs/transformers/utils.html @@ -135,19 +135,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/xl/experiment.html b/docs/transformers/xl/experiment.html index 49e9cad1..60ca545c 100644 --- a/docs/transformers/xl/experiment.html +++ b/docs/transformers/xl/experiment.html @@ -988,19 +988,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/xl/index.html b/docs/transformers/xl/index.html index a3ab5c4b..f657c2cd 100644 --- a/docs/transformers/xl/index.html +++ b/docs/transformers/xl/index.html @@ -441,19 +441,46 @@ which will become the memories for the next sequential batch.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/xl/readme.html b/docs/transformers/xl/readme.html index 5c476e39..ef974df3 100644 --- a/docs/transformers/xl/readme.html +++ b/docs/transformers/xl/readme.html @@ -114,19 +114,46 @@ are introduced at the attention calculation.

displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/docs/transformers/xl/relative_mha.html b/docs/transformers/xl/relative_mha.html index 915b7bc6..f68b99f2 100644 --- a/docs/transformers/xl/relative_mha.html +++ b/docs/transformers/xl/relative_mha.html @@ -411,19 +411,46 @@ to get + \ No newline at end of file diff --git a/docs/utils.html b/docs/utils.html index 0177e8c7..a30a2fd6 100644 --- a/docs/utils.html +++ b/docs/utils.html @@ -83,7 +83,8 @@
-

Make a nn.ModuleList with clones of a given layer

+

Clone Module

+

Make a nn.ModuleList with clones of a given module

15def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
@@ -97,7 +98,33 @@
-
19    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
+
21    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
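For example (a small sketch; layer stands for any nn.Module instance):

    layers = clone_module_list(layer, 6)   # six independent deep copies in a module list
    assert len(layers) == 6
    assert layers[0] is not layers[1]      # clones do not share parameters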
+
+ +
+
+ +

+

Cycle Data Loader

+

Infinite loader that recycles the data loader after each epoch
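For example (a sketch assuming a PyTorch DataLoader named train_loader and a step count total_steps; the function itself is defined just below):

    batches = cycle_dataloader(train_loader)
    for step in range(total_steps):        # step-based training loop instead of epoch-based
        batch = next(batches)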

+
+
+
24def cycle_dataloader(data_loader):
+
+
+
+
+ + +
+
+
31    while True:
+32        for batch in data_loader:
+33            yield batch
@@ -118,19 +145,46 @@ displayAlign: 'center', "HTML-CSS": { fonts: ["TeX"] } }); + + \ No newline at end of file diff --git a/labml_nn/gan/stylegan/__init__.py b/labml_nn/gan/stylegan/__init__.py new file mode 100644 index 00000000..570216b0 --- /dev/null +++ b/labml_nn/gan/stylegan/__init__.py @@ -0,0 +1,947 @@ +""" +--- +title: Style GAN 2 +summary: > + An annotated PyTorch implementation of StyleGAN2. +--- + +# Style GAN 2 + +This is a [PyTorch](https://pytorch.org) implementation of the paper + [Analyzing and Improving the Image Quality of StyleGAN](https://arxiv.org/abs/1912.04958) + which introduces **Style GAN2**. +Style GAN2 is an improvement over **Style GAN** from the paper + [A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/abs/1812.04948). +And Style GAN is based on **Progressive GAN** from the paper + [Progressive Growing of GANs for Improved Quality, Stability, and Variation](https://arxiv.org/abs/1710.10196). +All three papers are from the same authors from [NVIDIA AI](https://twitter.com/NVIDIAAI). + +*Our implementation is a minimalistic Style GAN2 model training code. +Only single GPU training is supported to keep the implementation simple. +We managed to shrink it to keep it at less than 500 lines of code, including the training loop.* + +**🏃 Here's the training code: [`experiment.py`](experiment.html).** + +![Generated Images](generated_64.png) + +*These are $64 \times 64$ images generated after training for about 80K steps.* + + +We'll first introduce the three papers at a high level. + +## Generative Adversarial Networks + +Generative adversarial networks have two components; the generator and the discriminator. +The generator network takes a random latent vector ($z \in \mathcal{Z}$) + and tries to generate a realistic image. +The discriminator network tries to differentiate the real images from generated images. +When we train the two networks together the generator starts generating images indistinguishable from real images. + +## Progressive GAN + +Progressive GAN generates high-resolution images ($1080 \times 1080$) of size. +It does so by *progressively* increasing the image size. +First, it trains a network that produces a $4 \times 4$ image, then $8 \times 8$ , + then an $16 \times 16$ image, and so on up to the desired image resolution. + +At each resolution, the generator network produces an image in latent space which is converted into RGB, +with a $1 \times 1$ convolution. +When we progress from a lower resolution to a higher resolution + (say from $4 \times 4$ to $8 \times 8$ ) we scale the latent image by $2\times$ + and add a new block (two $3 \times 3$ convolution layers) + and a new $1 \times 1$ layer to get RGB. +The transition is done smoothly by adding a residual connection to + the $2\times$ scaled $4 \times 4$ RGB image. +The weight of this residual connection is slowly reduced, to let the new block take over. + +The discriminator is a mirror image of the generator network. +The progressive growth of the discriminator is done similarly. + +![progressive_gan.svg](progressive_gan.svg) + +*$2\times$ and $0.5\times$ denote feature map resolution scaling and scaling. +$4\times4$, $8\times4$, ... denote feature map resolution at the generator or discriminator block. 
+Each discriminator and generator block consists of 2 convolution layers with leaky ReLU activations.* + +They use **minibatch standard deviation** to increase variation and + **equalized learning rate** which we discussed below in the implementation. +They also use **pixel-wise normalization** where at each pixel the feature vector is normalized. +They apply this to all the convolution layer outputs (except RGB). + + +## Style GAN + +Style GAN improves the generator of Progressive GAN keeping the discriminator architecture the same. + +#### Mapping Network + +It maps the random latent vector ($z \in \mathcal{Z}$) + into a different latent space ($w \in \mathcal{W}$), + with an 8-layer neural network. +This gives an intermediate latent space $\mathcal{W}$ +where the factors of variations are more linear (disentangled). + +#### AdaIN + +Then $w$ is transformed into two vectors (***styles***) per layer, + $i$, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$ and used for scaling and shifting (biasing) + in each layer with $\text{AdaIN}$ operator (normalize and scale): +$$\text{AdaIN}(x_i, y_i) = y_{s, i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$ + +#### Style Mixing + +To prevent the generator from assuming adjacent styles are correlated, + they randomly use different styles for different blocks. +That is, they sample two latent vectors $(z_1, z_2)$ and corresponding $(w_1, w_2)$ and + use $w_1$ based styles for some blocks and $w_2$ based styles for some blacks randomly. + +#### Stochastic Variation + +Noise is made available to each block which helps the generator create more realistic images. +Noise is scaled per channel by a learned weight. + +#### Bilinear Up and Down Sampling + +All the up and down-sampling operations are accompanied by bilinear smoothing. + +![style_gan.svg](style_gan.svg) + +*$A$ denotes a linear layer. +$B$ denotes a broadcast and scaling operation (noise is a single channel). +Style GAN also uses progressive growing like Progressive GAN* + +## Style GAN 2 + +Style GAN 2 changes both the generator and the discriminator of Style GAN. + +#### Weight Modulation and Demodulation + +They remove the $\text{AdaIN}$ operator and replace it with + the weight modulation and demodulation step. +This is supposed to improve what they call droplet artifacts that are present in generated images, + which are caused by the normalization in $\text{AdaIN}$ operator. +Style vector per layer is calculated from $w_i \in \mathcal{W}$ as $s_i = f_{A_i}(w_i)$. + +Then the convolution weights $w$ are modulated as follows. +($w$ here on refers to weights not intermediate latent space, + we are sticking to the same notation as the paper.) + +$$w'_{i, j, k} = s_i \cdot w_{i, j, k}$$ +Then it's demodulated by normalizing, +$$w''_{i,j,k} = \frac{w'_{i,j,k}}{\sqrt{\sum_{i,k}{w'_{i, j, k}}^2 + \epsilon}}$$ +where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index. + +#### Path Length Regularization + +Path length regularization encourages a fixed-size step in $\mathcal{W}$ to result in a non-zero, + fixed-magnitude change in the generated image. + +#### No Progressive Growing + +StyleGAN2 uses residual connections (with down-sampling) in the discriminator and skip connections + in the generator with up-sampling + (the RGB outputs from each layer are added - no residual connections in feature maps). +They show that with experiments that the contribution of low-resolution layers is higher + at beginning of the training and then high-resolution layers take over. 
+""" + +import math +from typing import Tuple, Optional, List + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.data +from torch import nn + + +class MappingNetwork(nn.Module): + """ + + ## Mapping Network + + ![Mapping Network](mapping_network.svg) + + This is an MLP with 8 linear layers. + The mapping network maps the latent vector $z \in \mathcal{W}$ + to an intermediate latent space $w \in \mathcal{W}$. + $\mathcal{W}$ space will be disentangled from the image space + where the factors of variation become more linear. + """ + + def __init__(self, features: int, n_layers: int): + """ + * `features` is the number of features in $z$ and $w$ + * `n_layers` is the number of layers in the mapping network. + """ + super().__init__() + + # Create the MLP + layers = [] + for i in range(n_layers): + # [Equalized learning-rate linear layers](#equalized_linear) + layers.append(EqualizedLinear(features, features)) + # Leaky Relu + layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True)) + + self.net = nn.Sequential(*layers) + + def forward(self, z: torch.Tensor): + # Normalize $z$ + z = F.normalize(z, dim=1) + # Map $z$ to $w$ + return self.net(z) + + +class Generator(nn.Module): + """ + + ## StyleGAN2 Generator + + ![Generator](style_gan2.svg) + + *$A$ denotes a linear layer. + $B$ denotes a broadcast and scaling operation (noise is a single channel). + [*toRGB*](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.* + + The generator starts with a learned constant. + Then it has a series of blocks. The feature map resolution is doubled at each block + Each block outputs an RGB image and they are scaled up and summed to get the final RGB image. + """ + + def __init__(self, log_resolution: int, d_latent: int, n_features: int = 32, max_features: int = 512): + """ + * `log_resolution` is the $\log_2$ of image resolution + * `d_latent` is the dimensionality of $w$ + * `n_features` number of features in the convolution layer at the highest resolution (final block) + * `max_features` maximum number of features in any generator block + """ + super().__init__() + + # Calculate the number of features for each block + # + # Something like `[512, 512, 256, 128, 64, 32]` + features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)] + # Number of generator blocks + self.n_blocks = len(features) + + # Trainable $4 \times 4$ constant + self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4))) + + # First style block for $4 \times 4$ resolution and layer to get RGB + self.style_block = StyleBlock(d_latent, features[0], features[0]) + self.to_rgb = ToRGB(d_latent, features[0]) + + # Generator blocks + blocks = [GeneratorBlock(d_latent, features[i - 1], features[i]) for i in range(1, self.n_blocks)] + self.blocks = nn.ModuleList(blocks) + + # $2 \times$ up sampling layer. The feature space is up sampled + # at each block + self.up_sample = UpSample() + + def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]): + """ + * `w` is $w$. In order to mix-styles (use different $w$ for different layers), we provide a separate + $w$ for each [generator block](#generator_block). It has shape `[n_blocks, batch_size, d_latent]1. + * `input_noise` is the noise for each block. + It's a list of pairs of noise sensors because each block (except the initial) has two noise inputs + after each convolution layer (see the diagram). 
+ """ + + # Get batch size + batch_size = w.shape[1] + + # Expand the learned constant to match batch size + x = self.initial_constant.expand(batch_size, -1, -1, -1) + + # The first style block + x = self.style_block(x, w[0], input_noise[0][1]) + # Get first rgb image + rgb = self.to_rgb(x, w[0]) + + # Evaluate rest of the blocks + for i in range(1, self.n_blocks): + # Up sample the feature map + x = self.up_sample(x) + # Run it through the [generator block](#generator_block) + x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i]) + # Up sample the RGB image and add to the rgb from the block + rgb = self.up_sample(rgb) + rgb_new + + # Return the final RGB image + return rgb + + +class GeneratorBlock(nn.Module): + """ + + ### Generator Block + + ![Generator block](generator_block.svg) + + *$A$ denotes a linear layer. + $B$ denotes a broadcast and scaling operation (noise is a single channel). + [*toRGB*](#to_rgb) also has a style modulation which is not shown in the diagram to keep it simple.* + + The generator block consists of two [style blocks](#style_block) ($3 \times 3$ convolutions with style modulation) + and an RGB output. + """ + + def __init__(self, d_latent: int, in_features: int, out_features: int): + """ + * `d_latent` is the dimensionality of $w$ + * `in_features` is the number of features in the input feature map + * `out_features` is the number of features in the output feature map + """ + super().__init__() + + # First [style block](#style_block) changes the feature map size to `out_features` + self.style_block1 = StyleBlock(d_latent, in_features, out_features) + # Second [style block](#style_block) + self.style_block2 = StyleBlock(d_latent, out_features, out_features) + + # *toRGB* layer + self.to_rgb = ToRGB(d_latent, out_features) + + def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]): + """ + * `x` is the input feature map of shape `[batch_size, in_features, height, width]` + * `w` is $w$ with shape `[batch_size, d_latent]` + * `noise` is a tuple of two noise tensors of shape `[batch_size, 1, height, width]` + """ + # First style block with first noise tensor. + # The output is of shape `[batch_size, out_features, height, width]` + x = self.style_block1(x, w, noise[0]) + # Second style block with second noise tensor. + # The output is of shape `[batch_size, out_features, height, width]` + x = self.style_block2(x, w, noise[1]) + + # Get RGB image + rgb = self.to_rgb(x, w) + + # Return feature map and rgb image + return x, rgb + + +class StyleBlock(nn.Module): + """ + + ### Style Block + + ![Style block](style_block.svg) + + *$A$ denotes a linear layer. + $B$ denotes a broadcast and scaling operation (noise is single channel).* + + Style block has a weight modulation convolution layer. 
+ """ + + def __init__(self, d_latent: int, in_features: int, out_features: int): + """ + * `d_latent` is the dimensionality of $w$ + * `in_features` is the number of features in the input feature map + * `out_features` is the number of features in the output feature map + """ + super().__init__() + # Get style vector from $w$ (denoted by $A$ in the diagram) with + # an [equalized learning-rate linear layer](#equalized_linear) + self.to_style = EqualizedLinear(d_latent, in_features, bias=1.0) + # Weight modulated convolution layer + self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3) + # Noise scale + self.scale_noise = nn.Parameter(torch.zeros(1)) + # Bias + self.bias = nn.Parameter(torch.zeros(out_features)) + + # Activation function + self.activation = nn.LeakyReLU(0.2, True) + + def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor]): + """ + * `x` is the input feature map of shape `[batch_size, in_features, height, width]` + * `w` is $w$ with shape `[batch_size, d_latent]` + * `noise` is a tensor of shape `[batch_size, 1, height, width]` + """ + # Get style vector $s$ + s = self.to_style(w) + # Weight modulated convolution + x = self.conv(x, s) + # Scale and add noise + if noise is not None: + x = x + self.scale_noise[None, :, None, None] * noise + # Add bias and evaluate activation function + return self.activation(x + self.bias[None, :, None, None]) + + +class ToRGB(nn.Module): + """ + + ### To RGB + + ![To RGB](to_rgb.svg) + + *$A$ denotes a linear layer.* + + Generates an RGB image from a feature map using $1 \times 1$ convolution. + """ + + def __init__(self, d_latent: int, features: int): + """ + * `d_latent` is the dimensionality of $w$ + * `features` is the number of features in the feature map + """ + super().__init__() + # Get style vector from $w$ (denoted by $A$ in the diagram) with + # an [equalized learning-rate linear layer](#equalized_linear) + self.to_style = EqualizedLinear(d_latent, features, bias=1.0) + + # Weight modulated convolution layer without demodulation + self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False) + # Bias + self.bias = nn.Parameter(torch.zeros(1)) + # Activation function + self.activation = nn.LeakyReLU(0.2, True) + + def forward(self, x: torch.Tensor, w: torch.Tensor): + """ + * `x` is the input feature map of shape `[batch_size, in_features, height, width]` + * `w` is $w$ with shape `[batch_size, d_latent]` + """ + # Get style vector $s$ + style = self.to_style(w) + # Weight modulated convolution + x = self.conv(x, style) + # Add bias and evaluate activation function + return self.activation(x + self.bias[None, :, None, None]) + + +class Conv2dWeightModulate(nn.Module): + """ + ### Convolution with Weight Modulation and Demodulation + + This layer scales the convolution weights by the style vector and demodulates by normalizing it. 
+    """
+
+    def __init__(self, in_features: int, out_features: int, kernel_size: int,
+                 demodulate: bool = True, eps: float = 1e-8):
+        """
+        * `in_features` is the number of features in the input feature map
+        * `out_features` is the number of features in the output feature map
+        * `kernel_size` is the size of the convolution kernel
+        * `demodulate` is a flag indicating whether to normalize the weights by their standard deviation
+        * `eps` is the $\epsilon$ for normalizing
+        """
+        super().__init__()
+        # Number of output features
+        self.out_features = out_features
+        # Whether to normalize weights
+        self.demodulate = demodulate
+        # Padding size
+        self.padding = (kernel_size - 1) // 2
+
+        # [Weights parameter with equalized learning rate](#equalized_weight)
+        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
+        # $\epsilon$
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor, s: torch.Tensor):
+        """
+        * `x` is the input feature map of shape `[batch_size, in_features, height, width]`
+        * `s` is the style-based scaling tensor of shape `[batch_size, in_features]`
+        """
+
+        # Get batch size, height and width
+        b, _, h, w = x.shape
+
+        # Reshape the scales
+        s = s[:, None, :, None, None]
+        # Get [learning rate equalized weights](#equalized_weight)
+        weights = self.weight()[None, :, :, :, :]
+        # $$w'_{i, j, k} = s_i \cdot w_{i, j, k}$$
+        # where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index.
+        #
+        # The result has shape `[batch_size, out_features, in_features, kernel_size, kernel_size]`
+        weights = weights * s
+
+        # Demodulate
+        if self.demodulate:
+            # $$\sigma_j = \sqrt{\sum_{i,k} (w'_{i, j, k})^2 + \epsilon}$$
+            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
+            # $$w''_{i,j,k} = \frac{w'_{i,j,k}}{\sqrt{\sum_{i,k} (w'_{i, j, k})^2 + \epsilon}}$$
+            weights = weights * sigma_inv
+
+        # Reshape `x`
+        x = x.reshape(1, -1, h, w)
+
+        # Reshape weights
+        _, _, *ws = weights.shape
+        weights = weights.reshape(b * self.out_features, *ws)
+
+        # Use grouped convolution to efficiently calculate the convolution with a sample-wise kernel.
+        # i.e. we have a different kernel (weights) for each sample in the batch
+        x = F.conv2d(x, weights, padding=self.padding, groups=b)
+
+        # Reshape `x` to `[batch_size, out_features, height, width]` and return
+        return x.reshape(-1, self.out_features, h, w)
+
+
+class Discriminator(nn.Module):
+    """
+
+    ## Style GAN2 Discriminator
+
+    ![Discriminator](style_gan2_disc.svg)
+
+    The discriminator first transforms the image to a feature map of the same resolution and then
+    runs it through a series of blocks with residual connections.
+    The resolution is down-sampled by $2 \times$ at each block while doubling the
+    number of features.
+    """
+
+    def __init__(self, log_resolution: int, n_features: int = 64, max_features: int = 512):
+        """
+        * `log_resolution` is the $\log_2$ of image resolution
+        * `n_features` is the number of features in the convolution layer at the highest resolution (first block)
+        * `max_features` is the maximum number of features in any discriminator block
+        """
+        super().__init__()
+
+        # Layer to convert RGB image to a feature map with `n_features` number of features.
+        self.from_rgb = nn.Sequential(
+            EqualizedConv2d(3, n_features, 1),
+            nn.LeakyReLU(0.2, True),
+        )
+
+        # Calculate the number of features for each block.
+        #
+        # Something like `[64, 128, 256, 512, 512, 512]`.
+ features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)] + # Number of [discirminator blocks](#discriminator_block) + n_blocks = len(features) - 1 + # Discriminator blocks + blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)] + self.blocks = nn.Sequential(*blocks) + + # [Mini-batch Standard Deviation](#mini_batch_std_dev) + self.std_dev = MiniBatchStdDev() + # Number of features after adding the standard deviations map + final_features = features[-1] + 1 + # Final $3 \times 3$ convolution layer + self.conv = EqualizedConv2d(final_features, final_features, 3) + # Final linear layer to get the classification + self.final = EqualizedLinear(2 * 2 * final_features, 1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the input image of shape `[batch_size, 3, height, width]` + """ + + # Try to normalize the image (this is totally optional, but sped up the early training a little) + x = x - 0.5 + # Convert from RGB + x = self.from_rgb(x) + # Run through the [discriminator blocks](#discriminator_block) + x = self.blocks(x) + + # Calculate and append [mini-batch standard deviation](#mini_batch_std_dev) + x = self.std_dev(x) + # $3 \times 3$ convolution + x = self.conv(x) + # Flatten + x = x.reshape(x.shape[0], -1) + # Return the classification score + return self.final(x) + + +class DiscriminatorBlock(nn.Module): + """ + + ### Discriminator Block + + ![Discriminator block](discriminator_block.svg) + + Discriminator block consists of two $3 \times 3$ convolutions with a residual connection. + """ + + def __init__(self, in_features, out_features): + """ + * `in_features` is the number of features in the input feature map + * `out_features` is the number of features in the output feature map + """ + super().__init__() + # Down-sampling and $1 \times 1$ convolution layer for the residual connection + self.residual = nn.Sequential(DownSample(), + EqualizedConv2d(in_features, out_features, kernel_size=1)) + + # Two $3 \times 3$ convolutions + self.block = nn.Sequential( + EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + ) + + # Down-sampling layer + self.down_sample = DownSample() + + # Scaling factor $\frac{1}{\sqrt 2}$ after adding the residual + self.scale = 1 / math.sqrt(2) + + def forward(self, x): + # Get the residual connection + residual = self.residual(x) + + # Convolutions + x = self.block(x) + # Down-sample + x = self.down_sample(x) + + # Add the residual and scale + return (x + residual) * self.scale + + +class MiniBatchStdDev(nn.Module): + """ + + + ### Mini-batch Standard Deviation + + Mini-batch standard deviation calculates the standard deviation + across a mini-batch (or a subgroups within the mini-batch) + for each feature in the feature map. Then it takes the mean of all + the standard deviations and appends it to the feature map as one extra feature. + """ + + def __init__(self, group_size: int = 4): + """ + * `group_size` is the number of samples to calculate standard deviation across. 
+ """ + super().__init__() + self.group_size = group_size + + def forward(self, x: torch.Tensor): + """ + * `x` is the feature map + """ + # Check if the batch size is divisible by the group size + assert x.shape[0] % self.group_size == 0 + # Split the samples into groups of `group_size`, we flatten the feature map to a single dimension + # since we want to calculate the standard deviation for each feature. + grouped = x.view(self.group_size, -1) + # Calculate the standard deviation for each feature among `group_size` samples + # $$\mu_{i} = \frac{1}{N} \sum_g x_{g,i} \\ + # \sigma_{i} = \sqrt{\frac{1}{N} \sum_g (x_{g,i} - \mu_i)^2 + \epsilon}$$ + std = torch.sqrt(grouped.var(dim=0) + 1e-8) + # Get the mean standard deviation + std = std.mean().view(1, 1, 1, 1) + # Expand the standard deviation to append to the feature map + b, _, h, w = x.shape + std = std.expand(b, -1, h, w) + # Append (concatenate) the standard deviations to the feature map + return torch.cat([x, std], dim=1) + + +class DownSample(nn.Module): + """ + + ### Down-sample + + The down-sample operation [smoothens](#smooth) each feature channel and + scale $2 \times$ using bilinear interpolation. + This is based on the paper + [Making Convolutional Networks Shift-Invariant Again](https://arxiv.org/abs/1904.11486). + """ + + def __init__(self): + super().__init__() + # Smoothing layer + self.smooth = Smooth() + + def forward(self, x: torch.Tensor): + # Smoothing or blurring + x = self.smooth(x) + # Scaled down + return F.interpolate(x, (x.shape[2] // 2, x.shape[3] // 2), mode='bilinear', align_corners=False) + + +class UpSample(nn.Module): + """ + + ### Up-sample + + The up-sample operation scales the image up by $2 \times$ and [smoothens](#smooth) each feature channel. + This is based on the paper + [Making Convolutional Networks Shift-Invariant Again](https://arxiv.org/abs/1904.11486). + """ + + def __init__(self): + super().__init__() + # Up-sampling layer + self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + # Smoothing layer + self.smooth = Smooth() + + def forward(self, x: torch.Tensor): + # Up-sample and smoothen + return self.smooth(self.up_sample(x)) + + +class Smooth(nn.Module): + """ + + ### Smoothing Layer + + This layer blurs each channel + """ + + def __init__(self): + super().__init__() + # Blurring kernel + kernel = [[1, 2, 1], + [2, 4, 2], + [1, 2, 1]] + # Convert the kernel to a PyTorch tensor + kernel = torch.tensor([[kernel]], dtype=torch.float) + # Normalize the kernel + kernel /= kernel.sum() + # Save kernel as a fixed parameter (no gradient updates) + self.kernel = nn.Parameter(kernel, requires_grad=False) + # Padding layer + self.pad = nn.ReplicationPad2d(1) + + def forward(self, x: torch.Tensor): + # Get shape of the input feature map + b, c, h, w = x.shape + # Reshape for smoothening + x = x.view(-1, 1, h, w) + + # Add padding + x = self.pad(x) + + # Smoothen (blur) with the kernel + x = F.conv2d(x, self.kernel) + + # Reshape and return + return x.view(b, c, h, w) + + +class EqualizedLinear(nn.Module): + """ + + ## Learning-rate Equalized Linear Layer + + This uses [learning-rate equalized weights]($equalized_weights) for a linear layer. 
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: float = 0.):
+        """
+        * `in_features` is the number of features in the input feature map
+        * `out_features` is the number of features in the output feature map
+        * `bias` is the bias initialization constant
+        """
+
+        super().__init__()
+        # [Learning-rate equalized weights](#equalized_weight)
+        self.weight = EqualizedWeight([out_features, in_features])
+        # Bias
+        self.bias = nn.Parameter(torch.ones(out_features) * bias)
+
+    def forward(self, x: torch.Tensor):
+        # Linear transformation
+        return F.linear(x, self.weight(), bias=self.bias)
+
+
+class EqualizedConv2d(nn.Module):
+    """
+
+    ## Learning-rate Equalized 2D Convolution Layer
+
+    This uses [learning-rate equalized weights](#equalized_weight) for a convolution layer.
+    """
+
+    def __init__(self, in_features: int, out_features: int,
+                 kernel_size: int, padding: int = 0):
+        """
+        * `in_features` is the number of features in the input feature map
+        * `out_features` is the number of features in the output feature map
+        * `kernel_size` is the size of the convolution kernel
+        * `padding` is the padding to be added on both sides of each spatial dimension
+        """
+        super().__init__()
+        # Padding size
+        self.padding = padding
+        # [Learning-rate equalized weights](#equalized_weight)
+        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
+        # Bias
+        self.bias = nn.Parameter(torch.ones(out_features))
+
+    def forward(self, x: torch.Tensor):
+        # Convolution
+        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)
+
+
+class EqualizedWeight(nn.Module):
+    """
+
+    ## Learning-rate Equalized Weights Parameter
+
+    This is based on equalized learning rate introduced in the Progressive GAN paper.
+    Instead of initializing weights at $\mathcal{N}(0,c)$ they initialize weights
+    to $\mathcal{N}(0, 1)$ and then multiply them by $c$ when using them.
+    $$w_i = c \hat{w}_i$$
+
+    The gradients on the stored parameters $\hat{w}$ get multiplied by $c$, but this doesn't have
+    an effect since optimizers such as Adam normalize them by a running mean of the squared gradients.
+
+    The optimizer updates on $\hat{w}$ are proportionate to the learning rate $\lambda$.
+    But the effective weights $w$ get updated proportionately to $c \lambda$.
+    Without equalized learning rate, the effective weights would get updated proportionately to just $\lambda$.
+
+    So we are effectively scaling the learning rate by $c$ for these weight parameters.
+    """
+
+    def __init__(self, shape: List[int]):
+        """
+        * `shape` is the shape of the weight parameter
+        """
+        super().__init__()
+
+        # He initialization constant
+        self.c = 1 / math.sqrt(np.prod(shape[1:]))
+        # Initialize the weights with $\mathcal{N}(0, 1)$
+        self.weight = nn.Parameter(torch.randn(shape))
+
+    def forward(self):
+        # Multiply the weights by $c$ and return
+        return self.weight * self.c
+
+
+class GradientPenalty(nn.Module):
+    """
+
+    ## Gradient Penalty
+
+    This is the $R_1$ regularization penalty from the paper
+    [Which Training Methods for GANs do actually Converge?](https://arxiv.org/abs/1801.04406).
+
+    $$R_1(\psi) = \frac{\gamma}{2} \mathbb{E}_{p_\mathcal{D}(x)}
+    \Big[\Vert \nabla_x D_\psi(x) \Vert^2\Big]$$
+
+    That is, we try to reduce the squared L2 norm of the gradients of the discriminator with
+    respect to images, for real images ($x \sim p_\mathcal{D}$).
+    """
+
+    def forward(self, x: torch.Tensor, d: torch.Tensor):
+        """
+        * `x` is $x \sim p_\mathcal{D}$
+        * `d` is $D(x)$
+        """
+
+        # Get batch size
+        batch_size = x.shape[0]
+
+        # Calculate gradients of $D(x)$ with respect to $x$.
+        # `grad_outputs` is set to $1$ since we want the gradients of $D(x)$,
+        # and we need to create and retain the graph since we have to compute gradients
+        # with respect to the weights on this loss.
+        gradients, *_ = torch.autograd.grad(outputs=d,
+                                            inputs=x,
+                                            grad_outputs=d.new_ones(d.shape),
+                                            create_graph=True)
+
+        # Reshape gradients to calculate the norm
+        gradients = gradients.reshape(batch_size, -1)
+        # Calculate the norm $\Vert \nabla_{x} D(x) \Vert_2$
+        norm = gradients.norm(2, dim=-1)
+        # Return the loss $\Vert \nabla_x D_\psi(x) \Vert^2$
+        return torch.mean(norm ** 2)
+
+
+class PathLengthPenalty(nn.Module):
+    """
+
+    ## Path Length Penalty
+
+    This regularization encourages a fixed-size step in $w$ to result in a fixed-magnitude
+    change in the image.
+
+    $$\mathbb{E}_{w \sim f(z), y \sim \mathcal{N}(0, \mathbf{I})}
+      \Big(\Vert \mathbf{J}^\top_{w} y \Vert_2 - a \Big)^2$$
+
+    where $\mathbf{J}_w$ is the Jacobian
+    $\mathbf{J}_w = \frac{\partial g}{\partial w}$,
+    $w$ is sampled from the mapping network ($w \in \mathcal{W}$), and
+    $y$ are images with noise $\mathcal{N}(0, \mathbf{I})$.
+
+    $a$ is the exponential moving average of $\Vert \mathbf{J}^\top_{w} y \Vert_2$
+    as the training progresses.
+
+    $\mathbf{J}^\top_{w} y$ is calculated without explicitly calculating the Jacobian using
+    $$\mathbf{J}^\top_{w} y = \nabla_w \big(g(w) \cdot y \big)$$
+    """
+
+    def __init__(self, beta: float):
+        """
+        * `beta` is the constant $\beta$ used to calculate the exponential moving average $a$
+        """
+        super().__init__()
+
+        # $\beta$
+        self.beta = beta
+        # Number of steps calculated $N$
+        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)
+        # Exponential sum of $\mathbf{J}^\top_{w} y$
+        # $$\sum^N_{i=1} \beta^{(N - i)}[\mathbf{J}^\top_{w} y]_i$$
+        # where $[\mathbf{J}^\top_{w} y]_i$ is the value of it at the $i$-th step of training
+        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)
+
+    def forward(self, w: torch.Tensor, x: torch.Tensor):
+        """
+        * `w` is the batch of $w$ of shape `[batch_size, d_latent]`
+        * `x` are the generated images of shape `[batch_size, 3, height, width]`
+        """
+
+        # Get the device
+        device = x.device
+        # Get number of pixels
+        image_size = x.shape[2] * x.shape[3]
+        # Calculate $y \sim \mathcal{N}(0, \mathbf{I})$
+        y = torch.randn(x.shape, device=device)
+        # Calculate $\big(g(w) \cdot y \big)$ and normalize by the square root of image size.
+        # This scaling is not mentioned in the paper but was present in
+        # [their implementation](https://github.com/NVlabs/stylegan2/blob/master/training/loss.py#L167).
+ output = (x * y).sum() / math.sqrt(image_size) + + # Calculate gradients to get $\mathbf{J}^\top_{w} y$ + gradients, *_ = torch.autograd.grad(outputs=output, + inputs=w, + grad_outputs=torch.ones(output.shape, device=device), + create_graph=True) + + # Calculate L2-norm of $\mathbf{J}^\top_{w} y$ + norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt() + + # Regularize after first step + if self.steps > 0: + # Calculate $a$ + # $$\frac{1}{1 - \beta^N} \sum^N_{i=1} \beta^{(N - i)}[\mathbf{J}^\top_{w} y]_i$$ + a = self.exp_sum_a / (1 - self.beta ** self.steps) + # Calculate the penalty + # $$\mathbb{E}_{w \sim f(z), y \sim \mathcal{N}(0, \mathbf{I})} + # \Big(\Vert \mathbf{J}^\top_{w} y \Vert_2 - a \Big)^2$$ + loss = torch.mean((norm - a) ** 2) + else: + # Return a dummy loss if we can't calculate $a$ + loss = norm.new_tensor(0) + + # Calculate the mean of $\Vert \mathbf{J}^\top_{w} y \Vert_2$ + mean = norm.mean().detach() + # Update exponential sum + self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta) + # Increment $N$ + self.steps.add_(1.) + + # Return the penalty + return loss diff --git a/labml_nn/gan/stylegan/experiment.py b/labml_nn/gan/stylegan/experiment.py new file mode 100644 index 00000000..4eec2d8a --- /dev/null +++ b/labml_nn/gan/stylegan/experiment.py @@ -0,0 +1,467 @@ +""" +--- +title: Style GAN 2 Model Training +summary: > + An annotated PyTorch implementation of StyleGAN2 model training code. +--- + +# [Style GAN 2](index.html) Model Training + +This is the training code for [Style GAN 2](index.html) model. + +![Generated Images](generated_64.png) + +*These are $64 \times 64$ images generated after training for about 80K steps.* + +*Our implementation is a minimalistic Style GAN2 model training code. +Only single GPU training is supported to keep the implementation simple. +We managed to shrink it to keep it at less than 500 lines of code, including the training loop.* + +*Without DDP (distributed data parallel) and multi-gpu training it will not be possible to train the model +for large resolutions (128+). +If you want training code with fp16 and DDP take a look at +[lucidrains/stylegan2-pytorch](https://github.com/lucidrains/stylegan2-pytorch).* + +We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans). +You can find the download instruction in this +[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3). +Save the images inside [`data/stylegan` folder](#dataset_path). +""" + +import math +from pathlib import Path +from typing import Iterator, Tuple + +import torch +import torch.utils.data +import torchvision +from PIL import Image + +from labml import tracker, lab, monit, experiment +from labml.configs import BaseConfigs +from labml_helpers.device import DeviceConfigs +from labml_helpers.train_valid import ModeState, hook_model_outputs +from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty +from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss +from labml_nn.utils import cycle_dataloader + + +class Dataset(torch.utils.data.Dataset): + """ + ## Dataset + + This loads the training dataset and resize it to the give image size. 
+ """ + + def __init__(self, path: str, image_size: int): + """ + * `path` path to the folder containing the images + * `image_size` size of the image + """ + super().__init__() + + # Get the paths of all `jpg` files + self.paths = [p for p in Path(path).glob(f'**/*.jpg')] + + # Transformation + self.transform = torchvision.transforms.Compose([ + # Resize the image + torchvision.transforms.Resize(image_size), + # Convert to PyTorch tensor + torchvision.transforms.ToTensor(), + ]) + + def __len__(self): + """Number of images""" + return len(self.paths) + + def __getitem__(self, index): + """Get the the `index`-th image""" + path = self.paths[index] + img = Image.open(path) + return self.transform(img) + + +class Configs(BaseConfigs): + """ + ## Configurations + """ + + # Device to train the model on. + # [`DeviceConfigs`](https://github.com/lab-ml/helpers/blob/master/labml_helpers/device.py) + # picks up an available CUDA device or defaults to CPU. + device: torch.device = DeviceConfigs() + + # [StyleGAN2 Discriminator](index.html#discriminator) + discriminator: Discriminator + # [StyleGAN2 Generator](index.html#generator) + generator: Generator + # [Mapping network](index.html#mapping_network) + mapping_network: MappingNetwork + + # Discriminator and generator loss functions. + # We use [Wasserstein loss](../wasserstein/index.html) + discriminator_loss: DiscriminatorLoss + generator_loss: GeneratorLoss + + # Optimizers + generator_optimizer: torch.optim.Adam + discriminator_optimizer: torch.optim.Adam + mapping_network_optimizer: torch.optim.Adam + + # [Gradient Penalty Regularization Loss](index.html#gradient_penalty) + gradient_penalty = GradientPenalty() + # Gradient penalty coefficient $\gamma$ + gradient_penalty_coefficient: float = 10. + + # [Path length penalty](index.html#path_length_penalty) + path_length_penalty: PathLengthPenalty + + # Data loader + loader: Iterator + + # Batch size + batch_size: int = 32 + # Dimensionality of $z$ and $w$ + d_latent: int = 512 + # Height/width of the image + image_size: int = 32 + # Number of layers in the mapping network + mapping_network_layers: int = 8 + # Generator & Discriminator learning rate + learning_rate: float = 1e-3 + # Mapping network learning rate ($100 \times$ lower than the others) + mapping_network_learning_rate: float = 1e-5 + # Number of steps to accumulate gradients on. Use this to increase the effective batch size. + gradient_accumulate_steps: int = 1 + # $\beta_1$ and $\beta_2$ for Adam optimizer + adam_betas: Tuple[float, float] = (0.0, 0.99) + # Probability of mixing styles + style_mixing_prob: float = 0.9 + + # Total number of training steps + training_steps: int = 150_000 + + # Number of blocks in the generator (calculated based on image resolution) + n_gen_blocks: int + + # ### Lazy regularization + # Instead of calculating the regularization losses, the paper proposes lazy regularization + # where the regularization terms are calculated once in a while. + # This improves the training efficiency a lot. 
+ + # The interval at which to compute gradient penalty + lazy_gradient_penalty_interval: int = 4 + # Path length penalty calculation interval + lazy_path_penalty_interval: int = 32 + # Skip calculating path length penalty during the initial phase of training + lazy_path_penalty_after: int = 5_000 + + # How often to log generated images + log_generated_interval: int = 500 + # How often to save model checkpoints + save_checkpoint_interval: int = 2_000 + + # Training mode state for logging activations + mode: ModeState + # Whether to log model layer outputs + log_layer_outputs: bool = False + + # + # We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans). + # You can find the download instruction in this + # [discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3). + # Save the images inside `data/stylegan` folder. + dataset_path: str = str(lab.get_data_path() / 'stylegan2') + + def init(self): + """ + ### Initialize + """ + # Create dataset + dataset = Dataset(self.dataset_path, self.image_size) + # Create data loader + dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=32, + shuffle=True, drop_last=True, pin_memory=True) + # Continuous [cyclic loader](../../utils.html#cycle_dataloader) + self.loader = cycle_dataloader(dataloader) + + # $\log_2$ of image resolution + log_resolution = int(math.log2(self.image_size)) + + # Create discriminator and generator + self.discriminator = Discriminator(log_resolution).to(self.device) + self.generator = Generator(log_resolution, self.d_latent).to(self.device) + # Get number of generator blocks for creating style and noise inputs + self.n_gen_blocks = self.generator.n_blocks + # Create mapping network + self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device) + # Create path length penalty loss + self.path_length_penalty = PathLengthPenalty(0.99).to(self.device) + + # Add model hooks to monitor layer outputs + if self.log_layer_outputs: + hook_model_outputs(self.mode, self.discriminator, 'discriminator') + hook_model_outputs(self.mode, self.generator, 'generator') + hook_model_outputs(self.mode, self.mapping_network, 'mapping_network') + + # Discriminator and generator losses + self.discriminator_loss = DiscriminatorLoss().to(self.device) + self.generator_loss = GeneratorLoss().to(self.device) + + # Create optimizers + self.discriminator_optimizer = torch.optim.Adam( + self.discriminator.parameters(), + lr=self.learning_rate, betas=self.adam_betas + ) + self.generator_optimizer = torch.optim.Adam( + self.generator.parameters(), + lr=self.learning_rate, betas=self.adam_betas + ) + self.mapping_network_optimizer = torch.optim.Adam( + self.mapping_network.parameters(), + lr=self.mapping_network_learning_rate, betas=self.adam_betas + ) + + # Set tracker configurations + tracker.set_image("generated", True) + + def get_w(self, batch_size: int): + """ + ### Sample $w$ + + This samples $z$ randomly and get $w$ from the mapping network. + + We also apply style mixing sometimes where we generate two latent variables + $z_1$ and $z_2$ and get corresponding $w_1$ and $w_2$. + Then we randomly sample a cross-over point and apply $w_1$ to + the generator blocks before the cross-over point and + $w_2$ to the blocks after. 
+        """
+
+        # Mix styles
+        if torch.rand(()).item() < self.style_mixing_prob:
+            # Random cross-over point
+            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
+            # Sample $z_1$ and $z_2$
+            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
+            z1 = torch.randn(batch_size, self.d_latent).to(self.device)
+            # Get $w_1$ and $w_2$
+            w1 = self.mapping_network(z1)
+            w2 = self.mapping_network(z2)
+            # Expand $w_1$ and $w_2$ for the generator blocks and concatenate
+            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
+            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
+            return torch.cat((w1, w2), dim=0)
+        # Without mixing
+        else:
+            # Sample $z$
+            z = torch.randn(batch_size, self.d_latent).to(self.device)
+            # Get $w$
+            w = self.mapping_network(z)
+            # Expand $w$ for the generator blocks
+            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
+
+    def get_noise(self, batch_size: int):
+        """
+        ### Generate noise
+
+        This generates noise for each [generator block](index.html#generator_block)
+        """
+        # List to store noise
+        noise = []
+        # Noise resolution starts from $4$
+        resolution = 4
+
+        # Generate noise for each generator block
+        for i in range(self.n_gen_blocks):
+            # The first block has only one $3 \times 3$ convolution
+            if i == 0:
+                n1 = None
+            # Generate noise to add after the first convolution layer
+            else:
+                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
+            # Generate noise to add after the second convolution layer
+            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
+
+            # Add noise tensors to the list
+            noise.append((n1, n2))
+
+            # Next block has $2 \times$ resolution
+            resolution *= 2
+
+        # Return noise tensors
+        return noise
+
+    def generate_images(self, batch_size: int):
+        """
+        ### Generate images
+
+        This generates images using the generator.
+        """
+
+        # Get $w$
+        w = self.get_w(batch_size)
+        # Get noise
+        noise = self.get_noise(batch_size)
+
+        # Generate images
+        images = self.generator(w, noise)
+
+        # Return images and $w$
+        return images, w
+
+    def step(self, idx: int):
+        """
+        ### Training Step
+        """
+
+        # Train the discriminator
+        with monit.section('Discriminator'):
+            # Reset gradients
+            self.discriminator_optimizer.zero_grad()
+
+            # Accumulate gradients for `gradient_accumulate_steps`
+            for i in range(self.gradient_accumulate_steps):
+                # Update `mode`. Set whether to log activations
+                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
+                    # Sample images from generator
+                    generated_images, _ = self.generate_images(self.batch_size)
+                    # Discriminator classification for generated images
+                    fake_output = self.discriminator(generated_images.detach())
+
+                    # Get real images from the data loader
+                    real_images = next(self.loader).to(self.device)
+                    # We need to calculate gradients w.r.t.
real images for gradient penalty + if (idx + 1) % self.lazy_gradient_penalty_interval == 0: + real_images.requires_grad_() + # Discriminator classification for real images + real_output = self.discriminator(real_images) + + # Get discriminator loss + real_loss, fake_loss = self.discriminator_loss(real_output, fake_output) + disc_loss = real_loss + fake_loss + + # Add gradient penalty + if (idx + 1) % self.lazy_gradient_penalty_interval == 0: + # Calculate and log gradient penalty + gp = self.gradient_penalty(real_images, real_output) + tracker.add('loss.gp', gp) + # Multiply by coefficient and add gradient penalty + disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval + + # Compute gradients + disc_loss.backward() + + # Log discriminator loss + tracker.add('loss.discriminator', disc_loss) + + if (idx + 1) % self.log_generated_interval == 0: + # Log discriminator model parameters occasionally + tracker.add('discriminator', self.discriminator) + + # Clip gradients for stabilization + torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0) + # Take optimizer step + self.discriminator_optimizer.step() + + # Train the generator + with monit.section('Generator'): + # Reset gradients + self.generator_optimizer.zero_grad() + self.mapping_network_optimizer.zero_grad() + + # Accumulate gradients for `gradient_accumulate_steps` + for i in range(self.gradient_accumulate_steps): + # Sample images from generator + generated_images, w = self.generate_images(self.batch_size) + # Discriminator classification for generated images + fake_output = self.discriminator(generated_images) + + # Get generator loss + gen_loss = self.generator_loss(fake_output) + + # Add path length penalty + if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0: + # Calculate path length penalty + plp = self.path_length_penalty(w, generated_images) + # Ignore if `nan` + if not torch.isnan(plp): + tracker.add('loss.plp', plp) + gen_loss = gen_loss + plp + + # Calculate gradients + gen_loss.backward() + + # Log generator loss + tracker.add('loss.generator', gen_loss) + + if (idx + 1) % self.log_generated_interval == 0: + # Log discriminator model parameters occasionally + tracker.add('generator', self.generator) + tracker.add('mapping_network', self.mapping_network) + + # Clip gradients for stabilization + torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0) + torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0) + + # Take optimizer step + self.generator_optimizer.step() + self.mapping_network_optimizer.step() + + # Log generated images + if (idx + 1) % self.log_generated_interval == 0: + tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0)) + # Save model checkpoints + if (idx + 1) % self.save_checkpoint_interval == 0: + experiment.save_checkpoint() + + # Flush tracker + tracker.save() + + def train(self): + """ + ## Train model + """ + + # Loop for `training_steps` + for i in monit.loop(self.training_steps): + # Take a training step + self.step(i) + # + if (i + 1) % self.log_generated_interval == 0: + tracker.new_line() + + +def main(): + """ + ### Train StyleGAN2 + """ + + # Create an experiment + experiment.create(name='stylegan2') + # Create configurations object + configs = Configs() + + # Set configurations and override some + experiment.configs(configs, { + 'device.cuda_device': 0, + 'image_size': 64, + 'log_generated_interval': 200 + }) + 
+ # Initialize + configs.init() + # Set models for saving and loading + experiment.add_pytorch_models(mapping_network=configs.mapping_network, + generator=configs.generator, + discriminator=configs.discriminator) + + # Start the experiment + with experiment.start(): + # Run the training loop + configs.train() + +# +if __name__ == '__main__': + main() diff --git a/labml_nn/utils.py b/labml_nn/utils.py index 1734c974..5267dc5e 100644 --- a/labml_nn/utils.py +++ b/labml_nn/utils.py @@ -14,6 +14,20 @@ from labml_helpers.module import M, TypedModuleList def clone_module_list(module: M, n: int) -> TypedModuleList[M]: """ - ## Make a `nn.ModuleList` with clones of a given layer + ## Clone Module + + Make a `nn.ModuleList` with clones of a given module """ return TypedModuleList([copy.deepcopy(module) for _ in range(n)]) + + +def cycle_dataloader(data_loader): + """ + + ## Cycle Data Loader + + Infinite loader that recycles the data loader after each epoch + """ + while True: + for batch in data_loader: + yield batch diff --git a/requirements.txt b/requirements.txt index 6fdc75b3..579efdb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ matplotlib>=3.0.3 einops>=0.3.0 gym[atari] opencv-python +Pillow>=6.2.1 diff --git a/utils/diagrams.py b/utils/diagrams.py new file mode 100644 index 00000000..07605158 --- /dev/null +++ b/utils/diagrams.py @@ -0,0 +1,260 @@ +import shutil +from pathlib import Path +from typing import List +from xml.dom import minidom + +from labml import monit + +HOME = Path('.').absolute() + +STYLES = """ +.black-stroke { + stroke: #aaa; +} + +rect.black-stroke { + stroke: #444; +} + +.black-fill { + fill: #ddd; +} + +.white-fill { + fill: #333; +} + +.blue-stroke { + stroke: #5b8fab; +} + +.blue-fill { + fill: #356782; +} + +.yellow-stroke { + stroke: #bbab52; +} + +.yellow-fill { + fill: #a7942b; +} + +.grey-stroke { + stroke: #484d5a; +} + +.grey-fill { + fill: #2e323c; +} + +.red-stroke { + stroke: #bb3232; +} + +.red-fill { + fill: #901c1c; +} + +.orange-stroke { + stroke: #a5753f; +} + +.orange-fill { + fill: #82531e; +} + +.purple-stroke { + stroke: #a556a5; +} + +.purple-fill { + fill: #8a308a; +} + +.green-stroke { + stroke: #80cc92; +} + +.green-fill { + fill: #499e5d; +} + +switch foreignObject div div div { + color: #ddd !important; +} + +switch foreignObject div div div span { + color: #ddd !important; +} + +.has-background { + background-color: #1d2127 !important; +} +""" + +STROKES = { + '#000000': 'black', + '#6c8ebf': 'blue', + '#d6b656': 'yellow', + '#666666': 'grey', + '#b85450': 'red', + '#d79b00': 'orange', + '#9673a6': 'purple', + '#82b366': 'green', +} + +FILLS = { + '#000000': 'black', + '#ffffff': 'white', + '#dae8fc': 'blue', + '#fff2cc': 'yellow', + '#f5f5f5': 'grey', + '#f8cecc': 'red', + '#ffe6cc': 'orange', + '#e1d5e7': 'purple', + '#d5e8d4': 'green', +} + + +def clear_switches(doc: minidom.Document): + switches = doc.getElementsByTagName('switch') + for s in switches: + children = s.childNodes + assert len(children) == 2 + if children[0].tagName == 'g' and 'requiredFeatures' in children[0].attributes: + s.parentNode.removeChild(s) + s.unlink() + continue + assert children[0].tagName == 'foreignObject' + assert children[1].tagName == 'text' + c = children[1] + s.removeChild(c) + s.parentNode.insertBefore(c, s) + s.parentNode.removeChild(s) + + +def add_class(node: minidom.Node, class_name: str): + if 'class' not in node.attributes: + node.attributes['class'] = class_name + return + + node.attributes['class'] = 
node.attributes['class'].value + f' {class_name}' + + +def add_bg_classes(nodes: List[minidom.Node]): + for node in nodes: + if 'style' in node.attributes: + s = node.attributes['style'].value + if s.count('background-color'): + add_class(node, 'has-background') + + +def add_stroke_classes(nodes: List[minidom.Node]): + for node in nodes: + if 'stroke' in node.attributes: + stroke = node.attributes['stroke'].value + if stroke not in STROKES: + continue + + node.removeAttribute('stroke') + add_class(node, f'{STROKES[stroke]}-stroke') + + +def add_fill_classes(nodes: List[minidom.Node]): + for node in nodes: + if 'fill' in node.attributes: + fill = node.attributes['fill'].value + if fill not in FILLS: + continue + + node.removeAttribute('fill') + add_class(node, f'{FILLS[fill]}-fill') + + +def add_classes(doc: minidom.Document): + paths = doc.getElementsByTagName('path') + add_stroke_classes(paths) + add_fill_classes(paths) + + rects = doc.getElementsByTagName('rect') + add_stroke_classes(rects) + add_fill_classes(rects) + + ellipse = doc.getElementsByTagName('ellipse') + add_stroke_classes(ellipse) + add_fill_classes(ellipse) + + text = doc.getElementsByTagName('text') + add_fill_classes(text) + + div = doc.getElementsByTagName('div') + add_bg_classes(div) + + span = doc.getElementsByTagName('span') + add_bg_classes(span) + + +def parse(source: Path, dest: Path): + doc: minidom.Document = minidom.parse(str(source)) + + svg = doc.getElementsByTagName('svg') + + assert len(svg) == 1 + svg = svg[0] + + if 'content' in svg.attributes: + svg.removeAttribute('content') + # svg.attributes['height'] = str(int(svg.attributes['height'].value[:-2]) + 30) + 'px' + # svg.attributes['width'] = str(int(svg.attributes['width'].value[:-2]) + 30) + 'px' + + view_box = svg.attributes['viewBox'].value.split(' ') + view_box = [float(v) for v in view_box] + view_box[0] -= 10 + view_box[1] -= 10 + view_box[2] += 20 + view_box[3] += 20 + svg.attributes['viewBox'] = ' '.join([str(v) for v in view_box]) + + svg.attributes['style'] = 'background: #1d2127;' # padding: 10px;' + + # clear_switches(doc) + + style = doc.createElement('style') + style.appendChild(doc.createTextNode(STYLES)) + svg.insertBefore(style, svg.childNodes[0]) + add_classes(doc) + + with open(str(dest), 'w') as f: + doc.writexml(f) + + +def recurse(path: Path): + files = [] + if path.is_file(): + files.append(path) + return files + + for f in path.iterdir(): + files += recurse(f) + + return files + + +def main(): + diagrams_path = HOME / 'diagrams' + docs_path = HOME / 'docs' + + for p in recurse(diagrams_path): + source_path = p + p = p.relative_to(diagrams_path) + dest_path = docs_path / p + if not dest_path.parent.exists(): + dest_path.parent.mkdir(parents=True) + + with monit.section(str(p)): + shutil.copy(str(source_path), str(dest_path)) + + +if __name__ == '__main__': + main()