comments fixes

This commit is contained in:
Varuna Jayasiri
2021-09-06 13:27:47 +05:30
parent 103cf81a13
commit c72d3b4f83
5 changed files with 92 additions and 75 deletions

View File

@ -23,7 +23,7 @@ def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray,
"""
## $k$-NN to get $p(w_t, c_t)$
Here we refer to $f($\color{yellowgreen}{c_t})$ as queries,
Here we refer to $f(\color{yellowgreen}{c_t})$ as queries,
$f(c_i)$ as keys and $w_i$ as values.
"""
@ -33,7 +33,7 @@ def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray,
# Flatten the `batch` and `sequence` dimensions of queries
queries = queries.view(-1, queries_shape[-1])
# Find 10 nearest neighbors of $f($\color{yellowgreen}{c_t})$ among $f(c_i)$.
# Find 10 nearest neighbors of $f(\color{yellowgreen}{c_t})$ among $f(c_i)$.
# `distance` is the distance given by FAISS and `idx`, $i$ is the index of it in `keys_store`.
distance, idx = index.search(queries.numpy(), 10)

View File

@ -119,6 +119,8 @@ class Configs(NLPAutoRegressionConfigs):
route_prob = route_prob / total
# Load balancing loss
# $$\mathscr{L} = N \sum_{i=1}^N f_i \cdot P_i$$
# $\mathscr{L}$ is the loss for a single layer and here we are
# taking the sum of losses across all layers.
load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
# Track stats