from functools import partial
from typing import Dict, NamedTuple, Tuple, Any, Callable
from typing import List, TypeVar, Generic
from typing import Union, Optional

import jax
import jax.numpy as jnp
import numpy as np

from labml import lab, monit, experiment, tracker
from labml import logger
from labml.logger import Text
from labml.utils.download import download_file

This is a base class for all modules. It handles parameters and transforms methods into pure functions for JAX to compile and differentiate.

You can skip these utility modules and go straight to the models.

A module stores its parameters and sub-modules separately. When we want to transform a method into a pure function, we pass the parameters of the module and its sub-modules as an argument and assign the passed values back to the class.

This is based on the blog post: From PyTorch to JAX: towards neural net frameworks that purify stateful code.
class Module:
    # Store all parameters and sub-modules in dictionaries
    _submodules: Dict[str, 'Module']
    _params: Dict[str, jnp.ndarray]

    # Initialize
    def __init__(self):
        self._params = {}
        self._submodules = {}

    # We override attribute access, so this gets called when you reference
    # `model.attribute`. Read a guide on Python magic methods if you are not
    # familiar with them.
    def __getattr__(self, attr_name: str):
        # If the attribute is a parameter
        if attr_name in self._params:
            return self._params[attr_name]
        # If the attribute is a sub-module
        elif attr_name in self._submodules:
            return self._submodules[attr_name]
        # Otherwise fall back to normal attributes, which Python stores in `__dict__`
        else:
            return self.__dict__[attr_name]

    # We override attribute assignment, so this gets called when you assign
    # `model.attribute = value`.
    def __setattr__(self, key: str, value: Any):
        # If the value is also a module
        if isinstance(value, Module):
            self._submodules[key] = value
        # If the value is a JAX array
        elif isinstance(value, jnp.ndarray):
            self._params[key] = value
        # Otherwise add it to `__dict__`
        else:
            self.__dict__[key] = value

    # This clears out all the parameters. It is used when a method is called as a
    # pure function: we first clear out all the parameters and then assign the
    # parameters passed to the pure function.
    def _clear_params(self):
        # Clear parameters of this module
        self._params = {}
        # Recursively clear parameters of the sub-modules
        for sm in self._submodules.values():
            sm._clear_params()

    # This recursively collects all the parameters of the module and its sub-modules
    # into a dictionary.
    def get_params(self) -> Dict[str, jnp.ndarray]:
        # Parameters of this module
        params = self._params.copy()
        # Parameters of the sub-modules
        for sm_name, sm in self._submodules.items():
            for name, value in sm.get_params().items():
                # The dictionary keys are of the form `module_name/module_name/param_name`
                params[sm_name + "/" + name] = value

        return params

    def _set_params(self, params: Dict[str, jnp.ndarray]):
        # Iterate through the parameters. Their names have the form
        # `module_name/module_name/param_name`.
        for name, value in params.items():
            # Split to get the module names and the parameter name
            self._set_param(name.split("/"), value)

    def _set_param(self, param_path: List[str], value: jnp.ndarray):
        # No module names; i.e. a parameter of this module
        if len(param_path) == 1:
            self._params[param_path[0]] = value
        # Parameter of a sub-module
        else:
            self._submodules[param_path[0]]._set_param(param_path[1:], value)

    # This transforms a member method into a pure function that accepts a dictionary
    # of parameters as an argument. For example,
    #
    #   params = model.get_params()
    #   pure_function = model.purify(model.calculate_loss)
    #   output = pure_function(params, data)
    def purify(self, method: Callable) -> Callable:
        def pure_method(params: Dict[str, jnp.ndarray], *args):
            # Clear the parameters in the object
            self._clear_params()
            # Assign the passed parameters
            self._set_params(params)
            # Invoke the method
            result = method(*args)
            # Return the result
            return result

        return pure_method
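To see how this is used, here is a small sketch (not part of the original code; the `Scale` module and its `loss` method are made up for illustration) that purifies a method and differentiates it with `jax.grad`:

# A hypothetical single-parameter module, just to illustrate `purify`
class Scale(Module):
    def __init__(self):
        super().__init__()
        self.w = jnp.array(2.0)

    def loss(self, x: jnp.ndarray):
        return ((self.w * x) ** 2).sum()


model = Scale()
params = model.get_params()            # a flat dict: {'w': ...}
pure_loss = model.purify(model.loss)   # pure function of (params, x)
# Differentiate with respect to the parameter dictionary
grads = jax.grad(pure_loss)(params, jnp.array([1.0, 2.0]))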
# Type for generics in the module list class
M = TypeVar('M', bound=Module)

ModuleList stores a list of modules. We need it for the transformer decoder to hold its list of transformer layers.

class ModuleList(Module, Generic[M]):
    # For the list of modules
    _submodules: List[M]

    # Initialize with a list of modules
    def __init__(self, modules: List[M]):
        super().__init__()
        self._submodules = modules

    # Get the `idx`-th module
    def __getitem__(self, idx: int) -> M:
        return self._submodules[idx]

    # Assigning to an index is not supported
    def __setitem__(self, key, value):
        raise NotImplementedError

    def __len__(self):
        return len(self._submodules)

    # Override `__getattr__` of `Module`
    def __getattr__(self, item):
        return self.__dict__[item]

    # Override `__setattr__` of `Module`
    def __setattr__(self, key, value):
        self.__dict__[key] = value

    def _clear_params(self):
        self._params = {}
        for sm in self._submodules:
            sm._clear_params()

    def get_params(self):
        params = self._params
        for i, sm in enumerate(self._submodules):
            for name, value in sm.get_params().items():
                params[f'{i}/{name}'] = value
        return params

    def _set_param(self, param_path: List[str], value: jnp.ndarray):
        self._submodules[int(param_path[0])]._set_param(param_path[1:], value)
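For example (again an illustrative sketch, reusing the made-up `Scale` module from above), the parameter names get prefixed by the index of the module in the list:

# Hypothetical: two copies of the toy `Scale` module
layers = ModuleList([Scale(), Scale()])
# ModuleList prefixes parameter names with the index of the module
print(sorted(layers.get_params().keys()))   # ['0/w', '1/w']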
An embedding layer that keeps an embedding matrix and looks embeddings up by token id.

class Embedding(Module):
    # * `rnd_key` is the PRNG state
    # * `n_embeddings` is the number of embeddings
    # * `n_dim` is the size of an embedding
    def __init__(self, rnd_key: jax.random.PRNGKey, n_embeddings: int, n_dim: int):
        super().__init__()
        # Embeddings are initialized from the standard normal distribution $\mathcal{N}(0, 1)$
        self.embeddings = jax.random.normal(rnd_key, (n_embeddings, n_dim))

    # Return the embeddings for the given ids
    def __call__(self, ids: jnp.ndarray):
        return self.embeddings[ids, :]
Token embeddings with learned positional encodings. This is based on our PyTorch implementation.

class EmbeddingsWithLearnedPositionalEncoding(Module):
    # * `rnd_key` is the PRNG state
    # * `n_vocab` is the vocabulary size
    # * `d_model` is the embedding size
    # * `max_len` is the maximum sequence length (to initialize positional encodings)
    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, max_len: int = 4096):
        super().__init__()
        # Embeddings
        self.embeddings = Embedding(rnd_key, n_vocab, d_model)
        # Positional encodings coefficient $\frac{1}{\sqrt{d_{model}}}$
        self.pe_coef = 1 / d_model ** 0.5
        # Positional encodings initialized to zeros
        self.positional_encodings = jnp.zeros((max_len, d_model))

    def __call__(self, x: jnp.ndarray):
        # Get the positional encodings
        pe = self.positional_encodings[:x.shape[0]]
        # Get the embeddings and add the positional encodings
        return self.embeddings(x) * self.pe_coef + pe
A fully connected linear layer.

class Linear(Module):
    # * `rnd_key` is the PRNG state
    # * `in_features` is the number of features in the input
    # * `out_features` is the number of features in the output
    def __init__(self, rnd_key: jax.random.PRNGKey, in_features: int, out_features: int):
        super().__init__()
        # Initialize the weights uniformly from
        # $\mathcal{U}\Big(-\frac{1}{\sqrt{in\_features}}, \frac{1}{\sqrt{in\_features}}\Big)$
        rnd_range = 1 / in_features ** 0.5
        self.weight = jax.random.uniform(rnd_key, (in_features, out_features),
                                         minval=-rnd_range, maxval=rnd_range)
        # Initialize the biases to $0$
        self.bias = jnp.zeros((out_features,))

    def __call__(self, x: jnp.ndarray):
        # Multiply by the weights and add the bias
        return jnp.matmul(x, self.weight) + self.bias
This implements layer normalization from the paper Layer Normalization.

When the input $X \in \mathbb{R}^{L \times C}$ is a sequence of embeddings, $C$ is the number of channels and $L$ is the length of the sequence; the gain and bias are $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$, and

$$\text{LN}(X) = \gamma \frac{X - \mathbb{E}[X]}{\sqrt{\mathrm{Var}[X] + \epsilon}} + \beta$$

This is based on our PyTorch implementation.

class LayerNorm(Module):
    # * `normalized_shape` is the shape of the elements (except the batch); the
    #   trailing dimensions of the input must match this shape
    # * `eps` is $\epsilon$, used in $\sqrt{\mathrm{Var}[X] + \epsilon}$ for numerical stability
    # * `elementwise_affine` is whether to scale and shift the normalized value
    def __init__(self, normalized_shape: Union[Tuple[int], List[int]], *,
                 eps: float = 1e-5, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.normalized_shape = tuple(normalized_shape)

        # Create parameters $\gamma$ and $\beta$ for gain and bias
        if elementwise_affine:
            self.gain = jnp.ones(normalized_shape)
            self.bias = jnp.zeros(normalized_shape)

    def __call__(self, x: jnp.ndarray):
        # Sanity check to make sure the shapes match
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

        # The axes to calculate the mean and variance on
        axes = [-(i + 1) for i in range(len(self.normalized_shape))]
        # Calculate the mean of all elements; i.e. the means for each element $\mathbb{E}[X]$
        mean = x.mean(axis=axes, keepdims=True)
        # Calculate the squared mean of all elements; i.e. the means for each element $\mathbb{E}[X^2]$
        mean_2 = (x ** 2).mean(axis=axes, keepdims=True)
        # Variance of all elements $\mathrm{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
        var = mean_2 - mean ** 2
        # Normalize $\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{\mathrm{Var}[X] + \epsilon}}$
        x_norm = (x - mean) / (var + self.eps) ** 0.5
        # Scale and shift $y = \gamma \hat{X} + \beta$
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias

        return x_norm
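As a quick sanity check (an illustrative sketch; the shapes and key are arbitrary), the normalized output should have zero mean and unit variance along the normalized axes:

# Hypothetical check: normalize a [seq_len, d_model] tensor
ln = LayerNorm([8])
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 3.0 + 2.0
y = ln(x)
# With the default gain of ones and bias of zeros, each row is standardized
print(jnp.allclose(y.mean(axis=-1), 0.0, atol=1e-5))   # True
print(jnp.allclose(y.var(axis=-1), 1.0, atol=1e-2))    # True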
This computes scaled multi-headed attention from the paper Attention Is All You Need for given `query`, `key` and `value` vectors.

In simple terms, it finds keys that match the query and gets the values of those keys.

It uses the dot-product of query and key as the indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the sequence (or time) axis of the keys.

This is based on our PyTorch implementation.

class MultiHeadAttention(Module):
    # * `rnd_key` is the PRNG state
    # * `heads` is the number of heads
    # * `d_model` is the number of features in the `query`, `key` and `value` vectors
    def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
        super().__init__()

        # Split the PRNG state
        _, *rnd_keys = jax.random.split(rnd_key, 5)

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention
        self.query = Linear(rnd_keys[0], d_model, d_model)
        self.key = Linear(rnd_keys[1], d_model, d_model)
        self.value = Linear(rnd_keys[2], d_model, d_model)

        # Output layer
        self.output = Linear(rnd_keys[3], d_model, d_model)

        # Scaling factor before the softmax $\frac{1}{\sqrt{d_k}}$
        self.scale = 1 / self.d_k ** 0.5

    # `query`, `key` and `value` are the tensors that store collections of query,
    # key and value vectors. They have shape `[seq_len, d_model]`.
    #
    # `mask` has shape `[seq_len, seq_len]` and `mask[i, j]` indicates whether the
    # query at position `i` can see the key-value at position `j`.
    def __call__(self, *,
                 query: jnp.ndarray,
                 key: jnp.ndarray,
                 value: jnp.ndarray,
                 mask: Optional[jnp.ndarray] = None):
        # Get the sequence length
        seq_len = len(query)

        if mask is not None:
            # Check the mask shape
            assert mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]

            # The same mask is applied to all heads
            mask = mask[:, :, None]

        # Apply the linear transformations
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Reshape to split into heads. The input has shape `[seq_len, d_model]`;
        # we split the last dimension into `heads` and `d_k`.
        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, heads]`.
        scores = jnp.einsum('ihd,jhd->ijh', query, key)

        # Scale the scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale

        # Apply the mask: set the scores of blocked positions to $-\infty$
        if mask is not None:
            scores = jnp.where(mask == 0, float('-inf'), scores)

        # Softmax attention along the key sequence dimension
        # $\underset{seq}{\mathrm{softmax}}\Big(\frac{Q K^\top}{\sqrt{d_k}}\Big)$
        attn = jax.nn.softmax(scores, axis=1)

        # Multiply by the values
        x = jnp.einsum("ijh,jhd->ihd", attn, value)

        # Concatenate the heads
        x = x.reshape(seq_len, -1)

        # Output layer
        return self.output(x)
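Here is a small usage sketch (illustrative only; the sizes, key and mask are arbitrary) of self-attention with a causal mask, where position `i` can only attend to positions `j <= i`:

# Hypothetical self-attention call on a short sequence
rnd_key = jax.random.PRNGKey(0)
mha = MultiHeadAttention(rnd_key, heads=4, d_model=16)
x = jax.random.normal(rnd_key, (5, 16))              # [seq_len, d_model]
causal_mask = jnp.tril(jnp.ones((5, 5), bool))       # mask[i, j] is True when j <= i
out = mha(query=x, key=x, value=x, mask=causal_mask)
print(out.shape)                                     # (5, 16)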
A position-wise feed-forward network.

class FeedForward(Module):
    # * `rnd_key` is the PRNG state
    # * `d_model` is the number of features in a token embedding
    # * `d_ff` is the number of features in the hidden layer of the FFN
    # * `activation` is the activation function
    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, d_ff: int,
                 activation=jax.nn.relu):
        super().__init__()

        # Split the PRNG state
        _, *rnd_keys = jax.random.split(rnd_key, 5)

        # Layer one, parameterized by weight $W_1$ and bias $b_1$
        self.layer1 = Linear(rnd_keys[0], d_model, d_ff)
        # Layer two, parameterized by weight $W_2$ and bias $b_2$
        self.layer2 = Linear(rnd_keys[1], d_ff, d_model)
        # Activation function
        self.activation = activation

    def __call__(self, x: jnp.ndarray):
        x = self.activation(self.layer1(x))
        return self.layer2(x)


This is a transformer layer with multi-head attention and a position-wise feed-forward layer. We use pre-layer normalization.

class TransformerLayer(Module):
    # * `d_model` is the token embedding size
    # * `self_attn` is the self attention module
    # * `feed_forward` is the feed forward module
    def __init__(self,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm_self_attn = LayerNorm([d_model])
        self.norm_ff = LayerNorm([d_model])

    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray):
        # Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
        # Run through self attention, i.e. keys and values are from self
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
        x = x + self_attn

        # Normalize for the feed-forward network
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results
        x = x + ff

        return x
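Written out (our notation, not from the original), the pre-layer-normalization residual structure of `__call__` is:

$$\begin{aligned}
x &\leftarrow x + \mathrm{SelfAttention}\big(\mathrm{LN}_{\text{attn}}(x)\big) \\
x &\leftarrow x + \mathrm{FFN}\big(\mathrm{LN}_{\text{ff}}(x)\big)
\end{aligned}$$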
class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()

        # Use `jax.vmap` to vectorize the loss function
        self._loss_vmap = jax.vmap(self._loss, in_axes=(0, 0,))

    def _loss(self, output: jnp.ndarray, target: jnp.ndarray):
        return -jax.nn.log_softmax(output)[target]

    # * `output` is the model output of shape `[seq_len, n_vocab]`
    # * `target` is the target of shape `[seq_len]`
    def __call__(self, output: jnp.ndarray, target: jnp.ndarray):
        # Use the vectorized loss function and calculate the mean.
        # We could have used a for loop to calculate the losses,
        # but using `vmap` is about 10x faster.
        return self._loss_vmap(output, target).mean()
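To make the vectorization concrete, here is an illustrative sketch (the logits and targets are random or made up): `_loss` handles a single position and `vmap` maps it over the sequence.

# Hypothetical loss computation on random logits for a 10-token sequence
rnd_key = jax.random.PRNGKey(0)
loss_fn = CrossEntropyLoss()
logits = jax.random.normal(rnd_key, (10, 65))      # [seq_len, n_vocab]
targets = jnp.zeros(10, dtype=jnp.int32)           # [seq_len]

# Equivalent explicit loop over positions (much slower than the vmapped version)
manual = jnp.mean(jnp.stack([-jax.nn.log_softmax(logits[i])[targets[i]]
                             for i in range(10)]))
print(jnp.allclose(loss_fn(logits, targets), manual))   # True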
This is the transformer decoder with embedding and output layers.

class AutoregressiveTransformer(Module):
    layers: ModuleList[TransformerLayer]

    # * `rnd_key` is the PRNG state
    # * `n_vocab` is the vocabulary size
    # * `d_model` is the number of features in a token embedding
    # * `n_layers` is the number of transformer layers
    # * `heads` is the number of attention heads
    # * `d_ff` is the number of features in the hidden layer of the FFN
    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, n_layers: int, heads: int, d_ff: int):
        super().__init__()
        self.n_vocab = n_vocab
        self.d_model = d_model
        self.loss_func = CrossEntropyLoss()

        # Create the transformer layers
        layers = []
        for i in range(n_layers):
            # Split the PRNG state
            rnd_key, mha_key, ffn_key = jax.random.split(rnd_key, 3)
            # Create a transformer layer
            attn = MultiHeadAttention(mha_key, heads, d_model)
            ffn = FeedForward(ffn_key, d_model, d_ff)
            layers.append(TransformerLayer(d_model, attn, ffn))
        # Make a module list
        self.layers = ModuleList(layers)

        # Split the PRNG state
        rnd_key, emb_key, out_key = jax.random.split(rnd_key, 3)
        # Create the embedding layer
        self.embeddings = EmbeddingsWithLearnedPositionalEncoding(emb_key, n_vocab, d_model)
        # Final normalization and output layer
        self.norm = LayerNorm([d_model])
        self.output = Linear(out_key, d_model, n_vocab)

    def __call__(self, x: jnp.ndarray):
        # Get the sequence length
        seq_len = len(x)
        # A mask for attention so that a token can only see itself and the tokens before it
        mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))
        # Get embeddings with positional encodings
        x = self.embeddings(x)
        # Apply the transformer layers
        for i in range(len(self.layers)):
            x = self.layers[i](x, mask)

        # Final normalization and linear transformation to get the logits
        return self.output(self.norm(x))

    def get_loss(self, x: jnp.ndarray):
        # Get the model outputs
        output = self(x)
        # Cross entropy loss between the predictions at positions `0 ... n-2`
        # and the tokens at positions `1 ... n-1`
        return self.loss_func(output[:-1], x[1:])

    # The starting sequence is given by `seq` and we greedily sample `length` tokens
    def sample(self, seq: jnp.ndarray, length: int = 20):
        for i in range(length):
            # Sample the highest probability token
            idx = jnp.argmax(self(seq)[-1])
            # Add it to the sequence
            seq = jnp.concatenate((seq, idx[None]))

        # Return the sampled sequence
        return seq
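A quick sketch of calling `sample` directly on a freshly initialized model (the sizes are arbitrary and the output is gibberish before training); the returned sequence contains the prompt followed by the greedily sampled tokens:

# Hypothetical: greedy sampling from an untrained model
model = AutoregressiveTransformer(jax.random.PRNGKey(0), n_vocab=65,
                                  d_model=32, n_layers=2, heads=4, d_ff=64)
prompt = jnp.array([1, 2, 3])
generated = model.sample(prompt, length=5)
print(generated.shape)   # (8,) -- the 3-token prompt plus 5 sampled tokens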
This is a named tuple for storing the Adam optimizer state of a parameter.

class AdamState(NamedTuple):
    m: jnp.ndarray
    v: jnp.ndarray


This is from the paper Adam: A Method for Stochastic Optimization.

For parameter $\theta_t$ and gradient $g_t$ at step $t$, the Adam update is

$$\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters, $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but it also acts as a form of hyper-parameter that works against variance in the gradients.

class Adam:
    # * `params` is the tree-map of parameters
    # * `lr` is the learning rate $\alpha$
    # * `betas` is the tuple ($\beta_1$, $\beta_2$)
    # * `eps` is $\epsilon$
    def __init__(self, params: Dict,
                 lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-16, ):
        super().__init__()
        self.lr = lr
        self.betas = betas
        self.eps = eps

        # States for each parameter
        self.states = jax.tree.map(self._init_state, params)
        # JIT-compiled step function
        self._step_jit = jax.jit(self._step)
        # Number of steps taken $t$
        self._n_steps = 0
        # JIT-compiled update-state function
        self._update_state_jit = jax.jit(self._update_state)

    # Initialize the state for a given parameter
    def _init_state(self, param: jnp.ndarray):
        return AdamState(jnp.zeros_like(param), jnp.zeros_like(param))

    # * `params` is a tree-map of parameters
    # * `grads` is a tree-map of gradients
    def step(self, params: Dict, grads: Dict):
        # Increment the step count $t$
        self._n_steps += 1
        # Update the state for each parameter
        self.states = jax.tree.map(self._update_state_jit, grads, self.states)
        # Return the updated parameters
        return jax.tree.map(partial(self._step_jit, self._n_steps), params, self.states)

    def _step(self, n_steps: int, param: jnp.ndarray, state: AdamState):
        # Bias corrections for $m_t$: $1 - \beta_1^t$, and for $v_t$: $1 - \beta_2^t$
        bias_correction = [1 - beta ** n_steps for beta in self.betas]
        # Uncorrected first and second moments $m_t$ and $v_t$
        m, v = state

        # $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$
        step_size = self.lr * (bias_correction[1] ** 0.5) / bias_correction[0]
        # $\sqrt{v_t} + \epsilon$
        den = (v ** 0.5) + self.eps

        # Update the parameter
        return param - step_size * m / den

    def _update_state(self, grad, state: AdamState):
        # Uncorrected first and second moments $m_{t-1}$ and $v_{t-1}$
        m, v = state
        # Clip the gradients
        grad = jnp.clip(grad, -1, 1)
        # $m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t$
        m = self.betas[0] * m + grad * (1 - self.betas[0])
        # $v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$
        v = self.betas[1] * v + (grad ** 2) * (1 - self.betas[1])

        # Return the new state
        return AdamState(m, v)
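Note how `_step` folds the two bias corrections into a single `step_size` scalar. Assuming $\epsilon$ is negligible (it defaults to $10^{-16}$ here), this is the same update written above, except that $\epsilon$ is added to $\sqrt{v_t}$ before the bias correction:

$$\theta_t \leftarrow \theta_{t-1} - \underbrace{\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}}_{\texttt{step\_size}} \cdot \frac{m_t}{\underbrace{\sqrt{v_t} + \epsilon}_{\texttt{den}}} \;\approx\; \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$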
A character-level Tiny Shakespeare dataset loader.

class TinyShakespeare:
    # * `rnd_key` is the PRNG state
    # * `seq_len` is the sequence length of a sample
    # * `batch_size` is the batch size
    def __init__(self, rnd_key: jax.random.PRNGKey, seq_len: int, batch_size: int):
        self.batch_size = batch_size
        # PRNG key for shuffling the samples
        _, self.rnd_key = jax.random.split(rnd_key)

        # Local path of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download it if it doesn't exist
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        if not path.exists():
            download_file(url, path)

        # Read the file
        with open(str(path), 'r') as f:
            self.text = f.read()

        # Get the characters/tokens
        tokens = sorted(list(set(self.text)))

        # Number of tokens
        self.n_tokens = len(tokens)
        # Map tokens to ids
        self.stoi = {t: i for i, t in enumerate(tokens)}
        # Map ids to tokens/characters
        self.itos = tokens

        # The text as a list of token ids
        data = jnp.array([self.stoi[s] for s in list(self.text)])
        # Number of batches
        self.n_batches = len(data) // (seq_len * batch_size)
        # Truncate
        data = data[:self.n_batches * seq_len * batch_size]
        # Reshape into samples (it would be better to use random offsets, but let's ignore that here)
        self.data = data.reshape((-1, seq_len))
        # List of sample indexes
        self.idx = jnp.arange(len(self.data))

    # Set up for iteration
    def __iter__(self):
        # Iteration step
        self._iter_idx = 0
        # Split the PRNG key
        self.rnd_key, rnd_key = jax.random.split(self.rnd_key)
        # Shuffle the sample indexes
        self.idx = jax.random.permutation(rnd_key, self.idx)

        return self

    # Number of batches
    def __len__(self):
        return self.n_batches

    # Get the next batch
    def __next__(self):
        # Stop the iteration after going through all the batches
        if self._iter_idx >= self.n_batches:
            raise StopIteration()

        # Sample indexes for the batch
        idx = self.idx[self._iter_idx * self.batch_size:(self._iter_idx + 1) * self.batch_size]
        # Increment the iteration step
        self._iter_idx += 1

        # Return the samples
        return self.data[idx]
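A quick illustrative sketch of the loader (it assumes the download succeeds; the sizes match the ones used in `main` below): iterating yields shuffled batches of token ids.

# Hypothetical: inspect one batch
dataset = TinyShakespeare(jax.random.PRNGKey(0), seq_len=32, batch_size=128)
print(dataset.n_tokens)            # vocabulary size (number of distinct characters)
batch = next(iter(dataset))
print(batch.shape)                 # (128, 32) -- [batch_size, seq_len]
print(''.join(dataset.itos[int(i)] for i in batch[0][:20]))   # decode a few characters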
def main():
    # Create the experiment
    experiment.create(name='jax')
    # Create the PRNG key
    rnd_key = jax.random.PRNGKey(0)
    # Create the dataset
    dataset = TinyShakespeare(rnd_key, seq_len=32, batch_size=128)

    # Create the model
    model = AutoregressiveTransformer(rnd_key, dataset.n_tokens,
                                      d_model=128, n_layers=3, heads=8, d_ff=512)
    # Get the model parameters
    params = model.get_params()

    # JAX-compiled pure sampling function
    pure_sample_fn = jax.jit(model.purify(model.sample))
    # JAX-compiled pure function to get the logits for a batch.
    # First we transform `model.__call__` into a pure function which accepts two
    # arguments: the parameters and the input sequence. Next we vectorize the
    # function to process a batch of samples. `in_axes` specifies which arguments
    # to parallelize and along which axis: `(None, 0)` means we use the same
    # parameters but parallelize the inputs across the first axis. `out_axes`
    # specifies along which axis to merge the results.
    pure_forward_fn = jax.jit(jax.vmap(model.purify(model.__call__),
                                       in_axes=(None, 0), out_axes=0))
    # Similarly, we vectorize the loss computation
    pure_loss_fn = jax.jit(jax.vmap(model.purify(model.get_loss),
                                    in_axes=(None, 0), out_axes=0))

    # A function to get the mean loss
    def get_loss(params, seq):
        return pure_loss_fn(params, seq).mean()

    # A function to compute the gradients with respect to the first argument (the parameters)
    grad_loss_fn = jax.jit(jax.grad(get_loss, argnums=0))

    # Create the optimizer
    optimizer = Adam(params)

    # Start the experiment
    with experiment.start():
        # Iterate for 32 epochs
        for epoch in monit.loop(32):
            # Iterate through the batches
            for data in monit.iterate('Train', dataset):
                # Compute and log the loss
                loss = get_loss(params, data)
                tracker.save('loss', np.asarray(loss))

                # Get the gradients
                grads = grad_loss_fn(params, data)

                # Update the parameters
                params = optimizer.step(params, grads)

            tracker.new_line()

            # Log a sample after each epoch
            prompt = [dataset.stoi[c] for c in 'It ']
            sampled = pure_sample_fn(params, jnp.array(prompt))[len(prompt):]
            sampled = ''.join([dataset.itos[i] for i in sampled])
            sampled = sampled.replace('\n', '\\n')
            logger.log(('It ', Text.meta), (sampled, Text.value))


if __name__ == '__main__':
    main()