✍️ mha english

Varuna Jayasiri
2021-02-01 07:28:33 +05:30
parent 53128f5679
commit 5cd2b8701b
2 changed files with 24 additions and 24 deletions

View File

@@ -164,8 +164,8 @@ This is used to transform <strong>key</strong>, <strong>query</strong>, and <str
 <a href='#section-7'>#</a>
 </div>
 <p>Input has shape <code>[seq_len, batch_size, d_model]</code> or <code>[batch_size, d_model]</code>.
-We apply the linear transformation of the last dimension and splits that into
-the heads</p>
+We apply the linear transformation to the last dimension and split that into
+the heads.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">49</span> <span class="n">head_shape</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
@@ -214,13 +214,13 @@ the heads</p>
 <p>
 <script type="math/tex; mode=display">\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V</script>
 </p>
-<p>In simple terms, it finds keys that matches the query, and get the values of
+<p>In simple terms, it finds keys that matches the query, and gets the values of
 those keys.</p>
 <p>It uses dot-product of query and key as the indicator of how matching they are.
 Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
 This is done to avoid large dot-product values causing softmax to
 give very small gradients when $d_k$ is large.</p>
-<p>Softmax is calculate along the axis of of the sequence (or time).</p>
+<p>Softmax is calculated along the axis of of the sequence (or time).</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">61</span><span class="k">class</span> <span class="nc">MultiHeadAttention</span><span class="p">(</span><span class="n">Module</span><span class="p">):</span></pre></div>
@@ -335,7 +335,7 @@ give very small gradients when $d_k$ is large.</p>
 <div class='section-link'>
 <a href='#section-21'>#</a>
 </div>
-<p>We store attentions so that it can used for logging, or other computations if needed</p>
+<p>We store attentions so that it can be used for logging, or other computations if needed</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">109</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="kc">None</span></pre></div>
@@ -370,9 +370,9 @@ give very small gradients when $d_k$ is large.</p>
 <a href='#section-24'>#</a>
 </div>
 <p><code>query</code>, <code>key</code> and <code>value</code> are the tensors that store
-collection of<em>query</em>, <em>key</em> and <em>value</em> vectors.
+collection of <em>query</em>, <em>key</em> and <em>value</em> vectors.
 They have shape <code>[seq_len, batch_size, d_model]</code>.</p>
-<p><code>mask</code> has shape <code>[seq_len, seq_len, batch_size]</code> and indicates
+<p><code>mask</code> has shape <code>[seq_len, seq_len, batch_size]</code> and
 <code>mask[i, j, b]</code> indicates whether for batch <code>b</code>,
 query at position <code>i</code> has access to key-value at position <code>j</code>.</p>
 </div>
@@ -404,7 +404,7 @@ query at position <code>i</code> has access to key-value at position <code>j</co
 </div>
 <p><code>mask</code> has shape <code>[seq_len, seq_len, batch_size]</code>,
 where first dimension is the query dimension.
-If the query dimension is equal to $1$ it will be broadcasted</p>
+If the query dimension is equal to $1$ it will be broadcasted.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">143</span> <span class="k">assert</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></pre></div>
@@ -426,8 +426,8 @@ If the query dimension is equal to $1$ it will be broadcasted</p>
 <div class='section-link'>
 <a href='#section-28'>#</a>
 </div>
-<p>Prepare <code>query</code>, <code>key</code> and <code>value</code> for attention computation
-These will then have shape <code>[seq_len, batch_size, heads, d_k]</code></p>
+<p>Prepare <code>query</code>, <code>key</code> and <code>value</code> for attention computation.
+These will then have shape <code>[seq_len, batch_size, heads, d_k]</code>.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">150</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
@@ -440,8 +440,8 @@ These will then have shape <code>[seq_len, batch_size, heads, d_k]</code></p>
 <div class='section-link'>
 <a href='#section-29'>#</a>
 </div>
-<p>Compute attention scores $Q K^\top$
-Results in a tensor of shape <code>[seq_len, seq_len, batch_size, heads]</code></p>
+<p>Compute attention scores $Q K^\top$.
+This gives a tensor of shape <code>[seq_len, seq_len, batch_size, heads]</code>.</p>
 </div>
 <div class='code'>
 <div class="highlight"><pre><span class="lineno">156</span> <span class="n">scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_scores</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span></pre></div>

View File

@@ -44,8 +44,8 @@ class PrepareForMultiHeadAttention(Module):

 def __call__(self, x: torch.Tensor):
 # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
-# We apply the linear transformation of the last dimension and splits that into
-# the heads
+# We apply the linear transformation to the last dimension and split that into
+# the heads.
 head_shape = x.shape[:-1]

 # Linear transform
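For context, a minimal standalone sketch of what this prepare step does to the shapes (hypothetical example code, not taken from the repository): a linear layer acts on the last dimension, and the result is viewed as separate heads.

import torch
import torch.nn as nn

seq_len, batch_size, d_model = 5, 2, 512
heads, d_k = 8, 64                   # assumed sizes for illustration

x = torch.randn(seq_len, batch_size, d_model)
linear = nn.Linear(d_model, heads * d_k)

head_shape = x.shape[:-1]            # all dimensions except the last one
x = linear(x)                        # linear transformation of the last dimension
x = x.view(*head_shape, heads, d_k)  # split the last dimension into the heads

print(x.shape)                       # torch.Size([5, 2, 8, 64])

The same code handles a `[batch_size, d_model]` input, since head_shape keeps whatever leading dimensions the input has.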
@@ -66,7 +66,7 @@ class MultiHeadAttention(Module):

 $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

-In simple terms, it finds keys that matches the query, and get the values of
+In simple terms, it finds keys that matches the query, and gets the values of
 those keys.

 It uses dot-product of query and key as the indicator of how matching they are.
@@ -74,7 +74,7 @@ class MultiHeadAttention(Module):
 This is done to avoid large dot-product values causing softmax to
 give very small gradients when $d_k$ is large.

-Softmax is calculate along the axis of of the sequence (or time).
+Softmax is calculated along the axis of of the sequence (or time).
 """

 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
@@ -105,7 +105,7 @@ class MultiHeadAttention(Module):
 # Scaling factor before the softmax
 self.scale = 1 / math.sqrt(self.d_k)

-# We store attentions so that it can used for logging, or other computations if needed
+# We store attentions so that it can be used for logging, or other computations if needed
 self.attn = None

 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -125,10 +125,10 @@ class MultiHeadAttention(Module):
 mask: Optional[torch.Tensor] = None):
 """
 `query`, `key` and `value` are the tensors that store
-collection of*query*, *key* and *value* vectors.
+collection of *query*, *key* and *value* vectors.
 They have shape `[seq_len, batch_size, d_model]`.

-`mask` has shape `[seq_len, seq_len, batch_size]` and indicates
+`mask` has shape `[seq_len, seq_len, batch_size]` and
 `mask[i, j, b]` indicates whether for batch `b`,
 query at position `i` has access to key-value at position `j`.
 """
@@ -139,20 +139,20 @@ class MultiHeadAttention(Module):
 if mask is not None:
 # `mask` has shape `[seq_len, seq_len, batch_size]`,
 # where first dimension is the query dimension.
-# If the query dimension is equal to $1$ it will be broadcasted
+# If the query dimension is equal to $1$ it will be broadcasted.
 assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]

 # Same mask applied to all heads.
 mask = mask.unsqueeze(-1)

-# Prepare `query`, `key` and `value` for attention computation
-# These will then have shape `[seq_len, batch_size, heads, d_k]`
+# Prepare `query`, `key` and `value` for attention computation.
+# These will then have shape `[seq_len, batch_size, heads, d_k]`.
 query = self.query(query)
 key = self.key(key)
 value = self.value(value)

-# Compute attention scores $Q K^\top$
-# Results in a tensor of shape `[seq_len, seq_len, batch_size, heads]`
+# Compute attention scores $Q K^\top$.
+# This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
 scores = self.get_scores(query, key)

 # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
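Tying the shapes in this hunk together, a hedged end-to-end sketch of the computation with these layouts (the einsum spellings below are illustrative assumptions; get_scores itself is not shown in this diff):

import math
import torch

seq_len, batch_size, heads, d_k = 5, 2, 8, 64

# `query`, `key` and `value` after the prepare step: `[seq_len, batch_size, heads, d_k]`
query = torch.randn(seq_len, batch_size, heads, d_k)
key = torch.randn(seq_len, batch_size, heads, d_k)
value = torch.randn(seq_len, batch_size, heads, d_k)

# Attention scores Q K^T: one score per (query position i, key position j, batch, head)
scores = torch.einsum('ibhd,jbhd->ijbh', query, key)    # [seq_len, seq_len, batch_size, heads]
# Scale by 1/sqrt(d_k)
scores = scores * (1 / math.sqrt(d_k))

# A causal mask in the `[seq_len, seq_len, batch_size]` layout, same mask for every head
mask = torch.ones(seq_len, seq_len).tril().bool()
mask = mask.unsqueeze(-1).expand(seq_len, seq_len, batch_size)
mask = mask.unsqueeze(-1)
scores = scores.masked_fill(~mask, float('-inf'))

# Softmax along the key/sequence dimension, then the weighted sum of the values
attn = torch.softmax(scores, dim=1)
out = torch.einsum('ijbh,jbhd->ibhd', attn, value)      # [seq_len, batch_size, heads, d_k]
print(out.shape)                                        # torch.Size([5, 2, 8, 64])

This sketch stops at the per-head attention output; it skips the dropout and the final output projection that a full module would apply.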