This is an annotated PyTorch implementation of HyperLSTM, introduced in the paper HyperNetworks. This blog post by David Ha gives a good explanation of HyperNetworks.
We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset.
Here’s the link to the code: experiment.py
HyperNetworks use a smaller network to generate the weights of a larger network. There are two variants: static HyperNetworks and dynamic HyperNetworks. Static HyperNetworks have a smaller network that generates the weights (kernels) of a convolutional network. Dynamic HyperNetworks generate the parameters of a recurrent neural network for each step. This is an implementation of the latter.
In a regular RNN the parameters are the same at every step. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters at each step are changed by a smaller LSTM network.
In the basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Let’s say the larger network has some parameter $\color{cyan}{W_h}$. The smaller network generates a feature vector $z_h$, and we dynamically compute $\color{cyan}{W_h}$ as a linear transformation of $z_h$. For instance, $\color{cyan}{W_h} = \langle W_{hz}, z_h \rangle$, where $W_{hz}$ is a 3-d tensor parameter and $\langle . \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
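The tensor-vector multiplication in the basic form can be sketched in PyTorch as below. This is only an illustration: the sizes are arbitrary, `z_h` is random rather than produced by a smaller recurrent network, and the variable names are assumptions made for this example.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not values from the paper)
n_h, n_z = 8, 4

# 3-d parameter tensor W_hz of shape [n_h, n_h, n_z]
w_hz = torch.randn(n_h, n_h, n_z)
# Feature vector z_h; in a real model this comes from the smaller network
z_h = torch.randn(n_z)

# Tensor-vector multiplication: contract the last axis of W_hz with z_h
w_h = torch.einsum('ijk,k->ij', w_hz, z_h)

# The generated matrix is then used like an ordinary recurrent weight
h = torch.randn(n_h)
h_next = w_h @ h
```

Here `w_h` has shape `[n_h, n_h]`, while the parameter that generates it holds `n_h * n_h * n_z` values.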
Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor. That is, when $\color{cyan}{W_h}$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.
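To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size,
$$d(z) = W_{hz} z_h, \qquad \color{cyan}{W_h} = \begin{pmatrix} d_1(z) W_{hd_1} \\ d_2(z) W_{hd_2} \\ \vdots \\ d_{N_h}(z) W_{hd_{N_h}} \end{pmatrix}$$
where $W_{hd}$ is a $N_h \times N_h$ parameter matrix, $W_{hd_i}$ is its $i$-th row, and $W_{hz}$ is now only a $N_h \times N_z$ matrix.
We can further optimize this when we compute $\color{cyan}{W_h} h$, as
$$\color{cyan}{W_h} h = d(z) \odot (W_{hd} h)$$
where $\odot$ stands for element-wise multiplication.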
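The row-scaling trick and its optimized form can be checked numerically. The following is a minimal sketch with arbitrary sizes and random values; the variable names are assumptions made for illustration. It compares building the scaled matrix explicitly against scaling the product element-wise, and also compares parameter counts with the 3-d tensor form.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions)
n_h, n_z = 8, 4

w_hd = torch.randn(n_h, n_h)  # W_hd, same size as W_h
w_hz = torch.randn(n_h, n_z)  # W_hz, a matrix in this formulation
z_h = torch.randn(n_z)        # feature vector (random here)
h = torch.randn(n_h)

# One scaling factor per row of W_hd
d = w_hz @ z_h

# Explicit form: scale row i of W_hd by d[i], then multiply by h
w_h = d.unsqueeze(-1) * w_hd
explicit = w_h @ h

# Optimized form: multiply first, then scale element-wise
optimized = d * (w_hd @ h)

assert torch.allclose(explicit, optimized, atol=1e-5)

# Parameter counts: 3-d tensor form versus row-scaling form
full_params = n_h * n_h * n_z
scaled_params = n_h * n_h + n_h * n_z
```

The optimized form never materializes a per-step $N_h \times N_h$ matrix, which matters when the scaling vector differs for every sample in a batch.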
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell

For HyperLSTM the smaller network and the larger networks both have the LSTM structure. This is defined in Appendix A.2.2 in the paper.
class HyperLSTMCell(Module):

input_size is the size of the input $x_t$,
hidden_size is the size of the LSTM, and
hyper_size is the size of the smaller LSTM that alters the weights of the larger outer LSTM.
n_z is the size of the feature vectors used to alter the LSTM weights.
We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. We calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$, and $d_b^{i,f,g,o}(z_b^{i,f,g,o})$ from these, again using linear transformations. These are then used to scale the rows of the weight and bias tensors of the main LSTM.
📝 Since the computations of $z$ and $d$ are two sequential linear transformations, they can be combined into a single linear transformation. However, we’ve implemented them separately so that the code matches the description in the paper.
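The note about combining the two linear transformations can be illustrated for a single gate. This is a sketch under assumptions: the layer names `lin_z`, `lin_d` and `combined` and the sizes are hypothetical and are not part of the implementation described here.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions)
hyper_size, n_z, hidden_size = 6, 3, 5

# Two sequential linear transformations: h_hat -> z -> d
lin_z = nn.Linear(hyper_size, n_z)
lin_d = nn.Linear(n_z, hidden_size, bias=False)

# A single linear layer with the composed weight and bias
combined = nn.Linear(hyper_size, hidden_size)
with torch.no_grad():
    combined.weight.copy_(lin_d.weight @ lin_z.weight)
    combined.bias.copy_(lin_d.weight @ lin_z.bias)

h_hat = torch.randn(2, hyper_size)  # a batch of two hyper LSTM outputs
two_step = lin_d(lin_z(h_hat))
one_step = combined(h_hat)

assert torch.allclose(two_step, one_step, atol=1e-5)
```

The two-step version is a low-rank factorization: the composed weight has rank at most `n_z`, and when `n_z` is small the two layers together hold fewer parameters than the single combined layer.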
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
        super().__init__()

The input to the hyper LSTM is
$$\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$$
where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step.
So the input size is hidden_size + input_size.
The output of the hyper LSTM is $\hat{h}_t$ and $\hat{c}_t$.

        self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)

🤔 The paper specifies this input differently; I feel that’s a typo.
        self.z_h = nn.Linear(hyper_size, 4 * n_z)
        self.z_x = nn.Linear(hyper_size, 4 * n_z)
        self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)

        d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_h = nn.ModuleList(d_h)
        d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
        self.d_x = nn.ModuleList(d_x)
        d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
        self.d_b = nn.ModuleList(d_b)

The weight matrices $W_h^{i,f,g,o}$

        self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])

The weight matrices $W_x^{i,f,g,o}$

        self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])

Layer normalization

        self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
        self.layer_norm_c = nn.LayerNorm(hidden_size)

    def __call__(self, x: torch.Tensor,
                 h: torch.Tensor, c: torch.Tensor,
                 h_hat: torch.Tensor, c_hat: torch.Tensor):
        x_hat = torch.cat((h, x), dim=-1)
        h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)

        z_h = self.z_h(h_hat).chunk(4, dim=-1)
        z_x = self.z_x(h_hat).chunk(4, dim=-1)
        z_b = self.z_b(h_hat).chunk(4, dim=-1)

We calculate $i$, $f$, $g$ and $o$ in a loop

        ifgo = []
        for i in range(4):
            d_h = self.d_h[i](z_h[i])
            d_x = self.d_x[i](z_x[i])
            y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
                d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
                self.d_b[i](z_b[i])

            ifgo.append(self.layer_norm[i](y))

        i, f, g, o = ifgo
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next, h_hat, c_hat

class HyperLSTM(Module):

Create a network of n_layers of HyperLSTM.
    def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
        super().__init__()

Store sizes to initialize state

        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.hyper_size = hyper_size

Create cells for each layer. Note that only the first layer gets the input directly. The rest of the layers get their input from the layer below.

        self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                                   [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                                    range(n_layers - 1)])

x has shape [n_steps, batch_size, input_size] and state is a tuple of $h, c, \hat{h}, \hat{c}$.
$h, c$ have shape [batch_size, hidden_size] and
$\hat{h}, \hat{c}$ have shape [batch_size, hyper_size] for each layer; when passed in as state, they are stacked along a leading n_layers dimension.

    def __call__(self, x: torch.Tensor,
                 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]

Initialize the state with zeros if None

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
            c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
        else:
            (h, c, h_hat, c_hat) = state

Reverse stack the tensors to get the states of each layer

📝 You can just work with the tensor itself but this is easier to debug

            h, c = list(torch.unbind(h)), list(torch.unbind(c))
            h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))

Collect the outputs of the final layer at each step
        out = []
        for t in range(n_steps):

Input to the first layer is the input itself

            inp = x[t]

Loop through the layers

            for layer in range(self.n_layers):

Get the state of the layer

                h[layer], c[layer], h_hat[layer], c_hat[layer] = \
                    self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])

Input to the next layer is the state of this layer

                inp = h[layer]

Collect the output $h$ of the final layer

            out.append(h[-1])

Stack the outputs and states

        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)
        h_hat = torch.stack(h_hat)
        c_hat = torch.stack(c_hat)

        return out, (h, c, h_hat, c_hat)
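The handling of the state, unbinding a stacked tensor into a per-layer list and stacking it back, can be seen in isolation in the following small sketch. The sizes are arbitrary assumptions, and the update applied to the first layer is a stand-in for what a cell would return.

```python
import torch

# Illustrative sizes (assumptions)
n_layers, batch_size, hidden_size = 2, 3, 4

# A stacked state tensor, as returned after the final step
h = torch.zeros(n_layers, batch_size, hidden_size)

# Split along the first dimension into one tensor per layer
h_layers = list(torch.unbind(h))

# Replace the state of the first layer (stand-in for a cell update)
h_layers[0] = h_layers[0] + 1.0

# Stack the per-layer states back into a single tensor
h_stacked = torch.stack(h_layers)
```

Keeping the states in a Python list lets each layer's entry be reassigned independently inside the loop, which is easier to inspect while debugging than in-place writes into one tensor.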