diff --git a/docs/transformers/retro/model.html b/docs/transformers/retro/model.html index f4c3486d..a22b4a53 100644 --- a/docs/transformers/retro/model.html +++ b/docs/transformers/retro/model.html @@ -1809,11 +1809,12 @@ M834 80h400000v40h-400000z">. We use same embeddings for both input and neighbors

+

Embeddings of the retrieved neighbors .

+

We use same embeddings for both input and neighbors

-
557        ret_emb = self.emb(ret)
+
558        ret_emb = self.emb(ret)
@@ -1825,7 +1826,7 @@ M834 80h400000v40h-400000z">
560        p_ca = 0
+
561        p_ca = 0
@@ -1837,7 +1838,7 @@ M834 80h400000v40h-400000z">
562        for p in range(len(self.attn)):
+
563        for p in range(len(self.attn)):
@@ -1849,7 +1850,7 @@ M834 80h400000v40h-400000z">
564            h = self.attn[p](h)
+
565            h = self.attn[p](h)
@@ -1861,7 +1862,7 @@ M834 80h400000v40h-400000z">
568            if self.ca_layers and p == min(self.ca_layers):
+
569            if self.ca_layers and p == min(self.ca_layers):
@@ -1869,11 +1870,12 @@ M834 80h400000v40h-400000z"> We passed the embeddings of to encoder.

+

+

We passed the embeddings of to encoder.

-
571                e = self.encoder(ret_emb, h)
+
573                e = self.encoder(ret_emb, h)
@@ -1885,7 +1887,7 @@ M834 80h400000v40h-400000z">
573                e = self.norm_e(e)
+
575                e = self.norm_e(e)
@@ -1897,7 +1899,7 @@ M834 80h400000v40h-400000z">
576            if p in self.ca_layers:
+
578            if p in self.ca_layers:
@@ -1909,7 +1911,7 @@ M834 80h400000v40h-400000z">
578                h = self.cca[p_ca](h, e)
+
580                h = self.cca[p_ca](h, e)
@@ -1921,7 +1923,7 @@ M834 80h400000v40h-400000z">
580                p_ca += 1
+
582                p_ca += 1
@@ -1933,7 +1935,7 @@ M834 80h400000v40h-400000z">
583            h = self.ffw[p](h)
+
585            h = self.ffw[p](h)
@@ -1945,7 +1947,7 @@ M834 80h400000v40h-400000z">
586        return self.read(h)
+
588        return self.read(h)
@@ -1957,7 +1959,7 @@ M834 80h400000v40h-400000z">
589def _test():
+
591def _test():
@@ -1968,26 +1970,26 @@ M834 80h400000v40h-400000z">
593    chunk_len = 4
-594    d_model = 8
-595    d_ff = 32
-596    n_heads = 2
-597    d_k = 4
-598
-599    device = torch.device('cuda:0')
+            
595    chunk_len = 4
+596    d_model = 8
+597    d_ff = 32
+598    n_heads = 2
+599    d_k = 4
 600
-601    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
-602                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
-603
-604    m.to(device)
-605    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
-606    ret = [
-607        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
-608        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
-609    ]
-610    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
-611
-612    inspect(res)
+601 device = torch.device('cuda:0') +602 +603 m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff, +604 encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff)) +605 +606 m.to(device) +607 x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3] +608 ret = [ +609 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]], +610 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]], +611 ] +612 res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device)) +613 +614 inspect(res)
@@ -1999,8 +2001,8 @@ M834 80h400000v40h-400000z">
616if __name__ == '__main__':
-617    _test()
+
618if __name__ == '__main__':
+619    _test()