mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 18:27:03 +08:00
formatting
@@ -91,8 +91,8 @@ class MultiHeadAttention(Module):
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
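PrepareForMultiHeadAttention itself is not shown in this hunk. As a rough orientation, here is a minimal sketch of such a projection module, assuming (from the constructor arguments above) that it applies one linear layer and then splits the output into `heads` vectors of size `d_k`:

import torch
from torch import nn


class PrepareForMultiHeadAttention(nn.Module):
    # Sketch only: a single linear projection, then split the last dimension into heads.
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size, d_model] (or any shape ending in d_model)
        head_shape = x.shape[:-1]
        x = self.linear(x)
        # Split the projected vector into `heads` heads of size `d_k` each.
        return x.view(*head_shape, self.heads, self.d_k)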
@@ -119,10 +119,10 @@ class MultiHeadAttention(Module):
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collection of *query*, *key* and *value* vectors.
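The einsum in the first line of this hunk computes, for every batch `b` and head `h`, the dot product over `d_k` between query position `i` and key position `j`, so the attention scores have shape `[seq_len_q, seq_len_k, batch_size, heads]`. A small shape check (the concrete sizes are arbitrary illustration values):

import torch

seq_q, seq_k, batch, heads, d_k = 5, 7, 2, 4, 16
query = torch.rand(seq_q, batch, heads, d_k)
key = torch.rand(seq_k, batch, heads, d_k)

scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
assert scores.shape == (seq_q, seq_k, batch, heads)

# Same result written out as a broadcasted dot product over d_k
manual = (query[:, None] * key[None, :]).sum(-1)
assert torch.allclose(scores, manual)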
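For context, a hedged usage sketch of the keyword-only `forward` signature above, assuming the constructor takes `heads` and `d_model` and that the inputs are shaped `[seq_len, batch_size, d_model]` as the projection modules suggest:

import torch
from labml_nn.transformers.mha import MultiHeadAttention  # assumed import path

seq_len, batch_size, d_model = 10, 4, 512
mha = MultiHeadAttention(heads=8, d_model=d_model)

x = torch.rand(seq_len, batch_size, d_model)
# Self-attention: query, key and value are the same tensor; no mask.
out = mha(query=x, key=x, value=x, mask=None)
print(out.shape)  # expected: [seq_len, batch_size, d_model]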