mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-06 15:46:15 +08:00
✨ batch norm
0	labml_nn/normalization/__init__.py	Normal file
47	labml_nn/normalization/batch_norm.py	Normal file
@@ -0,0 +1,47 @@
import torch
from torch import nn

from labml_helpers.module import Module


class BatchNorm(Module):
    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
        super().__init__()

        self.channels = channels

        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Learnable per-channel scale and shift, applied after normalization
        if self.affine:
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))
        # Exponential moving averages of mean and variance, used at inference time
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(channels))
            self.register_buffer('running_var', torch.ones(channels))

    def __call__(self, x: torch.Tensor):
        # `x` has shape `[batch_size, channels, *]`
        x_shape = x.shape
        batch_size = x_shape[0]

        # Flatten everything after the channel dimension
        x = x.view(batch_size, self.channels, -1)
        # Use batch statistics during training, or when running stats are not tracked
        if self.training or not self.track_running_stats:
            # Per-channel mean and (biased) variance over the batch and spatial dimensions
            mean = x.mean(dim=[0, 2])
            mean_x2 = (x ** 2).mean(dim=[0, 2])
            var = mean_x2 - mean ** 2

            # Update the exponential moving averages
            if self.training and self.track_running_stats:
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        # At inference, normalize with the tracked running statistics
        else:
            mean = self.running_mean
            var = self.running_var

        # Normalize, then scale and shift if affine parameters are enabled
        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
        if self.affine:
            x_norm = self.weight.view(1, -1, 1) * x_norm + self.bias.view(1, -1, 1)

        # Restore the original shape
        return x_norm.view(x_shape)
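For a quick sanity check of the module above, here is a minimal usage sketch (not part of the commit; it assumes the labml-nn and labml-helpers packages are installed so the import resolves). In training mode each output channel should come out with roughly zero mean and unit variance, since weight and bias start at one and zero:

import torch

from labml_nn.normalization.batch_norm import BatchNorm

# 3-channel feature maps, deliberately shifted and scaled
x = torch.randn(8, 3, 16, 16) * 5.0 + 2.0

bn = BatchNorm(3)
bn.train()  # normalize with batch statistics
y = bn(x)

# Per-channel mean should be ~0 and (biased) variance ~1,
# since the affine parameters are initialized to 1 and 0
print(y.mean(dim=[0, 2, 3]))                 # ≈ [0, 0, 0]
print(y.var(dim=[0, 2, 3], unbiased=False))  # ≈ [1, 1, 1]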
105	labml_nn/normalization/mnist.py	Normal file
@@ -0,0 +1,105 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.seed import SeedConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.normalization.batch_norm import BatchNorm


class Net(Module):
    # LeNet-style MNIST classifier with BatchNorm after each
    # convolutional layer and the first fully-connected layer
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.bn1 = BatchNorm(20)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn2 = BatchNorm(50)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.bn3 = BatchNorm(500)
        self.fc2 = nn.Linear(500, 10)

    def __call__(self, x: torch.Tensor):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.bn3(self.fc1(x)))
        return self.fc2(x)


class Configs(MNISTConfigs, TrainValidConfigs):
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Count processed samples as the global step during training
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Log model activations on the last batch of each epoch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()


@option(Configs.model)
def model(c: Configs):
    return Net().to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    from labml_helpers.optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


def main():
    conf = Configs()
    experiment.create(name='mnist_labml_helpers')
    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
    conf.set_seed.set()
    experiment.add_pytorch_models(dict(model=conf.model))
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
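As a cross-check (a sketch, not part of the commit): in training mode PyTorch's built-in nn.BatchNorm2d also normalizes with the biased per-batch variance and initializes its affine parameters to one and zero, so its forward output should agree with the BatchNorm module above up to floating-point tolerance. The running statistics will drift apart over time, though, since PyTorch updates running_var with the unbiased variance estimate while the implementation above uses the biased one.

import torch
from torch import nn

from labml_nn.normalization.batch_norm import BatchNorm

x = torch.randn(4, 20, 8, 8)

custom = BatchNorm(20)
builtin = nn.BatchNorm2d(20)
custom.train()
builtin.train()

# Both normalize with per-channel batch statistics in training mode,
# so the forward outputs should match up to numerical tolerance
assert torch.allclose(custom(x), builtin(x), atol=1e-5)
print('outputs match')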