12from typing import Optional, Tuple
13
14import torch
15from torch import nn
16
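17from labml_helpers.module import Module

An LSTM cell computes $c$ and $h$. $c$ is like the long-term memory, and $h$ is like the short-term memory. We use the input $x$ and $h$ to update the long-term memory. In the update, some features of $c$ are cleared with a forget gate $f$, and some candidate features $g$ are added through an input gate $i$.
The new short-term memory $h$ is the $\tanh$ of the long-term memory multiplied by the output gate $o$.
Note that the cell doesn't look at the long-term memory $c$ when computing the gates; it only modifies it. Also, $c$ never goes through a linear transformation. This additive path lets gradients flow through $c$ across many time steps, which is what mitigates vanishing gradients (exploding gradients are usually still handled separately, e.g. by gradient clipping).
Here's the update rule:

\begin{align}
c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t) \\
h_t &= \sigma(o_t) \odot \tanh(c_t)
\end{align}

$\odot$ stands for element-wise multiplication.
The intermediate values and gates are computed as linear transformations of the hidden state and the input:

\begin{align}
i_t &= lin_x^i(x_t) + lin_h^i(h_{t-1}) \\
f_t &= lin_x^f(x_t) + lin_h^f(h_{t-1}) \\
g_t &= lin_x^g(x_t) + lin_h^g(h_{t-1}) \\
o_t &= lin_x^o(x_t) + lin_h^o(h_{t-1})
\end{align}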
class LSTMCell(Module):
    r"""
    ## Long Short-Term Memory Cell

    Computes one step of an LSTM. The cell state $c$ acts as long-term memory
    and the hidden state $h$ as short-term memory:

    \begin{align}
    c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t) \\
    h_t &= \sigma(o_t) \odot \tanh(c_t)
    \end{align}

    $\odot$ is element-wise multiplication. The gate pre-activations are linear
    transformations of the input and the previous hidden state only (the cell
    never reads $c_{t-1}$ to compute them):

    \begin{align}
    i_t &= lin_x^i(x_t) + lin_h^i(h_{t-1}) \\
    f_t &= lin_x^f(x_t) + lin_h^f(h_{t-1}) \\
    g_t &= lin_x^g(x_t) + lin_h^g(h_{t-1}) \\
    o_t &= lin_x^o(x_t) + lin_h^o(h_{t-1})
    \end{align}

    * `input_size` is the size of the last dimension of the input $x$
    * `hidden_size` is the size of the last dimension of $h$ and $c$
    * `layer_norm` enables layer normalization of $i_t, f_t, g_t, o_t$ and of
      $c_t$ inside the $h_t$ update (not part of the original LSTM paper)
    """

    def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False):
        super().__init__()

        # Each linear layer emits all four pre-activations (i, f, g, o) at once,
        # hence `4 * hidden_size` outputs. The two results are summed, so one
        # bias is enough and `input_lin` is created without one.
        # `hidden_lin` combines the $lin_h^i$, $lin_h^f$, $lin_h^g$ and $lin_h^o$ transformations.
        self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)
        # `input_lin` combines the $lin_x^i$, $lin_x^f$, $lin_x^g$ and $lin_x^o$ transformations.
        self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)

        # One normalizer per gate pre-activation, plus one applied to $c_t$ in
        # the $h_t$ update: $h_t = \sigma(o_t) \odot \tanh(\mathop{LN}(c_t))$.
        # `nn.Identity` stands in when disabled so `forward` needs no branches.
        if layer_norm:
            self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
            self.layer_norm_c = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.ModuleList([nn.Identity() for _ in range(4)])
            self.layer_norm_c = nn.Identity()

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        """
        Run a single time step.

        * `x` is the input $x_t$ (last dimension `input_size`)
        * `h` and `c` are $h_{t-1}$ and $c_{t-1}$ (last dimension `hidden_size`)

        Returns `(h_next, c_next)`. `c_next` is returned without layer
        normalization; it is only normalized inside the $h_t$ computation.
        """
        # All four pre-activations from both inputs in one pass, then split the
        # last dimension into i, f, g, o (each `hidden_size` wide).
        ifgo = self.hidden_lin(h) + self.input_lin(x)
        ifgo = ifgo.chunk(4, dim=-1)
        # Optional per-gate layer normalization (identity when disabled).
        i, f, g, o = [norm(z) for norm, z in zip(self.layer_norm, ifgo)]

        # Forget part of the old memory and add gated candidate features.
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        # Expose a gated view of the (optionally normalized) memory.
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next


class LSTM(Module):
    """
    ## Multilayer LSTM

    A stack of `n_layers` LSTM cells. Only the first layer receives the input
    directly; every other layer receives the hidden state of the layer below.

    * `input_size` is the size of the last dimension of the input
    * `hidden_size` is the size of each layer's hidden and cell state
    * `n_layers` is the number of stacked cells; must be at least 1
    * `layer_norm` is forwarded to every `LSTMCell` (default `False`)

    Raises `ValueError` if `n_layers < 1` (such a stack could never produce
    an output).
    """

    def __init__(self, input_size: int, hidden_size: int, n_layers: int, layer_norm: bool = False):
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be at least 1, got {n_layers}")
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        # First cell maps input_size -> hidden_size; the rest are hidden -> hidden.
        self.cells = nn.ModuleList(
            [LSTMCell(input_size, hidden_size, layer_norm=layer_norm)] +
            [LSTMCell(hidden_size, hidden_size, layer_norm=layer_norm) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """
        * `x` has shape `[n_steps, batch_size, input_size]`
        * `state` is an optional tuple `(h, c)`; each tensor has shape
          `[n_layers, batch_size, hidden_size]` (it is unbound along dim 0,
          one slice per layer). Zeros are used when it is `None`.

        Returns `(out, (h, c))`:
        * `out` has shape `[n_steps, batch_size, hidden_size]` and holds the
          final layer's hidden state at every step (empty when `n_steps == 0`)
        * `h` and `c` are the last states of every layer, each
          `[n_layers, batch_size, hidden_size]`
        """
        n_steps, batch_size = x.shape[:2]

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            h, c = state
            # Split the stacked per-layer states into Python lists; indexing a
            # list per layer is easier to debug than updating a tensor in place.
            # The lists are new objects, so the caller's tensors are not mutated.
            h, c = list(torch.unbind(h)), list(torch.unbind(c))

        # Final-layer output at each time step.
        out = []
        for inp in x:  # iterate over the time dimension
            for layer, cell in enumerate(self.cells):
                h[layer], c[layer] = cell(inp, h[layer], c[layer])
                # The next layer consumes this layer's new hidden state.
                inp = h[layer]
            out.append(h[-1])

        # `torch.stack` rejects an empty list, so build the empty output directly.
        if out:
            out = torch.stack(out)
        else:
            out = x.new_zeros(0, batch_size, self.hidden_size)

        return out, (torch.stack(h), torch.stack(c))