mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	formatting
This commit is contained in:
		| @ -91,8 +91,8 @@ class MultiHeadAttention(Module): | |||||||
|         self.heads = heads |         self.heads = heads | ||||||
|  |  | ||||||
|         # These transform the `query`, `key` and `value` vectors for multi-headed attention. |         # These transform the `query`, `key` and `value` vectors for multi-headed attention. | ||||||
|         self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k,  bias=bias) |         self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias) | ||||||
|         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k,  bias=bias) |         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias) | ||||||
|         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True) |         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True) | ||||||
|  |  | ||||||
|         # Softmax for attention along the time dimension of `key` |         # Softmax for attention along the time dimension of `key` | ||||||
| @ -119,10 +119,10 @@ class MultiHeadAttention(Module): | |||||||
|         return torch.einsum('ibhd,jbhd->ijbh', query, key) |         return torch.einsum('ibhd,jbhd->ijbh', query, key) | ||||||
|  |  | ||||||
|     def forward(self, *, |     def forward(self, *, | ||||||
|                  query: torch.Tensor, |                 query: torch.Tensor, | ||||||
|                  key: torch.Tensor, |                 key: torch.Tensor, | ||||||
|                  value: torch.Tensor, |                 value: torch.Tensor, | ||||||
|                  mask: Optional[torch.Tensor] = None): |                 mask: Optional[torch.Tensor] = None): | ||||||
|         """ |         """ | ||||||
|         `query`, `key` and `value` are the tensors that store |         `query`, `key` and `value` are the tensors that store | ||||||
|         collection of *query*, *key* and *value* vectors. |         collection of *query*, *key* and *value* vectors. | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri