mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-15 18:27:20 +08:00
rhn docs
This commit is contained in:
@ -2,10 +2,9 @@
|
||||
# LabML Models
|
||||
|
||||
* [Transformers](transformers/index.html)
|
||||
* [Recurrent Highway Networks](recurrent_highway_networks/index.html)
|
||||
* [LSTM](lstm/index.html)
|
||||
|
||||
TODO:
|
||||
|
||||
* LSTM
|
||||
* Highway Networks
|
||||
* 🤔
|
||||
If you have any suggestions for other new implementations,
|
||||
please create a [Github Issue](https://github.com/lab-ml/labml_nn/issues).
|
||||
"""
|
||||
|
@ -1,3 +1,8 @@
|
||||
"""
|
||||
This is an implementation of [Recurrent Highway Networks](https://arxiv.org/abs/1607.03474).
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@ -5,54 +10,142 @@ from labml_helpers.module import Module
|
||||
|
||||
|
||||
class RHNCell(Module):
|
||||
"""
|
||||
## Recurrent Highway Network Cell
|
||||
|
||||
This implements equations $(6) - (9)$.
|
||||
|
||||
$s_d^t = h_d^t . g_d^t + s_{d - 1}^t . c_d^t$
|
||||
|
||||
where
|
||||
|
||||
\begin{align}
|
||||
h_0^t &= tanh(lin_{hx}(x) + lin_{hs}(s_D^{t-1})) \\
|
||||
g_0^t &= \sigma(lin_{gx}(x) + lin_{gs}^1(s_D^{t-1})) \\
|
||||
c_0^t &= \sigma(lin_{cx}(x) + lin_{cs}^1(s_D^{t-1}))
|
||||
\end{align}
|
||||
|
||||
and for $0 < d < D$
|
||||
|
||||
\begin{align}
|
||||
h_d^t &= tanh(lin_{hs}^d(s_d^t)) \\
|
||||
g_d^t &= \sigma(lin_{gs}^d(s_d^t)) \\
|
||||
c_d^t &= \sigma(lin_{cs}^d(s_d^t))
|
||||
\end{align}
|
||||
|
||||
Here we have made a couple of changes to notations from the paper.
|
||||
To avoid confusion with time, the gate is represented with $g$,
|
||||
which was $t$ in the paper.
|
||||
To avoid confusion with multiple layers we use $d$ for depth and $D$ for
|
||||
total depth instead of $l$ and $L$ from paper.
|
||||
|
||||
We have also replaced the weight matrices and bias vectors from the equations with
|
||||
linear transforms, because that's how the implementation is going to look like.
|
||||
|
||||
We implement weight tying, as described in paper, $c_d^t = (1 - g_d^t$.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, depth: int):
|
||||
"""
|
||||
`input_size` is the feature length of the input and `hidden_size` is
|
||||
feature length of the cell.
|
||||
`depth` is $D$.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
# We combine $lin_{hs}$ and $lin_{gs}$, with a single linear layer.
|
||||
# We can then split the results to get the $lin_{hs}$ and $lin_{gs}$ components.
|
||||
# This is the $lin_{hs}^d$ and $lin_{gs}^d$ for $0 \leq d < D$.
|
||||
self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
|
||||
|
||||
# Similarly we combine $lin_{hx}$ and $lin_{gx}$.
|
||||
self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x, s):
|
||||
for i in range(self.depth):
|
||||
if i == 0:
|
||||
ht = self.input_lin(x) + self.hidden_lin[i](s)
|
||||
def __call__(self, x: torch.Tensor, s: torch.Tensor):
|
||||
"""
|
||||
`x` has shape `[batch_size, input_size]` and
|
||||
`s` has shape `[batch_size, hidden_size]`.
|
||||
"""
|
||||
|
||||
# Iterate $0 \leq d < D$
|
||||
for d in range(self.depth):
|
||||
# We calculate the concatenation of linear transforms for $h$ and $g$
|
||||
if d == 0:
|
||||
# The input is used only when $d$ is $0$.
|
||||
hg = self.input_lin(x) + self.hidden_lin[d](s)
|
||||
else:
|
||||
ht = self.hidden_lin[i](s)
|
||||
hg = self.hidden_lin[d](s)
|
||||
|
||||
h = torch.tanh(ht[:, :self.hidden_size])
|
||||
t = torch.sigmoid(ht[:, self.hidden_size:])
|
||||
# Use the first half of `hg` to get $h_d^t$
|
||||
# \begin{align}
|
||||
# h_0^t &= tanh(lin_{hx}(x) + lin_{hs}(s_D^{t-1})) \\
|
||||
# h_d^t &= tanh(lin_{hs}^d(s_d^t))
|
||||
# \end{align}
|
||||
h = torch.tanh(hg[:, :self.hidden_size])
|
||||
# Use the second half of `hg` to get $g_d^t$
|
||||
# \begin{align}
|
||||
# g_0^t &= \sigma(lin_{gx}(x) + lin_{gs}^1(s_D^{t-1})) \\
|
||||
# g_d^t &= \sigma(lin_{gs}^d(s_d^t))
|
||||
# \end{align}
|
||||
g = torch.sigmoid(hg[:, self.hidden_size:])
|
||||
|
||||
s = s + (h - s) * t
|
||||
s = h * g + s * (1 - g)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class RHN(Module):
|
||||
"""
|
||||
### Multilayer Recurrent Highway Network
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
|
||||
"""
|
||||
Create a network of `n_layers` of recurrent highway network layers, each with depth `depth`, $D$.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.n_layers = n_layers
|
||||
self.hidden_size = hidden_size
|
||||
# Create cells for each layer. Note that only the first layer gets the input directly.
|
||||
# Rest of the layers get the input from the layer below
|
||||
self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
|
||||
[RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
|
||||
|
||||
def __call__(self, x: torch.Tensor, state=None):
|
||||
# x [seq_len, batch, d_model]
|
||||
def __call__(self, x: torch.Tensor, state: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
`x` has shape `[seq_len, batch_size, input_size]` and
|
||||
`s` has shape `[batch_size, hidden_size]`.
|
||||
"""
|
||||
time_steps, batch_size = x.shape[:2]
|
||||
|
||||
# Initialize the state if `None`
|
||||
if state is None:
|
||||
s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
|
||||
else:
|
||||
# Reverse stack the state to get the state of each layer <br />
|
||||
# 📝 You can just work with the tensor itself but this is easier to debug
|
||||
s = torch.unbind(state)
|
||||
|
||||
# Array to collect the outputs of the final layer at each time step.
|
||||
out = []
|
||||
|
||||
# Run through the network for each time step
|
||||
for t in range(time_steps):
|
||||
# Input to the first layer is the input itself
|
||||
inp = x[t]
|
||||
for i in range(self.n_layers):
|
||||
s[i] = self.cells[i](inp, s[i])
|
||||
inp = s[i]
|
||||
# Loop through the layers
|
||||
for layer in range(self.n_layers):
|
||||
# Get the state of the first layer
|
||||
s[layer] = self.cells[layer](inp, s[layer])
|
||||
# Input to the next layer is the state of this layer
|
||||
inp = s[layer]
|
||||
# Collect the output of the final layer
|
||||
out.append(s[-1])
|
||||
|
||||
# Stack the outputs and states
|
||||
out = torch.stack(out)
|
||||
s = torch.stack(s)
|
||||
return out, s
|
||||
|
15
readme.rst
15
readme.rst
@ -16,13 +16,18 @@ Transformers
|
||||
and
|
||||
`relative multi-headed attention <http://lab-ml.com/labml_nn/transformers/relative_mha.html>`_.
|
||||
|
||||
✅ TODO
|
||||
-------
|
||||
Recurrent Highway Networks
|
||||
--------------------------
|
||||
|
||||
* Recurrent Highway Networks
|
||||
* LSTMs
|
||||
This is the implementation for `Recurrent Highway Networks <http://lab-ml.com/labml_nn/recurrent_highway_networks>`_.
|
||||
|
||||
Please create a Github issue if there's something you'ld like to see implemented here.
|
||||
|
||||
LSTM
|
||||
----
|
||||
|
||||
This is the implementation for `LSTMs <http://lab-ml.com/labml_nn/lstm>`_.
|
||||
|
||||
✅ Please create a Github issue if there's something you'ld like to see implemented here.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
Reference in New Issue
Block a user