mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-01 20:28:41 +08:00
formatting
@@ -91,8 +91,8 @@ class MultiHeadAttention(Module):
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
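The three projections in this hunk share the call `PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=...)`, but the module itself is not part of the diff. A minimal sketch, assuming it is the usual linear projection of `d_model` features into `heads` vectors of size `d_k` (an assumption, not taken from this commit):

import torch
from torch import nn

class PrepareForMultiHeadAttention(nn.Module):
    # Hypothetical sketch: project d_model features to heads * d_k,
    # then split the last dimension into one d_k-sized vector per head.
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_len, batch_size, d_model] -> [seq_len, batch_size, heads, d_k]
        head_shape = x.shape[:-1]
        x = self.linear(x)
        return x.view(*head_shape, self.heads, self.d_k)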
@@ -119,10 +119,10 @@ class MultiHeadAttention(Module):
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collection of *query*, *key* and *value* vectors.
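The `torch.einsum('ibhd,jbhd->ijbh', query, key)` context line above computes the dot product of every query position with every key position, per batch and per head. A minimal shape check, assuming the `[seq_len, batch_size, heads, d_k]` layout implied by the subscript names (the layout is inferred, not stated in this hunk):

import torch

# i/j: query/key sequence positions, b: batch, h: heads, d: per-head size (assumed layout)
seq_q, seq_k, batch_size, heads, d_k = 5, 7, 2, 4, 16
query = torch.randn(seq_q, batch_size, heads, d_k)
key = torch.randn(seq_k, batch_size, heads, d_k)

# Attention scores: one dot product per (query position, key position, batch, head)
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)
assert scores.shape == (seq_q, seq_k, batch_size, heads)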