mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-02 13:00:17 +08:00
✨ batch norm
0    labml_nn/normalization/__init__.py    Normal file
47   labml_nn/normalization/batch_norm.py  Normal file
@@ -0,0 +1,47 @@
import torch
from torch import nn

from labml_helpers.module import Module


class BatchNorm(Module):
    """
    Batch normalization layer that normalizes each channel over the batch
    and spatial dimensions, with an optional learnable scale and shift.
    """

    def __init__(self, channels: int, *,
                 eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = True, track_running_stats: bool = True):
        super().__init__()

        # Number of feature channels
        self.channels = channels

        # `eps` is added to the variance for numerical stability;
        # `momentum` is the weight of the current batch statistics
        # in the exponential moving averages.
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # Learnable per-channel scale (gamma) and shift (beta)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))
        # Exponential moving averages of mean and variance, used at inference time
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(channels))
            self.register_buffer('running_var', torch.ones(channels))

    def __call__(self, x: torch.Tensor):
        # Keep the original shape so it can be restored at the end
        x_shape = x.shape
        batch_size = x_shape[0]

        # Reshape to `[batch_size, channels, n]`, flattening any spatial dimensions
        x = x.view(batch_size, self.channels, -1)
        # Use batch statistics in training mode, or when running statistics are not tracked
        if self.training or not self.track_running_stats:
            # Per-channel mean and variance over the batch and spatial dimensions;
            # the variance is computed as E[x^2] - E[x]^2
            mean = x.mean(dim=[0, 2])
            mean_x2 = (x ** 2).mean(dim=[0, 2])
            var = mean_x2 - mean ** 2

            # Update the exponential moving averages
            if self.training and self.track_running_stats:
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        # In evaluation mode, use the running statistics
        else:
            mean = self.running_mean
            var = self.running_var

        # Normalize, then apply the learnable scale and shift
        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1)
        if self.affine:
            x_norm = self.weight.view(1, -1, 1) * x_norm + self.bias.view(1, -1, 1)

        # Restore the original shape
        return x_norm.view(x_shape)
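As a quick sanity check of the layer above (not part of this commit, just a hedged sketch): in training mode this `BatchNorm` should match PyTorch's built-in `nn.BatchNorm2d` up to numerical tolerance, since both normalize per channel over the batch and spatial dimensions using the biased variance. The input shape and tolerance below are illustrative assumptions.

import torch
from torch import nn

from labml_nn.normalization.batch_norm import BatchNorm

# Random [batch, channels, height, width] feature map (shape is arbitrary)
x = torch.randn(8, 20, 24, 24)

# Our implementation and the built-in reference, both in training mode
bn = BatchNorm(20)
ref = nn.BatchNorm2d(20)
bn.train()
ref.train()

# Both normalize with batch statistics, so the outputs should agree closely
assert torch.allclose(bn(x), ref(x), atol=1e-5)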
105  labml_nn/normalization/mnist.py       Normal file
@@ -0,0 +1,105 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.seed import SeedConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs

from labml_nn.normalization.batch_norm import BatchNorm


class Net(Module):
    """
    Simple LeNet-style MNIST classifier with `BatchNorm` after each
    convolutional layer and the first fully connected layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.bn1 = BatchNorm(20)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn2 = BatchNorm(50)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.bn3 = BatchNorm(500)
        self.fc2 = nn.Linear(500, 10)

    def __call__(self, x: torch.Tensor):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.bn3(self.fc1(x)))
        return self.fc2(x)


class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Experiment configuration: MNIST data comes from `MNISTConfigs`
    and the training/validation loop from `TrainValidConfigs`.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        # Set up tracker metrics and hook the model outputs for activation logging
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        hook_model_outputs(self.mode, self.model, 'model')
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        # Move the batch to the configured device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Increment the global step (number of samples processed) in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model; log activations on the last batch of each epoch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        # Compute and track the loss and accuracy
        loss = self.loss_func(output, target)
        self.accuracy_func(output, target)
        tracker.add("loss.", loss)

        # Optimize in training mode
        if self.mode.is_train:
            loss.backward()

            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()


@option(Configs.model)
def model(c: Configs):
    # Create the model on the configured device
    return Net().to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    # Default optimizer configuration; the concrete optimizer
    # (e.g. Adam) is selected through `experiment.configs`
    from labml_helpers.optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


def main():
    conf = Configs()
    experiment.create(name='mnist_labml_helpers')
    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
    conf.set_seed.set()
    experiment.add_pytorch_models(dict(model=conf.model))
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
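Running the module directly (e.g. `python -m labml_nn.normalization.mnist`) calls `main()` and launches the experiment with Adam. The sketch below (not part of the commit) shows one way to run a shorter variant with overridden configuration values; it assumes `epochs` can be overridden through `experiment.configs` (it is declared as a config item in `Configs` above), and the experiment name is made up for illustration.

from labml import experiment
from labml_nn.normalization.mnist import Configs


def run_short():
    conf = Configs()
    experiment.create(name='mnist_batch_norm_short')
    # 'optimizer.optimizer' selects Adam as in `main()`;
    # 'epochs' is an assumed override of the config item declared in `Configs`
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'epochs': 2})
    conf.set_seed.set()
    experiment.add_pytorch_models(dict(model=conf.model))
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    run_short()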