diff --git a/docs/index.html b/docs/index.html
index 7e5e6b9d..c68af898 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -84,7 +84,10 @@ implementations.

-54    def __init__(self, *,
-55                 d_model: int,
-56                 self_attn: RelativeMultiHeadAttention,
-57                 feed_forward: FeedForward,
-58                 dropout_prob: float):
+53    def __init__(self, *,
+54                 d_model: int,
+55                 self_attn: RelativeMultiHeadAttention,
+56                 feed_forward: FeedForward,
+57                 dropout_prob: float):
@@ -144,13 +144,13 @@ are introduced at the attention calculation.

-65        super().__init__()
-66        self.size = d_model
-67        self.self_attn = self_attn
-68        self.feed_forward = feed_forward
-69        self.dropout = nn.Dropout(dropout_prob)
-70        self.norm_self_attn = nn.LayerNorm([d_model])
-71        self.norm_ff = nn.LayerNorm([d_model])
+64        super().__init__()
+65        self.size = d_model
+66        self.self_attn = self_attn
+67        self.feed_forward = feed_forward
+68        self.dropout = nn.Dropout(dropout_prob)
+69        self.norm_self_attn = nn.LayerNorm([d_model])
+70        self.norm_ff = nn.LayerNorm([d_model])
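
For readability, here is the constructor these two hunks describe, gathered into one block (reassembled from the renumbered listing above with explanatory comments added; a sketch, not the verbatim source file):

```python
def __init__(self, *,
             d_model: int,
             self_attn: RelativeMultiHeadAttention,
             feed_forward: FeedForward,
             dropout_prob: float):
    super().__init__()
    # Keep d_model around as `size`; TransformerXL reads it for its final LayerNorm
    self.size = d_model
    # Sub-layers are injected, so the same layer works with any attention/FFN configuration
    self.self_attn = self_attn
    self.feed_forward = feed_forward
    self.dropout = nn.Dropout(dropout_prob)
    # Separate pre-norm layers for the attention and feed-forward blocks
    self.norm_self_attn = nn.LayerNorm([d_model])
    self.norm_ff = nn.LayerNorm([d_model])
```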
@@ -166,10 +166,10 @@ are introduced at the attention calculation.

-73    def forward(self, *,
-74                x: torch.Tensor,
-75                mem: Optional[torch.Tensor],
-76                mask: torch.Tensor):
+72    def forward(self, *,
+73                x: torch.Tensor,
+74                mem: Optional[torch.Tensor],
+75                mask: torch.Tensor):
@@ -180,7 +180,7 @@ are introduced at the attention calculation.

Normalize the vectors before doing self attention

-84        z = self.norm_self_attn(x)
+83        z = self.norm_self_attn(x)
@@ -191,7 +191,7 @@ are introduced at the attention calculation.

If there is memory

-86        if mem is not None:
+85        if mem is not None:
@@ -202,7 +202,7 @@ are introduced at the attention calculation.

Normalize it

-88            mem = self.norm_self_attn(mem)
+87            mem = self.norm_self_attn(mem)
@@ -213,7 +213,7 @@ are introduced at the attention calculation.

Concatenate with z

-90            m_z = torch.cat((mem, z), dim=0)
+89            m_z = torch.cat((mem, z), dim=0)
@@ -224,8 +224,8 @@ are introduced at the attention calculation.

Ignore if there is no memory

-92        else:
-93            m_z = z
+91        else:
+92            m_z = z
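
The point of this branch is that memories only ever extend the keys and values; the queries remain the current segment. A minimal shape sketch (assuming the `[seq_len, batch_size, d_model]` layout implied by the `dim=0` concatenation; the sizes are illustrative):

```python
import torch

seq_len, mem_len, batch_size, d_model = 16, 32, 4, 128  # illustrative sizes

z = torch.randn(seq_len, batch_size, d_model)    # normalized current segment
mem = torch.randn(mem_len, batch_size, d_model)  # stored embeddings from the previous segment

# Keys and values cover memory + current segment; queries stay seq_len long.
m_z = torch.cat((mem, z), dim=0)
assert m_z.shape == (mem_len + seq_len, batch_size, d_model)
```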
@@ -236,7 +236,7 @@ are introduced at the attention calculation.

Attention

-95        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
+94        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
@@ -247,7 +247,7 @@ are introduced at the attention calculation.

Add the attention results

-97        x = x + self.dropout(self_attn)
+96        x = x + self.dropout(self_attn)
@@ -258,7 +258,7 @@ are introduced at the attention calculation.

Normalize for feed-forward

-100        z = self.norm_ff(x)
+99        z = self.norm_ff(x)
@@ -269,7 +269,7 @@ are introduced at the attention calculation.

Pass through the feed-forward network

-102        ff = self.feed_forward(z)
+101        ff = self.feed_forward(z)
@@ -280,7 +280,7 @@ are introduced at the attention calculation.

Add the feed-forward results back

-104        x = x + self.dropout(ff)
+103        x = x + self.dropout(ff)
@@ -291,7 +291,7 @@ are introduced at the attention calculation.

-107        return x
+106        return x
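
Gathering the renumbered fragments above, the layer's `forward` reads as follows (an assembled excerpt with comments added for this review, not the verbatim source file):

```python
def forward(self, *,
            x: torch.Tensor,
            mem: Optional[torch.Tensor],
            mask: torch.Tensor):
    # Pre-norm: normalize the current segment before self-attention
    z = self.norm_self_attn(x)
    if mem is not None:
        # Normalize the memories with the same layer norm and prepend them,
        # so keys and values span memory + current segment
        mem = self.norm_self_attn(mem)
        m_z = torch.cat((mem, z), dim=0)
    else:
        m_z = z
    # Queries come only from the current segment
    self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
    # Residual connection around the attention block
    x = x + self.dropout(self_attn)
    # Pre-norm feed-forward block with its own residual connection
    z = self.norm_ff(x)
    ff = self.feed_forward(z)
    x = x + self.dropout(ff)
    return x
```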
@@ -303,7 +303,7 @@ are introduced at the attention calculation.

This consists of multiple transformer XL layers

-110class TransformerXL(Module):
+109class TransformerXL(Module):
@@ -314,8 +314,8 @@ are introduced at the attention calculation.

-117    def __init__(self, layer: TransformerXLLayer, n_layers: int):
-118        super().__init__()
+116    def __init__(self, layer: TransformerXLLayer, n_layers: int):
+117        super().__init__()
@@ -326,7 +326,7 @@ are introduced at the attention calculation.

Make copies of the transformer layer

-120        self.layers = clone_module_list(layer, n_layers)
+119        self.layers = clone_module_list(layer, n_layers)
@@ -337,7 +337,7 @@ are introduced at the attention calculation.

Final normalization layer

-122        self.norm = nn.LayerNorm([layer.size])
+121        self.norm = nn.LayerNorm([layer.size])
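
Assembled from the two hunks above, the stack's constructor is simply (an excerpt sketch with comments added):

```python
def __init__(self, layer: TransformerXLLayer, n_layers: int):
    super().__init__()
    # Make n_layers deep copies of the given transformer XL layer
    self.layers = clone_module_list(layer, n_layers)
    # Final normalization layer; layer.size is the d_model stored by TransformerXLLayer
    self.norm = nn.LayerNorm([layer.size])
```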
@@ -352,7 +352,7 @@ are introduced at the attention calculation.

-124    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
+123    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
@@ -364,7 +364,7 @@ are introduced at the attention calculation.

which will be the memories for the next sequential batch.

-132        new_mem = []
+131        new_mem = []
@@ -375,7 +375,7 @@ which will be the memories for the next sequential batch.

Run through each transformer layer

-134        for i, layer in enumerate(self.layers):
+133        for i, layer in enumerate(self.layers):
@@ -386,7 +386,7 @@ which will be the memories for the next sequential batch.

Add to the list of feature vectors

-136            new_mem.append(x.detach())
+135            new_mem.append(x.detach())
@@ -397,7 +397,7 @@ which will be the memories for the next sequential batch.

Memory

-138            m = mem[i] if mem else None
+137            m = mem[i] if mem else None
@@ -408,7 +408,7 @@ which will be the memories for the next sequential batch.

Run through the transformer XL layer

-140            x = layer(x=x, mem=m, mask=mask)
+139            x = layer(x=x, mem=m, mask=mask)
@@ -419,7 +419,7 @@ which will be the memories for the next sequential batch.

Finally, normalize the vectors

-142        return self.norm(x), new_mem
+141        return self.norm(x), new_mem
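
And the stack's `forward`, assembled from the hunks above (an excerpt with comments added; note that the memories stored here are each layer's inputs for this batch, detached so gradients never flow into previous segments):

```python
def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
    # Feature vectors collected here become the memories for the next sequential batch
    new_mem = []
    for i, layer in enumerate(self.layers):
        # Store the input to this layer, detached from the graph
        new_mem.append(x.detach())
        # Memory for this layer from the previous batch (None on the first batch)
        m = mem[i] if mem else None
        x = layer(x=x, mem=m, mask=mask)
    # Finally, normalize the vectors and return them with the new memories
    return self.norm(x), new_mem
```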
diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py
index 676c5a07..84f8fe82 100644
--- a/labml_nn/__init__.py
+++ b/labml_nn/__init__.py
@@ -17,7 +17,8 @@ implementations.
* [Multi-headed attention](transformers/mha.html)
* [Transformer building blocks](transformers/models.html)
-* [Relative multi-headed attention](transformers/xl/relative_mha.html).
+* [Transformer XL](transformers/xl/index.html)
+ * [Relative multi-headed attention](transformers/xl/relative_mha.html)
* [GPT Architecture](transformers/gpt/index.html)
* [GLU Variants](transformers/glu_variants/simple.html)
* [kNN-LM: Generalization through Memorization](transformers/knn/index.html)
diff --git a/labml_nn/transformers/__init__.py b/labml_nn/transformers/__init__.py
index db916987..d699bb0e 100644
--- a/labml_nn/transformers/__init__.py
+++ b/labml_nn/transformers/__init__.py
@@ -14,10 +14,13 @@ from paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762),
and derivatives and enhancements of it.
* [Multi-head attention](mha.html)
-* [Relative multi-head attention](xl/relative_mha.html)
* [Transformer Encoder and Decoder Models](models.html)
* [Fixed positional encoding](positional_encoding.html)
+## [Transformer XL](xl/index.html)
+This implements the Transformer XL model using
+[relative multi-head attention](xl/relative_mha.html).
+
## [GPT Architecture](gpt)
This is an implementation of GPT-2 architecture.
diff --git a/labml_nn/transformers/xl/__init__.py b/labml_nn/transformers/xl/__init__.py
index 257bf1f1..85cf0da5 100644
--- a/labml_nn/transformers/xl/__init__.py
+++ b/labml_nn/transformers/xl/__init__.py
@@ -30,7 +30,6 @@ Here's [the training code](experiment.html) and a notebook for training a transf
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/xl/experiment.ipynb)
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d3b6760c692e11ebb6a70242ac1c0002)
-
"""
diff --git a/labml_nn/transformers/xl/readme.md b/labml_nn/transformers/xl/readme.md
new file mode 100644
index 00000000..f8f17a23
--- /dev/null
+++ b/labml_nn/transformers/xl/readme.md
@@ -0,0 +1,24 @@
+# Transformer XL
+
+This is an implementation of
+[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860)
+in [PyTorch](https://pytorch.org).
+
+The transformer has a limited attention span,
+equal to the length of the sequence trained in parallel.
+All these positions have a fixed positional encoding.
+Transformer XL increases this attention span by letting
+each position attend to precalculated past embeddings.
+For instance, if the context length is $l$, it will keep the embeddings of
+all layers for the previous batch of length $l$ and feed them to the current step.
+If we use fixed positional encodings, these pre-calculated embeddings will have
+the same positions as the current context.
+Instead, the paper introduces relative positional encoding, where the positional encodings
+are applied at the attention calculation.
+
+The annotated implementation of relative multi-headed attention is in [`relative_mha.py`](relative_mha.html).
+
+Here's [the training code](experiment.html) and a notebook for training a transformer XL model on the Tiny Shakespeare dataset.
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/transformers/xl/experiment.ipynb)
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=d3b6760c692e11ebb6a70242ac1c0002)
diff --git a/readme.md b/readme.md
index 61c4ac85..5c18945f 100644
--- a/readme.md
+++ b/readme.md
@@ -23,7 +23,8 @@ implementations almost weekly.
* [Multi-headed attention](https://nn.labml.ai/transformers/mha.html)
* [Transformer building blocks](https://nn.labml.ai/transformers/models.html)
-* [Relative multi-headed attention](https://nn.labml.ai/transformers/xl/relative_mha.html).
+* [Transformer XL](https://nn.labml.ai/transformers/xl/index.html)
+ * [Relative multi-headed attention](https://nn.labml.ai/transformers/xl/relative_mha.html)
* [GPT Architecture](https://nn.labml.ai/transformers/gpt/index.html)
* [GLU Variants](https://nn.labml.ai/transformers/glu_variants/simple.html)
* [kNN-LM: Generalization through Memorization](https://nn.labml.ai/transformers/knn)
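
To make the new readme's description concrete — keep the previous segment's embeddings and feed them to the current step — here is a hedged sketch of how the returned memories can be threaded through consecutive batches. The stand-in `model` callable, the tensor sizes, and the mask layout are assumptions for illustration, not part of this diff:

```python
import torch

def model(x, mem, mask):
    # Stand-in with TransformerXL's (x, mem, mask) -> (output, new_mem) contract;
    # in practice this would be a TransformerXL instance.
    return x, [x.detach()]

seq_len, batch_size, d_model = 8, 2, 16
# Consecutive segments of one long text, already embedded (illustrative random data)
segments = [torch.randn(seq_len, batch_size, d_model) for _ in range(3)]

mem = []  # no memories before the first segment
for x in segments:
    mem_len = mem[0].shape[0] if mem else 0
    # Causal mask over memory + current segment: query position i may attend to
    # every memory position and to current positions j <= i.
    mask = torch.tril(torch.ones(seq_len, mem_len + seq_len), diagonal=mem_len)
    mask = mask.unsqueeze(-1).bool()  # assumed [query, key, batch-broadcast] layout
    output, mem = model(x, mem, mask)
```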