From 6a41c82b30157fb146ac4d7f455e57a8e7aa7565 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Mon, 23 May 2022 22:26:39 +0530
Subject: [PATCH] FTA (#115)

---
 docs/activations/fta/experiment.html         | 833 ++++++++++++++++++
 docs/activations/fta/index.html              | 402 +++++++++
 docs/activations/index.html                  |  23 +-
 docs/index.html                              |   2 +
 docs/normalization/deep_norm/experiment.html |   5 +-
 docs/sitemap.xml                             |  22 +-
 docs/transformers/rope/index.html            |   2 +-
 labml_nn/__init__.py                         |   4 +
 labml_nn/activations/__init__.py             |  13 +
 labml_nn/activations/fta/__init__.py         | 132 +++
 labml_nn/activations/fta/experiment.ipynb    | 299 +++++++
 labml_nn/activations/fta/experiment.py       | 220 +++++
 .../normalization/deep_norm/experiment.py    |   2 +-
 labml_nn/transformers/rope/__init__.py       |   2 +-
 readme.md                                    |   5 +
 setup.py                                     |   2 +-
 16 files changed, 1947 insertions(+), 21 deletions(-)
 create mode 100644 docs/activations/fta/experiment.html
 create mode 100644 docs/activations/fta/index.html
 create mode 100644 labml_nn/activations/fta/__init__.py
 create mode 100644 labml_nn/activations/fta/experiment.ipynb
 create mode 100644 labml_nn/activations/fta/experiment.py

diff --git a/docs/activations/fta/experiment.html b/docs/activations/fta/experiment.html
new file mode 100644
index 00000000..3d0d2ce3
--- /dev/null
+++ b/docs/activations/fta/experiment.html
@@ -0,0 +1,833 @@
[Rendered documentation page "Fuzzy Tiling Activation Experiment" — the generated HTML view of
labml_nn/activations/fta/experiment.py, added below.

Here we train a transformer that uses Fuzzy Tiling Activation in the Feed-Forward Network.
We use it for a language model and train it on the Tiny Shakespeare dataset for demonstration.
However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for
modeling data with continuous variables.

The page carries "Open In Colab" and "Open In Comet" badges and embeds the annotated experiment code.]
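Since FTA expands its input by the number of bins, the second projection in `FeedForwardFTA` takes
`d_ff * activation.expansion_factor` features instead of `d_ff`. A minimal sketch of that expansion
(not part of the patch, assuming the `FTA` module added below) with the experiment's settings:

```python
import torch

from labml_nn.activations.fta import FTA

# FTA settings used by the experiment: bins of width 0.2 covering [-1, 1), eta = 0.05
fta = FTA(lower_limit=-1., upper_limit=1., delta=0.2, eta=0.05)

d_ff = 256
x = torch.randn(4, 2, d_ff)  # [seq_len, batch_size, d_ff]
y = fta(x)

# 10 bins, so the last dimension grows from 256 to 2560,
# which is why `layer2` is `nn.Linear(d_ff * activation.expansion_factor, d_model)`
print(fta.expansion_factor)  # 10
print(y.shape)               # torch.Size([4, 2, 2560])
```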
diff --git a/docs/activations/fta/index.html b/docs/activations/fta/index.html
new file mode 100644
index 00000000..32bee1ea
--- /dev/null
+++ b/docs/activations/fta/index.html
@@ -0,0 +1,402 @@
[Rendered documentation page "Fuzzy Tiling Activations (FTA)" — the generated HTML view of
labml_nn/activations/fta/__init__.py, added below. Its text:]

This is a PyTorch implementation/tutorial of
[Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online](https://papers.labml.ai/paper/aca66d8edc8911eba3db37f65e372566).

Fuzzy tiling activations are a form of sparse activations based on binning.

Binning is classification of a scalar value into a bin based on intervals. One problem with binning is
that it gives zero gradients for most values (except at the boundaries of bins). The other is that
binning loses precision if the bin intervals are large.

FTA overcomes these disadvantages. Instead of hard boundaries as in tiling activations, FTA uses soft
boundaries between bins. This gives non-zero gradients for all or a wide range of values, and it also
doesn't lose precision, since the precision is captured in the partial activation values.

#### Tiling Activations

$\mathbf{c}$ is the tiling vector,

$$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$

where $[l, u]$ is the input range, $\delta$ is the bin size, and $u - l$ is divisible by $\delta$.

The tiling activation is

$$\phi(z) = 1 - I_+ \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$

where $I_+(\cdot)$ is the indicator function, which gives $1$ if the input is positive and $0$ otherwise.

Note that the tiling activation gives zero gradients because it has hard boundaries.

#### Fuzzy Tiling Activations

The fuzzy indicator function,

$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$

increases linearly from $0$ to $1$ when $0 \le x \lt \eta$ and is equal to $1$ for $\eta \le x$.
$\eta$ is a hyper-parameter.

FTA uses this to create soft boundaries between bins:

$$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$

[Here's a simple experiment](experiment.html) that uses FTA in a transformer
(with "Open In Colab" and "Open In Comet" badges).
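To make the soft boundaries concrete, here is a small numeric sketch (not part of the patch) using the
same settings as the module's `_test()` function: bins of width $2$ covering $[-10, 10)$ and $\eta = 0.5$.
An input of $1.9$ falls inside the bin $[0, 2)$ and lies within $\eta$ of the next bin, so the
neighbouring bin gets a partial, differentiable activation:

```python
import torch

from labml_nn.activations.fta import FTA

# Same settings as `_test()`: 10 bins of width 2 covering [-10, 10), eta = 0.5
fta = FTA(lower_limit=-10., upper_limit=10., delta=2., eta=0.5)

z = torch.tensor([1.9])
# 1.9 is inside the bin [0, 2) and only 0.1 away from the bin [2, 4), so the output
# should be roughly [0, 0, 0, 0, 0, 1.0, 0.9, 0, 0, 0]: full activation for its own
# bin and a partial activation for the neighbouring bin.
print(fta(z))
```

With a hard tiling activation ($\eta = 0$) the same input would give a one-hot vector and zero gradients
with respect to $z$, which is the problem FTA avoids.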
[The remainder of the page is the annotated source of the `FTA` module, identical to
labml_nn/activations/fta/__init__.py below.]

diff --git a/docs/activations/index.html b/docs/activations/index.html
index 6f9e2821..cd4d7b0f 100644
--- a/docs/activations/index.html
+++ b/docs/activations/index.html
[Rendered activations index page regenerated: its title changes from "__init__.py" to
"Neural Network Activation Functions", and the "Neural Networks Activations" listing now links to the
new Fuzzy Tiling Activations page, matching the docstring added to labml_nn/activations/__init__.py
below.]
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 74a8e125..42cf8562 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -85,7 +85,21 @@
 
     
       https://nn.labml.ai/activations/index.html
-      2021-01-25T16:30:00+00:00
+      2022-05-23T16:30:00+00:00
+      1.00
+    
+    
+
+    
+      https://nn.labml.ai/activations/fta/index.html
+      2022-05-23T16:30:00+00:00
+      1.00
+    
+    
+
+    
+      https://nn.labml.ai/activations/fta/experiment.html
+      2022-05-23T16:30:00+00:00
       1.00
     
     
@@ -197,7 +211,7 @@
 
     
       https://nn.labml.ai/normalization/deep_norm/experiment.html
-      2022-04-23T16:30:00+00:00
+      2022-05-23T16:30:00+00:00
       1.00
     
     
@@ -295,7 +309,7 @@
 
     
       https://nn.labml.ai/index.html
-      2022-05-03T16:30:00+00:00
+      2022-05-23T16:30:00+00:00
       1.00
     
     
@@ -589,7 +603,7 @@
 
     
       https://nn.labml.ai/transformers/rope/index.html
-      2022-02-23T16:30:00+00:00
+      2022-04-05T16:30:00+00:00
       1.00
     
     
diff --git a/docs/transformers/rope/index.html b/docs/transformers/rope/index.html
index 5b40fd74..06512a47 100644
--- a/docs/transformers/rope/index.html
+++ b/docs/transformers/rope/index.html
@@ -214,7 +214,7 @@
[Rendered RoPE page regenerated for the comment fix in labml_nn/transformers/rope/__init__.py below:
"Concatenate so that for row $m$ we have
$[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$";
previously the second half was typeset as $m \theta 0, m \theta 1, \dots$.]
diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py index 229592fe..5688b209 100644 --- a/labml_nn/__init__.py +++ b/labml_nn/__init__.py @@ -112,6 +112,10 @@ Solving games with incomplete information such as poker with CFR. * [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html) +#### ✨ [Activations](activations/index.html) + +* [Fuzzy Tiling Activations](activations/fta/index.html) + ## Highlighted Research Paper PDFs * [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf) diff --git a/labml_nn/activations/__init__.py b/labml_nn/activations/__init__.py index 0306daa9..3aa9d24e 100644 --- a/labml_nn/activations/__init__.py +++ b/labml_nn/activations/__init__.py @@ -1 +1,14 @@ +""" +--- +title: Neural Network Activation Functions +summary: > + A set of PyTorch implementations/tutorials related to neural network activations +--- + +# Neural Networks Activations + +* [Fuzzy Tiling Activations](fta/index.html) +* 🚧 [Swish](swish/index.html) +""" + from .swish import Swish diff --git a/labml_nn/activations/fta/__init__.py b/labml_nn/activations/fta/__init__.py new file mode 100644 index 00000000..fd24407e --- /dev/null +++ b/labml_nn/activations/fta/__init__.py @@ -0,0 +1,132 @@ +""" +--- +title: Fuzzy Tiling Activations +summary: > + PyTorch implementation and tutorial of Fuzzy Tiling Activations from the + paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online. +--- + +# Fuzzy Tiling Activations (FTA) + +This is a [PyTorch](https://pytorch.org) implementation/tutorial of +[Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online](https://papers.labml.ai/paper/aca66d8edc8911eba3db37f65e372566). + +Fuzzy tiling activations are a form of sparse activations based on binning. + +Binning is classification of a scalar value into a bin based on intervals. +One problem with binning is that it gives zero gradients for most values (except at the boundary of bins). +The other is that binning loses precision if the bin intervals are large. + +FTA overcomes these disadvantages. +Instead of hard boundaries like in Tiling Activations, FTA uses soft boundaries +between bins. +This gives non-zero gradients for all or a wide range of values. +And also doesn't lose precision since it's captured in partial values. + +#### Tiling Activations + +$\mathbf{c}$ is the tiling vector, + +$$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$ + +where $[l, u]$ is the input range, $\delta$ is the bin size, and $u - l$ is divisible by $\delta$. + +Tiling activation is, + +$$\phi(z) = 1 - I_+ \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}) \big)$$ + +where $I_+(\cdot)$ is the indicator function which gives $1$ if the input is positive and $0$ otherwise. + +Note that tiling activation gives zero gradients because it has hard boundaries. + +#### Fuzzy Tiling Activations + +The fuzzy indicator function, + +$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$ + +which increases linearly from $0$ to $1$ when $0 \le x \lt \eta$ +and is equal to $1$ for $\eta \le x$. +$\eta$ is a hyper-parameter. + +FTA uses this to create soft boundaries between bins. + +$$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$ + +[Here's a simple experiment](experiment.html) that uses FTA in a transformer. 
+ +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb) +[![Open In Comet](https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model)](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step) +""" + +import torch +from torch import nn + + +class FTA(nn.Module): + """ + ### Fuzzy Tiling Activations (FTA) + """ + + def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float): + """ + :param lower_limit: is the lower limit $l$ + :param upper_limit: is the upper limit $u$ + :param delta: is the bin size $\delta$ + :param eta: is the parameter $\eta$ that detemines the softness of the boundaries. + """ + super().__init__() + # Initialize tiling vector + # $$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$ + self.c = nn.Parameter(torch.arange(lower_limit, upper_limit, delta), requires_grad=False) + # The input vector expands by a factor equal to the number of bins $\frac{u - l}{\delta}$ + self.expansion_factor = len(self.c) + # $\delta$ + self.delta = delta + # $\eta$ + self.eta = eta + + def fuzzy_i_plus(self, x: torch.Tensor): + """ + #### Fuzzy indicator function + + $$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$ + """ + return (x <= self.eta) * x + (x > self.eta) + + def forward(self, z: torch.Tensor): + # Add another dimension of size $1$. + # We will expand this into bins. + z = z.view(*z.shape, 1) + + # $$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$ + z = 1. - self.fuzzy_i_plus(torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.)) + + # Reshape back to original number of dimensions. + # The last dimension size gets expanded by the number of bins, $\frac{u - l}{\delta}$. 
+ return z.view(*z.shape[:-2], -1) + + +def _test(): + """ + #### Code to test the FTA module + """ + from labml.logger import inspect + + # Initialize + a = FTA(-10, 10, 2., 0.5) + # Print $\mathbf{c}$ + inspect(a.c) + # Print number of bins $\frac{u - l}{\delta}$ + inspect(a.expansion_factor) + + # Input $z$ + z = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9., 10., 11.]) + # Print $z$ + inspect(z) + # Print $\phi_\eta(z)$ + inspect(a(z)) + + +if __name__ == '__main__': + _test() diff --git a/labml_nn/activations/fta/experiment.ipynb b/labml_nn/activations/fta/experiment.ipynb new file mode 100644 index 00000000..f929f2e7 --- /dev/null +++ b/labml_nn/activations/fta/experiment.ipynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "AYV_dMVDxyc2", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)\n", + "[![Open In Comet](https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model)](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)\n", + "\n", + "## [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)\n", + "\n", + "Here we train a transformer that uses [Fuzzy Tiling Activation](https://nn.labml.ai/activations/fta/index.html) in the\n", + "[Feed-Forward Network](https://nn.labml.ai/transformers/feed_forward.html).\n", + "We use it for a language model and train it on Tiny Shakespeare dataset\n", + "for demonstration.\n", + "However, this is probably not the ideal task for FTA, and we\n", + "believe FTA is more suitable for modeling data with continuous variables." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AahG_i2y5tY9", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Install the packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZCzmCrAIVg0L", + "outputId": "cf107fb2-4d50-4c67-af34-367624553421", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "!pip install labml-nn comet_ml --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Enable [Comet](https://www.comet.ml)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "#@markdown Select in order to enable logging this experiment to [Comet](https://www.comet.ml).\n", + "use_comet = False #@param {type:\"boolean\"}\n", + "\n", + "if use_comet:\n", + " import comet_ml\n", + " comet_ml.init(project_name='fta')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SE2VUQ6L5zxI", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from labml import experiment\n", + "from labml.configs import option\n", + "from labml_nn.activations.fta.experiment import Configs" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "### Create an experiment" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "experiment.create(name=\"fta\", writers={\"screen\", \"comet\"} if use_comet else {'screen'})" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "### Configurations" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "conf = Configs()" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "Set experiment configurations and assign a configurations dictionary to override configurations" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "experiment.configs(conf, {\n", + " 'tokenizer': 'character',\n", + " 'prompt_separator': '',\n", + " 'prompt': 'It is ',\n", + " 'text': 'tiny_shakespeare',\n", + "\n", + " 'seq_len': 256,\n", + " 'epochs': 32,\n", + " 'batch_size': 16,\n", + " 'inner_iterations': 10,\n", + "\n", + " 'optimizer.optimizer': 'Adam',\n", + " 'optimizer.learning_rate': 3e-4,\n", + "})" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EvI7MtgJ61w5", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "Set PyTorch models for loading and saving" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 255 + }, + "id": "GDlt7dp-5ALt", + "outputId": 
"e7548e8f-c541-4618-dc5a-1597cae42003", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "experiment.add_pytorch_models({'model': conf.model})" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KJZRf8527GxL", + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Start the experiment and run the training loop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "aIAWo7Fw5DR8", + "outputId": "db979785-bfe3-4eda-d3eb-8ccbe61053e5", + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Start the experiment\n", + "with experiment.start():\n", + " conf.run()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "FTA", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/labml_nn/activations/fta/experiment.py b/labml_nn/activations/fta/experiment.py new file mode 100644 index 00000000..be740a02 --- /dev/null +++ b/labml_nn/activations/fta/experiment.py @@ -0,0 +1,220 @@ +""" +--- +title: Fuzzy Tiling Activation Experiment +summary: > + Training a transformer with FTA in FFN on Tiny Shakespeare. +--- + +# [Fuzzy Tiling Activation](index.html) Experiment + +Here we train a transformer that uses [Fuzzy Tiling Activation](index.html) in the +[Feed-Forward Network](../../transformers/feed_forward.html). +We use it for a language model and train it on Tiny Shakespeare dataset +for demonstration. + +However, this is probably not the ideal task for FTA, and we +believe FTA is more suitable for modeling data with continuous variables. 
+ +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb) +[![Open In Comet](https://images.labml.ai/images/comet.svg?experiment=capsule_networks&file=model)](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step) +""" + +import copy + +import torch +import torch.nn as nn + +from labml import experiment +from labml.configs import option +from labml_helpers.module import Module +from labml_nn.activations.fta import FTA +from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs +from labml_nn.transformers import MultiHeadAttention, TransformerLayer +from labml_nn.transformers.utils import subsequent_mask + + +class FeedForwardFTA(nn.Module): + """ + ## FFN module with [FTA](index.html) activation + """ + + def __init__(self, d_model: int, d_ff: int, + activation: FTA, + dropout: float = 0.1): + """ + * `d_model` is the number of features in a token embedding + * `d_ff` is the number of features in the hidden layer of the FFN + * `activation` is FTA activation module + * `dropout` is dropout probability for the hidden layer + """ + super().__init__() + # Layer one parameterized by weight $W_1$ and bias $b_1$ + self.layer1 = nn.Linear(d_model, d_ff) + # Layer two parameterized by weight $W_1$ and bias $b_1$ + self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model) + # Hidden layer dropout + self.dropout = nn.Dropout(dropout) + # Activation function $f$ + self.activation = activation + + def forward(self, x: torch.Tensor): + # $f(x W_1 + b_1)$ + x = self.activation(self.layer1(x)) + # Apply dropout + x = self.dropout(x) + # + return self.layer2(x) + + +class AutoregressiveTransformer(Module): + """ + ## Auto-Regressive model + + This is an autoregressive transformer model that uses Feed-Forward Networks with + (Fuzzy Tiling Activations)(index.html). + """ + + def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer): + """ + :param n_tokens: is the number of tokens in the vocabulary + :param d_model: is the embedding size + :param n_layers: is the number of transformer layers + :param layer: is the layer. We use `n_layers` copies of this for the transformer. 
+ """ + super().__init__() + # Transformer with `n_layers` layers + self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)]) + + # Token embedding layer + self.emb = nn.Embedding(n_tokens, d_model) + # Readout layer + self.readout = nn.Linear(d_model, n_tokens) + + # The mask will be initialized on the first call + self.mask = None + + def forward(self, x: torch.Tensor): + """ + :param x: are the input tokens of shape `[seq_len, batch_size]` + """ + # Create auto-regressive mask + if self.mask is None or self.mask.size(0) != len(x): + # Subsequent mask, will mask out tokens from seeing future tokens + self.mask = subsequent_mask(len(x)).to(x.device) + + # Get the token embeddings + x = self.emb(x) + # Transformer encoder + for layer in self.transformer_layers: + x = layer(x=x, mask=self.mask) + # Get logits + x = self.readout(x) + + # Return results + return x, None + + +class Configs(NLPAutoRegressionConfigs): + """ + ## Configurations + + This inherits from + [`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs) + """ + + # Model + model: AutoregressiveTransformer + + # Number of layers + n_layers: int = 4 + + # $\alpha$ and $\beta$ for DeepNorm + deep_norm_alpha: float + deep_norm_beta: float + + # Number of heads in the attention + n_heads: int = 4 + # Embedding size + d_model: int = 256 + # Size of each attention head + d_k: int = 16 + # Feed forward layer size + d_ff: int = 256 + + # FTA + fta_lower_limit: float = -1. + fta_upper_limit: float = +1. + fta_delta: float = 0.2 + fta_eta: float = 0.05 + + +@option(Configs.model) +def _model(c: Configs): + """ + #### Initialize the model + """ + + # Create FTA activation module + fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta) + # Create the transformer. + # We re-use [`TransformerLayer`](../../transformers/models.html#TransformerLayer) and + # [`MultiHeadAttention`](../../transformers/mha.html) implementations. 
+ m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers, + TransformerLayer(d_model=c.d_model, + feed_forward=FeedForwardFTA(d_model=c.d_model, + d_ff=c.d_ff, + activation=fta, + dropout=0.1), + self_attn=MultiHeadAttention(c.n_heads, c.d_model, + dropout_prob=0.0), + dropout_prob=0.0)) + + # Move to the device + return m.to(c.device) + + +def main(): + """ + #### Create and run the experiment + """ + # Create experiment + experiment.create(name="fta", writers={'screen', 'comet', 'labml'}) + # Create configs + conf = Configs() + # Override configurations + experiment.configs(conf, { + # Use character level tokenizer + 'tokenizer': 'character', + # Prompt separator is blank + 'prompt_separator': '', + # Starting prompt for sampling + 'prompt': 'It is ', + # Use Tiny Shakespeare dataset + 'text': 'tiny_shakespeare', + + # Use a context size of $256$ + 'seq_len': 256, + # Train for 32 epochs + 'epochs': 32, + # Batch size $16$ + 'batch_size': 16, + # Switch between training and validation for $10$ times per epoch + 'inner_iterations': 10, + + # Adam optimizer with no warmup + 'optimizer.optimizer': 'Adam', + 'optimizer.learning_rate': 3e-4, + }) + + # Set model(s) for saving and loading + experiment.add_pytorch_models({'model': conf.model}) + + # Start the experiment + with experiment.start(): + # Run training + conf.run() + + +# +if __name__ == '__main__': + main() diff --git a/labml_nn/normalization/deep_norm/experiment.py b/labml_nn/normalization/deep_norm/experiment.py index fbb008c5..ade34a97 100644 --- a/labml_nn/normalization/deep_norm/experiment.py +++ b/labml_nn/normalization/deep_norm/experiment.py @@ -5,7 +5,7 @@ summary: > Training a DeepNorm transformer on Tiny Shakespeare. --- -# DeepNorm Experiment +# [DeepNorm](index.html) Experiment [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb) [![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d) diff --git a/labml_nn/transformers/rope/__init__.py b/labml_nn/transformers/rope/__init__.py index 4674f037..ae93f859 100644 --- a/labml_nn/transformers/rope/__init__.py +++ b/labml_nn/transformers/rope/__init__.py @@ -141,7 +141,7 @@ class RotaryPositionalEmbeddings(nn.Module): idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta) # Concatenate so that for row $m$ we have - # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta 0, m \theta 1, ..., m \theta_{\frac{d}{2}}]$ + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., -x^{(\frac{d}{2})}]$ diff --git a/readme.md b/readme.md index f52723a8..8b59adf4 100644 --- a/readme.md +++ b/readme.md @@ -115,6 +115,11 @@ Solving games with incomplete information such as poker with CFR. 
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html) +#### ✨ [Activations](https://nn.labml.ai/activations/index.html) + +* [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html) + + ## Highlighted Research Paper PDFs * [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf) diff --git a/setup.py b/setup.py index 4f22edf5..cc48eadc 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ with open("readme.md", "r") as f: setuptools.setup( name='labml-nn', - version='0.4.121', + version='0.4.122', author="Varuna Jayasiri, Nipun Wijerathne", author_email="vpjayasiri@gmail.com, hnipun@gmail.com", description="🧑‍🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",