We have implemented HyperLSTM, introduced in the paper HyperNetworks, in PyTorch, with annotations. This blog post by David Ha gives a good explanation of HyperNetworks.
We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset.
Here's the link to the code: experiment.py
HyperNetworks use a smaller network to generate the weights of a larger network. There are two variants: static HyperNetworks and dynamic HyperNetworks. Static HyperNetworks have smaller networks that generate the weights (kernels) of a convolutional network. Dynamic HyperNetworks generate the parameters of a recurrent neural network for each step. This is an implementation of the latter.
In an RNN the parameters stay constant across steps. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters of each step are changed by a smaller LSTM network.
In the basic form, a dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Let's say the larger network has some parameter $\color{cyan}{W_h}$. The smaller network generates a feature vector $z_h$, and we dynamically compute $\color{cyan}{W_h}$ as a linear transformation of $z_h$. For instance, $\color{cyan}{W_h} = \langle W_{hz}, z_h \rangle$, where $W_{hz}$ is a 3-d tensor parameter and $\langle . \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
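The tensor-vector multiplication can be sketched in a few lines of PyTorch. This assumes a single sample (no batch dimension) and uses arbitrary small sizes `n_h` and `n_z`, both hypothetical. The 3-d tensor $W_{hz}$ is contracted with $z_h$ over its last dimension to give an $N_h \times N_h$ matrix.

```python
import torch

torch.manual_seed(0)
n_h, n_z = 8, 4  # hypothetical sizes of the hidden state and the feature vector

# 3-d tensor parameter W_hz with shape [N_h, N_h, N_z]
w_hz = torch.randn(n_h, n_h, n_z)
# Feature vector z_h produced by the smaller network
z_h = torch.randn(n_z)

# Tensor-vector multiplication <W_hz, z_h>: contract over the last dimension
w_h = torch.einsum('ijk,k->ij', w_hz, z_h)

# The same contraction written as a matrix-vector product on a reshaped tensor
w_h_alt = (w_hz.reshape(n_h * n_h, n_z) @ z_h).reshape(n_h, n_h)

print(w_h.shape)  # torch.Size([8, 8])
```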
Large recurrent networks have large dynamically computed parameters. These are computed as a linear transformation of the feature vector $z$, so the transformation needs an even larger weight tensor. When $\color{cyan}{W_h}$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.
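A back-of-the-envelope parameter count makes the problem concrete. It assumes hypothetical sizes $N_h = 512$ and $N_z = 16$, counts only the weights that produce $\color{cyan}{W_h}$ (biases are ignored), and compares the result with the row-scaling alternative introduced next. That alternative is assumed here to use an $N_h \times N_h$ matrix plus an $N_h \times N_z$ map from $z_h$ to the scaling vector, as the cell's `d_h` layers do.

```python
def w_h_params(n_h: int) -> int:
    # The dynamically computed matrix W_h itself: N_h x N_h
    return n_h * n_h


def full_tensor_params(n_h: int, n_z: int) -> int:
    # The 3-d tensor W_hz that generates W_h from z_h: N_h x N_h x N_z
    return n_h * n_h * n_z


def row_scaling_params(n_h: int, n_z: int) -> int:
    # Row scaling: an N_h x N_h matrix plus an N_h x N_z matrix mapping z_h to the scales
    return n_h * n_h + n_h * n_z


n_h, n_z = 512, 16  # hypothetical sizes
print(w_h_params(n_h))               # 262144
print(full_tensor_params(n_h, n_z))  # 4194304
print(row_scaling_params(n_h, n_z))  # 270336
```

With these sizes, the 3-d tensor is $N_z = 16$ times larger than $\color{cyan}{W_h}$. Row scaling adds only $N_h N_z$ parameters on top of a matrix the size of $\color{cyan}{W_h}$.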
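To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size:
$$
\begin{align}
d(z_h) &= W_{hz} z_h \\
\color{cyan}{W_h} &=
\begin{pmatrix}
d_0(z_h) W_{hd_0} \\
d_1(z_h) W_{hd_1} \\
\vdots \\
d_{N_h - 1}(z_h) W_{hd_{N_h - 1}}
\end{pmatrix}
\end{align}
$$
where $W_{hd}$ is an $N_h \times N_h$ parameter matrix, $W_{hd_i}$ is its $i$-th row, $d_i(z_h)$ is the $i$-th element of $d(z_h)$, and $W_{hz}$ is now an $N_h \times N_z$ matrix.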
We can further optimize this when we compute $\color{cyan}{W_h} h$, as
$$\color{cyan}{W_h} h = d(z_h) \odot \left(W_{hd} h\right)$$
where $\odot$ stands for element-wise multiplication.
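A small numerical check of this identity, using hypothetical sizes and random values, is shown below. Scaling the rows of $W_{hd}$ by $d(z_h)$ and then multiplying by $h$ gives the same result as multiplying first and scaling element-wise. The batched optimized form uses the same `torch.einsum('ij,bj->bi', ...)` pattern as the cell implementation.

```python
import torch

torch.manual_seed(0)
batch_size, n_h = 3, 6  # hypothetical sizes

w_hd = torch.randn(n_h, n_h)      # shared N_h x N_h parameter matrix
d = torch.randn(batch_size, n_h)  # per-sample scaling vector d(z_h)
h = torch.randn(batch_size, n_h)  # hidden state, one row per sample

# Naive: build a separate W_h per sample by scaling row i of W_hd by d_i, then multiply
w_h = d.unsqueeze(-1) * w_hd                 # [batch_size, N_h, N_h]
naive = torch.einsum('bij,bj->bi', w_h, h)

# Optimized: multiply by the shared W_hd once, then scale element-wise by d
fast = d * torch.einsum('ij,bj->bi', w_hd, h)
```

The optimized form never materializes the per-sample $N_h \times N_h$ matrices. It only needs the shared matrix and one $N_h$-vector per sample.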
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell

For HyperLSTM the smaller network and the larger network both have the LSTM structure. This is defined in Appendix A.2.2 in the paper.
class HyperLSTMCell(Module):

input_size is the size of the input $x_t$,
hidden_size is the size of the outer LSTM,
hyper_size is the size of the smaller LSTM that alters the weights of the larger outer LSTM, and
n_z is the size of the feature vectors used to alter the LSTM weights.
We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. From these we calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$, and $d_b^{i,f,g,o}(z_b^{i,f,g,o})$, again using linear transformations. $d_h$ and $d_x$ scale the rows of the weight matrices of the main LSTM, and $d_b$ is used directly as its bias.
📝 Since the computations of $z$ and $d$ are two sequential linear transformations, they could be combined into a single linear transformation. However, we've implemented them separately so that the code matches the description in the paper.
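Putting these pieces together, one step of the outer LSTM can be written as shown below. This is a transcription of what HyperLSTMCell computes, not a quotation of the paper. $LN$ denotes layer normalization and $\sigma$ the sigmoid. $d_h^{i,f,g,o}$ is shorthand for $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, and likewise for $d_x$ and $d_b$.

$$
\begin{align}
\{i, f, g, o\} &= LN\left(d_h^{i,f,g,o} \odot \left(W_h^{i,f,g,o} h_{t-1}\right) + d_x^{i,f,g,o} \odot \left(W_x^{i,f,g,o} x_t\right) + d_b^{i,f,g,o}\right) \\
c_t &= \sigma(f) \odot c_{t-1} + \sigma(i) \odot \tanh(g) \\
h_t &= \sigma(o) \odot \tanh\left(LN(c_t)\right)
\end{align}
$$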
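The claim in the note can be checked numerically. This minimal sketch uses hypothetical sizes and mirrors a single gate's $z_h$ layer (with bias) followed by its $d_h$ layer (without bias). The two `nn.Linear` layers collapse into one whose weight is the product of the two weight matrices and whose bias is the second weight applied to the first bias.

```python
import torch
from torch import nn

torch.manual_seed(0)
hyper_size, n_z, hidden_size = 5, 3, 7  # hypothetical sizes

z_layer = nn.Linear(hyper_size, n_z)               # z = A x + a
d_layer = nn.Linear(n_z, hidden_size, bias=False)  # d = B z

# Combined layer: d = (B A) x + B a
combined = nn.Linear(hyper_size, hidden_size)
with torch.no_grad():
    combined.weight.copy_(d_layer.weight @ z_layer.weight)
    combined.bias.copy_(d_layer.weight @ z_layer.bias)

x = torch.randn(4, hyper_size)  # a batch of 4 hyper LSTM outputs
two_step = d_layer(z_layer(x))
one_step = combined(x)
```

The separate form is also a low-rank factorization. Per gate, it has $N_z \times$ (hyper_size + hidden_size) weights instead of hyper_size $\times$ hidden_size, which is fewer when $N_z$ is small.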
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
        super().__init__()

The input to the smaller (hyper) LSTM is
$$\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$$
where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step.
So the input size is hidden_size + input_size.
The output of the hyper LSTM is $\hat{h}_t$ and $\hat{c}_t$.
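        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)

$z_h^{i,f,g,o} = lin_{h}^{i,f,g,o}(\hat{h}_t)$, where $lin$ denotes a linear transformation. $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ are computed from $\hat{h}_t$ in the same way.
🤔 In the paper it was specified as $z_h^{i,f,g,o} = lin_{h}^{i,f,g,o}(\hat{h}_{t-1})$. I feel that it's a typo.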
        # Linear transformations from the hyper LSTM output to z_h, z_x and z_b (all four gates at once)
        self.z_h = nn.Linear(hyper_size, 4 * n_z)
        self.z_x = nn.Linear(hyper_size, 4 * n_z)
        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)

        # d_h(z_h), d_x(z_x) and d_b(z_b): one linear transformation per gate i, f, g, o
        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)
        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)
        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)

        # The weight matrices W_h^{i,f,g,o}
        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
        # The weight matrices W_x^{i,f,g,o}
        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

        # Layer normalization
        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)

    def __call__(self, x: torch.Tensor,
                 h: torch.Tensor, c: torch.Tensor,
                 h_hat: torch.Tensor, c_hat: torch.Tensor):
        # x_hat = [h_{t-1}; x_t] is the input to the hyper LSTM
        x_hat = torch.cat((h, x), dim=-1)
        # One step of the hyper LSTM gives h_hat_t and c_hat_t
        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)

        # Feature vectors z_h, z_x and z_b, split into one chunk per gate
        z_h = self.z_h(h_hat).chunk(4, dim=-1)
        z_x = self.z_x(h_hat).chunk(4, dim=-1)
        z_b = self.z_b(h_hat).chunk(4, dim=-1)

        # We calculate i, f, g and o in a loop
        ifgo = []
        for i in range(4):
            # Row scaling vectors d_h(z_h) and d_x(z_x) for this gate
            d_h = self.d_h[i](z_h[i])
            d_x = self.d_x[i](z_x[i])

            # y = d_h ⊙ (W_h h_{t-1}) + d_x ⊙ (W_x x_t) + d_b(z_b)
            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
                self.d_b[i](z_b[i])

            # Layer normalization gives the pre-activation of this gate
            ifgo.append(self.layer_norm[i](y))

        i, f, g, o = ifgo

        # c_t = σ(f) ⊙ c_{t-1} + σ(i) ⊙ tanh(g)
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

        # h_t = σ(o) ⊙ tanh(LN(c_t))
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
        return h_next, c_next, h_hat, c_hat


class HyperLSTM(Module):

Create a network of n_layers of HyperLSTM.

    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
        super().__init__()

        # Store sizes to initialize state
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.hyper_size = hyper_size

        # Create cells for each layer. Note that only the first layer gets the input directly.
        # The rest of the layers get their input from the layer below.
        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                                    range(n_layers - 1)])

x has shape [n_steps, batch_size, input_size] and state is a tuple of $h, c, \hat{h}, \hat{c}$, stacked across layers.
$h, c$ have shape [n_layers, batch_size, hidden_size] and
$\hat{h}, \hat{c}$ have shape [n_layers, batch_size, hyper_size].

    def __call__(self, x: torch.Tensor,
                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]

        # Initialize the state with zeros if None
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
        else:
            (h, c, h_hat, c_hat) = state
            # Unstack the tensors along the layer dimension to get the state of each layer.
            # 📝 You can just work with the tensor itself, but this is easier to debug
            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))

        # Collect the outputs of the final layer at each step
        out = []
        for t in range(n_steps):
            # Input to the first layer is the input itself
            inp = x[t]
            # Loop through the layers
            for layer in range(self.n_layers):
                # Compute the new state of the layer
                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
                # Input to the next layer is the output h of this layer
                inp = h[layer]
            # Collect the output h of the final layer
            out.append(h[-1])

        # Stack the outputs and states
        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        h_hat = torch.stack(h_hat)
        c_hat = torch.stack(c_hat)

        return out, (h, c, h_hat, c_hat)
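The layer-stacking and state-handling logic of HyperLSTM does not depend on the cell type. The self-contained sketch below runs the same loop with PyTorch's built-in `nn.LSTMCell` swapped in for HyperLSTMCell, so it has no $\hat{h}, \hat{c}$ state. The class name `StackedLSTM` is hypothetical. The sketch shows that the returned states are stacked along a leading layer dimension and can be passed straight back in to continue a sequence.

```python
import torch
from torch import nn


class StackedLSTM(nn.Module):
    # Hypothetical stand-in for HyperLSTM: nn.LSTMCell replaces HyperLSTMCell
    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        # Only the first layer sees the input; the rest take the h of the layer below
        self.cells = nn.ModuleList([nn.LSTMCell(input_size, hidden_size)] +
                                   [nn.LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state=None):
        n_steps, batch_size = x.shape[:2]
        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            # Unstack [n_layers, batch_size, hidden_size] tensors into per-layer tensors
            h, c = list(torch.unbind(state[0])), list(torch.unbind(state[1]))

        out = []
        for t in range(n_steps):
            inp = x[t]
            for layer in range(self.n_layers):
                # nn.LSTMCell takes (input, (h, c)) and returns the new (h, c)
                h[layer], c[layer] = self.cells[layer](inp, (h[layer], c[layer]))
                inp = h[layer]
            out.append(h[-1])

        # Stack outputs over time and states over layers
        return torch.stack(out), (torch.stack(h), torch.stack(c))


model = StackedLSTM(input_size=10, hidden_size=16, n_layers=3)
x = torch.randn(5, 2, 10)  # [n_steps, batch_size, input_size]
out, (h, c) = model(x)
# Continue from the returned state
out2, _ = model(x, (h, c))
```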