From e6b3c8a6a2cdae4071bcde1abeaaf9a2bc37bd53 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Thu, 21 Aug 2025 12:34:04 +0530
Subject: [PATCH] jax docs

---
 docs/sitemap.xml                             |    7 +
 docs/transformers/jax_transformer/index.html | 3402 ++++++++++++++++++
 2 files changed, 3409 insertions(+)
 create mode 100644 docs/transformers/jax_transformer/index.html

diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index c713c598..63627e5f 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -1042,6 +1042,13 @@
+        https://nn.labml.ai/transformers/jax_transformer/index.html
+        2025-08-21T16:30:00+00:00
+        1.00
+
         https://nn.labml.ai/transformers/feedback/index.html
         2025-07-18T16:30:00+00:00

diff --git a/docs/transformers/jax_transformer/index.html b/docs/transformers/jax_transformer/index.html
new file mode 100644
index 00000000..fb3e3c0f
--- /dev/null
+++ b/docs/transformers/jax_transformer/index.html
@@ -0,0 +1,3402 @@
+Autoregressive Transformer Decoder in JAX from scratch
+
+ +
+ +
+
28from functools import partial
+29from typing import Dict, NamedTuple, Tuple, Any, Callable
+30from typing import List, TypeVar, Generic
+31from typing import Union, Optional
+32
+33import jax
+34import jax.numpy as jnp
+35import numpy as np
+36
+37from labml import lab, monit, experiment, tracker
+38from labml import logger
+39from labml.logger import Text
+40from labml.utils.download import download_file
+
+
+
+
+ +

+

Module

+

This is a base class for all modules. It handles parameters and transforms methods to pure functions for JAX to compile and differentiate.

+

You can skip these modules to get into the models directly.

+

The module stores parameters and sub-modules separately. When we want to transform any method into a pure function, we pass the parameters of the module and its sub-modules as an argument and assign the passed values to the class.

+

This is based on a blog post: From PyTorch to JAX: towards neural net frameworks that purify stateful code.
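A minimal sketch of the same idea (a hypothetical Scale class, not part of the code below): a method that reads its parameters from self becomes a pure function once those parameters are passed in explicitly, so JAX can trace and differentiate it.

```python
import jax
import jax.numpy as jnp

class Scale:
    """A tiny stateful 'module' with a single parameter."""
    def __init__(self):
        self.w = jnp.array(2.0)

    def forward(self, x):
        return self.w * x

scale = Scale()

def pure_forward(params, x):
    # Assign the passed parameters to the object, then call the stateful method
    scale.w = params['w']
    return scale.forward(x)

# The pure function can now be differentiated with respect to the parameters
grad_fn = jax.grad(lambda p, x: pure_forward(p, x).sum())
print(grad_fn({'w': jnp.array(2.0)}, jnp.ones(3)))  # {'w': Array(3., dtype=float32)}
```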

+ +
+
+
43class Module:
+
+
+
+
+ +

Store all parameters and sub-modules in dictionaries

+ +
+
+
63    _submodules: Dict[str, 'Module']
+64    _params: Dict[str, jnp.ndarray]
+
+
+
+
+ +

Initialize

+ +
+
+
66    def __init__(self):
+
+
+
+
+ + +
+
+
68        self._params = {}
+69        self._submodules = {}
+
+
+
+
+ +

Get attribute

+

We override the get attribute operation. So when you reference an attribute with model.attribute + this function gets called.

+

Read this guide if you are not familiar with Python magic methods.

+ +
+
+
71    def __getattr__(self, attr_name: str):
+
+
+
+
+ +

If the attribute is a parameter

+ +
+
+
83        if attr_name in self._params:
+84            return self._params[attr_name]
+
+
+
+
+ +

If the attribute is a sub-module

+ +
+
+
86        elif attr_name in self._submodules:
+87            return self._submodules[attr_name]
+
+
+
+
+ +

Otherwise fall back to normal attributes. The attributes are stored in __dict__ by Python.

+ +
+
+
90        else:
+91            return self.__dict__[attr_name]
+
+
+
+
+ +

Set attribute

+

We override the set attribute operation. So when you assign an attribute with model.attribute + this function gets called.

+ +
+
+
93    def __setattr__(self, key: str, value: Any):
+
+
+
+
+ +

If the value is also a module

+ +
+
+
102        if isinstance(value, Module):
+103            self._submodules[key] = value
+
+
+
+
+ +

If the value is a JAX array

+ +
+
+
105        elif isinstance(value, jnp.ndarray):
+106            self._params[key] = value
+
+
+
+
+ +

Otherwise add it to __dict__ +

+ +
+
+
108        else:
+109            self.__dict__[key] = value
+
+
+
+
+ +

Clear parameters

+

This clears out all the parameters. It is used when a method is called as a pure function: we first clear out all the parameters and then assign the parameters passed to the pure function.

+ +
+
+
111    def _clear_params(self):
+
+
+
+
+ +

Clear parameters of the module

+ +
+
+
120        self._params = {}
+
+
+
+
+ +

Recursively clear parameters of submodules

+ +
+
+
122        for sm in self._submodules.values():
+123            sm._clear_params()
+
+
+
+
+ +

Collect all the parameters

+

This recursively collects all the parameters of the module and sub-modules into a dictionary.
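For example (a sketch assuming the Module class above; the Child and Parent classes here are hypothetical), nested parameters are flattened into a single dictionary with /-separated keys:

```python
# Sketch: a submodule's parameters get prefixed with its attribute name
class Child(Module):
    def __init__(self):
        super().__init__()
        self.weight = jnp.ones((2, 2))

class Parent(Module):
    def __init__(self):
        super().__init__()
        self.scale = jnp.array(1.0)
        self.child = Child()

print(sorted(Parent().get_params().keys()))  # ['child/weight', 'scale']
```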

+ +
+
+
125    def get_params(self) -> Dict[str, jnp.ndarray]:
+
+
+
+
+ +

Parameters of the model

+ +
+
+
133        params = self._params.copy()
+
+
+
+
+ +

Parameters of the submodules

+ +
+
+
135        for sm_name, sm in self._submodules.items():
+136            for name, value in sm.get_params().items():
+
+
+
+
+ +

The dictionary keys are of the form module_name/module_name/param_name +

+ +
+
+
138                params[sm_name + "/" + name] = value
+
+
+
+
+ +

+ +
+
+
140        return params
+
+
+
+
+ +

Set all the parameters

+ +
+
+
142    def _set_params(self, params: Dict[str, jnp.ndarray]):
+
+
+
+
+ +

Iterate through parameters. Their names have the form module_name/module_name/param_name +

+ +
+
+
149        for name, value in params.items():
+
+
+
+
+ +

Split to get module names and parameter name

+ +
+
+
151            self._set_param(name.split("/"), value)
+
+
+
+
+ +

Set a single parameter

+

This is called by _set_params +

+ +
+
+
153    def _set_param(self, param_path: List[str], value: jnp.ndarray):
+
+
+
+
+ +

No module names; i.e. a parameter of this module

+ +
+
+
160        if len(param_path) == 1:
+161            self._params[param_path[0]] = value
+
+
+
+
+ +

Parameter of a submodule

+ +
+
+
163        else:
+164            self._submodules[param_path[0]]._set_param(param_path[1:], value)
+
+
+
+
+ +

Transform a member method to a pure function

+

This transforms a member method to a pure function that accepts a dictionary of parameters as an argument.

+

For example,

+
params = model.get_params()
+pure_function = model.purify(model.calculate_loss)
+output = pure_function(params, data)
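Once purified, the method composes with the usual JAX transformations. For instance (a sketch continuing the example above, assuming calculate_loss returns a scalar):

```python
# Differentiate with respect to the parameters and JIT-compile the result
grad_fn = jax.jit(jax.grad(pure_function, argnums=0))
grads = grad_fn(params, data)  # gradients with the same keys as the params dictionary
```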
+ +
+
+
166    def purify(self, method: Callable) -> Callable:
+
+
+
+
+ + +
+
+
182        def pure_method(params: Dict[str, jnp.array], *args):
+
+
+
+
+ +

Clear parameters in the object

+ +
+
+
184            self._clear_params()
+
+
+
+
+ +

Assign the passed parameters

+ +
+
+
186            self._set_params(params)
+
+
+
+
+ +

Invoke the method

+ +
+
+
188            result = method(*args)
+
+
+
+
+ +

Return the result

+ +
+
+
190            return result
+
+
+
+
+ +

+ +
+
+
193        return pure_method
+
+
+
+
+ +

Type for generics in the module list class

+ +
+
+
197M = TypeVar('M', bound=Module)
+
+
+
+
+ +

Module list

+

This stores a list of modules. We need this for the transformer decoder to hold the list of transformer layers.
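For example (a sketch reusing the hypothetical Child module from the earlier sketch), parameters inside a ModuleList are keyed by the index of the layer:

```python
layers = ModuleList([Child(), Child()])
print(sorted(layers.get_params().keys()))  # ['0/weight', '1/weight']
```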

+ +
+
+
200class ModuleList(Module, Generic[M]):
+
+
+
+
+ +

For list of modules

+ +
+
+
209    _submodules: List[M]
+
+
+
+
+ +

Initialize with a list of modules.

+ +
+
+
211    def __init__(self, modules: List[M]):
+
+
+
+
+ + +
+
+
215        super().__init__()
+216        self._submodules = modules
+
+
+
+
+ +

Get the idx +-th module

+ +
+
+
218    def __getitem__(self, idx: int) -> M:
+
+
+
+
+ + +
+
+
222        return self._submodules[idx]
+
+
+
+
+ +

This is not supported

+ +
+
+
224    def __setitem__(self, key, value):
+
+
+
+
+ + +
+
+
228        raise NotImplementedError
+
+
+
+
+ +

Number of modules

+ +
+
+
230    def __len__(self):
+
+
+
+
+ + +
+
+
234        return len(self._submodules)
+
+
+
+
+ +

Override __getattr__ + of Module +

+ +
+
+
236    def __getattr__(self, item):
+
+
+
+
+ + +
+
+
240        return self.__dict__[item]
+
+
+
+
+ +

Override __setattr__ + of Module +

+ +
+
+
242    def __setattr__(self, key, value):
+
+
+
+
+ + +
+
+
246        self.__dict__[key] = value
+
+
+
+
+ +

Clear all parameters

+ +
+
+
248    def _clear_params(self):
+
+
+
+
+ + +
+
+
252        self._params = {}
+253        for sm in self._submodules:
+254            sm._clear_params()
+
+
+
+
+ +

Get all parameters

+ +
+
+
256    def get_params(self):
+
+
+
+
+ + +
+
+
260        params = self._params
+261        for i, sm in enumerate(self._submodules):
+262            for name, value in sm.get_params().items():
+263                params[f'{i}/{name}'] = value
+264        return params
+
+
+
+
+ +

Set a parameter

+ +
+
+
266    def _set_param(self, param_path: List[str], value: jnp.ndarray):
+
+
+
+
+ + +
+
+
270        self._submodules[int(param_path[0])]._set_param(param_path[1:], value)
+
+
+
+
+ +

+

Embedding layer

+

This maintains embeddings by id.

+ +
+
+
273class Embedding(Module):
+
+
+
+
+ +
  • rnd_key + is the PRNG state
  • +
  • n_embeddings + is the number of embeddings
  • +
  • n_dim + is the size of an embedding
+ +
+
+
282    def __init__(self, rnd_key: jax.random.PRNGKey, n_embeddings: int, n_dim: int):
+
+
+
+
+ + +
+
+
288        super().__init__()
+
+
+
+
+ +

Embeddings are initialized from $\mathcal{N}(0, 1)$

+ +
+
+
290        self.embeddings = jax.random.normal(rnd_key, (n_embeddings, n_dim))
+
+
+
+
+ +

Return the embeddings for the given ids

+ +
+
+
292    def __call__(self, ids: jnp.ndarray):
+
+
+
+
+ + +
+
+
296        return self.embeddings[ids, :]
+
+
+
+
+ +

+

Embed tokens and add parameterized positional encodings

+

This is based on our PyTorch implementation.

+ +
+
+
299class EmbeddingsWithLearnedPositionalEncoding(Module):
+
+
+
+
+ +
  • rnd_key + is the PRNG state
  • +
  • n_vocab + is the vocabulary size
  • +
  • d_model + is the embedding size
  • +
  • max_len + is the maximum sequence length (to initialize positional encodings)
+ +
+
+
309    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, max_len: int = 4096):
+
+
+
+
+ + +
+
+
316        super().__init__()
+
+
+
+
+ +

Embeddings

+ +
+
+
318        self.embeddings = Embedding(rnd_key, n_vocab, d_model)
+
+
+
+
+ +

Positional encodings coefficient

+ +
+
+
320        self.pe_coef = 1 / d_model ** 0.5
+
+
+
+
+ +

Positional encodings initialized to zeros

+ +
+
+
322        self.positional_encodings = jnp.zeros((max_len, d_model))
+
+
+
+
+ + +
+
+
324    def __call__(self, x: jnp.ndarray):
+
+
+
+
+ +

Get positional encodings

+ +
+
+
326        pe = self.positional_encodings[:x.shape[0]]
+
+
+
+
+ +

Get embeddings and add positional encodings

+ +
+
+
328        return self.embeddings(x) * self.pe_coef + pe
+
+
+
+
+ +

+

Linear Layer

+

This is a simple linear layer with a weight matrix and a bias vector

+ +
+
+
331class Linear(Module):
+
+
+
+
+ +
  • rnd_key + is the PRNG state
  • +
  • in_features + is the number of features in the input
  • +
  • out_features + is the number of features in the output
+ +
+
+
340    def __init__(self, rnd_key: jax.random.PRNGKey, in_features: int, out_features: int):
+
+
+
+
+ + +
+
+
346        super().__init__()
+
+
+
+
+ +

Initialize weights from $\mathcal{U}\left(-\frac{1}{\sqrt{d_{in}}}, \frac{1}{\sqrt{d_{in}}}\right)$, where $d_{in}$ is in_features

+ +
+
+
349        rnd_range = 1 / in_features ** 0.5
+350        self.weight = jax.random.uniform(rnd_key, (in_features, out_features),
+351                                         minval=-rnd_range, maxval=rnd_range)
+
+
+
+
+ +

Initialize the biases to zero

+ +
+
+
353        self.bias = jnp.zeros((out_features,))
+
+
+
+
+ + +
+
+
355    def __call__(self, x: jnp.ndarray):
+
+
+
+
+ +

Multiply by weights and add the bias

+ +
+
+
357        return jnp.matmul(x, self.weight) + self.bias
+
+
+
+
+ +

+

Layer Normalization

+

This implements layer normalization from the paper Layer Normalization.

+

When the input $X \in \mathbb{R}^{L \times C}$ is a sequence of embeddings, where $C$ is the number of channels and $L$ is the length of the sequence, layer normalization computes $\text{LN}(X) = \gamma \frac{X - \mathbb{E}[X]}{\sqrt{\mathrm{Var}[X] + \epsilon}} + \beta$, with gain $\gamma \in \mathbb{R}^{C}$ and bias $\beta \in \mathbb{R}^{C}$.

+

This is based on our PyTorch implementation.
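A quick numerical check of the formula (a standalone sketch, independent of the class below): computing the variance as $\mathbb{E}[X^2] - \mathbb{E}[X]^2$, as done in __call__, matches jnp.var.

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (5, 16))  # [seq_len, d_model]
eps = 1e-5

mean = x.mean(axis=-1, keepdims=True)
var = (x ** 2).mean(axis=-1, keepdims=True) - mean ** 2   # E[X^2] - E[X]^2
x_norm = (x - mean) / (var + eps) ** 0.5

# Same result using jnp.var directly
x_norm_ref = (x - mean) / (jnp.var(x, axis=-1, keepdims=True) + eps) ** 0.5
print(jnp.allclose(x_norm, x_norm_ref, atol=1e-5))  # True
```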

+ +
+
+
360class LayerNorm(Module):
+
+
+
+
+ +
  • normalized_shape is the shape of the elements to normalize over (everything except the batch/sequence dimensions). The input should then have its trailing dimensions equal to normalized_shape; here the input is [seq_len, d_model] and normalized_shape is [d_model]
  • +
  • eps is $\epsilon$, used in $\sqrt{\mathrm{Var}[X] + \epsilon}$ for numerical stability
  • +
  • elementwise_affine + is whether to scale and shift the normalized value
+ +
+
+
380    def __init__(self, normalized_shape: Union[Tuple[int], List[int]], *,
+381                 eps: float = 1e-5, elementwise_affine: bool = True):
+
+
+
+
+ + +
+
+
389        super().__init__()
+390
+391        self.eps = eps
+392        self.elementwise_affine = elementwise_affine
+393        self.normalized_shape = tuple(normalized_shape)
+
+
+
+
+ +

Create parameters $\gamma$ and $\beta$ for gain and bias

+ +
+
+
396        if elementwise_affine:
+397            self.gain = jnp.ones(normalized_shape)
+398            self.bias = jnp.zeros(normalized_shape)
+
+
+
+
+ + +
+
+
400    def __call__(self, x: jnp.ndarray):
+
+
+
+
+ +

Sanity check to make sure the shapes match

+ +
+
+
402        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
+
+
+
+
+ +

The axes to calculate the mean and variance on

+ +
+
+
405        axes = [-(i + 1) for i in range(len(self.normalized_shape))]
+
+
+
+
+ +

Calculate the mean of all elements; i.e. the means for each element $\mathbb{E}[X]$

+ +
+
+
408        mean = x.mean(axis=axes, keepdims=True)
+
+
+
+
+ +

Calculate the squared mean of all elements; i.e. the means for each element $\mathbb{E}[X^2]$

+ +
+
+
411        mean_2 = (x ** 2).mean(axis=axes, keepdims=True)
+
+
+
+
+ +

Variance of all elements, $\mathrm{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$

+ +
+
+
413        var = mean_2 - mean ** 2
+
+
+
+
+ +

Normalize

+ +
+
+
415        x_norm = (x - mean) / (var + self.eps) ** 0.5
+
+
+
+
+ +

Scale and shift

+ +
+
+
418        if self.elementwise_affine:
+419            x_norm = self.gain * x_norm + self.bias
+
+
+
+
+ +

+ +
+
+
422        return x_norm
+
+
+
+
+ +

+

Multi-Head Attention Module

+

This computes scaled multi-headed attention from the paper Attention Is All You Need for given query +, key + and value + vectors.

+

+

In simple terms, it finds keys that match the query, and gets the values of those keys.

+

It uses the dot-product of query and key as the indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

+

Softmax is calculated along the sequence (or time) axis of the keys.

+

This is based on our PyTorch implementation.
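The shapes involved can be confusing, so here is a standalone sketch of the per-head score computation and causal masking used in __call__ below (toy sizes, hypothetical names):

```python
import jax
import jax.numpy as jnp

seq_len, heads, d_k = 5, 2, 8
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (seq_len, heads, d_k))
k = jax.random.normal(kk, (seq_len, heads, d_k))
v = jax.random.normal(kv, (seq_len, heads, d_k))

# Scores for every (query position i, key position j, head h), scaled by 1/sqrt(d_k)
scores = jnp.einsum('ihd,jhd->ijh', q, k) / d_k ** 0.5

# Causal mask: position i may only attend to positions j <= i
mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))[:, :, None]
scores = jnp.where(mask, scores, float('-inf'))

attn = jax.nn.softmax(scores, axis=1)        # softmax over the key axis
out = jnp.einsum('ijh,jhd->ihd', attn, v)    # weighted sum of the values
print(out.shape)                              # (5, 2, 8)
```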

+ +
+
+
425class MultiHeadAttention(Module):
+
+
+
+
+ +
  • rnd_key + is the PRNG state
  • +
  • heads + is the number of heads.
  • +
  • d_model + is the number of features in the query +, key + and value + vectors.
+ +
+
+
451    def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
+
+
+
+
+ + +
+
+
458        super().__init__()
+
+
+
+
+ +

Split the PRNG state

+ +
+
+
461        _, *rnd_keys = jax.random.split(rnd_key, 5)
+
+
+
+
+ +

Number of features per head

+ +
+
+
464        self.d_k = d_model // heads
+
+
+
+
+ +

Number of heads

+ +
+
+
466        self.heads = heads
+
+
+
+
+ +

These transform the query +, key + and value + vectors for multi-headed attention.

+ +
+
+
469        self.query = Linear(rnd_keys[0], d_model, d_model)
+470        self.key = Linear(rnd_keys[1], d_model, d_model)
+471        self.value = Linear(rnd_keys[2], d_model, d_model)
+
+
+
+
+ +

Output layer

+ +
+
+
474        self.output = Linear(rnd_keys[3], d_model, d_model)
+
+
+
+
+ +

Scaling factor before the softmax

+ +
+
+
476        self.scale = 1 / self.d_k ** 0.5
+
+
+
+
+ +

query, key and value are the tensors that store collections of query, key and value vectors. They have shape [seq_len, d_model].

+

mask + has shape [seq_len, seq_len] + and mask[i, j] + indicates whether query at position i + can see key-value at position j +.

+ +
+
+
478    def __call__(self, *,
+479                 query: jnp.ndarray,
+480                 key: jnp.ndarray,
+481                 value: jnp.ndarray,
+482                 mask: Optional[jnp.ndarray] = None):
+
+
+
+
+ +

Get sequence length

+ +
+
+
493        seq_len = len(query)
+494
+495        if mask is not None:
+
+
+
+
+ +

Check mask shape

+ +
+
+
497            assert mask.shape[0] == query.shape[0]
+498            assert mask.shape[1] == key.shape[0]
+
+
+
+
+ +

Same mask applied to all heads.

+ +
+
+
501            mask = mask[:, :, None]
+
+
+
+
+ +

Apply linear transformations

+ +
+
+
504        query = self.query(query)
+505        key = self.key(key)
+506        value = self.value(value)
+
+
+
+
+ +

Reshape to split into heads. The input has shape [seq_len, d_model]; we split the last dimension into heads and d_k.

+ +
+
+
511        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
+512        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
+513        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
+
+
+
+
+ +

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, heads].

+ +
+
+
518        scores = jnp.einsum('ihd,jhd->ijh', query, key)
+
+
+
+
+ +

Scale scores

+ +
+
+
521        scores *= self.scale
+
+
+
+
+ +

Apply mask

+ +
+
+
524        if mask is not None:
+525            scores = jnp.where(mask, scores, float('-inf'))
+
+
+
+
+ +

Softmax attention along the key sequence dimension

+ +
+
+
529        attn = jax.nn.softmax(scores, axis=1)
+
+
+
+
+ +

Multiply by values

+ +
+
+
533        x = jnp.einsum("ijh,jhd->ihd", attn, value)
+
+
+
+
+ +

Concatenate multiple heads

+ +
+
+
536        x = x.reshape(seq_len, -1)
+
+
+
+
+ +

Output layer

+ +
+
+
539        return self.output(x)
+
+
+
+
+ +

+

Position-wise Feed-Forward layer

+

This is based on our PyTorch implementation.
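"Position-wise" means the same two-layer network is applied to every position independently. A standalone sketch (toy shapes, hypothetical names):

```python
import jax
import jax.numpy as jnp

seq_len, d_model, d_ff = 5, 16, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (d_model, d_ff))
w2 = jax.random.normal(k2, (d_ff, d_model))
x = jax.random.normal(k3, (seq_len, d_model))

# Applying the MLP to the whole sequence at once...
y_all = jnp.matmul(jax.nn.relu(jnp.matmul(x, w1)), w2)
# ...equals applying it to each position separately
y_pos = jnp.stack([jnp.matmul(jax.nn.relu(jnp.matmul(x[i], w1)), w2) for i in range(seq_len)])
print(jnp.allclose(y_all, y_pos, atol=1e-4))  # True
```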

+ +
+
+
542class FeedForward(Module):
+
+
+
+
+ +
  • rnd_key + is the PRNG state
  • +
  • d_model + is the number of features in a token embedding
  • +
  • d_ff + is the number of features in the hidden layer of the FFN
  • +
  • activation + is the activation function
+ +
+
+
552    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, d_ff: int,
+553                 activation=jax.nn.relu):
+
+
+
+
+ + +
+
+
560        super().__init__()
+
+
+
+
+ +

Split the PRNG state

+ +
+
+
562        _, *rnd_keys = jax.random.split(rnd_key, 5)
+
+
+
+
+ +

Layer one parameterized by weight $W_1$ and bias $b_1$

+ +
+
+
565        self.layer1 = Linear(rnd_keys[0], d_model, d_ff)
+
+
+
+
+ +

Layer two parameterized by weight $W_2$ and bias $b_2$

+ +
+
+
567        self.layer2 = Linear(rnd_keys[1], d_ff, d_model)
+
+
+
+
+ +

Activation function

+ +
+
+
569        self.activation = activation
+
+
+
+
+ + +
+
+
571    def __call__(self, x: jnp.ndarray):
+
+
+
+
+ +

+ +
+
+
573        x = self.activation(self.layer1(x))
+
+
+
+
+ +

+ +
+
+
575        return self.layer2(x)
+
+
+
+
+ +

+

Transformer Layer

+

This is a transformer layer with multi-head attention and a position-wise feed-forward layer. We use pre-layer normalization; i.e. layer normalization is applied before the attention and the feed-forward sub-layers.

+ +
+
+
578class TransformerLayer(Module):
+
+
+
+
+ +
  • d_model + is the token embedding size
  • +
  • self_attn + is the self attention module
  • +
  • feed_forward + is the feed forward module
+ +
+
+
588    def __init__(self,
+589                 d_model: int,
+590                 self_attn: MultiHeadAttention,
+591                 feed_forward: FeedForward):
+
+
+
+
+ + +
+
+
597        super().__init__()
+598        self.size = d_model
+599        self.self_attn = self_attn
+600        self.feed_forward = feed_forward
+601        self.norm_self_attn = LayerNorm([d_model])
+602        self.norm_ff = LayerNorm([d_model])
+
+
+
+
+ + +
+
+
604    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray):
+
+
+
+
+ +

Normalize the vectors before doing self attention

+ +
+
+
606        z = self.norm_self_attn(x)
+
+
+
+
+ +

Run through self attention, i.e. keys and values are from self

+ +
+
+
608        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+609        x = x + self_attn
+
+
+
+
+ +

Normalize for feed-forward

+ +
+
+
612        z = self.norm_ff(x)
+
+
+
+
+ +

Pass through the feed-forward network

+ +
+
+
614        ff = self.feed_forward(z)
+
+
+
+
+ +

Add the feed-forward results

+ +
+
+
616        x = x + ff
+
+
+
+
+ +

+ +
+
+
618        return x
+
+
+
+
+ +

+

Cross Entropy Loss

+ +
+
+
621class CrossEntropyLoss(Module):
+
+
+
+
+ + +
+
+
628    def __init__(self):
+629        super().__init__()
+
+
+
+
+ +

Use jax.vmap + to vectorize the loss function

+ +
+
+
632        self._loss_vmap = jax.vmap(self._loss, in_axes=(0, 0,))
+
+
+
+
+ + +
+
+
634    def _loss(self, output: jnp.ndarray, target: jnp.ndarray):
+
+
+
+
+ +

+ +
+
+
636        return -jax.nn.log_softmax(output)[target]
+
+
+
+
+ +
  • output + is the model outputs of shape [seq_len, n_vocab] +
  • +
  • target + is the target of shape [seq_len] +
+ +
+
+
638    def __call__(self, output: jnp.ndarray, target: jnp.ndarray):
+
+
+
+
+ +

Use the vectorized loss function and calculate the mean.

+

We could have used a for loop to calculate the losses but using vmap is about 10X faster

+ +
+
+
647        return self._loss_vmap(output, target).mean()
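For instance (a standalone sketch of the same pattern with toy shapes), the vectorized loss gives exactly the per-token losses an explicit loop would compute:

```python
import jax
import jax.numpy as jnp

def token_loss(output, target):
    # Negative log-probability of the target token
    return -jax.nn.log_softmax(output)[target]

outputs = jax.random.normal(jax.random.PRNGKey(0), (7, 10))  # [seq_len, n_vocab]
targets = jnp.arange(7) % 10                                  # [seq_len]

loss_vmap = jax.vmap(token_loss, in_axes=(0, 0))(outputs, targets).mean()
loss_loop = jnp.stack([token_loss(outputs[i], targets[i]) for i in range(7)]).mean()
print(jnp.allclose(loss_vmap, loss_loop))  # True
```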
+
+
+
+
+ +

+

Autoregressive Transformer

+

This is the transformer decoder with embedding and output layers.

+ +
+
+
650class AutoregressiveTransformer(Module):
+
+
+
+
+ + +
+
+
658    layers: ModuleList[TransformerLayer]
+
+
+
+
+ +
  • rnd_key + is the PRNG state
  • +
  • n_vocab + is the vocabulary size
  • +
  • d_model + is the number of features in a token embedding
  • +
  • n_layers + is the number of transformer layers
  • +
  • heads + is the number of attention heads
  • +
  • d_ff + is the number of features in the hidden layer of the FFN
+ +
+
+
660    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, n_layers: int, heads: int, d_ff: int):
+
+
+
+
+ + +
+
+
669        super().__init__()
+670        self.n_vocab = n_vocab
+671        self.d_model = d_model
+672        self.loss_func = CrossEntropyLoss()
+
+
+
+
+ +

For transformer layers

+ +
+
+
675        layers = []
+676        for i in range(n_layers):
+
+
+
+
+ +

Split PRNG state

+ +
+
+
678            rnd_key, mha_key, ffn_key = jax.random.split(rnd_key, 3)
+
+
+
+
+ +

Create a transformer layer

+ +
+
+
680            attn = MultiHeadAttention(mha_key, heads, d_model)
+681            ffn = FeedForward(ffn_key, d_model, d_ff)
+682            layers.append(TransformerLayer(d_model, attn, ffn))
+
+
+
+
+ +

Make a module list

+ +
+
+
684        self.layers = ModuleList(layers)
+
+
+
+
+ +

Split PRNG state

+ +
+
+
687        rnd_key, emb_key, out_key = jax.random.split(rnd_key, 3)
+
+
+
+
+ +

Create embedding layer

+ +
+
+
689        self.embeddings = EmbeddingsWithLearnedPositionalEncoding(emb_key, n_vocab, d_model)
+
+
+
+
+ +

Final normalization and output layer

+ +
+
+
691        self.norm = LayerNorm([d_model])
+692        self.output = Linear(out_key, d_model, n_vocab)
+
+
+
+
+ + +
+
+
694    def __call__(self, x: jnp.ndarray):
+
+
+
+
+ +

Get sequence length

+ +
+
+
696        seq_len = len(x)
+
+
+
+
+ +

A mask for attention so that a token can only attend to itself and the tokens before it

+ +
+
+
698        mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))
+
+
+
+
+ +

Get embeddings with positional encodings

+ +
+
+
700        x = self.embeddings(x)
+
+
+
+
+ +

Apply the transformer layers

+ +
+
+
702        for i in range(len(self.layers)):
+703            x = self.layers[i](x, mask)
+
+
+
+
+ +

Final normalization and linear transformation to get the logits

+ +
+
+
706        return self.output(self.norm(x))
+
+
+
+
+ +

Calculate the loss

+ +
+
+
708    def get_loss(self, x: jnp.ndarray):
+
+
+
+
+ +

Get model outputs

+ +
+
+
713        output = self(x)
+
+
+
+
+ +

Cross entropy loss

+ +
+
+
715        return self.loss_func(output[:-1], x[1:])
+
+
+
+
+ +

Sample

+

The starting sequence is given by seq and we greedily sample length tokens

+ +
+
+
717    def sample(self, seq: jnp.ndarray, length: int = 20):
+
+
+
+
+ + +
+
+
723        for i in range(length):
+
+
+
+
+ +

Sample the highest probability token

+ +
+
+
725            idx = jnp.argmax(self(seq)[-1])
+
+
+
+
+ +

Add it to the sequence

+ +
+
+
727            seq = jnp.concatenate((seq, idx[None]))
+
+
+
+
+ +

Return the sampled sequence

+ +
+
+
730        return seq
+
+
+
+
+ +

This is a named tuple for storing Adam optimizer state for a parameter

+ +
+
+
733class AdamState(NamedTuple):
+
+
+
+
+ + +
+
+
737    m: jnp.ndarray
+738    v: jnp.ndarray
+
+
+
+
+ +

+

Adam Optimizer

+

This is from paper Adam: A Method for Stochastic Optimization.

+

For parameter $\theta_t$ and gradient $g_t$ at step $t$, the Adam update is,

$$\begin{align} m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\ \hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\ \hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\ \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \end{align}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are first and second order moments. $\hat{m}_t$ and $\hat{v}_t$ are bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but also acts as a form of a hyper-parameter that acts against variance in gradients.
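A single update on a scalar parameter, written directly from the equations above (a standalone sketch, separate from the Adam class that follows):

```python
import jax.numpy as jnp

lr, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-16
param, m, v, t = jnp.array(1.0), jnp.array(0.0), jnp.array(0.0), 1

grad = jnp.array(0.5)
m = beta1 * m + (1 - beta1) * grad          # first moment m_t
v = beta2 * v + (1 - beta2) * grad ** 2     # second moment v_t
m_hat = m / (1 - beta1 ** t)                # bias-corrected first moment
v_hat = v / (1 - beta2 ** t)                # bias-corrected second moment
param = param - lr * m_hat / (v_hat ** 0.5 + eps)
print(param)  # ~0.999: a step of roughly lr against the gradient direction
```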

+ +
+
+
741class Adam:
+
+
+
+
+ +
  • params + is the tree-map of parameters
  • +
  • lr + is the learning rate
  • +
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • +
  • eps is $\epsilon$
+ +
+
+
767    def __init__(self, params: Dict,
+768                 lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
+769                 eps: float = 1e-16, ):
+
+
+
+
+ + +
+
+
777        super().__init__()
+778        self.lr = lr
+779        self.betas = betas
+780        self.eps = eps
+
+
+
+
+ +

States for each parameter

+ +
+
+
783        self.states = jax.tree.map(self._init_state, params)
+
+
+
+
+ +

Optimized step function

+ +
+
+
785        self._step_jit = jax.jit(self._step)
+
+
+
+
+ +

Number of steps taken

+ +
+
+
787        self._n_steps = 0
+
+
+
+
+ +

Optimized update state function

+ +
+
+
789        self._update_state_jit = jax.jit(self._update_state)
+
+
+
+
+ +

Initialize the state for a given parameter

+ +
+
+
791    def _init_state(self, param: jnp.ndarray):
+
+
+
+
+ + +
+
+
795        return AdamState(jnp.zeros_like(param), jnp.zeros_like(param))
+
+
+
+
+ +

Step function

+
  • params + is a tree-map of parameters
  • +
  • grads + is a tree-map of gradients
+ +
+
+
797    def step(self, params: Dict, grads: Dict):
+
+
+
+
+ +

Increment step

+ +
+
+
805        self._n_steps += 1
+
+
+
+
+ +

Update states for each parameter

+ +
+
+
807        self.states = jax.tree.map(self._update_state_jit, grads, self.states)
+
+
+
+
+ +

Return updated parameters

+ +
+
+
809        return jax.tree.map(partial(self._step_jit, self._n_steps), params, self.states)
+
+
+
+
+ +

Update parameters

+

This performs an Adam update on the given parameter

+ +
+
+
811    def _step(self, n_steps: int, param: jnp.ndarray, state: AdamState):
+
+
+
+
+ +

Bias corrections for $\hat{m}_t$: $1 - \beta_1^t$, and for $\hat{v}_t$: $1 - \beta_2^t$

+ +
+
+
819        bias_correction = [1 - beta ** n_steps for beta in self.betas]
+
+
+
+
+ +

Uncorrected first and second moments $m_t$ and $v_t$

+ +
+
+
821        m, v = state
+
+
+
+
+ +

+ +
+
+
824        step_size = self.lr * (bias_correction[1] ** 0.5) / bias_correction[0]
+
+
+
+
+ +

+ +
+
+
826        den = (v ** 0.5) + self.eps
+
+
+
+
+ +

+ +
+
+
830        return param - step_size * m / den
+
+
+
+
+ +

Update state

+

This updates the uncorrected first and second moments $m_t$ and $v_t$

+ +
+
+
832    def _update_state(self, grad, state: AdamState):
+
+
+
+
+ +

Uncorrected first and second moments $m_{t-1}$ and $v_{t-1}$

+ +
+
+
839        m, v = state
+
+
+
+
+ +

Clip gradients

+ +
+
+
841        grad = jnp.clip(grad, -1, 1)
+
+
+
+
+ +

+ +
+
+
843        m = self.betas[0] * m + grad * (1 - self.betas[0])
+
+
+
+
+ +

+ +
+
+
845        v = self.betas[1] * v + (grad ** 2) * (1 - self.betas[1])
+
+
+
+
+ +

Return the new state

+ +
+
+
848        return AdamState(m, v)
+
+
+
+
+ +

+

Tiny Shakespeare dataset

+ +
+
+
851class TinyShakespeare:
+
+
+
+
+ +
  • rnd_key + is the PRNG state
  • +
  • seq_len + is the sequence length of a sample
  • +
  • batch_size + is the batch size
+ +
+
+
858    def __init__(self, rnd_key: jax.random.PRNGKey, seq_len: int, batch_size: int):
+
+
+
+
+ + +
+
+
865        self.batch_size = batch_size
+
+
+
+
+ +

PRNG key for shuffling the samples

+ +
+
+
867        _, self.rnd_key = jax.random.split(rnd_key)
+
+
+
+
+ +

Local path of the text file

+ +
+
+
870        path = lab.get_data_path() / 'tiny_shakespeare.txt'
+
+
+
+
+ +

Download if it doesn't exist

+ +
+
+
872        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
+873        if not path.exists():
+874            download_file(url, path)
+
+
+
+
+ +

Read the file

+ +
+
+
877        with open(str(path), 'r') as f:
+878            self.text = f.read()
+
+
+
+
+ +

Get the characters/tokens

+ +
+
+
881        tokens = sorted(list(set(self.text)))
+
+
+
+
+ +

Number of tokens

+ +
+
+
884        self.n_tokens = len(tokens)
+
+
+
+
+ +

Map tokens to ids

+ +
+
+
886        self.stoi = {t: i for i, t in enumerate(tokens)}
+
+
+
+
+ +

Id to token/character

+ +
+
+
888        self.itos = tokens
+
+
+
+
+ +

As a list of ids

+ +
+
+
891        data = jnp.array([self.stoi[s] for s in list(self.text)])
+
+
+
+
+ +

Number of batches

+ +
+
+
893        self.n_batches = len(data) // (seq_len * batch_size)
+
+
+
+
+ +

Truncate

+ +
+
+
895        data = data[:self.n_batches * seq_len * batch_size]
+
+
+
+
+ +

Reshape into samples (better to use random offsets, but let's ignore that here)

+ +
+
+
897        self.data = data.reshape((-1, seq_len))
+
+
+
+
+ +

List of sample indexes

+ +
+
+
899        self.idx = jnp.arange(len(self.data))
+
+
+
+
+ +

Setup for iteration

+ +
+
+
901    def __iter__(self):
+
+
+
+
+ +

Iteration step

+ +
+
+
906        self._iter_idx = 0
+
+
+
+
+ +

Split PRNG key

+ +
+
+
908        self.rnd_key, rnd_key = jax.random.split(self.rnd_key)
+
+
+
+
+ +

Shuffle sample indexes

+ +
+
+
910        self.idx = jax.random.permutation(rnd_key, self.idx)
+
+
+
+
+ +

+ +
+
+
913        return self
+
+
+
+
+ +

Number of batches

+ +
+
+
915    def __len__(self):
+
+
+
+
+ + +
+
+
919        return self.n_batches
+
+
+
+
+ +

Get next batch

+ +
+
+
921    def __next__(self):
+
+
+
+
+ +

Stop iteration after iterating through all batches

+ +
+
+
927        if self._iter_idx >= self.n_batches:
+928            raise StopIteration()
+
+
+
+
+ +

Sample indexes for the batch

+ +
+
+
931        idx = self.idx[self._iter_idx * self.batch_size:(self._iter_idx + 1) * self.batch_size]
+
+
+
+
+ +

Increment iteration step

+ +
+
+
933        self._iter_idx += 1
+
+
+
+
+ +

Return samples

+ +
+
+
936        return self.data[idx]
+
+
+
+
+ +

+

Run the experiment

+ +
+
+
939def main():
+
+
+
+
+ +

Create experiment

+ +
+
+
947    experiment.create(name='jax')
+
+
+
+
+ +

Create PRNG key

+ +
+
+
949    rnd_key = jax.random.PRNGKey(0)
+
+
+
+
+ +

Create dataset

+ +
+
+
951    dataset = TinyShakespeare(rnd_key, seq_len=32, batch_size=128)
+
+
+
+
+ +

Create the model

+ +
+
+
954    model = AutoregressiveTransformer(rnd_key, dataset.n_tokens,
+955                                      d_model=128, n_layers=3, heads=8, d_ff=512)
+
+
+
+
+ +

Get model parameters

+ +
+
+
957    params = model.get_params()
+
+
+
+
+ +

JAX compiled pure sampling function

+ +
+
+
960    pure_sample_fn = jax.jit(model.purify(model.sample))
+
+
+
+
+ +

JAX compiled pure function to get logits for a batch. First we transform model.__call__ + to a pure function which accepts two arguments: parameters, and input sequence. Next we vectorize the function to process a batch of samples. in_axes + specifies which arguments to parallelize and along which axis. (None, 0) + means we have the same parameters but parallelize the inputs across the first axis. out_axes + specifies along which axis to merge the results.

+ +
+
+
968    pure_forward_fn = jax.jit(jax.vmap(model.purify(model.__call__),
+969                                       in_axes=(None, 0), out_axes=0))
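The in_axes=(None, 0) pattern on its own (a standalone sketch): the first argument is shared across the batch, the second is mapped over its leading axis.

```python
import jax
import jax.numpy as jnp

def scale(w, x):
    return w * x

batched = jax.vmap(scale, in_axes=(None, 0), out_axes=0)
print(batched(jnp.array(2.0), jnp.arange(3.0)))  # [0. 2. 4.]
```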
+
+
+
+
+ +

Similarly we vectorize loss computation

+ +
+
+
971    pure_loss_fn = jax.jit(jax.vmap(model.purify(model.get_loss),
+972                                    in_axes=(None, 0), out_axes=0))
+
+
+
+
+ +

A function to get mean loss

+ +
+
+
975    def get_loss(params, seq):
+976        return pure_loss_fn(params, seq).mean()
+
+
+
+
+ +

A function to compute gradients for the first argument (parameters)

+ +
+
+
979    grad_loss_fn = jax.jit(jax.grad(get_loss, argnums=0))
+
+
+
+
+ +

Create optimizer

+ +
+
+
982    optimizer = Adam(params)
+
+
+
+
+ +

Start the experiment

+ +
+
+
985    with experiment.start():
+
+
+
+
+ +

Iterate for 32 epochs

+ +
+
+
987        for epoch in monit.loop(32):
+
+
+
+
+ +

Iterate through batches

+ +
+
+
989            for data in monit.iterate('Train', dataset):
+
+
+
+
+ +

Compute and log the loss

+ +
+
+
991                loss = get_loss(params, data)
+992                tracker.save('loss', np.asarray(loss))
+
+
+
+
+ +

Get the gradients

+ +
+
+
994                grads = grad_loss_fn(params, data)
+
+
+
+
+ +

Update parameters

+ +
+
+
996                params = optimizer.step(params, grads)
+
+
+
+
+ +

+ +
+
+
999            tracker.new_line()
+
+
+
+
+ +

Log a sample after each epoch

+ +
+
+
1001            prompt = [dataset.stoi[c] for c in 'It ']
+1002            sampled = pure_sample_fn(params, jnp.array(prompt))[len(prompt):]
+1003            sampled = ''.join([dataset.itos[i] for i in sampled])
+1004            sampled = sampled.replace('\n', '\\n')
+1005            logger.log(('It ', Text.meta), (sampled, Text.value))
+
+
+
+
+ +

+ +
+
+
1009if __name__ == '__main__':
+1010    main()
+
+
+ +
+ + + + \ No newline at end of file