From 9a9d6c671df73becb6a1b23f623ced98fac2db79 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Wed, 17 Feb 2021 11:43:59 +0530
Subject: [PATCH] formatting

---
 labml_nn/transformers/mha.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/labml_nn/transformers/mha.py b/labml_nn/transformers/mha.py
index 7552d342..8af94d31 100644
--- a/labml_nn/transformers/mha.py
+++ b/labml_nn/transformers/mha.py
@@ -91,8 +91,8 @@ class MultiHeadAttention(Module):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
 
         # Softmax for attention along the time dimension of `key`
@@ -119,10 +119,10 @@ class MultiHeadAttention(Module):
         return torch.einsum('ibhd,jbhd->ijbh', query, key)
 
     def forward(self, *,
-                query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
-                mask: Optional[torch.Tensor] = None):
+                query: torch.Tensor,
+                key: torch.Tensor,
+                value: torch.Tensor,
+                mask: Optional[torch.Tensor] = None):
         """
         `query`, `key` and `value` are the tensors that store
         collection of *query*, *key* and *value* vectors.
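
Note (not part of the patch): the `get_scores` line visible as context in the second
hunk computes the unscaled attention scores with `torch.einsum`. A minimal runnable
sketch of that einsum, with illustrative tensor sizes assumed here, that checks the
resulting shape:

    import torch

    # Illustrative sizes, not from the patch.
    seq_len, batch_size, heads, d_k = 5, 2, 4, 8

    # `query` and `key` are shaped `[seq_len, batch_size, heads, d_k]`,
    # the layout produced by `PrepareForMultiHeadAttention`.
    query = torch.randn(seq_len, batch_size, heads, d_k)
    key = torch.randn(seq_len, batch_size, heads, d_k)

    # scores[i, j, b, h] = dot product over d_k of query position i
    # and key position j, per batch b and head h.
    scores = torch.einsum('ibhd,jbhd->ijbh', query, key)

    assert scores.shape == (seq_len, seq_len, batch_size, heads)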