formatting

This commit is contained in:
Varuna Jayasiri
2021-02-17 11:43:59 +05:30
parent 7f2e4dff07
commit 9a9d6c671d

View File

@ -91,8 +91,8 @@ class MultiHeadAttention(Module):
self.heads = heads self.heads = heads
# These transform the `query`, `key` and `value` vectors for multi-headed attention. # These transform the `query`, `key` and `value` vectors for multi-headed attention.
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias) self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias) self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True) self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
# Softmax for attention along the time dimension of `key` # Softmax for attention along the time dimension of `key`
@ -119,10 +119,10 @@ class MultiHeadAttention(Module):
return torch.einsum('ibhd,jbhd->ijbh', query, key) return torch.einsum('ibhd,jbhd->ijbh', query, key)
def forward(self, *, def forward(self, *,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
mask: Optional[torch.Tensor] = None): mask: Optional[torch.Tensor] = None):
""" """
`query`, `key` and `value` are the tensors that store `query`, `key` and `value` are the tensors that store
collection of *query*, *key* and *value* vectors. collection of *query*, *key* and *value* vectors.