diff --git a/docs/transformers/retro/model.html b/docs/transformers/retro/model.html index f4c3486d..a22b4a53 100644 --- a/docs/transformers/retro/model.html +++ b/docs/transformers/retro/model.html @@ -1809,11 +1809,12 @@ M834 80h400000v40h-400000z">. We use same embeddings for both input and neighbors
+Embeddings of the retrieved neighbors .
+We use same embeddings for both input and neighbors
557 ret_emb = self.emb(ret)
558 ret_emb = self.emb(ret)
560 p_ca = 0
561 p_ca = 0
562 for p in range(len(self.attn)):
563 for p in range(len(self.attn)):
564 h = self.attn[p](h)
565 h = self.attn[p](h)
568 if self.ca_layers and p == min(self.ca_layers):
569 if self.ca_layers and p == min(self.ca_layers):
+
We passed the embeddings of to encoder.
571 e = self.encoder(ret_emb, h)
573 e = self.encoder(ret_emb, h)
573 e = self.norm_e(e)
575 e = self.norm_e(e)
576 if p in self.ca_layers:
578 if p in self.ca_layers:
578 h = self.cca[p_ca](h, e)
580 h = self.cca[p_ca](h, e)
580 p_ca += 1
582 p_ca += 1
583 h = self.ffw[p](h)
585 h = self.ffw[p](h)
586 return self.read(h)
588 return self.read(h)
589def _test():
591def _test():
593 chunk_len = 4 -594 d_model = 8 -595 d_ff = 32 -596 n_heads = 2 -597 d_k = 4 -598 -599 device = torch.device('cuda:0') ++601 device = torch.device('cuda:0') +602 +603 m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff, +604 encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff)) +605 +606 m.to(device) +607 x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3] +608 ret = [ +609 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]], +610 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]], +611 ] +612 res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device)) +613 +614 inspect(res)595 chunk_len = 4 +596 d_model = 8 +597 d_ff = 32 +598 n_heads = 2 +599 d_k = 4 600 -601 m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff, -602 encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff)) -603 -604 m.to(device) -605 x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3] -606 ret = [ -607 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]], -608 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]], -609 ] -610 res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device)) -611 -612 inspect(res)
616if __name__ == '__main__': -617 _test()
618if __name__ == '__main__':
+619 _test()