Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-10-31 02:39:16 +08:00
comments fixes
@@ -23,7 +23,7 @@ def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray,
     """
     ## $k$-NN to get $p(w_t, c_t)$
 
-    Here we refer to $f($\color{yellowgreen}{c_t})$ as queries,
+    Here we refer to $f(\color{yellowgreen}{c_t})$ as queries,
     $f(c_i)$ as keys and $w_i$ as values.
     """
 
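For context (not part of the diff): the docstring above refers to the $k$-NN language-model estimate of Khandelwal et al. (2020), which, roughly, weights the retrieved target tokens $w_i$ by how close their keys are to the query:

$$p_{kNN}(w \mid \color{yellowgreen}{c_t}) \propto \sum_{(c_i, w_i) \in \mathcal{N}} \mathbb{1}[w = w_i] \, \exp\big(-d(f(\color{yellowgreen}{c_t}), f(c_i))\big)$$

where $\mathcal{N}$ is the set of retrieved neighbors and $d(\cdot,\cdot)$ is the L2 distance returned by FAISS.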
@@ -33,7 +33,7 @@ def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray,
     # Flatten the `batch` and `sequence` dimensions of queries
     queries = queries.view(-1, queries_shape[-1])
 
-    # Find 10 nearest neighbors of $f($\color{yellowgreen}{c_t})$ among $f(c_i)$.
+    # Find 10 nearest neighbors of $f(\color{yellowgreen}{c_t})$ among $f(c_i)$.
     # `distance` is the distance given by FAISS and `idx`, $i$ is the index of it in `keys_store`.
     distance, idx = index.search(queries.numpy(), 10)
 
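As a reference for how `index.search` behaves in this hunk, here is a minimal, self-contained sketch with made-up sizes and random data. The real `keys_store` holds the saved $f(c_i)$ vectors alongside a parallel array of target tokens $w_i$; the names `vals_store`, `d`, and `n_keys` below are illustrative, not taken from the repository.

import faiss
import numpy as np

# Illustrative sizes: 1,000 stored contexts with 16-dimensional representations f(c_i)
d, n_keys = 16, 1_000
keys_store = np.random.rand(n_keys, d).astype('float32')  # stand-in for the saved keys f(c_i)
vals_store = np.random.randint(0, 100, size=n_keys)       # stand-in for the saved targets w_i

# Exact (flat) L2 index over the keys, as in the function signature above
index = faiss.IndexFlatL2(d)
index.add(keys_store)

# Query with two f(c_t) vectors; FAISS returns squared L2 distances and
# the indices of the 10 nearest keys, which map back to the targets w_i
queries = np.random.rand(2, d).astype('float32')
distance, idx = index.search(queries, 10)
print(distance.shape, idx.shape)   # (2, 10) (2, 10)
print(vals_store[idx].shape)       # (2, 10) retrieved target tokens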
@@ -119,6 +119,8 @@ class Configs(NLPAutoRegressionConfigs):
         route_prob = route_prob / total
         # Load balancing loss
         # $$\mathscr{L} = N \sum_{i=1}^N f_i \cdot P_i$$
+        # $\mathscr{L}$ is the loss for a single layer and here we are
+        # taking the sum of losses across all layers.
         load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
 
         # Track stats
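For the load-balancing term, here is a standalone sketch of the quantities the added comment describes. Shapes and names are illustrative; in the experiment code, `route_frac` and `route_prob` are gathered from the model's switch layers rather than computed as below.

import torch

# Illustrative routing for one layer: 8 tokens over 4 experts
n_experts = 4
probs = torch.softmax(torch.randn(8, n_experts), dim=-1)  # routing probabilities per token
routes = probs.argmax(dim=-1)                             # expert chosen for each token

# $f_i$: fraction of tokens dispatched to expert $i$
route_frac = torch.stack([(routes == i).float().mean() for i in range(n_experts)])
# $P_i$: mean routing probability assigned to expert $i$
route_prob = probs.mean(dim=0)

# $\mathscr{L} = N \sum_{i=1}^N f_i \cdot P_i$, minimized when tokens spread evenly across experts
load_balancing_loss = n_experts * (route_frac * route_prob).sum()
print(load_balancing_loss)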