This commit is contained in:
Varuna Jayasiri
2020-09-07 16:22:40 +05:30
parent 7630233b3a
commit 0eb20f8559
3 changed files with 68 additions and 7 deletions

View File

@ -10,15 +10,15 @@ class LSTMCell(Module):
self.hidden_size = hidden_size
self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)
self.input_lin = nn.Linear(input_size, 4 * hidden_size)
self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)
def __call__(self, x, h, c):
ifgo = self.hidden_lin(h) + self.input_lin(x)
i = torch.sigmoid(ifgo[:, :self.hidden_size])
f = torch.sigmoid(ifgo[:, self.hidden_size:self.hidden_size * 2])
g = torch.tanh(ifgo[:, self.hidden_size * 2:self.d_model * 4])
o = torch.sigmoid(ifgo[:, self.hidden_size * 3:self.d_model * 3])
g = torch.tanh(ifgo[:, self.hidden_size * 2:self.hidden_size * 3])
o = torch.sigmoid(ifgo[:, self.hidden_size * 3:self.hidden_size * 4])
c_next = f * c + i * g
h_next = o * torch.tanh(c_next)
@ -30,8 +30,8 @@ class LSTM(Module):
super().__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
cells = [LSTMCell(input_size, hidden_size)] + [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)]
self.cells = nn.ModuleList(cells)
self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
[LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
def __call__(self, x: torch.Tensor, state=None):
time_steps, batch_size = x.shape[:2]
@ -43,13 +43,16 @@ class LSTM(Module):
(h, c) = state
h, c = torch.unbind(h), torch.unbind(c)
out = []
for t in range(time_steps):
inp = x[t]
for i in range(self.n_layers):
h[i], c[i] = self.cells[i](inp, h[i], c[i])
inp = h[i]
out.append(h[-1])
out = torch.stack(out)
h = torch.stack(h)
c = torch.stack(c)
return h, c
return out, (h, c)

View File

@ -0,0 +1,58 @@
import torch
from torch import nn
from labml_helpers.module import Module
class RHNCell(Module):
def __init__(self, input_size: int, hidden_size: int, depth: int):
super().__init__()
self.hidden_size = hidden_size
self.depth = depth
self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
def __call__(self, x, s):
for i in range(self.depth):
if i == 0:
ht = self.input_lin(x) + self.hidden_lin[i](s)
else:
ht = self.hidden_lin[i](s)
h = torch.tanh(ht[:, :self.hidden_size])
t = torch.sigmoid(ht[:, self.hidden_size:])
s = s + (h - s) * t
return s
class RHN(Module):
def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
super().__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
[RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
def __call__(self, x: torch.Tensor, state=None):
# x [seq_len, batch, d_model]
time_steps, batch_size = x.shape[:2]
if state is None:
s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
else:
s = torch.unbind(state)
out = []
for t in range(time_steps):
inp = x[t]
for i in range(self.n_layers):
s[i] = self.cells[i](inp, s[i])
inp = s[i]
out.append(s[-1])
out = torch.stack(out)
s = torch.stack(s)
return out, s

View File

@ -5,7 +5,7 @@ with open("readme.rst", "r") as f:
setuptools.setup(
name='labml_nn',
version='0.4.0',
version='0.4.1',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",