diff --git a/docs/transformers/retro/index.html b/docs/transformers/retro/index.html
index 6739b1e0..0fab694d 100644
--- a/docs/transformers/retro/index.html
+++ b/docs/transformers/retro/index.html
@@ -80,6 +80,7 @@
This is the model definition for RETRO (Retrieval-Enhanced Transformer).
import math
from typing import Set

import torch
from torch import nn

from labml.logger import inspect


class RotaryPositionalEmbeddings(nn.Module):
    """
    Rotary positional embeddings (RoPE) encode positions by rotating
    pairs of feature dimensions by an angle proportional to the position.
    """

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        # Frequencies theta_i = base^(-2i / d), kept as a non-trainable parameter
        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)

    def forward(self, x: torch.Tensor):
        # x has shape [batch_size, seq_len, n_heads, d]
        batch_size, seq_len, n_heads, d = x.shape
        d_2 = d // 2

        # Position indices [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
        # Angles n * theta_i for every position/frequency pair
        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Pair each feature with its counterpart in the other half, negated
        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

        # Apply the rotation
        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

        return rx
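A minimal shape check, not part of the original file and using made-up sizes, makes the layout concrete; the layout [batch_size, seq_len, n_heads, d] below matches what SelfAttention passes in:

# Illustrative sketch only; d must be even so the feature halves pair up.
rope = RotaryPositionalEmbeddings(d=16)
x = torch.randn(2, 10, 4, 16)  # [batch_size, seq_len, n_heads, d]
assert rope(x).shape == x.shape  # the rotation preserves the shape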
class SelfAttention(nn.Module):
    """
    Pre-norm multi-head self-attention with rotary positional embeddings
    and a residual connection. Used with a causal mask in the decoder and
    without one in the nearest-neighbor encoder.
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
        super().__init__()

        self.is_causal = is_causal
        self.n_heads = n_heads
        self.d_k = d_k
        # Attention scaling factor 1 / sqrt(d_k)
        self.scale = 1 / math.sqrt(self.d_k)

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

        self.norm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(dim=-1)
        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
        self.output = nn.Linear(n_heads * d_k, d_model)

    def mask_attention(self, attn: torch.Tensor):
        # No masking in the non-causal (encoder) case
        if not self.is_causal:
            return attn

        # Lower-triangular mask so each position attends only to earlier positions
        mask = torch.tril(attn.new_ones(attn.shape[-2:]))
        return attn.masked_fill(mask == 0, float('-inf'))

    def forward(self, h: torch.Tensor):
        h_res = h
        h = self.norm(h)

        # Project to multi-head queries, keys and values
        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
        q = self.query(h).view(mh_shape)
        k = self.key(h).view(mh_shape)
        v = self.value(h).view(mh_shape)

        # Apply rotary positional embeddings to queries and keys
        q = self.rotary_pe(q)
        k = self.rotary_pe(k)

        # Scaled dot-product attention
        attn = torch.einsum('bihd,bjhd->bhij', q, k)
        attn = attn * self.scale
        attn = self.mask_attention(attn)
        attn = self.softmax(attn)

        # Weighted sum of values, then merge the heads
        h = torch.einsum("bhij,bjhd->bihd", attn, v)
        h = h.reshape(*h.shape[:-2], -1)
        h = self.output(h)

        return h + h_res
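A quick sanity sketch (not from the original source, sizes are illustrative); because of the residual connection the block maps [batch_size, seq_len, d_model] to the same shape:

# Illustrative sizes only
attn = SelfAttention(d_model=8, n_heads=2, d_k=4, is_causal=True)
h = torch.randn(3, 10, 8)  # [batch_size, seq_len, d_model]
assert attn(h).shape == (3, 10, 8)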
class CrossAttention(nn.Module):
    """
    Cross-attention from the retrieved neighbors to the chunks they were
    retrieved for. Used inside the nearest-neighbor encoder; no positional
    embeddings are added here.
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int):
        super().__init__()

        self.n_heads = n_heads
        self.d_k = d_k
        self.scale = 1 / math.sqrt(self.d_k)

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

        self.norm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(dim=-1)
        self.output = nn.Linear(n_heads * d_k, d_model)

    def forward(self, e: torch.Tensor, h: torch.Tensor):
        # e: neighbor embeddings [batch, chunks, neighbors, neighbor_len, d_model]
        # h: chunk embeddings [batch, chunks, chunk_len, d_model]
        e_res = e
        e = self.norm(e)

        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

        # Attention over the chunk positions j
        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
        attn = attn * self.scale
        attn = self.softmax(attn)

        # Weighted sum of values, then merge the heads
        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
        e = e.reshape(*e.shape[:-2], -1)
        e = self.output(e)

        return e + e_res
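A shape sketch with illustrative sizes (not part of the original file): every position of every retrieved neighbor attends over the chunk_len positions of the chunk it was retrieved for:

ca = CrossAttention(d_model=8, n_heads=2, d_k=4)
e = torch.randn(2, 3, 2, 6, 8)  # [batch, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 3, 4, 8)     # [batch, chunks, chunk_len, d_model]
assert ca(e, h).shape == e.shape  # output keeps the neighbor layout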
class ChunkedCrossAttention(nn.Module):
    """
    Chunked cross-attention from the decoder to the encoded neighbors.
    Queries are shifted by chunk_len - 1 so that a position attends only
    to neighbors of chunks that are already complete, which keeps the
    model autoregressive.
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
        super().__init__()

        self.chunk_len = chunk_len
        self.n_heads = n_heads
        self.d_k = d_k
        self.scale = 1 / math.sqrt(self.d_k)

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

        self.norm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(dim=-1)
        self.output = nn.Linear(n_heads * d_k, d_model)

    def forward(self, h: torch.Tensor, e: torch.Tensor):
        # h: decoder embeddings [batch, seq_len, d_model]
        # e: encoded neighbors [batch, chunks, neighbors, neighbor_len, d_model]
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

        # No retrieved neighbors; nothing to attend to
        if chunks == 0:
            return h

        h_res = h

        # Shift the queries by chunk_len - 1 to maintain causality
        h = h[:, self.chunk_len - 1:]
        h = self.norm(h)
        # Pad the tail so the sequence splits evenly into chunks
        if h.shape[1] < chunks * self.chunk_len:
            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)

        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
        attn = attn * self.scale
        # Softmax jointly over all neighbor positions (n, j)
        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)

        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
        h = h.reshape(batch_size, chunks * self.chunk_len, -1)
        h = self.output(h)

        # Undo the shift with zero padding at the front and add the residual
        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
        return h[:, :h_res.shape[1]] + h_res
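A sketch with made-up sizes (not from the original source). With chunk_len = 4, a 10-token sequence contains two complete chunks; the chunk_len - 1 query shift and the zero padding keep the output aligned with the input:

cca = ChunkedCrossAttention(d_model=8, n_heads=2, d_k=4, chunk_len=4)
h = torch.randn(2, 10, 8)       # [batch, seq_len, d_model]; two complete chunks
e = torch.randn(2, 2, 3, 6, 8)  # [batch, chunks, neighbors, neighbor_len, d_model]
assert cca(h, e).shape == h.shape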
class FeedForward(nn.Module):
    """
    Pre-norm position-wise feed-forward block with a residual connection.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()

        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)
        self.act = nn.ReLU()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor):
        h_res = h
        h = self.norm(h)
        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)
        return h + h_res
class NearestNeighborEncoder(nn.Module):
    """
    Bi-directional transformer that encodes the retrieved neighbors, with
    cross-attention to the corresponding input chunks at the layers given
    in ca_layers.
    """

    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
                 d_model: int, n_heads: int, d_k: int, d_ff: int):
        super().__init__()
        self.ca_layers = ca_layers
        self.chunk_len = chunk_len

        # One cross-attention layer per entry in ca_layers
        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
        # Non-causal self-attention and feed-forward layers
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

        self.norm_h = nn.LayerNorm(d_model)

    def forward(self, e: torch.Tensor, h: torch.Tensor):
        # e: neighbor embeddings [batch, chunks, neighbors, neighbor_len, d_model]
        # h: decoder embeddings of the input [batch, seq_len, d_model]
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

        # Split the input into chunks and normalize
        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
        h_split = self.norm_h(h_split)

        p_ca = 0
        for p in range(len(self.attn)):
            # Self-attention over each neighbor independently
            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)

            # Cross-attend to the input chunks at the selected layers
            if p in self.ca_layers:
                e = self.ca[p_ca](e, h_split)
                p_ca += 1

            e = self.ffw[p](e)

        return e
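A usage sketch with illustrative sizes (not part of the original file); the sequence must cover at least chunks * chunk_len tokens, since the encoder slices one chunk of the input per chunk of neighbors:

enc = NearestNeighborEncoder(chunk_len=4, n_layers=2, ca_layers={1},
                             d_model=8, n_heads=2, d_k=4, d_ff=32)
e = torch.randn(2, 3, 2, 6, 8)  # [batch, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 12, 8)       # seq_len = chunks * chunk_len = 12
assert enc(e, h).shape == e.shape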
class RetroModel(nn.Module):
    """
    The RETRO decoder: causal self-attention layers with chunked
    cross-attention to the encoded neighbors at the layers given in
    ca_layers. The neighbors are encoded when the first cross-attention
    layer is reached.
    """

    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
        super().__init__()

        self.ca_layers = ca_layers
        self.encoder = encoder

        # Token embeddings, shared by the input and the retrieved neighbors
        self.emb = nn.Embedding(n_vocab, d_model)
        self.cca = nn.ModuleList(
            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
        # Final projection to logits
        self.read = nn.Linear(d_model, n_vocab)

        self.norm_e = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, ret: torch.Tensor):
        # x: input token ids [batch, seq_len]
        # ret: retrieved neighbor token ids [batch, chunks, neighbors, neighbor_len]
        h = self.emb(x)
        ret_emb = self.emb(ret)

        p_ca = 0
        for p in range(len(self.attn)):
            h = self.attn[p](h)

            # Encode the neighbors just before the first cross-attention layer
            if self.ca_layers and p == min(self.ca_layers):
                e = self.encoder(ret_emb, h)
                e = self.norm_e(e)

            if p in self.ca_layers:
                h = self.cca[p_ca](h, e)
                p_ca += 1

            h = self.ffw[p](h)

        return self.read(h)
def _test():
    # Check the model on synthetic data
    chunk_len = 4
    d_model = 8
    d_ff = 32
    n_heads = 2
    d_k = 4

    device = torch.device('cuda:0')

    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))

    m.to(device)
    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    ret = [
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    ]
    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))

    inspect(res)


if __name__ == '__main__':
    _test()


This is the training code for RETRO.
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler

from labml import monit, lab, tracker, experiment, logger
from labml.logger import Text
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers.retro import model as retro
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
class Sampler:
    """
    Greedy sampler that retrieves nearest neighbors for each completed
    chunk of the prompt while generating.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device

        # Retrieval index over the training text
        self.index = RetroIndex()

    def retrieve_nearest_neighbours(self, chunk: str):
        # Offsets of the nearest neighbors in the training text
        neighbor_offsets = self.index([chunk], None)

        # Each neighbor is the retrieved chunk plus its continuation
        text = self.tds.train
        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]

        return neighbors

    def sample(self, prompt: str, sample_len: int):
        # Retrieved neighbors, one list per completed chunk
        neighbors_str = []
        # Generated text
        sampled = ''

        for i in range(sample_len):
            # Retrieve neighbors for any newly completed chunks
            while len(neighbors_str) < len(prompt) // self.chunk_len:
                off = len(neighbors_str) * self.chunk_len
                chunk = prompt[off: off + self.chunk_len]
                neighbors_str.append(self.retrieve_nearest_neighbours(chunk))

            src = self.tds.text_to_i(prompt)
            neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str])

            src = src.to(self.device)
            neighbors = neighbors.to(self.device)

            res = self.model(src[None, :], neighbors[None, :, :, :])

            # Greedily pick the most likely next token
            token = res[0, -1, :].argmax(dim=-1)

            prompt += self.tds.itos[token.item()]
            sampled += self.tds.itos[token.item()]

        return sampled
class Trainer:
    """
    A simple training loop over the retrieval-augmented dataset.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel,
                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer
        self.device = device
        self.dataloader = dataloader
        self.model = model
        self.loss_func = nn.CrossEntropyLoss()

    def __call__(self):
        for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)

            res = self.model(src, neighbors)
            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            tracker.save({'loss.train': loss})
            tracker.add_global_step(len(src))
def train():
    experiment.create(name='retro_small')

    device = torch.device('cuda:0')

    # Tiny Shakespeare dataset at the character level
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Pre-built dataset with retrieved neighbors
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16

    # Nearest-neighbor encoder with cross-attention at layer 3
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    # RETRO decoder with chunked cross-attention at layers 3 and 5
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)
    model = model.to(device)

    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    trainer = Trainer(device, model, train_dl, optimizer)
    sampler = Sampler(device, model, tds, chunk_len)
    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

    experiment.add_pytorch_models(model=model)

    with experiment.start():
        for epoch in monit.loop(32):
            trainer()
            tracker.new_line()
            # Log a sample from the model
            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
            experiment.save_checkpoint()


if __name__ == '__main__':
    train()