mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-16 02:41:38 +08:00
rhn
This commit is contained in:
@ -10,15 +10,15 @@ class LSTMCell(Module):
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)
|
||||
self.input_lin = nn.Linear(input_size, 4 * hidden_size)
|
||||
self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x, h, c):
|
||||
ifgo = self.hidden_lin(h) + self.input_lin(x)
|
||||
|
||||
i = torch.sigmoid(ifgo[:, :self.hidden_size])
|
||||
f = torch.sigmoid(ifgo[:, self.hidden_size:self.hidden_size * 2])
|
||||
g = torch.tanh(ifgo[:, self.hidden_size * 2:self.d_model * 4])
|
||||
o = torch.sigmoid(ifgo[:, self.hidden_size * 3:self.d_model * 3])
|
||||
g = torch.tanh(ifgo[:, self.hidden_size * 2:self.hidden_size * 3])
|
||||
o = torch.sigmoid(ifgo[:, self.hidden_size * 3:self.hidden_size * 4])
|
||||
c_next = f * c + i * g
|
||||
h_next = o * torch.tanh(c_next)
|
||||
|
||||
@ -30,8 +30,8 @@ class LSTM(Module):
|
||||
super().__init__()
|
||||
self.n_layers = n_layers
|
||||
self.hidden_size = hidden_size
|
||||
cells = [LSTMCell(input_size, hidden_size)] + [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)]
|
||||
self.cells = nn.ModuleList(cells)
|
||||
self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
|
||||
[LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
|
||||
|
||||
def __call__(self, x: torch.Tensor, state=None):
|
||||
time_steps, batch_size = x.shape[:2]
|
||||
@ -43,13 +43,16 @@ class LSTM(Module):
|
||||
(h, c) = state
|
||||
h, c = torch.unbind(h), torch.unbind(c)
|
||||
|
||||
out = []
|
||||
for t in range(time_steps):
|
||||
inp = x[t]
|
||||
for i in range(self.n_layers):
|
||||
h[i], c[i] = self.cells[i](inp, h[i], c[i])
|
||||
inp = h[i]
|
||||
out.append(h[-1])
|
||||
|
||||
out = torch.stack(out)
|
||||
h = torch.stack(h)
|
||||
c = torch.stack(c)
|
||||
|
||||
return h, c
|
||||
return out, (h, c)
|
||||
|
58
labml_nn/recurrent_highway_networks/__init__.py
Normal file
58
labml_nn/recurrent_highway_networks/__init__.py
Normal file
@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from labml_helpers.module import Module
|
||||
|
||||
|
||||
class RHNCell(Module):
|
||||
def __init__(self, input_size: int, hidden_size: int, depth: int):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
self.hidden_lin = nn.ModuleList([nn.Linear(hidden_size, 2 * hidden_size) for _ in range(depth)])
|
||||
self.input_lin = nn.Linear(input_size, 2 * hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x, s):
|
||||
for i in range(self.depth):
|
||||
if i == 0:
|
||||
ht = self.input_lin(x) + self.hidden_lin[i](s)
|
||||
else:
|
||||
ht = self.hidden_lin[i](s)
|
||||
|
||||
h = torch.tanh(ht[:, :self.hidden_size])
|
||||
t = torch.sigmoid(ht[:, self.hidden_size:])
|
||||
|
||||
s = s + (h - s) * t
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class RHN(Module):
|
||||
def __init__(self, input_size: int, hidden_size: int, depth: int, n_layers: int):
|
||||
super().__init__()
|
||||
self.n_layers = n_layers
|
||||
self.hidden_size = hidden_size
|
||||
self.cells = nn.ModuleList([RHNCell(input_size, hidden_size, depth)] +
|
||||
[RHNCell(hidden_size, hidden_size, depth) for _ in range(n_layers - 1)])
|
||||
|
||||
def __call__(self, x: torch.Tensor, state=None):
|
||||
# x [seq_len, batch, d_model]
|
||||
time_steps, batch_size = x.shape[:2]
|
||||
|
||||
if state is None:
|
||||
s = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
|
||||
else:
|
||||
s = torch.unbind(state)
|
||||
|
||||
out = []
|
||||
for t in range(time_steps):
|
||||
inp = x[t]
|
||||
for i in range(self.n_layers):
|
||||
s[i] = self.cells[i](inp, s[i])
|
||||
inp = s[i]
|
||||
out.append(s[-1])
|
||||
|
||||
out = torch.stack(out)
|
||||
s = torch.stack(s)
|
||||
return out, s
|
2
setup.py
2
setup.py
@ -5,7 +5,7 @@ with open("readme.rst", "r") as f:
|
||||
|
||||
setuptools.setup(
|
||||
name='labml_nn',
|
||||
version='0.4.0',
|
||||
version='0.4.1',
|
||||
author="Varuna Jayasiri, Nipun Wijerathne",
|
||||
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
|
||||
description="A collection of PyTorch implementations of neural network architectures and layers.",
|
||||
|
Reference in New Issue
Block a user