From 7e54d43c0cf541ad3e2330bb4168af22415b3313 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Mon, 1 Feb 2021 10:25:40 +0530
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20batch=20norm?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 labml_nn/normalization/__init__.py   |   0
 labml_nn/normalization/batch_norm.py |  47 ++++++++++++
 labml_nn/normalization/mnist.py      | 105 +++++++++++++++++++++++++++
 3 files changed, 152 insertions(+)
 create mode 100644 labml_nn/normalization/__init__.py
 create mode 100644 labml_nn/normalization/batch_norm.py
 create mode 100644 labml_nn/normalization/mnist.py

diff --git a/labml_nn/normalization/__init__.py b/labml_nn/normalization/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/labml_nn/normalization/batch_norm.py b/labml_nn/normalization/batch_norm.py
new file mode 100644
index 00000000..6a2d5daf
--- /dev/null
+++ b/labml_nn/normalization/batch_norm.py
@@ -0,0 +1,47 @@
+import torch
+from torch import nn
+
+from labml_helpers.module import Module
+
+
+class BatchNorm(Module):
+    def __init__(self, channels: int, *,
+                 eps: float = 1e-5, momentum: float = 0.1,
+                 affine: bool = True, track_running_stats: bool = True):
+        super().__init__()
+
+        self.channels = channels
+
+        self.eps = eps
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        if self.affine:
+            self.weight = nn.Parameter(torch.ones(channels))
+            self.bias = nn.Parameter(torch.zeros(channels))
+        if self.track_running_stats:
+            self.register_buffer('running_mean', torch.zeros(channels))
+            self.register_buffer('running_var', torch.ones(channels))
+
+    def __call__(self, x: torch.Tensor):
+        x_shape = x.shape
+        batch_size = x_shape[0]
+
+        x = x.view(batch_size, self.channels, -1)
+        if self.training or not self.track_running_stats:
+            mean = x.mean(dim=[0, 2])
+            mean_x2 = (x ** 2).mean(dim=[0, 2])
+            var = mean_x2 - mean ** 2
+
+            if self.training and self.track_running_stats:
+                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
+                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
+        else:
+            mean = self.running_mean
+            var = self.running_var
+
+        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
+        if self.affine:
+            x_norm = self.weight.view(1, -1, 1) * x_norm + self.bias.view(1, -1, 1)
+
+        return x_norm.view(x_shape)
diff --git a/labml_nn/normalization/mnist.py b/labml_nn/normalization/mnist.py
new file mode 100644
index 00000000..cb98f57e
--- /dev/null
+++ b/labml_nn/normalization/mnist.py
@@ -0,0 +1,105 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+
+from labml import experiment, tracker
+from labml.configs import option
+from labml_helpers.datasets.mnist import MNISTConfigs
+from labml_helpers.device import DeviceConfigs
+from labml_helpers.metrics.accuracy import Accuracy
+from labml_helpers.module import Module
+from labml_helpers.seed import SeedConfigs
+from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.normalization.batch_norm import BatchNorm
+
+
+class Net(Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.bn1 = BatchNorm(20)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.bn2 = BatchNorm(50)
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.bn3 = BatchNorm(500)
+        self.fc2 = nn.Linear(500, 10)
+
+    def __call__(self, x: torch.Tensor):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.max_pool2d(x, 2, 2)
+        x = x.view(-1, 4 * 4 * 50)
+        x = F.relu(self.bn3(self.fc1(x)))
+        return self.fc2(x)
+
+
+class Configs(MNISTConfigs, TrainValidConfigs):
+    optimizer: torch.optim.Adam
+    model: nn.Module
+    set_seed = SeedConfigs()
+    device: torch.device = DeviceConfigs()
+    epochs: int = 10
+
+    is_save_models = True
+    model: nn.Module
+    inner_iterations = 10
+
+    accuracy_func = Accuracy()
+    loss_func = nn.CrossEntropyLoss()
+
+    def init(self):
+        tracker.set_queue("loss.*", 20, True)
+        tracker.set_scalar("accuracy.*", True)
+        hook_model_outputs(self.mode, self.model, 'model')
+        self.state_modules = [self.accuracy_func]
+
+    def step(self, batch: any, batch_idx: BatchIndex):
+        data, target = batch[0].to(self.device), batch[1].to(self.device)
+
+        if self.mode.is_train:
+            tracker.add_global_step(len(data))
+
+        with self.mode.update(is_log_activations=batch_idx.is_last):
+            output = self.model(data)
+
+        loss = self.loss_func(output, target)
+        self.accuracy_func(output, target)
+        tracker.add("loss.", loss)
+
+        if self.mode.is_train:
+            loss.backward()
+
+            self.optimizer.step()
+            if batch_idx.is_last:
+                tracker.add('model', self.model)
+            self.optimizer.zero_grad()
+
+        tracker.save()
+
+
+@option(Configs.model)
+def model(c: Configs):
+    return Net().to(c.device)
+
+
+@option(Configs.optimizer)
+def _optimizer(c: Configs):
+    from labml_helpers.optimizer import OptimizerConfigs
+    opt_conf = OptimizerConfigs()
+    opt_conf.parameters = c.model.parameters()
+    return opt_conf
+
+
+def main():
+    conf = Configs()
+    experiment.create(name='mnist_labml_helpers')
+    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
+    conf.set_seed.set()
+    experiment.add_pytorch_models(dict(model=conf.model))
+    with experiment.start():
+        conf.run()
+
+
+if __name__ == '__main__':
+    main()
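
A quick sanity check for the BatchNorm module, not part of the patch: in
training mode it should produce the same output as torch.nn.BatchNorm2d
(the input shape and tolerance below are illustrative choices):

    import torch
    from labml_nn.normalization.batch_norm import BatchNorm

    torch.manual_seed(0)
    x = torch.randn(8, 20, 4, 4)  # [batch, channels, height, width]

    ours = BatchNorm(20)
    ref = torch.nn.BatchNorm2d(20)

    # Both modules start in training mode, so both normalize with
    # per-channel batch statistics and freshly initialized affine
    # parameters (weight=1, bias=0)
    assert torch.allclose(ours(x), ref(x), atol=1e-5)

Note that the running variance estimates will drift apart slightly:
torch.nn.BatchNorm2d updates running_var with the unbiased
(Bessel-corrected) variance, while this patch uses the biased estimate
E[x^2] - E[x]^2 for both normalization and the running average.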
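
Because __call__ reshapes the input to [batch_size, channels, -1] and
reduces over dims [0, 2], statistics are computed per channel across the
batch and all spatial positions, whatever the number of trailing
dimensions. A minimal check of that behavior (shapes and constants are
again illustrative):

    import torch
    from labml_nn.normalization.batch_norm import BatchNorm

    bn = BatchNorm(3, affine=False)
    x = torch.randn(16, 3, 8, 8) * 5.0 + 2.0  # shifted and scaled input
    y = bn(x)
    # The normalized output has roughly zero mean and unit variance
    # per channel
    print(y.mean(dim=[0, 2, 3]))  # ~0
    print(y.var(dim=[0, 2, 3]))   # ~1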
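
To train the MNIST model itself, assuming labml and labml_helpers are
installed and the repository is on PYTHONPATH, the experiment can be run
as a module:

    python -m labml_nn.normalization.mnist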