formatting

Varuna Jayasiri
2021-02-17 11:43:59 +05:30
parent 7f2e4dff07
commit 9a9d6c671d


@@ -91,8 +91,8 @@ class MultiHeadAttention(Module):
         self.heads = heads
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
         # Softmax for attention along the time dimension of `key`
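For context on this hunk: `PrepareForMultiHeadAttention` is defined elsewhere in the file. Below is a minimal sketch of what such a module does, assuming only the `(d_model, heads, d_k, bias)` signature visible in the calls above; it is an illustrative reconstruction, not the file's exact source:

```python
import torch
from torch import nn


class PrepareForMultiHeadAttention(nn.Module):
    """Sketch: project vectors and split the result into `heads` vectors of size `d_k`."""

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # One shared projection; the output is reshaped per head afterwards.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        # `x`: `[seq_len, batch_size, d_model]` (or `[batch_size, d_model]` for a single step)
        head_shape = x.shape[:-1]
        x = self.linear(x)
        # Split the projected dimension into `(heads, d_k)`
        return x.view(*head_shape, self.heads, self.d_k)
```

One detail visible in the hunk itself: `value` is constructed with `bias=True` unconditionally, while `query` and `key` use the constructor's `bias` flag.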
@@ -119,10 +119,10 @@ class MultiHeadAttention(Module):
         return torch.einsum('ibhd,jbhd->ijbh', query, key)
 
     def forward(self, *,
-                query: torch.Tensor,
-                key: torch.Tensor,
-                value: torch.Tensor,
-                mask: Optional[torch.Tensor] = None):
+                query: torch.Tensor,
+                key: torch.Tensor,
+                value: torch.Tensor,
+                mask: Optional[torch.Tensor] = None):
         """
         `query`, `key` and `value` are the tensors that store
         collection of *query*, *key* and *value* vectors.
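The einsum at the top of this hunk computes the raw attention scores. The snippet below sanity-checks the index pattern `'ibhd,jbhd->ijbh'` in isolation; the dimension order `[seq_len, batch_size, heads, d_k]` is inferred from the index names, and the sizes are arbitrary:

```python
import torch

seq_len, batch_size, heads, d_k = 5, 2, 8, 64
query = torch.randn(seq_len, batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)

# Dot product of every query vector with every key vector, per batch and per head:
# scores[i, j, b, h] == torch.dot(query[i, b, h], key[j, b, h])
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
assert scores.shape == (seq_len, seq_len, batch_size, heads)
```

This layout explains the comment in the first hunk about taking the softmax "along the time dimension of `key`": for a fixed query position `i`, the scores over all key positions `j` sit along the second axis of `scores`.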