This builds the database and retrieves the nearest neighbors for the RETRO model.
We use the FAISS library for the database, whereas the paper used the SCaNN library.
from typing import List, Optional

import faiss
import numpy as np
import torch

from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings

Build the database:

* chunk_len is the length of a chunk (number of characters)
* batch_size is the batch size to use when calculating BERT embeddings
* d_emb is the number of features in the BERT embeddings
* n_centeroids is the number of lists in the FAISS index
* code_size is the encoded vector size in the index
* n_probe is the number of lists to probe
* n_train is the number of chunks to train the index on

def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):

Load the dataset text file.
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

Get the training data (a string).
    text = dataset.train

Split the text into chunks of chunk_len characters, keeping only offsets that leave room for a chunk plus its continuation.

    chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]

Get the offsets of each of the chunks.

    chunk_offsets = np.array([i for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)])

Number of chunks:

    n_chunks = len(chunks)
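To illustrate the chunking logic, here is a tiny worked example with a made-up string and a smaller chunk_len than the default:

text = 'abcdefghijklmnop'  # 16 characters, hypothetical
chunk_len = 4
# Offsets 8 and 12 are dropped because i + chunk_len * 2 would reach past the end of the text
chunks = [text[i:i + chunk_len] for i in range(0, len(text), chunk_len) if i + chunk_len * 2 < len(text)]
assert chunks == ['abcd', 'efgh']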
Initialize BERT to get chunk embeddings.

    bert = BERTChunkEmbeddings(torch.device('cuda:0'))

Get the chunk embeddings by processing batch_size chunks on each iteration.
    chunk_emb = []
    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())

Merge them into a single tensor.

    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()
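BERTChunkEmbeddings is treated as a black box here. As a rough sketch of what such a chunk-embedding module might do (this is an assumption based on HuggingFace transformers, not the actual BERTChunkEmbeddings implementation):

from transformers import BertModel, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

def embed_chunks(chunks):
    # Tokenize the chunks, run BERT, and mean-pool the token embeddings
    # while masking out padding; gives one 768-dimensional vector per chunk
    tokens = tokenizer(chunks, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        out = model(**tokens)
    mask = tokens['attention_mask'].unsqueeze(-1)
    return (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)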
Create the FAISS index.

    quantizer = faiss.IndexFlatL2(d_emb)
    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
    index.nprobe = n_probe
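For intuition on the IVFPQ parameters: the quantizer clusters vectors into n_centeroids inverted lists, each vector is compressed to code_size bytes (8 bits per sub-quantizer), and a query scans only nprobe lists. A self-contained sketch with random data (the values here are toy choices, not the ones used above):

import faiss
import numpy as np

d = 64                                      # toy dimensionality
xb = np.random.random((10_000, d)).astype('float32')
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, 256, 8, 8)  # 256 lists, 8-byte codes
index.train(xb)                             # learn coarse centroids and PQ codebooks
index.add(xb)
index.nprobe = 8                            # probe 8 of the 256 lists per query
distances, ids = index.search(xb[:3], 4)    # 4 approximate neighbors for 3 queries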
Get a random sample of the chunk indexes.

    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)

Train the index to store the keys.

    with monit.section('Train index'):
        index.train(chunk_emb[random_sample])

Add the chunks to the index in batches of size 1024.
 
    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
        e = min(s + 1024, n_chunks)
        # Add to the index, using the chunk offsets as the vector IDs
        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s:e])

Save the index.
    with monit.section('Save'):
        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))

The RetroIndex class retrieves nearest neighbors from the saved database for a given list of query chunks.

class RetroIndex:

* chunk_len is the chunk length
* n_probe is the number of lists to probe
* n_neighbors is the number of neighbors to retrieve
* n_extra is the number of extra neighbors to retrieve, since we will be removing neighbors that overlap with the query chunk
* exclude_neighbor_span is the extra text length to avoid when checking for overlaps

    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
                 n_neighbors: int = 2, n_extra: int = 2,
                 exclude_neighbor_span: int = 8):
        self.n_neighbors = n_neighbors
        self.chunk_len = chunk_len
        self.exclude_neighbor_span = exclude_neighbor_span
        self.n_extra = n_extra

Initialize BERT to get chunk embeddings.
        self.bert = BERTChunkEmbeddings(torch.device('cuda:0'))

Load the database.

        with monit.section('Load index'):
            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
            self.index.nprobe = n_probe

Filter out neighbors that overlap with the query chunk. The positions of the neighbors are given by neighbor_offsets, and the position of the query chunk is offset.

    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
        return [n for n in neighbor_offsets
                if n < offset - (self.chunk_len + self.exclude_neighbor_span)
                or n > offset + (self.chunk_len + self.exclude_neighbor_span)]
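A small worked check of the overlap filter with made-up numbers: with the defaults chunk_len = 16 and exclude_neighbor_span = 8, a query chunk at offset 128 rejects any neighbor whose offset falls within 24 characters of it:

chunk_len, exclude_neighbor_span = 16, 8
offset = 128                                # hypothetical query chunk position
span = chunk_len + exclude_neighbor_span    # 24
neighbor_offsets = [0, 112, 128, 150, 300]  # hypothetical search results
kept = [n for n in neighbor_offsets if n < offset - span or n > offset + span]
assert kept == [0, 300]                     # 112, 128 and 150 overlap the query region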
Retrieve the nearest neighbors for a list of query chunks.

    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]]):

Get the BERT embeddings of the query chunks.

        emb = self.bert(query_chunks).cpu()

Get n_neighbors + n_extra nearest neighbors from the database.

        distance, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)

If the query chunk offsets are given, filter out overlapping chunks.

        if offsets is not None:
            neighbor_offsets = [self.filter_neighbors(off, n_off)
                                for off, n_off in zip(offsets, neighbor_offsets)]

Keep the closest n_neighbors after filtering.

        neighbor_offsets = [n_off[:self.n_neighbors] for n_off in neighbor_offsets]
        return neighbor_offsets
if __name__ == '__main__':
    build_database()
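A minimal usage sketch (assuming retro.index has already been built by build_database above, and that a CUDA device is available for BERT); the query strings and offsets are illustrative:

index = RetroIndex()
# Two chunk_len-character query chunks taken from offsets 0 and 16 of the training text
neighbors = index(['First Citizen:\nB', 'efore we proceed'], offsets=[0, 16])
# neighbors[i] holds the text offsets of the retrieved chunks for query i;
# presumably the retrieved text would be text[n: n + chunk_len * 2] (chunk plus
# continuation), which is why the chunking above requires i + chunk_len * 2 < len(text)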