Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-11-03 05:46:16 +08:00
✍️ mha english
@@ -164,8 +164,8 @@ This is used to transform <strong>key</strong>, <strong>query</strong>, and <str
 <a href='#section-7'>#</a>
 </div>
 <p>Input has shape <code>[seq_len, batch_size, d_model]</code> or <code>[batch_size, d_model]</code>.
-We apply the linear transformation of the last dimension and splits that into
-the heads</p>
+We apply the linear transformation to the last dimension and split that into
+the heads.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">49</span> <span class="n">head_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
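As an aside for readers of this hunk: the comment above describes a linear projection of the last dimension followed by a reshape into per-head vectors. Below is a minimal standalone sketch of that step; `split_heads`, the tensor sizes, and the `nn.Linear` dimensions are illustrative assumptions, not code from the commit.

```python
import torch
import torch.nn as nn

def split_heads(x: torch.Tensor, linear: nn.Linear, heads: int, d_k: int) -> torch.Tensor:
    # Keep every dimension except the last; this works for both
    # [seq_len, batch_size, d_model] and [batch_size, d_model] inputs.
    head_shape = x.shape[:-1]
    # Linear transform of the last dimension, then split it into the heads.
    x = linear(x)
    return x.view(*head_shape, heads, d_k)

x = torch.randn(10, 4, 512)                 # [seq_len, batch_size, d_model]
proj = nn.Linear(512, 8 * 64)               # heads=8, d_k=64 (illustrative sizes)
print(split_heads(x, proj, 8, 64).shape)    # torch.Size([10, 4, 8, 64])
```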
@@ -214,13 +214,13 @@ the heads</p>
 <p>
 <script type="math/tex; mode=display">\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V</script>
 </p>
-<p>In simple terms, it finds keys that matches the query, and get the values of
+<p>In simple terms, it finds keys that match the query, and gets the values of
 those keys.</p>
 <p>It uses dot-product of query and key as the indicator of how matching they are.
 Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
 This is done to avoid large dot-product values causing softmax to
 give very small gradients when $d_k$ is large.</p>
-<p>Softmax is calculate along the axis of of the sequence (or time).</p>
+<p>Softmax is calculated along the axis of the sequence (or time).</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">61</span><span class="k">class</span> <span class="nc">MultiHeadAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
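To make the formula concrete, here is a hedged sketch of the scaling, masking, and sequence-axis softmax it describes, assuming the shape convention documented later in this file (scores of shape `[seq_len, seq_len, batch_size, heads]`, values of shape `[seq_len, batch_size, heads, d_k]`). The function name and einsum layout are assumptions, not code shown in this commit.

```python
import math
from typing import Optional

import torch

def attention_from_scores(scores: torch.Tensor,
                          value: torch.Tensor,
                          d_k: int,
                          mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Scale the dot-products by 1/sqrt(d_k) to avoid tiny softmax gradients.
    scores = scores * (1 / math.sqrt(d_k))
    # Optionally mask out positions the query may not attend to.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax along the key/sequence axis (dim=1 of [seq_q, seq_k, batch, heads]).
    attn = torch.softmax(scores, dim=1)
    # Weighted sum of the value vectors.
    return torch.einsum('ijbh,jbhd->ibhd', attn, value)
```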
@@ -335,7 +335,7 @@ give very small gradients when $d_k$ is large.</p>
 <div class='section-link'>
 <a href='#section-21'>#</a>
 </div>
-<p>We store attentions so that it can used for logging, or other computations if needed</p>
+<p>We store attentions so that they can be used for logging, or other computations if needed</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
@@ -370,9 +370,9 @@ give very small gradients when $d_k$ is large.</p>
 <a href='#section-24'>#</a>
 </div>
 <p><code>query</code>, <code>key</code> and <code>value</code> are the tensors that store
-collection of<em>query</em>, <em>key</em> and <em>value</em> vectors.
+collections of <em>query</em>, <em>key</em> and <em>value</em> vectors.
 They have shape <code>[seq_len, batch_size, d_model]</code>.</p>
-<p><code>mask</code> has shape <code>[seq_len, seq_len, batch_size]</code> and indicates
+<p><code>mask</code> has shape <code>[seq_len, seq_len, batch_size]</code> and
 <code>mask[i, j, b]</code> indicates whether for batch <code>b</code>,
 query at position <code>i</code> has access to key-value at position <code>j</code>.</p>
 </div>
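For illustration only (no mask is constructed anywhere in this hunk), a causal mask in the documented `[seq_len, seq_len, batch_size]` layout could be built like this; the sizes are arbitrary:

```python
import torch

seq_len, batch_size = 5, 2                  # illustrative sizes
# mask[i, j, b] is True when query position i may attend to key position j.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
mask = causal.unsqueeze(-1).expand(seq_len, seq_len, batch_size)
```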
@@ -404,7 +404,7 @@ query at position <code>i</code> has access to key-value at position <code>j</co
 </div>
 <p><code>mask</code> has shape <code>[seq_len, seq_len, batch_size]</code>,
 where first dimension is the query dimension.
-If the query dimension is equal to $1$ it will be broadcasted</p>
+If the query dimension is equal to $1$ it will be broadcasted.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">143</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
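A small sketch of the broadcasting rule stated above, under the same layout; the tensors here are stand-ins for illustration, not values from the module:

```python
import torch

seq_len, batch_size = 5, 2
# A mask whose query dimension is 1 applies the same row of allowed
# key positions to every query position.
mask = torch.ones(1, seq_len, batch_size, dtype=torch.bool)
assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
# A trailing dimension lets the same mask broadcast over all heads
# against scores of shape [seq_len, seq_len, batch_size, heads].
mask = mask.unsqueeze(-1)
```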
@@ -426,8 +426,8 @@ If the query dimension is equal to $1$ it will be broadcasted</p>
 <div class='section-link'>
 <a href='#section-28'>#</a>
 </div>
-<p>Prepare <code>query</code>, <code>key</code> and <code>value</code> for attention computation
-These will then have shape <code>[seq_len, batch_size, heads, d_k]</code></p>
+<p>Prepare <code>query</code>, <code>key</code> and <code>value</code> for attention computation.
+These will then have shape <code>[seq_len, batch_size, heads, d_k]</code>.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">150</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
@@ -440,8 +440,8 @@ These will then have shape <code>[seq_len, batch_size, heads, d_k]</code></p>
 <div class='section-link'>
 <a href='#section-29'>#</a>
 </div>
-<p>Compute attention scores $Q K^\top$
-Results in a tensor of shape <code>[seq_len, seq_len, batch_size, heads]</code></p>
+<p>Compute attention scores $Q K^\top$.
+This gives a tensor of shape <code>[seq_len, seq_len, batch_size, heads]</code>.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">156</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>
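The hunk above only shows the call site of `get_scores`. One way to produce the documented `[seq_len, seq_len, batch_size, heads]` score tensor is an einsum over the head dimension, sketched below; the repository's exact implementation is not shown in this diff.

```python
import torch

def get_scores(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # query, key: [seq_len, batch_size, heads, d_k]
    # returns:    [seq_len, seq_len, batch_size, heads], i.e. Q K^T per head
    return torch.einsum('ibhd,jbhd->ijbh', query, key)
```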
@@ -44,8 +44,8 @@ class PrepareForMultiHeadAttention(Module):
 
     def __call__(self, x: torch.Tensor):
         # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
-        # We apply the linear transformation of the last dimension and splits that into
-        # the heads
+        # We apply the linear transformation to the last dimension and split that into
+        # the heads.
         head_shape = x.shape[:-1]
 
         # Linear transform
@@ -66,7 +66,7 @@ class MultiHeadAttention(Module):
 
     $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
 
-    In simple terms, it finds keys that matches the query, and get the values of
+    In simple terms, it finds keys that match the query, and gets the values of
     those keys.
 
     It uses dot-product of query and key as the indicator of how matching they are.
@@ -74,7 +74,7 @@ class MultiHeadAttention(Module):
     This is done to avoid large dot-product values causing softmax to
     give very small gradients when $d_k$ is large.
 
-    Softmax is calculate along the axis of of the sequence (or time).
+    Softmax is calculated along the axis of the sequence (or time).
     """
 
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
@@ -105,7 +105,7 @@ class MultiHeadAttention(Module):
         # Scaling factor before the softmax
         self.scale = 1 / math.sqrt(self.d_k)
 
-        # We store attentions so that it can used for logging, or other computations if needed
+        # We store attentions so that they can be used for logging, or other computations if needed
         self.attn = None
 
     def get_scores(self, query: torch.Tensor, key: torch.Tensor):
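For concreteness, the scaling factor shown above depends on the per-head size. Assuming the usual `d_k = d_model // heads` split (the derivation of `d_k` is not part of this hunk), a quick check of the numbers:

```python
import math

d_model, heads = 512, 8          # illustrative sizes
d_k = d_model // heads           # 64 dimensions per head
scale = 1 / math.sqrt(d_k)       # 1 / 8 = 0.125
print(d_k, scale)                # 64 0.125
```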
@@ -125,10 +125,10 @@ class MultiHeadAttention(Module):
                  mask: Optional[torch.Tensor] = None):
         """
         `query`, `key` and `value` are the tensors that store
-        collection of*query*, *key* and *value* vectors.
+        collections of *query*, *key* and *value* vectors.
         They have shape `[seq_len, batch_size, d_model]`.
 
-        `mask` has shape `[seq_len, seq_len, batch_size]` and indicates
+        `mask` has shape `[seq_len, seq_len, batch_size]` and
         `mask[i, j, b]` indicates whether for batch `b`,
         query at position `i` has access to key-value at position `j`.
         """
@@ -139,20 +139,20 @@ class MultiHeadAttention(Module):
         if mask is not None:
             # `mask` has shape `[seq_len, seq_len, batch_size]`,
             # where first dimension is the query dimension.
-            # If the query dimension is equal to $1$ it will be broadcasted
+            # If the query dimension is equal to $1$ it will be broadcasted.
             assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
 
             # Same mask applied to all heads.
             mask = mask.unsqueeze(-1)
 
-        # Prepare `query`, `key` and `value` for attention computation
-        # These will then have shape `[seq_len, batch_size, heads, d_k]`
+        # Prepare `query`, `key` and `value` for attention computation.
+        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
         query = self.query(query)
         key = self.key(key)
         value = self.value(value)
 
-        # Compute attention scores $Q K^\top$
-        # Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
+        # Compute attention scores $Q K^\top$.
+        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
         scores = self.get_scores(query, key)
 
         # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
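Finally, a hedged end-to-end usage sketch of the module this commit documents. The import path, tensor sizes, and causal mask are assumptions made for illustration; the keyword-only call and the input shapes follow the docstring shown above.

```python
import torch
from labml_nn.transformers.mha import MultiHeadAttention  # assumed import path

seq_len, batch_size, d_model, heads = 10, 4, 512, 8
mha = MultiHeadAttention(heads, d_model)

x = torch.randn(seq_len, batch_size, d_model)
# Causal mask with a broadcastable batch dimension of 1 (assumed to be accepted).
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

out = mha(query=x, key=x, value=x, mask=mask)
print(out.shape)  # expected: torch.Size([10, 4, 512])
```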