HyperNetworks - HyperLSTM

We have implemented HyperLSTM, introduced in the paper HyperNetworks, with annotations. This blog post by David Ha gives a good explanation of HyperNetworks.

We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset. Here's the link to the code: experiment.py


HyperNetworks use a smaller network to generate the weights of a larger network. There are two variants: static hyper-networks and dynamic hyper-networks. Static HyperNetworks have a smaller network that generates the weights (kernels) of a convolutional network. Dynamic HyperNetworks generate the parameters of a recurrent neural network at each step. This is an implementation of the latter.

Dynamic HyperNetworks

In an RNN the parameters stay constant across steps. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters of each step are changed by a smaller LSTM network.

In the basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Let's say the larger network has some parameter $\color{cyan}{W_h}$. The smaller network generates a feature vector $z_h$, and we dynamically compute $\color{cyan}{W_h}$ as a linear transformation of $z_h$. For instance, $\color{cyan}{W_h} = \langle W_{hz}, z_h \rangle$, where $W_{hz}$ is a 3-d tensor parameter and $\langle . \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
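A minimal sketch of this basic form, assuming PyTorch and small illustrative sizes (n_h and n_z are hypothetical names, not from the paper). Contracting the 3-d tensor $W_{hz}$ with $z_h$ yields a full $N_h \times N_h$ matrix, which is the same as a $z_h$-weighted sum of $N_z$ matrices.

```python
import torch

# Illustrative sizes (assumed for the example)
n_h, n_z = 8, 4

# 3-d parameter tensor W_hz of shape [N_h, N_h, N_z]
w_hz = torch.randn(n_h, n_h, n_z)
# Feature vector z_h, which the smaller network would produce
z_h = torch.randn(n_z)

# Tensor-vector product <W_hz, z_h>: contract the last axis of W_hz with z_h
w_h = torch.einsum('ijk,k->ij', w_hz, z_h)

# The same matrix written as a weighted sum of N_z slices of shape [N_h, N_h]
w_h_sum = sum(z_h[k] * w_hz[:, :, k] for k in range(n_z))
```

A new $z_h$ at every step gives a new $\color{cyan}{W_h}$ at every step, at the cost of storing all of $W_{hz}$.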

Weight scaling instead of computing

Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor: when $\color{cyan}{W_h}$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.
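A quick parameter count makes the problem concrete. The sizes $N_h = 512$ and $N_z = 16$ are assumed for illustration only. The second count corresponds to the row-scaling scheme described next: one $N_h \times N_h$ matrix plus an $N_h \times N_z$ projection.

```python
# Illustrative sizes (assumed, not from the paper)
n_h, n_z = 512, 16

# Full dynamic generation: W_hz has shape N_h x N_h x N_z
full = n_h * n_h * n_z

# Row scaling: W_hd is N_h x N_h, plus an N_h x N_z projection producing d(z)
scaled = n_h * n_h + n_h * n_z

print(full, scaled)  # → 4194304 270336
```

With these sizes, generating $\color{cyan}{W_h}$ in full needs about 15.5 times more parameters than row scaling, and that is for a single weight matrix.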

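To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size:

$$d(z_h) = W_{hz} z_h$$

$$\color{cyan}{W_h} = \begin{pmatrix} d_1(z_h) \, W_{hd_1} \\ d_2(z_h) \, W_{hd_2} \\ \vdots \\ d_{N_h}(z_h) \, W_{hd_{N_h}} \end{pmatrix}$$

where $W_{hd}$ is an $N_h \times N_h$ parameter matrix with rows $W_{hd_k}$, $d_k(z_h)$ is the $k$-th element of $d(z_h)$, and $W_{hz}$ is now only an $N_h \times N_z$ matrix.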

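We can further optimize this when we compute $\color{cyan}{W_h} h$, as

$$\color{cyan}{W_h} h = d(z_h) \odot (W_{hd} h)$$

where $\odot$ stands for element-wise multiplication. This way the full $\color{cyan}{W_h}$ never has to be materialized.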
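A minimal sketch, assuming PyTorch and small illustrative sizes, that checks this optimization on a batch. One route builds a separate row-scaled $\color{cyan}{W_h}$ for every example and multiplies it by $h$. The other route scales $W_{hd} h$ element-wise by $d(z_h)$, using the same torch.einsum('ij,bj->bi', ...) pattern that the cell's forward pass uses. The names d, w_hd, explicit and cheap are only for this example.

```python
import torch

torch.manual_seed(0)
batch_size, n_h = 3, 5

# Shared parameter matrix W_hd (N_h x N_h)
w_hd = torch.randn(n_h, n_h)
# Per-example scaling vectors d(z_h): one scale per row / output unit
d = torch.randn(batch_size, n_h)
# Previous hidden states, one per example
h = torch.randn(batch_size, n_h)

# Explicit route: materialize W_h for each example by scaling the rows of W_hd
w_h = d[:, :, None] * w_hd[None, :, :]            # [batch_size, N_h, N_h]
explicit = torch.bmm(w_h, h[:, :, None]).squeeze(-1)

# Cheap route: d * (W_hd h), computed without building any per-example matrix
cheap = d * torch.einsum('ij,bj->bi', w_hd, h)
```

Both routes give each example its own effective weight matrix. Only the explicit one pays for the $N_h \times N_h$ memory per example.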

from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell

HyperLSTM Cell

For HyperLSTM, both the smaller network and the larger network have the LSTM structure. This is defined in Appendix A.2.2 of the paper.

class HyperLSTMCell(Module):

input_size is the size of the input $x_t$, hidden_size is the size of the LSTM, and hyper_size is the size of the smaller LSTM that alters the weights of the larger outer LSTM. n_z is the size of the feature vectors used to alter the LSTM weights.

We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. From these we calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$ and $d_b^{i,f,g,o}(z_b^{i,f,g,o})$, again using linear transformations. $d_h$ and $d_x$ scale the rows of the main LSTM's weight matrices, and $d_b$ gives its bias.
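The superscripts $i,f,g,o$ mean there is one feature vector per gate. In the code, each of z_h, z_x and z_b is a single nn.Linear(hyper_size, 4 * n_z) whose output is split with .chunk(4, dim=-1). A minimal sketch, assuming PyTorch and illustrative sizes, shows that this is the same as four separate gate projections. Each chunk simply uses its own slice of rows of the weight matrix and bias (z_proj and the manual slice are names used only here).

```python
import torch
from torch import nn

torch.manual_seed(0)
hyper_size, n_z, batch_size = 6, 3, 2

# One projection that produces the features for all four gates at once
z_proj = nn.Linear(hyper_size, 4 * n_z)
h_hat = torch.randn(batch_size, hyper_size)
z_i, z_f, z_g, z_o = z_proj(h_hat).chunk(4, dim=-1)

# The forget-gate features computed from the matching slice of weight and bias
w_chunks = z_proj.weight.chunk(4, dim=0)   # each [n_z, hyper_size]
b_chunks = z_proj.bias.chunk(4, dim=0)     # each [n_z]
z_f_manual = h_hat @ w_chunks[1].T + b_chunks[1]
```

One wide matrix multiply is usually cheaper than four narrow ones, which is a likely reason for this layout.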

📝 Since $z$ and $d$ are computed by two sequential linear transformations, they can be combined into a single linear transformation. However, we've implemented them separately so that the code matches the description in the paper.
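To illustrate the note above, here is a minimal sketch, assuming PyTorch and illustrative sizes. It folds a z-projection with bias (like z_h) and a bias-free d-projection (like d_h) into one nn.Linear with weight $W_d W_z$ and bias $W_d b_z$. The names lin_z, lin_d and fused are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
hyper_size, n_z, hidden_size = 6, 3, 5

# Two sequential maps, as in the cell: z = lin_z(h_hat), then d = lin_d(z)
lin_z = nn.Linear(hyper_size, n_z)
lin_d = nn.Linear(n_z, hidden_size, bias=False)

# Fold them into one layer: W = W_d W_z and b = W_d b_z
fused = nn.Linear(hyper_size, hidden_size)
with torch.no_grad():
    fused.weight.copy_(lin_d.weight @ lin_z.weight)
    fused.bias.copy_(lin_d.weight @ lin_z.bias)

h_hat = torch.randn(4, hyper_size)
two_step = lin_d(lin_z(h_hat))
one_step = fused(h_hat)
```

The folding only works in one direction. The product $W_d W_z$ has rank at most n_z, so a single layer trained freely would not be constrained in the same way.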

    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
        super().__init__()

The input to the hyper LSTM is

$$\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$$

where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step. So the input size is hidden_size + input_size.

The outputs of the hyper LSTM are $\hat{h}_t$ and $\hat{c}_t$.

        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)

Here $z_h^{i,f,g,o}$ is computed from $\hat{h}_t$. 🤔 In the paper it was specified as a linear transformation of $\hat{h}_{t-1}$; I feel that's a typo.

        self.z_h = nn.Linear(hyper_size, 4 * n_z)

        self.z_x = nn.Linear(hyper_size, 4 * n_z)

        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)

        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)

        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)

        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)

The weight matrices $W_h^{i,f,g,o}$

        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])

The weight matrices $W_x^{i,f,g,o}$

        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

Layer normalization

        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)

    def __call__(self, x: torch.Tensor,
                 h: torch.Tensor, c: torch.Tensor,
                 h_hat: torch.Tensor, c_hat: torch.Tensor):

        # Input to the hyper LSTM: concatenation of h_{t-1} and x_t
        x_hat = torch.cat((h, x), dim=-1)

        # One step of the smaller (hyper) LSTM
        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)

        # Feature vectors for the i, f, g, o gates, each of size n_z
        z_h = self.z_h(h_hat).chunk(4, dim=-1)

        z_x = self.z_x(h_hat).chunk(4, dim=-1)

        z_b = self.z_b(h_hat).chunk(4, dim=-1)

We calculate $i$, $f$, $g$ and $o$ in a loop

        ifgo = []
        for i in range(4):

            # Row-scaling vectors for this gate's recurrent and input weights
            d_h = self.d_h[i](z_h[i])

            d_x = self.d_x[i](z_x[i])

            # d_h * (W_h h) + d_x * (W_x x) + dynamic bias, without materializing the scaled weights
            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
                self.d_b[i](z_b[i])

            # Layer normalization of the gate pre-activation
            ifgo.append(self.layer_norm[i](y))

        i, f, g, o = ifgo

        # Cell state update
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

        # Output, with layer normalization applied to the new cell state
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next, h_hat, c_hat

HyperLSTM module

class HyperLSTM(Module):

Create a network of n_layers stacked HyperLSTM layers.

    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
        super().__init__()

Store sizes to initialize state

        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.hyper_size = hyper_size

Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.

        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                                    range(n_layers - 1)])
  • x has shape [n_steps, batch_size, input_size] and
  • state is a tuple of $h, c, \hat{h}, \hat{c}$. $h, c$ have shape [batch_size, hidden_size] and $\hat{h}, \hat{c}$ have shape [batch_size, hyper_size].
    def __call__(self, x: torch.Tensor,
                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]

Initialize the state with zeros if None

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
        else:
            (h, c, h_hat, c_hat) = state

Unstack the tensors to get the states of each layer.

📝 You could work with the stacked tensors directly, but per-layer lists are easier to debug.

            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))

Collect the outputs of the final layer at each step

        out = []
        for t in range(n_steps):

Input to the first layer is the input itself

            inp = x[t]

Loop through the layers

            for layer in range(self.n_layers):

Compute the next state of the layer

                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])

The input to the next layer is the output $h$ of this layer

                inp = h[layer]

Collect the output $h$ of the final layer

            out.append(h[-1])

Stack the outputs and states

        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        h_hat = torch.stack(h_hat)
        c_hat = torch.stack(c_hat)

        return out, (h, c, h_hat, c_hat)
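The returned state tensors have a leading layer axis: h and c have shape [n_layers, batch_size, hidden_size], and $\hat{h}, \hat{c}$ have shape [n_layers, batch_size, hyper_size]. The returned state can be passed back in, for example to continue over the next chunk of a long sequence, because torch.unbind undoes torch.stack. A minimal sketch of that round trip, assuming PyTorch and illustrative sizes (h_list and h_again are names used only here):

```python
import torch

n_layers, batch_size, hidden_size = 3, 2, 4

# Per-layer states kept as a Python list, as inside the step loop
h_list = [torch.full((batch_size, hidden_size), float(layer)) for layer in range(n_layers)]

# What the module returns: one tensor with a leading layer axis
h = torch.stack(h_list)            # [n_layers, batch_size, hidden_size]

# What the module does with a state passed back in: split along the layer axis
h_again = list(torch.unbind(h))
```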