diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index 74a8e125..42cf8562 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -85,7 +85,21 @@
https://nn.labml.ai/activations/index.html
- 2021-01-25T16:30:00+00:00
+ 2022-05-23T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/activations/fta/index.html
+ 2022-05-23T16:30:00+00:00
+ 1.00
+
+
+
+
+ https://nn.labml.ai/activations/fta/experiment.html
+ 2022-05-23T16:30:00+00:00
1.00
@@ -197,7 +211,7 @@
https://nn.labml.ai/normalization/deep_norm/experiment.html
- 2022-04-23T16:30:00+00:00
+ 2022-05-23T16:30:00+00:00
1.00
@@ -295,7 +309,7 @@
https://nn.labml.ai/index.html
- 2022-05-03T16:30:00+00:00
+ 2022-05-23T16:30:00+00:00
1.00
@@ -589,7 +603,7 @@
https://nn.labml.ai/transformers/rope/index.html
- 2022-02-23T16:30:00+00:00
+ 2022-04-05T16:30:00+00:00
1.00
diff --git a/docs/transformers/rope/index.html b/docs/transformers/rope/index.html
index 5b40fd74..06512a47 100644
--- a/docs/transformers/rope/index.html
+++ b/docs/transformers/rope/index.html
@@ -214,7 +214,7 @@
- Concatenate so that for row m we have [mθ_0, mθ_1, ..., mθ_{d/2}, mθ 0, mθ 1, ..., mθ_{d/2}]
+ Concatenate so that for row m we have [mθ_0, mθ_1, ..., mθ_{d/2}, mθ_0, mθ_1, ..., mθ_{d/2}]
diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py
index 229592fe..5688b209 100644
--- a/labml_nn/__init__.py
+++ b/labml_nn/__init__.py
@@ -112,6 +112,10 @@ Solving games with incomplete information such as poker with CFR.
* [Evidential Deep Learning to Quantify Classification Uncertainty](uncertainty/evidence/index.html)
+#### ✨ [Activations](activations/index.html)
+
+* [Fuzzy Tiling Activations](activations/fta/index.html)
+
## Highlighted Research Paper PDFs
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
diff --git a/labml_nn/activations/__init__.py b/labml_nn/activations/__init__.py
index 0306daa9..3aa9d24e 100644
--- a/labml_nn/activations/__init__.py
+++ b/labml_nn/activations/__init__.py
@@ -1 +1,14 @@
+"""
+---
+title: Neural Network Activation Functions
+summary: >
+ A set of PyTorch implementations/tutorials related to neural network activations
+---
+
+# Neural Network Activations
+
+* [Fuzzy Tiling Activations](fta/index.html)
+* 🚧 [Swish](swish/index.html)
+"""
+
from .swish import Swish
diff --git a/labml_nn/activations/fta/__init__.py b/labml_nn/activations/fta/__init__.py
new file mode 100644
index 00000000..fd24407e
--- /dev/null
+++ b/labml_nn/activations/fta/__init__.py
@@ -0,0 +1,132 @@
+"""
+---
+title: Fuzzy Tiling Activations
+summary: >
+ PyTorch implementation and tutorial of Fuzzy Tiling Activations from the
+ paper Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online.
+---
+
+# Fuzzy Tiling Activations (FTA)
+
+This is a [PyTorch](https://pytorch.org) implementation/tutorial of
+[Fuzzy Tiling Activations: A Simple Approach to Learning Sparse Representations Online](https://papers.labml.ai/paper/aca66d8edc8911eba3db37f65e372566).
+
+Fuzzy tiling activations are a form of sparse activations based on binning.
+
+Binning is the classification of a scalar value into a bin based on intervals.
+One problem with binning is that it gives zero gradients for most values (except at the boundaries of bins).
+Another is that binning loses precision if the bin intervals are large.
+
+FTA overcomes these disadvantages.
+Instead of hard boundaries like in Tiling Activations, FTA uses soft boundaries
+between bins.
+This gives non-zero gradients for all, or at least a wide range of, values.
+It also doesn't lose precision, since values near a bin boundary are captured as partial activations.
+
+#### Tiling Activations
+
+$\mathbf{c}$ is the tiling vector,
+
+$$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$
+
+where $[l, u]$ is the input range, $\delta$ is the bin size, and $u - l$ is divisible by $\delta$.
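+
+For example, with $l = -10$, $u = 10$ and $\delta = 2$ (the values used in the test at
+the bottom of this file), $\mathbf{c} = (-10, -8, \dots, 6, 8)$, giving $\frac{u - l}{\delta} = 10$ bins.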
+
+Tiling activation is,
+
+$$\phi(z) = 1 - I_+ \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$
+
+where $I_+(\cdot)$ is the indicator function which gives $1$ if the input is positive and $0$ otherwise.
+
+Note that the tiling activation gives zero gradients because its hard boundaries make it piecewise constant.
+
+#### Fuzzy Tiling Activations
+
+The fuzzy indicator function,
+
+$$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$
+
+which increases linearly (with slope $1$) from $0$ when $0 \le x \le \eta$,
+and is equal to $1$ for $\eta \lt x$.
+$\eta$ is a hyper-parameter.
+
+FTA uses this to create soft boundaries between bins.
+
+$$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$
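+
+For instance, take $l = 0$, $u = 1$, $\delta = 0.25$ and $\eta = 0.1$,
+so that $\mathbf{c} = (0, 0.25, 0.5, 0.75)$.
+Then for an input $z = 0.3$,
+$$\phi_\eta(0.3) = (0.95, 1, 0, 0)$$
+The bin $[0.25, 0.5)$ containing $z$ is fully activated,
+and the first bin is partially activated with $1 - 0.05 = 0.95$,
+since $\max(z - \delta - c, 0) = 0.05 \le \eta$ for $c = 0$.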
+
+[Here's a simple experiment](experiment.html) that uses FTA in a transformer.
+
+[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)
+[](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)
+"""
+
+import torch
+from torch import nn
+
+
+class FTA(nn.Module):
+ """
+ ### Fuzzy Tiling Activations (FTA)
+ """
+
+ def __init__(self, lower_limit: float, upper_limit: float, delta: float, eta: float):
+ """
+ :param lower_limit: is the lower limit $l$
+ :param upper_limit: is the upper limit $u$
+ :param delta: is the bin size $\delta$
+ :param eta: is the parameter $\eta$ that determines the softness of the boundaries.
+ """
+ super().__init__()
+ # Initialize tiling vector
+ # $$\mathbf{c} = (l, l + \delta, l + 2 \delta, \dots, u - 2 \delta, u - \delta)$$
+ self.c = nn.Parameter(torch.arange(lower_limit, upper_limit, delta), requires_grad=False)
+ # The input vector expands by a factor equal to the number of bins $\frac{u - l}{\delta}$
+ self.expansion_factor = len(self.c)
+ # $\delta$
+ self.delta = delta
+ # $\eta$
+ self.eta = eta
+
+ def fuzzy_i_plus(self, x: torch.Tensor):
+ """
+ #### Fuzzy indicator function
+
+ $$I_{\eta,+}(x) = I_+(\eta - x) x + I_+ (x - \eta)$$
+ """
+ return (x <= self.eta) * x + (x > self.eta)
+
+ def forward(self, z: torch.Tensor):
+ # Add another dimension of size $1$.
+ # We will expand this into bins.
+ z = z.view(*z.shape, 1)
+
+ # $$\phi_\eta(z) = 1 - I_{\eta,+} \big( \max(\mathbf{c} - z, 0) + \max(z - \delta - \mathbf{c}, 0) \big)$$
+ z = 1. - self.fuzzy_i_plus(torch.clip(self.c - z, min=0.) + torch.clip(z - self.delta - self.c, min=0.))
+
+ # Reshape back to original number of dimensions.
+ # The last dimension size gets expanded by the number of bins, $\frac{u - l}{\delta}$.
+ return z.view(*z.shape[:-2], -1)
+
+
+def _test():
+ """
+ #### Code to test the FTA module
+ """
+ from labml.logger import inspect
+
+ # Initialize
+ a = FTA(-10, 10, 2., 0.5)
+ # Print $\mathbf{c}$
+ inspect(a.c)
+ # Print number of bins $\frac{u - l}{\delta}$
+ inspect(a.expansion_factor)
+
+ # Input $z$
+ z = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9., 10., 11.])
+ # Print $z$
+ inspect(z)
+ # Print $\phi_\eta(z)$
+ inspect(a(z))
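+ # Print the shape of $\phi_\eta(z)$; the last dimension expands by the
+ # number of bins, so the input of shape `[11]` becomes shape `[110]`
+ inspect(a(z).shape)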
+
+
+if __name__ == '__main__':
+ _test()
diff --git a/labml_nn/activations/fta/experiment.ipynb b/labml_nn/activations/fta/experiment.ipynb
new file mode 100644
index 00000000..f929f2e7
--- /dev/null
+++ b/labml_nn/activations/fta/experiment.ipynb
@@ -0,0 +1,299 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AYV_dMVDxyc2",
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "[](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
+ "[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)\n",
+ "[](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)\n",
+ "\n",
+ "## [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)\n",
+ "\n",
+ "Here we train a transformer that uses [Fuzzy Tiling Activation](https://nn.labml.ai/activations/fta/index.html) in the\n",
+ "[Feed-Forward Network](https://nn.labml.ai/transformers/feed_forward.html).\n",
+ "We use it for a language model and train it on Tiny Shakespeare dataset\n",
+ "for demonstration.\n",
+ "However, this is probably not the ideal task for FTA, and we\n",
+ "believe FTA is more suitable for modeling data with continuous variables."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AahG_i2y5tY9",
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "### Install the packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ZCzmCrAIVg0L",
+ "outputId": "cf107fb2-4d50-4c67-af34-367624553421",
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "!pip install labml-nn comet_ml --quiet"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "### Enable [Comet](https://www.comet.ml)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "#@markdown Select in order to enable logging this experiment to [Comet](https://www.comet.ml).\n",
+ "use_comet = False #@param {type:\"boolean\"}\n",
+ "\n",
+ "if use_comet:\n",
+ " import comet_ml\n",
+ " comet_ml.init(project_name='fta')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SE2VUQ6L5zxI",
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "### Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "from labml import experiment\n",
+ "from labml.configs import option\n",
+ "from labml_nn.activations.fta.experiment import Configs"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Create an experiment"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "experiment.create(name=\"fta\", writers={\"screen\", \"comet\"} if use_comet else {'screen'})"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Configurations"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "conf = Configs()"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Set experiment configurations and assign a configurations dictionary to override configurations"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "experiment.configs(conf, {\n",
+ " 'tokenizer': 'character',\n",
+ " 'prompt_separator': '',\n",
+ " 'prompt': 'It is ',\n",
+ " 'text': 'tiny_shakespeare',\n",
+ "\n",
+ " 'seq_len': 256,\n",
+ " 'epochs': 32,\n",
+ " 'batch_size': 16,\n",
+ " 'inner_iterations': 10,\n",
+ "\n",
+ " 'optimizer.optimizer': 'Adam',\n",
+ " 'optimizer.learning_rate': 3e-4,\n",
+ "})"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EvI7MtgJ61w5",
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "Set PyTorch models for loading and saving"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 255
+ },
+ "id": "GDlt7dp-5ALt",
+ "outputId": "e7548e8f-c541-4618-dc5a-1597cae42003",
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "experiment.add_pytorch_models({'model': conf.model})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KJZRf8527GxL",
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "### Start the experiment and run the training loop."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "aIAWo7Fw5DR8",
+ "outputId": "db979785-bfe3-4eda-d3eb-8ccbe61053e5",
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Start the experiment\n",
+ "with experiment.start():\n",
+ " conf.run()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "FTA",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
\ No newline at end of file
diff --git a/labml_nn/activations/fta/experiment.py b/labml_nn/activations/fta/experiment.py
new file mode 100644
index 00000000..be740a02
--- /dev/null
+++ b/labml_nn/activations/fta/experiment.py
@@ -0,0 +1,220 @@
+"""
+---
+title: Fuzzy Tiling Activation Experiment
+summary: >
+ Training a transformer with FTA in the FFN on Tiny Shakespeare.
+---
+
+# [Fuzzy Tiling Activation](index.html) Experiment
+
+Here we train a transformer that uses [Fuzzy Tiling Activation](index.html) in the
+[Feed-Forward Network](../../transformers/feed_forward.html).
+We use it in a language model and train it on the Tiny Shakespeare dataset
+for demonstration.
+
+However, this is probably not the ideal task for FTA, and we
+believe FTA is more suitable for modeling data with continuous variables.
+
+[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)
+[](https://www.comet.ml/labml/fta/69be11f83693407f82a86dcbb232bcfe?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&viewId=rlJOpXDGtL8zbkcX66R77P5me&xAxis=step)
+"""
+
+import copy
+
+import torch
+import torch.nn as nn
+
+from labml import experiment
+from labml.configs import option
+from labml_helpers.module import Module
+from labml_nn.activations.fta import FTA
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+from labml_nn.transformers import MultiHeadAttention, TransformerLayer
+from labml_nn.transformers.utils import subsequent_mask
+
+
+class FeedForwardFTA(nn.Module):
+ """
+ ## FFN module with [FTA](index.html) activation
+ """
+
+ def __init__(self, d_model: int, d_ff: int,
+ activation: FTA,
+ dropout: float = 0.1):
+ """
+ * `d_model` is the number of features in a token embedding
+ * `d_ff` is the number of features in the hidden layer of the FFN
+ * `activation` is the FTA activation module
+ * `dropout` is dropout probability for the hidden layer
+ """
+ super().__init__()
+ # Layer one parameterized by weight $W_1$ and bias $b_1$
+ self.layer1 = nn.Linear(d_model, d_ff)
+ # Layer two parameterized by weight $W_2$ and bias $b_2$
+ self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)
+ # Hidden layer dropout
+ self.dropout = nn.Dropout(dropout)
+ # Activation function $f$
+ self.activation = activation
+
+ def forward(self, x: torch.Tensor):
+ # $f(x W_1 + b_1)$
+ x = self.activation(self.layer1(x))
+ # Apply dropout
+ x = self.dropout(x)
+ #
+ return self.layer2(x)
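+
+# As a rough usage sketch (with the hyper-parameters from `Configs` below),
+# `FTA(-1., 1., 0.2, 0.05)` has an expansion factor of $10$, so the hidden
+# layer widens from `d_ff` to `d_ff * 10` features before `layer2`:
+#
+#     ffn = FeedForwardFTA(d_model=256, d_ff=256,
+#                          activation=FTA(-1., 1., 0.2, 0.05))
+#     y = ffn(torch.randn(16, 256))  # y.shape == (16, 256)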
+
+
+class AutoregressiveTransformer(Module):
+ """
+ ## Auto-Regressive model
+
+ This is an autoregressive transformer model that uses Feed-Forward Networks with
+ [Fuzzy Tiling Activations](index.html).
+ """
+
+ def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
+ """
+ :param n_tokens: is the number of tokens in the vocabulary
+ :param d_model: is the embedding size
+ :param n_layers: is the number of transformer layers
+ :param layer: is the layer. We use `n_layers` copies of this for the transformer.
+ """
+ super().__init__()
+ # Transformer with `n_layers` layers
+ self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
+
+ # Token embedding layer
+ self.emb = nn.Embedding(n_tokens, d_model)
+ # Readout layer
+ self.readout = nn.Linear(d_model, n_tokens)
+
+ # The mask will be initialized on the first call
+ self.mask = None
+
+ def forward(self, x: torch.Tensor):
+ """
+ :param x: are the input tokens of shape `[seq_len, batch_size]`
+ """
+ # Create auto-regressive mask
+ if self.mask is None or self.mask.size(0) != len(x):
+ # Subsequent mask, will mask out tokens from seeing future tokens
+ self.mask = subsequent_mask(len(x)).to(x.device)
+
+ # Get the token embeddings
+ x = self.emb(x)
+ # Transformer encoder
+ for layer in self.transformer_layers:
+ x = layer(x=x, mask=self.mask)
+ # Get logits
+ x = self.readout(x)
+
+ # Return results (the second value is for state, which this model doesn't use)
+ return x, None
+
+
+class Configs(NLPAutoRegressionConfigs):
+ """
+ ## Configurations
+
+ This inherits from
+ [`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs)
+ """
+
+ # Model
+ model: AutoregressiveTransformer
+
+ # Number of layers
+ n_layers: int = 4
+
+ # $\alpha$ and $\beta$ for DeepNorm (not used by this experiment)
+ deep_norm_alpha: float
+ deep_norm_beta: float
+
+ # Number of heads in the attention
+ n_heads: int = 4
+ # Embedding size
+ d_model: int = 256
+ # Size of each attention head
+ d_k: int = 16
+ # Feed forward layer size
+ d_ff: int = 256
+
+ # FTA
+ fta_lower_limit: float = -1.
+ fta_upper_limit: float = +1.
+ fta_delta: float = 0.2
+ fta_eta: float = 0.05
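+ # With these values the number of bins is $\frac{u - l}{\delta} = 10$,
+ # so the FFN hidden layer expands to `d_ff * 10` features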
+
+
+@option(Configs.model)
+def _model(c: Configs):
+ """
+ #### Initialize the model
+ """
+
+ # Create FTA activation module
+ fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)
+ # Create the transformer.
+ # We re-use [`TransformerLayer`](../../transformers/models.html#TransformerLayer) and
+ # [`MultiHeadAttention`](../../transformers/mha.html) implementations.
+ m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
+ TransformerLayer(d_model=c.d_model,
+ feed_forward=FeedForwardFTA(d_model=c.d_model,
+ d_ff=c.d_ff,
+ activation=fta,
+ dropout=0.1),
+ self_attn=MultiHeadAttention(c.n_heads, c.d_model,
+ dropout_prob=0.0),
+ dropout_prob=0.0))
+
+ # Move to the device
+ return m.to(c.device)
+
+
+def main():
+ """
+ #### Create and run the experiment
+ """
+ # Create experiment
+ experiment.create(name="fta", writers={'screen', 'comet', 'labml'})
+ # Create configs
+ conf = Configs()
+ # Override configurations
+ experiment.configs(conf, {
+ # Use character level tokenizer
+ 'tokenizer': 'character',
+ # Prompt separator is blank
+ 'prompt_separator': '',
+ # Starting prompt for sampling
+ 'prompt': 'It is ',
+ # Use Tiny Shakespeare dataset
+ 'text': 'tiny_shakespeare',
+
+ # Use a context size of $256$
+ 'seq_len': 256,
+ # Train for 32 epochs
+ 'epochs': 32,
+ # Batch size $16$
+ 'batch_size': 16,
+ # Switch between training and validation for $10$ times per epoch
+ 'inner_iterations': 10,
+
+ # Adam optimizer with no warmup
+ 'optimizer.optimizer': 'Adam',
+ 'optimizer.learning_rate': 3e-4,
+ })
+
+ # Set model(s) for saving and loading
+ experiment.add_pytorch_models({'model': conf.model})
+
+ # Start the experiment
+ with experiment.start():
+ # Run training
+ conf.run()
+
+
+#
+if __name__ == '__main__':
+ main()
diff --git a/labml_nn/normalization/deep_norm/experiment.py b/labml_nn/normalization/deep_norm/experiment.py
index fbb008c5..ade34a97 100644
--- a/labml_nn/normalization/deep_norm/experiment.py
+++ b/labml_nn/normalization/deep_norm/experiment.py
@@ -5,7 +5,7 @@ summary: >
Training a DeepNorm transformer on Tiny Shakespeare.
---
-# DeepNorm Experiment
+# [DeepNorm](index.html) Experiment
[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/deep_norm/experiment.ipynb)
[](https://app.labml.ai/run/ec8e4dacb7f311ec8d1cd37d50b05c3d)
diff --git a/labml_nn/transformers/rope/__init__.py b/labml_nn/transformers/rope/__init__.py
index 4674f037..ae93f859 100644
--- a/labml_nn/transformers/rope/__init__.py
+++ b/labml_nn/transformers/rope/__init__.py
@@ -141,7 +141,7 @@ class RotaryPositionalEmbeddings(nn.Module):
idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
# Concatenate so that for row $m$ we have
- # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta 0, m \theta 1, ..., m \theta_{\frac{d}{2}}]$
+ # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., -x^{(\frac{d}{2})}]$
diff --git a/readme.md b/readme.md
index f52723a8..8b59adf4 100644
--- a/readme.md
+++ b/readme.md
@@ -115,6 +115,11 @@ Solving games with incomplete information such as poker with CFR.
* [Evidential Deep Learning to Quantify Classification Uncertainty](https://nn.labml.ai/uncertainty/evidence/index.html)
+#### ✨ [Activations](https://nn.labml.ai/activations/index.html)
+
+* [Fuzzy Tiling Activations](https://nn.labml.ai/activations/fta/index.html)
+
+
## Highlighted Research Paper PDFs
* [Autoregressive Search Engines: Generating Substrings as Document Identifiers](https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/papers/2204.10628.pdf)
diff --git a/setup.py b/setup.py
index 4f22edf5..cc48eadc 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
- version='0.4.121',
+ version='0.4.122',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="🧑🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, etc. 🧠",