diff --git a/labml_nn/transformers/mha.py b/labml_nn/transformers/mha.py
index 7552d342..8af94d31 100644
--- a/labml_nn/transformers/mha.py
+++ b/labml_nn/transformers/mha.py
@@ -91,8 +91,8 @@ class MultiHeadAttention(Module):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
 
         # Softmax for attention along the time dimension of `key`
@@ -119,10 +119,10 @@ class MultiHeadAttention(Module):
         return torch.einsum('ibhd,jbhd->ijbh', query, key)
 
     def forward(self, *,
-                query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
-                mask: Optional[torch.Tensor] = None):
+                query: torch.Tensor,
+                key: torch.Tensor,
+                value: torch.Tensor,
+                mask: Optional[torch.Tensor] = None):
         """
         `query`, `key` and `value` are the tensors that store
         collection of *query*, *key* and *value* vectors.
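
Note (not part of the patch): below is a minimal, self-contained sketch of the tensor shapes assumed by the torch.einsum('ibhd,jbhd->ijbh', query, key) call in the second hunk. The subscripts imply a [seq_len, batch, heads, d_k] layout coming out of PrepareForMultiHeadAttention; the concrete sizes are arbitrary illustration values, not values from this patch.

import torch

# Arbitrary illustration sizes; the [seq_len, batch, heads, d_k] layout
# follows from the einsum subscripts 'ibhd,jbhd->ijbh'.
seq_len, batch, heads, d_k = 5, 2, 4, 8
query = torch.randn(seq_len, batch, heads, d_k)
key = torch.randn(seq_len, batch, heads, d_k)

# scores[i, j, b, h] = sum_d query[i, b, h, d] * key[j, b, h, d], i.e. Q K^T per head
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
assert scores.shape == (seq_len, seq_len, batch, heads)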