formatting

This commit is contained in:
Varuna Jayasiri
2022-03-12 15:57:09 +05:30
parent 1536c6ec5e
commit a39d91dacd
2 changed files with 40 additions and 36 deletions

File diff suppressed because one or more lines are too long

View File

@ -553,6 +553,7 @@ class RetroModel(nn.Module):
# Embeddings of the retrieved neighbors
# $E^j_u = \text{E\small{MB}}_{\text{enc}}\big(\text{R\small{ET}}(C_u)^j\big)$.
#
# We use same embeddings for both input and neighbors
ret_emb = self.emb(ret)
@ -567,6 +568,7 @@ class RetroModel(nn.Module):
# when $p = \min(P)$
if self.ca_layers and p == min(self.ca_layers):
# $E = \text{E\small{NCODER}}(\text{R\small{ET}}(C_u)_{1 \le u \le l}, H)$
#
# We passed the embeddings of $\text{R\small{ET}}(C_u)_{1 \le u \le l}$ to encoder.
e = self.encoder(ret_emb, h)
# Normalize encoder embeddings