From 16bf5d0b10097ff578bf27b48bd0c6f980095f8d Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 12 Mar 2022 16:04:05 +0530
Subject: [PATCH] experiment link

---
 docs/transformers/retro/index.html      |   1 +
 docs/transformers/retro/model.html      | 409 ++++++++++++------------
 docs/transformers/retro/train.html      | 189 +++++------
 labml_nn/transformers/retro/__init__.py |   2 +
 labml_nn/transformers/retro/model.py    |   2 +
 labml_nn/transformers/retro/train.py    |   2 +
 6 files changed, 307 insertions(+), 298 deletions(-)

diff --git a/docs/transformers/retro/index.html b/docs/transformers/retro/index.html
index 6739b1e0..0fab694d 100644
--- a/docs/transformers/retro/index.html
+++ b/docs/transformers/retro/index.html
@@ -80,6 +80,7 @@
  • Model
  • Dataset: Pre-calculate the nearest neighbors
  • Training code
  • +

    View Run

diff --git a/docs/transformers/retro/model.html b/docs/transformers/retro/model.html
index a22b4a53..1133cf2f 100644
--- a/docs/transformers/retro/model.html
+++ b/docs/transformers/retro/model.html
@@ -71,16 +71,17 @@

    RETRO model

    This is the model definition for RETRO.

    +

    View Run

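    A minimal usage sketch of RetroModel, mirroring the _test() function that appears near the end of this file's diff (same toy hyperparameters; run on CPU here for simplicity):

        import torch
        from labml.logger import inspect
        from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder

        # Toy sizes, copied from the _test() function shown later in this diff
        chunk_len, d_model, d_ff, n_heads, d_k = 4, 8, 32, 2, 4

        # Encoder for the retrieved neighbors; cross-attends to the input chunks at layer 1
        encoder = NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff)
        # Decoder with 6 layers and chunked cross-attention at layers 2 and 5
        m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff, encoder=encoder)

        # x: [batch, seq_len] token ids; ret: [batch, chunks, neighbors, neighbor_len] retrieved tokens
        x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
        ret = [[[0] * 6, [1] * 6], [[0] * 6, [1] * 6]]
        res = m(torch.tensor([x] * 10), torch.tensor([ret] * 10))  # logits: [batch, seq_len, n_vocab]
        inspect(res)
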
    -
    14import math
    -15from typing import Set
    -16
    -17import torch
    -18from torch import nn
    -19
    -20from labml.logger import inspect
    +
    16import math
    +17from typing import Set
    +18
    +19import torch
    +20from torch import nn
    +21
    +22from labml.logger import inspect
    @@ -93,7 +94,7 @@
    -
    23class RotaryPositionalEmbeddings(nn.Module):
    +
    25class RotaryPositionalEmbeddings(nn.Module):
    @@ -108,7 +109,7 @@
    -
    34    def __init__(self, d: int, base: int = 10_000):
    +
    36    def __init__(self, d: int, base: int = 10_000):
    @@ -119,7 +120,7 @@
    -
    39        super().__init__()
    +
    41        super().__init__()
    @@ -131,7 +132,7 @@
    -
    41        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
    +
    43        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
    @@ -145,7 +146,7 @@
    -
    43    def forward(self, x: torch.Tensor):
    +
    45    def forward(self, x: torch.Tensor):
    @@ -157,7 +158,7 @@
    -
    48        batch_size, seq_len, n_heads, d = x.shape
    +
    50        batch_size, seq_len, n_heads, d = x.shape
    @@ -169,7 +170,7 @@
    -
    51        d_2 = d // 2
    +
    53        d_2 = d // 2
    @@ -182,7 +183,7 @@
    -
    54        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
    +
    56        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
    @@ -194,7 +195,7 @@
    -
    57        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
    +
    59        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
    @@ -206,7 +207,7 @@
    -
    61        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
    +
    63        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
    @@ -218,7 +219,7 @@
    -
    65        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
    +
    67        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
    @@ -231,7 +232,7 @@
    -
    77        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])
    +
    79        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])
    @@ -243,7 +244,7 @@
    -
    80        return rx
    +
    82        return rx
    @@ -256,7 +257,7 @@
    -
    83class SelfAttention(nn.Module):
    +
    85class SelfAttention(nn.Module):
    @@ -275,7 +276,7 @@
    -
    90    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
    +
    92    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
    @@ -286,11 +287,11 @@
    -
    97        super().__init__()
    -98
    -99        self.is_causal = is_causal
    -100        self.n_heads = n_heads
    -101        self.d_k = d_k
    +
    99        super().__init__()
    +100
    +101        self.is_causal = is_causal
    +102        self.n_heads = n_heads
    +103        self.d_k = d_k
    @@ -313,7 +314,7 @@ M834 80h400000v40h-400000z">
    104        self.scale = 1 / math.sqrt(self.d_k)
    +
    106        self.scale = 1 / math.sqrt(self.d_k)
    @@ -325,9 +326,9 @@ M834 80h400000v40h-400000z">
    107        self.query = nn.Linear(d_model, n_heads * d_k)
    -108        self.key = nn.Linear(d_model, n_heads * d_k)
    -109        self.value = nn.Linear(d_model, n_heads * d_k)
    +
    109        self.query = nn.Linear(d_model, n_heads * d_k)
    +110        self.key = nn.Linear(d_model, n_heads * d_k)
    +111        self.value = nn.Linear(d_model, n_heads * d_k)
    @@ -339,7 +340,7 @@ M834 80h400000v40h-400000z">
    112        self.norm = nn.LayerNorm(d_model)
    +
    114        self.norm = nn.LayerNorm(d_model)
    @@ -351,7 +352,7 @@ M834 80h400000v40h-400000z">
    115        self.softmax = nn.Softmax(dim=-1)
    +
    117        self.softmax = nn.Softmax(dim=-1)
    @@ -363,7 +364,7 @@ M834 80h400000v40h-400000z">
    118        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
    +
    120        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
    @@ -375,7 +376,7 @@ M834 80h400000v40h-400000z">
    121        self.output = nn.Linear(n_heads * d_k, d_model)
    +
    123        self.output = nn.Linear(n_heads * d_k, d_model)
    @@ -390,7 +391,7 @@ M834 80h400000v40h-400000z">
    123    def mask_attention(self, attn: torch.Tensor):
    +
    125    def mask_attention(self, attn: torch.Tensor):
    @@ -402,8 +403,8 @@ M834 80h400000v40h-400000z">
    131        if not self.is_causal:
    -132            return attn
    +
    133        if not self.is_causal:
    +134            return attn
    @@ -415,7 +416,7 @@ M834 80h400000v40h-400000z">
    135        mask = torch.tril(attn.new_ones(attn.shape[-2:]))
    +
    137        mask = torch.tril(attn.new_ones(attn.shape[-2:]))
    @@ -427,7 +428,7 @@ M834 80h400000v40h-400000z">
    137        return attn.masked_fill(mask == 0, float('-inf'))
    +
    139        return attn.masked_fill(mask == 0, float('-inf'))
    @@ -441,7 +442,7 @@ M834 80h400000v40h-400000z">
    139    def forward(self, h: torch.Tensor):
    +
    141    def forward(self, h: torch.Tensor):
    @@ -453,7 +454,7 @@ M834 80h400000v40h-400000z">
    145        h_res = h
    +
    147        h_res = h
    @@ -465,7 +466,7 @@ M834 80h400000v40h-400000z">
    148        h = self.norm(h)
    +
    150        h = self.norm(h)
    @@ -478,10 +479,10 @@ M834 80h400000v40h-400000z">
    152        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
    -153        q = self.query(h).view(mh_shape)
    -154        k = self.key(h).view(mh_shape)
    -155        v = self.value(h).view(mh_shape)
    +
    154        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
    +155        q = self.query(h).view(mh_shape)
    +156        k = self.key(h).view(mh_shape)
    +157        v = self.value(h).view(mh_shape)
    @@ -493,8 +494,8 @@ M834 80h400000v40h-400000z">
    158        q = self.rotary_pe(q)
    -159        k = self.rotary_pe(k)
    +
    160        q = self.rotary_pe(q)
    +161        k = self.rotary_pe(k)
    @@ -506,7 +507,7 @@ M834 80h400000v40h-400000z">
    162        attn = torch.einsum('bihd,bjhd->bhij', q, k)
    +
    164        attn = torch.einsum('bihd,bjhd->bhij', q, k)
    @@ -529,7 +530,7 @@ M834 80h400000v40h-400000z">
    164        attn = attn * self.scale
    +
    166        attn = attn * self.scale
    @@ -541,7 +542,7 @@ M834 80h400000v40h-400000z">
    167        attn = self.mask_attention(attn)
    +
    169        attn = self.mask_attention(attn)
    @@ -553,7 +554,7 @@ M834 80h400000v40h-400000z">
    170        attn = self.softmax(attn)
    +
    172        attn = self.softmax(attn)
    @@ -565,7 +566,7 @@ M834 80h400000v40h-400000z">
    173        h = torch.einsum("bhij,bjhd->bihd", attn, v)
    +
    175        h = torch.einsum("bhij,bjhd->bihd", attn, v)
    @@ -579,7 +580,7 @@ M834 80h400000v40h-400000z">
    177        h = h.reshape(*h.shape[:-2], -1)
    +
    179        h = h.reshape(*h.shape[:-2], -1)
    @@ -592,7 +593,7 @@ M834 80h400000v40h-400000z">
    181        h = self.output(h)
    +
    183        h = self.output(h)
    @@ -604,7 +605,7 @@ M834 80h400000v40h-400000z">
    184        return h + h_res
    +
    186        return h + h_res
    @@ -619,7 +620,7 @@ M834 80h400000v40h-400000z">
    187class CrossAttention(nn.Module):
    +
    189class CrossAttention(nn.Module):
    @@ -636,7 +637,7 @@ M834 80h400000v40h-400000z">
    201    def __init__(self, d_model: int, n_heads: int, d_k: int):
    +
    203    def __init__(self, d_model: int, n_heads: int, d_k: int):
    @@ -647,10 +648,10 @@ M834 80h400000v40h-400000z">
    207        super().__init__()
    -208
    -209        self.n_heads = n_heads
    -210        self.d_k = d_k
    +
    209        super().__init__()
    +210
    +211        self.n_heads = n_heads
    +212        self.d_k = d_k
    @@ -673,7 +674,7 @@ M834 80h400000v40h-400000z">
    213        self.scale = 1 / math.sqrt(self.d_k)
    +
    215        self.scale = 1 / math.sqrt(self.d_k)
    @@ -685,9 +686,9 @@ M834 80h400000v40h-400000z">
    216        self.query = nn.Linear(d_model, n_heads * d_k)
    -217        self.key = nn.Linear(d_model, n_heads * d_k)
    -218        self.value = nn.Linear(d_model, n_heads * d_k)
    +
    218        self.query = nn.Linear(d_model, n_heads * d_k)
    +219        self.key = nn.Linear(d_model, n_heads * d_k)
    +220        self.value = nn.Linear(d_model, n_heads * d_k)
    @@ -699,7 +700,7 @@ M834 80h400000v40h-400000z">
    221        self.norm = nn.LayerNorm(d_model)
    +
    223        self.norm = nn.LayerNorm(d_model)
    @@ -711,7 +712,7 @@ M834 80h400000v40h-400000z">
    224        self.softmax = nn.Softmax(dim=-1)
    +
    226        self.softmax = nn.Softmax(dim=-1)
    @@ -723,7 +724,7 @@ M834 80h400000v40h-400000z">
    227        self.output = nn.Linear(n_heads * d_k, d_model)
    +
    229        self.output = nn.Linear(n_heads * d_k, d_model)
    @@ -740,7 +741,7 @@ M834 80h400000v40h-400000z">
    229    def forward(self, e: torch.Tensor, h: torch.Tensor):
    +
    231    def forward(self, e: torch.Tensor, h: torch.Tensor):
    @@ -752,7 +753,7 @@ M834 80h400000v40h-400000z">
    238        e_res = e
    +
    240        e_res = e
    @@ -764,7 +765,7 @@ M834 80h400000v40h-400000z">
    241        e = self.norm(e)
    +
    243        e = self.norm(e)
    @@ -776,7 +777,7 @@ M834 80h400000v40h-400000z">
    244        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
    +
    246        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
    @@ -788,8 +789,8 @@ M834 80h400000v40h-400000z">
    246        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
    -247        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
    +
    248        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
    +249        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
    @@ -802,7 +803,7 @@ M834 80h400000v40h-400000z">
    252        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
    +
    254        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
    @@ -814,7 +815,7 @@ M834 80h400000v40h-400000z">
    254        attn = attn * self.scale
    +
    256        attn = attn * self.scale
    @@ -826,7 +827,7 @@ M834 80h400000v40h-400000z">
    257        attn = self.softmax(attn)
    +
    259        attn = self.softmax(attn)
    @@ -838,7 +839,7 @@ M834 80h400000v40h-400000z">
    260        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
    +
    262        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
    @@ -852,7 +853,7 @@ M834 80h400000v40h-400000z">
    264        e = e.reshape(*e.shape[:-2], -1)
    +
    266        e = e.reshape(*e.shape[:-2], -1)
    @@ -865,7 +866,7 @@ M834 80h400000v40h-400000z">
    268        e = self.output(e)
    +
    270        e = self.output(e)
    @@ -877,7 +878,7 @@ M834 80h400000v40h-400000z">
    271        return e + e_res
    +
    273        return e + e_res
    @@ -892,7 +893,7 @@ M834 80h400000v40h-400000z">
    274class ChunkedCrossAttention(nn.Module):
    +
    276class ChunkedCrossAttention(nn.Module):
    @@ -911,7 +912,7 @@ M834 80h400000v40h-400000z">
    286    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
    +
    288    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
    @@ -922,11 +923,11 @@ M834 80h400000v40h-400000z">
    294        super().__init__()
    -295
    -296        self.chunk_len = chunk_len
    -297        self.n_heads = n_heads
    -298        self.d_k = d_k
    +
    296        super().__init__()
    +297
    +298        self.chunk_len = chunk_len
    +299        self.n_heads = n_heads
    +300        self.d_k = d_k
    @@ -949,7 +950,7 @@ M834 80h400000v40h-400000z">
    301        self.scale = 1 / math.sqrt(self.d_k)
    +
    303        self.scale = 1 / math.sqrt(self.d_k)
    @@ -961,9 +962,9 @@ M834 80h400000v40h-400000z">
    304        self.query = nn.Linear(d_model, n_heads * d_k)
    -305        self.key = nn.Linear(d_model, n_heads * d_k)
    -306        self.value = nn.Linear(d_model, n_heads * d_k)
    +
    306        self.query = nn.Linear(d_model, n_heads * d_k)
    +307        self.key = nn.Linear(d_model, n_heads * d_k)
    +308        self.value = nn.Linear(d_model, n_heads * d_k)
    @@ -975,7 +976,7 @@ M834 80h400000v40h-400000z">
    309        self.norm = nn.LayerNorm(d_model)
    +
    311        self.norm = nn.LayerNorm(d_model)
    @@ -987,7 +988,7 @@ M834 80h400000v40h-400000z">
    312        self.softmax = nn.Softmax(dim=-1)
    +
    314        self.softmax = nn.Softmax(dim=-1)
    @@ -999,7 +1000,7 @@ M834 80h400000v40h-400000z">
    315        self.output = nn.Linear(n_heads * d_k, d_model)
    +
    317        self.output = nn.Linear(n_heads * d_k, d_model)
    @@ -1015,7 +1016,7 @@ M834 80h400000v40h-400000z">
    317    def forward(self, h: torch.Tensor, e: torch.Tensor):
    +
    319    def forward(self, h: torch.Tensor, e: torch.Tensor):
    @@ -1027,7 +1028,7 @@ M834 80h400000v40h-400000z">
    324        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
    +
    326        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
    @@ -1039,8 +1040,8 @@ M834 80h400000v40h-400000z">
    327        if chunks == 0:
    -328            return h
    +
    329        if chunks == 0:
    +330            return h
    @@ -1052,7 +1053,7 @@ M834 80h400000v40h-400000z">
    331        h_res = h
    +
    333        h_res = h
    @@ -1066,7 +1067,7 @@ M834 80h400000v40h-400000z">
    339        h = h[:, self.chunk_len - 1:]
    +
    341        h = h[:, self.chunk_len - 1:]
    @@ -1078,7 +1079,7 @@ M834 80h400000v40h-400000z">
    341        h = self.norm(h)
    +
    343        h = self.norm(h)
    @@ -1090,8 +1091,8 @@ M834 80h400000v40h-400000z">
    343        if h.shape[1] < chunks * self.chunk_len:
    -344            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
    +
    345        if h.shape[1] < chunks * self.chunk_len:
    +346            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
    @@ -1103,7 +1104,7 @@ M834 80h400000v40h-400000z">
    346        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
    +
    348        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
    @@ -1115,7 +1116,7 @@ M834 80h400000v40h-400000z">
    349        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
    +
    351        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
    @@ -1127,8 +1128,8 @@ M834 80h400000v40h-400000z">
    351        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
    -352        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
    +
    353        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
    +354        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
    @@ -1141,7 +1142,7 @@ M834 80h400000v40h-400000z">
    357        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
    +
    359        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
    @@ -1153,7 +1154,7 @@ M834 80h400000v40h-400000z">
    359        attn = attn * self.scale
    +
    361        attn = attn * self.scale
    @@ -1166,7 +1167,7 @@ M834 80h400000v40h-400000z">
    362        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
    +
    364        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
    @@ -1178,7 +1179,7 @@ M834 80h400000v40h-400000z">
    365        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
    +
    367        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
    @@ -1192,7 +1193,7 @@ M834 80h400000v40h-400000z">
    369        h = h.reshape(batch_size, chunks * self.chunk_len, -1)
    +
    371        h = h.reshape(batch_size, chunks * self.chunk_len, -1)
    @@ -1205,7 +1206,7 @@ M834 80h400000v40h-400000z">
    373        h = self.output(h)
    +
    375        h = self.output(h)
    @@ -1218,7 +1219,7 @@ M834 80h400000v40h-400000z">
    376        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
    +
    378        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
    @@ -1230,7 +1231,7 @@ M834 80h400000v40h-400000z">
    379        return h[:, :h_res.shape[1]] + h_res
    +
    381        return h[:, :h_res.shape[1]] + h_res
    @@ -1243,7 +1244,7 @@ M834 80h400000v40h-400000z">
    382class FeedForward(nn.Module):
    +
    384class FeedForward(nn.Module):
    @@ -1258,7 +1259,7 @@ M834 80h400000v40h-400000z">
    389    def __init__(self, d_model: int, d_ff: int):
    +
    391    def __init__(self, d_model: int, d_ff: int):
    @@ -1269,7 +1270,7 @@ M834 80h400000v40h-400000z">
    395        super().__init__()
    +
    397        super().__init__()
    @@ -1281,8 +1282,8 @@ M834 80h400000v40h-400000z">
    398        self.lin1 = nn.Linear(d_model, d_ff)
    -399        self.lin2 = nn.Linear(d_ff, d_model)
    +
    400        self.lin1 = nn.Linear(d_model, d_ff)
    +401        self.lin2 = nn.Linear(d_ff, d_model)
    @@ -1294,7 +1295,7 @@ M834 80h400000v40h-400000z">
    402        self.act = nn.ReLU()
    +
    404        self.act = nn.ReLU()
    @@ -1306,7 +1307,7 @@ M834 80h400000v40h-400000z">
    405        self.norm = nn.LayerNorm(d_model)
    +
    407        self.norm = nn.LayerNorm(d_model)
    @@ -1320,7 +1321,7 @@ M834 80h400000v40h-400000z">
    407    def forward(self, h: torch.Tensor):
    +
    409    def forward(self, h: torch.Tensor):
    @@ -1332,7 +1333,7 @@ M834 80h400000v40h-400000z">
    413        h_res = h
    +
    415        h_res = h
    @@ -1344,7 +1345,7 @@ M834 80h400000v40h-400000z">
    415        h = self.norm(h)
    +
    417        h = self.norm(h)
    @@ -1356,7 +1357,7 @@ M834 80h400000v40h-400000z">
    417        h = self.lin1(h)
    +
    419        h = self.lin1(h)
    @@ -1368,7 +1369,7 @@ M834 80h400000v40h-400000z">
    419        h = self.act(h)
    +
    421        h = self.act(h)
    @@ -1380,7 +1381,7 @@ M834 80h400000v40h-400000z">
    421        h = self.lin2(h)
    +
    423        h = self.lin2(h)
    @@ -1392,7 +1393,7 @@ M834 80h400000v40h-400000z">
    424        return h + h_res
    +
    426        return h + h_res
    @@ -1405,7 +1406,7 @@ M834 80h400000v40h-400000z">
    427class NearestNeighborEncoder(nn.Module):
    +
    429class NearestNeighborEncoder(nn.Module):
    @@ -1430,8 +1431,8 @@ M834 80h400000v40h-400000z">
    434    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
    -435                 d_model: int, n_heads: int, d_k: int, d_ff: int):
    +
    436    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
    +437                 d_model: int, n_heads: int, d_k: int, d_ff: int):
    @@ -1442,9 +1443,9 @@ M834 80h400000v40h-400000z">
    446        super().__init__()
    -447        self.ca_layers = ca_layers
    -448        self.chunk_len = chunk_len
    +
    448        super().__init__()
    +449        self.ca_layers = ca_layers
    +450        self.chunk_len = chunk_len
    @@ -1456,7 +1457,7 @@ M834 80h400000v40h-400000z">
    450        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
    +
    452        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
    @@ -1468,7 +1469,7 @@ M834 80h400000v40h-400000z">
    452        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
    +
    454        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
    @@ -1480,7 +1481,7 @@ M834 80h400000v40h-400000z">
    454        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
    +
    456        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
    @@ -1492,7 +1493,7 @@ M834 80h400000v40h-400000z">
    457        self.norm_h = nn.LayerNorm(d_model)
    +
    459        self.norm_h = nn.LayerNorm(d_model)
    @@ -1510,7 +1511,7 @@ M834 80h400000v40h-400000z">
    459    def forward(self, e: torch.Tensor, h: torch.Tensor):
    +
    461    def forward(self, e: torch.Tensor, h: torch.Tensor):
    @@ -1522,7 +1523,7 @@ M834 80h400000v40h-400000z">
    472        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
    +
    474        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
    @@ -1534,7 +1535,7 @@ M834 80h400000v40h-400000z">
    475        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
    +
    477        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
    @@ -1546,7 +1547,7 @@ M834 80h400000v40h-400000z">
    478        h_split = self.norm_h(h_split)
    +
    480        h_split = self.norm_h(h_split)
    @@ -1558,7 +1559,7 @@ M834 80h400000v40h-400000z">
    481        p_ca = 0
    +
    483        p_ca = 0
    @@ -1570,7 +1571,7 @@ M834 80h400000v40h-400000z">
    483        for p in range(len(self.attn)):
    +
    485        for p in range(len(self.attn)):
    @@ -1582,7 +1583,7 @@ M834 80h400000v40h-400000z">
    486            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
    +
    488            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
    @@ -1594,7 +1595,7 @@ M834 80h400000v40h-400000z">
    489            if p in self.ca_layers:
    +
    491            if p in self.ca_layers:
    @@ -1606,7 +1607,7 @@ M834 80h400000v40h-400000z">
    491                e = self.ca[p_ca](e, h_split)
    +
    493                e = self.ca[p_ca](e, h_split)
    @@ -1618,7 +1619,7 @@ M834 80h400000v40h-400000z">
    493                p_ca += 1
    +
    495                p_ca += 1
    @@ -1630,7 +1631,7 @@ M834 80h400000v40h-400000z">
    496            e = self.ffw[p](e)
    +
    498            e = self.ffw[p](e)
    @@ -1642,7 +1643,7 @@ M834 80h400000v40h-400000z">
    499        return e
    +
    501        return e
    @@ -1655,7 +1656,7 @@ M834 80h400000v40h-400000z">
    502class RetroModel(nn.Module):
    +
    504class RetroModel(nn.Module):
    @@ -1684,8 +1685,8 @@ M834 80h400000v40h-400000z">
    509    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
    -510                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
    +
    511    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
    +512                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
    @@ -1696,10 +1697,10 @@ M834 80h400000v40h-400000z">
    522        super().__init__()
    -523
    -524        self.ca_layers = ca_layers
    -525        self.encoder = encoder
    +
    524        super().__init__()
    +525
    +526        self.ca_layers = ca_layers
    +527        self.encoder = encoder
    @@ -1711,7 +1712,7 @@ M834 80h400000v40h-400000z">
    528        self.emb = nn.Embedding(n_vocab, d_model)
    +
    530        self.emb = nn.Embedding(n_vocab, d_model)
    @@ -1723,8 +1724,8 @@ M834 80h400000v40h-400000z">
    530        self.cca = nn.ModuleList(
    -531            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
    +
    532        self.cca = nn.ModuleList(
    +533            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
    @@ -1736,7 +1737,7 @@ M834 80h400000v40h-400000z">
    533        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
    +
    535        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
    @@ -1748,7 +1749,7 @@ M834 80h400000v40h-400000z">
    535        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
    +
    537        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
    @@ -1760,7 +1761,7 @@ M834 80h400000v40h-400000z">
    537        self.read = nn.Linear(d_model, n_vocab)
    +
    539        self.read = nn.Linear(d_model, n_vocab)
    @@ -1772,7 +1773,7 @@ M834 80h400000v40h-400000z">
    541        self.norm_e = nn.LayerNorm(d_model)
    +
    543        self.norm_e = nn.LayerNorm(d_model)
    @@ -1789,7 +1790,7 @@ M834 80h400000v40h-400000z">
    543    def forward(self, x: torch.Tensor, ret: torch.Tensor):
    +
    545    def forward(self, x: torch.Tensor, ret: torch.Tensor):
    @@ -1801,7 +1802,7 @@ M834 80h400000v40h-400000z">
    552        h = self.emb(x)
    +
    554        h = self.emb(x)
    @@ -1814,7 +1815,7 @@ M834 80h400000v40h-400000z">
    558        ret_emb = self.emb(ret)
    +
    560        ret_emb = self.emb(ret)
    @@ -1826,7 +1827,7 @@ M834 80h400000v40h-400000z">
    561        p_ca = 0
    +
    563        p_ca = 0
    @@ -1838,7 +1839,7 @@ M834 80h400000v40h-400000z">
    563        for p in range(len(self.attn)):
    +
    565        for p in range(len(self.attn)):
    @@ -1850,7 +1851,7 @@ M834 80h400000v40h-400000z">
    565            h = self.attn[p](h)
    +
    567            h = self.attn[p](h)
    @@ -1862,7 +1863,7 @@ M834 80h400000v40h-400000z">
    569            if self.ca_layers and p == min(self.ca_layers):
    +
    571            if self.ca_layers and p == min(self.ca_layers):
    @@ -1875,7 +1876,7 @@ M834 80h400000v40h-400000z">
    573                e = self.encoder(ret_emb, h)
    +
    575                e = self.encoder(ret_emb, h)
    @@ -1887,7 +1888,7 @@ M834 80h400000v40h-400000z">
    575                e = self.norm_e(e)
    +
    577                e = self.norm_e(e)
    @@ -1899,7 +1900,7 @@ M834 80h400000v40h-400000z">
    578            if p in self.ca_layers:
    +
    580            if p in self.ca_layers:
    @@ -1911,7 +1912,7 @@ M834 80h400000v40h-400000z">
    580                h = self.cca[p_ca](h, e)
    +
    582                h = self.cca[p_ca](h, e)
    @@ -1923,7 +1924,7 @@ M834 80h400000v40h-400000z">
    582                p_ca += 1
    +
    584                p_ca += 1
    @@ -1935,7 +1936,7 @@ M834 80h400000v40h-400000z">
    585            h = self.ffw[p](h)
    +
    587            h = self.ffw[p](h)
    @@ -1947,7 +1948,7 @@ M834 80h400000v40h-400000z">
    588        return self.read(h)
    +
    590        return self.read(h)
    @@ -1959,7 +1960,7 @@ M834 80h400000v40h-400000z">
    591def _test():
    +
    593def _test():
    @@ -1970,26 +1971,26 @@ M834 80h400000v40h-400000z">
    595    chunk_len = 4
    -596    d_model = 8
    -597    d_ff = 32
    -598    n_heads = 2
    -599    d_k = 4
    -600
    -601    device = torch.device('cuda:0')
    +            
    597    chunk_len = 4
    +598    d_model = 8
    +599    d_ff = 32
    +600    n_heads = 2
    +601    d_k = 4
     602
    -603    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
    -604                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
    -605
    -606    m.to(device)
    -607    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    -608    ret = [
    -609        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    -610        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    -611    ]
    -612    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
    -613
    -614    inspect(res)
    +603    device = torch.device('cuda:0')
    +604
    +605    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
    +606                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
    +607
    +608    m.to(device)
    +609    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    +610    ret = [
    +611        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    +612        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    +613    ]
    +614    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
    +615
    +616    inspect(res)
    @@ -2001,8 +2002,8 @@ M834 80h400000v40h-400000z">
    618if __name__ == '__main__':
    -619    _test()
    +
    620if __name__ == '__main__':
    +621    _test()

diff --git a/docs/transformers/retro/train.html b/docs/transformers/retro/train.html
--- a/docs/transformers/retro/train.html
+++ b/docs/transformers/retro/train.html

    RETRO training

    This is the training code for RETRO.

    +

    View Run

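    A minimal sketch of how the pieces defined in this file fit together, condensed from the train() function shown in the hunks below (same hyperparameters, dataset and optimizer; the experiment and tracker bookkeeping is omitted):

        import torch
        from torch.utils.data import DataLoader, RandomSampler

        from labml import lab
        from labml_helpers.datasets.text import TextFileDataset
        from labml_nn.optimizers.noam import Noam
        from labml_nn.transformers.retro.dataset import Dataset
        from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
        from labml_nn.transformers.retro.train import Sampler, Trainer

        device = torch.device('cuda:0')  # as in train(); switch to 'cpu' if no GPU is available

        # Character-level Tiny Shakespeare dataset
        tds = TextFileDataset(
            lab.get_data_path() / 'tiny_shakespeare.txt', list,
            url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

        # Pre-calculated nearest-neighbor dataset (built by the dataset code linked from index.html)
        train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)
        train_dl = DataLoader(train_dataset, batch_size=4,
                              sampler=RandomSampler(train_dataset, replacement=True))

        chunk_len, d_model, d_ff, n_heads, d_k = 16, 128, 512, 16, 16
        encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
        model = RetroModel(tds.n_tokens, d_model, 6, {3, 5}, chunk_len, n_heads, d_k, d_ff,
                           encoder=encoder).to(device)

        optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
        trainer = Trainer(device, model, train_dl, optimizer)
        sampler = Sampler(device, model, tds, chunk_len)

        trainer()  # one pass over the training data
        prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''
        print(sampler.sample(prompt, 128))  # greedy sampling from the prompt
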
    -
    14import torch
    -15from torch import nn
    -16from torch.utils.data import DataLoader, RandomSampler
    -17
    -18from labml import monit, lab, tracker, experiment, logger
    -19from labml.logger import Text
    -20from labml_helpers.datasets.text import TextFileDataset
    -21from labml_nn.optimizers.noam import Noam
    -22from labml_nn.transformers.retro import model as retro
    -23from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
    -24from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
    +
    16import torch
    +17from torch import nn
    +18from torch.utils.data import DataLoader, RandomSampler
    +19
    +20from labml import monit, lab, tracker, experiment, logger
    +21from labml.logger import Text
    +22from labml_helpers.datasets.text import TextFileDataset
    +23from labml_nn.optimizers.noam import Noam
    +24from labml_nn.transformers.retro import model as retro
    +25from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
    +26from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
    @@ -97,7 +98,7 @@
    -
    27class Sampler:
    +
    29class Sampler:
    @@ -116,7 +117,7 @@
    -
    34    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
    +
    36    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
    @@ -127,10 +128,10 @@
    -
    41        self.chunk_len = chunk_len
    -42        self.tds = tds
    -43        self.model = model
    -44        self.device = device
    +
    43        self.chunk_len = chunk_len
    +44        self.tds = tds
    +45        self.model = model
    +46        self.device = device
    @@ -142,7 +143,7 @@
    -
    47        self.index = RetroIndex()
    +
    49        self.index = RetroIndex()
    @@ -154,7 +155,7 @@
    -
    49    def retrieve_nearest_neighbours(self, chunk: str):
    +
    51    def retrieve_nearest_neighbours(self, chunk: str):
    @@ -166,7 +167,7 @@
    -
    55        neighbor_offsets = self.index([chunk], None)
    +
    57        neighbor_offsets = self.index([chunk], None)
    @@ -179,8 +180,8 @@
    -
    58        text = self.tds.train
    -59        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]
    +
    60        text = self.tds.train
    +61        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]
    @@ -192,7 +193,7 @@
    -
    62        return neighbors
    +
    64        return neighbors
    @@ -204,7 +205,7 @@
    -
    64    def sample(self, prompt: str, sample_len: int):
    +
    66    def sample(self, prompt: str, sample_len: int):
    @@ -216,7 +217,7 @@
    -
    70        neighbors_str = []
    +
    72        neighbors_str = []
    @@ -228,7 +229,7 @@
    -
    73        sampled = ''
    +
    75        sampled = ''
    @@ -241,7 +242,7 @@
    -
    76        for i in range(sample_len):
    +
    78        for i in range(sample_len):
    @@ -253,7 +254,7 @@
    -
    79            while len(neighbors_str) < len(prompt) // self.chunk_len:
    +
    81            while len(neighbors_str) < len(prompt) // self.chunk_len:
    @@ -265,8 +266,8 @@
    -
    81                off = len(neighbors_str) * self.chunk_len
    -82                chunk = prompt[off: off + self.chunk_len]
    +
    83                off = len(neighbors_str) * self.chunk_len
    +84                chunk = prompt[off: off + self.chunk_len]
    @@ -278,7 +279,7 @@
    -
    84                neighbors_str.append(self.retrieve_nearest_neighbours(chunk))
    +
    86                neighbors_str.append(self.retrieve_nearest_neighbours(chunk))
    @@ -290,7 +291,7 @@
    -
    87            src = self.tds.text_to_i(prompt)
    +
    89            src = self.tds.text_to_i(prompt)
    @@ -302,7 +303,7 @@
    -
    89            neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str])
    +
    91            neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str])
    @@ -314,8 +315,8 @@
    -
    92            src = src.to(self.device)
    -93            neighbors = neighbors.to(self.device)
    +
    94            src = src.to(self.device)
    +95            neighbors = neighbors.to(self.device)
    @@ -327,7 +328,7 @@
    -
    96            res = self.model(src[None, :], neighbors[None, :, :, :])
    +
    98            res = self.model(src[None, :], neighbors[None, :, :, :])
    @@ -339,7 +340,7 @@
    -
    99            token = res[0, -1, :].argmax(dim=-1)
    +
    101            token = res[0, -1, :].argmax(dim=-1)
    @@ -351,8 +352,8 @@
    -
    102            prompt += self.tds.itos[token.item()]
    -103            sampled += self.tds.itos[token.item()]
    +
    104            prompt += self.tds.itos[token.item()]
    +105            sampled += self.tds.itos[token.item()]
    @@ -364,7 +365,7 @@
    -
    106        return sampled
    +
    108        return sampled
    @@ -376,7 +377,7 @@
    -
    109class Trainer:
    +
    111class Trainer:
    @@ -395,8 +396,8 @@
    -
    114    def __init__(self, device: torch.device, model: retro.RetroModel,
    -115                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
    +
    116    def __init__(self, device: torch.device, model: retro.RetroModel,
    +117                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
    @@ -407,11 +408,11 @@
    -
    122        self.optimizer = optimizer
    -123        self.device = device
    -124        self.dataloader = dataloader
    -125        self.model = model
    -126        self.loss_func = nn.CrossEntropyLoss()
    +
    124        self.optimizer = optimizer
    +125        self.device = device
    +126        self.dataloader = dataloader
    +127        self.model = model
    +128        self.loss_func = nn.CrossEntropyLoss()
    @@ -423,7 +424,7 @@
    -
    128    def __call__(self):
    +
    130    def __call__(self):
    @@ -435,7 +436,7 @@
    -
    134        for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
    +
    136        for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
    @@ -447,7 +448,7 @@
    -
    136            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)
    +
    138            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)
    @@ -459,7 +460,7 @@
    -
    139            res = self.model(src, neighbors)
    +
    141            res = self.model(src, neighbors)
    @@ -471,7 +472,7 @@
    -
    141            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))
    +
    143            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))
    @@ -483,7 +484,7 @@
    -
    144            self.optimizer.zero_grad()
    +
    146            self.optimizer.zero_grad()
    @@ -495,7 +496,7 @@
    -
    146            loss.backward()
    +
    148            loss.backward()
    @@ -507,7 +508,7 @@
    -
    148            self.optimizer.step()
    +
    150            self.optimizer.step()
    @@ -519,8 +520,8 @@
    -
    151            tracker.save({'loss.train': loss})
    -152            tracker.add_global_step(len(src))
    +
    153            tracker.save({'loss.train': loss})
    +154            tracker.add_global_step(len(src))
    @@ -532,7 +533,7 @@
    -
    155def train():
    +
    157def train():
    @@ -544,7 +545,7 @@
    -
    161    experiment.create(name='retro_small')
    +
    163    experiment.create(name='retro_small')
    @@ -556,7 +557,7 @@
    -
    164    device = torch.device('cuda:0')
    +
    166    device = torch.device('cuda:0')
    @@ -568,10 +569,10 @@
    -
    167    tds = TextFileDataset(
    -168        lab.get_data_path() / 'tiny_shakespeare.txt',
    -169        list,
    -170        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
    +
    169    tds = TextFileDataset(
    +170        lab.get_data_path() / 'tiny_shakespeare.txt',
    +171        list,
    +172        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
    @@ -583,7 +584,7 @@
    -
    173    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)
    +
    175    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)
    @@ -595,9 +596,9 @@
    -
    176    train_dl = DataLoader(train_dataset,
    -177                          batch_size=4,
    -178                          sampler=RandomSampler(train_dataset, replacement=True))
    +
    178    train_dl = DataLoader(train_dataset,
    +179                          batch_size=4,
    +180                          sampler=RandomSampler(train_dataset, replacement=True))
    @@ -609,11 +610,11 @@
    -
    181    chunk_len = 16
    -182    d_model = 128
    -183    d_ff = 512
    -184    n_heads = 16
    -185    d_k = 16
    +
    183    chunk_len = 16
    +184    d_model = 128
    +185    d_ff = 512
    +186    n_heads = 16
    +187    d_k = 16
    @@ -625,7 +626,7 @@
    -
    188    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    +
    190    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    @@ -637,10 +638,10 @@
    -
    190    model = RetroModel(tds.n_tokens, d_model, 6,
    -191                       {3, 5},
    -192                       chunk_len, n_heads, d_k, d_ff,
    -193                       encoder=nearest_neighbor_encoder)
    +
    192    model = RetroModel(tds.n_tokens, d_model, 6,
    +193                       {3, 5},
    +194                       chunk_len, n_heads, d_k, d_ff,
    +195                       encoder=nearest_neighbor_encoder)
    @@ -652,7 +653,7 @@
    -
    195    model = model.to(device)
    +
    197    model = model.to(device)
    @@ -664,7 +665,7 @@
    -
    197    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    +
    199    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    @@ -677,7 +678,7 @@
    -
    199    trainer = Trainer(device, model, train_dl, optimizer)
    +
    201    trainer = Trainer(device, model, train_dl, optimizer)
    @@ -690,7 +691,7 @@
    -
    201    sampler = Sampler(device, model, tds, chunk_len)
    +
    203    sampler = Sampler(device, model, tds, chunk_len)
    @@ -702,7 +703,7 @@
    -
    203    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''
    +
    205    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''
    @@ -714,7 +715,7 @@
    -
    206    experiment.add_pytorch_models(model=model)
    +
    208    experiment.add_pytorch_models(model=model)
    @@ -726,7 +727,7 @@
    -
    209    with experiment.start():
    +
    211    with experiment.start():
    @@ -739,7 +740,7 @@
    -
    211        for epoch in monit.loop(32):
    +
    213        for epoch in monit.loop(32):
    @@ -751,7 +752,7 @@
    -
    213            trainer()
    +
    215            trainer()
    @@ -763,7 +764,7 @@
    -
    215            tracker.new_line()
    +
    217            tracker.new_line()
    @@ -776,8 +777,8 @@
    -
    217            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
    -218                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
    +
    219            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
    +220                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
    @@ -789,7 +790,7 @@
    -
    220            experiment.save_checkpoint()
    +
    222            experiment.save_checkpoint()
    @@ -801,8 +802,8 @@
    -
    224if __name__ == '__main__':
    -225    train()
    +
    226if __name__ == '__main__':
    +227    train()