make RNN definition independent of global variables

Ke Ding
2017-03-16 15:44:09 +08:00
parent 96acc3108d
commit 1862d0fb28
4 changed files with 19 additions and 11 deletions

View File

@@ -38,13 +38,15 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
 class RNN(nn.Module):
     def __init__(self, input_size, hidden_size, num_layers, num_classes):
         super(RNN, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
         self.fc = nn.Linear(hidden_size, num_classes)
 
     def forward(self, x):
         # Set initial states
-        h0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size).cuda())
-        c0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size).cuda())
+        h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size).cuda())
+        c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size).cuda())
 
         # Forward propagate RNN
         out, _ = self.lstm(x, (h0, c0))
@@ -90,4 +92,4 @@ for images, labels in test_loader:
 print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
 
 # Save the Model
-torch.save(rnn, 'rnn.pkl')
+torch.save(rnn, 'rnn.pkl')
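
For reference, a minimal sketch of how the refactored RNN reads after this change (CPU variant). The concrete sizes below (28-step sequences of 28 features, 2 layers, 128 hidden units, 10 classes) and the last-time-step classifier head are assumed for illustration and sit outside the hunks shown above; the point is that forward() now builds its initial states from self.num_layers and self.hidden_size instead of module-level globals, so the class can be constructed from its arguments alone.

import torch
import torch.nn as nn
from torch.autograd import Variable

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size   # stored so forward() needs no globals
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Initial hidden and cell states, sized from the stored attributes
        h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
        c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
        out, _ = self.lstm(x, (h0, c0))
        # Classify from the last time step (assumed, as in the tutorial scripts)
        return self.fc(out[:, -1, :])

# Assumed sizes, for illustration only
rnn = RNN(input_size=28, hidden_size=128, num_layers=2, num_classes=10)
dummy = Variable(torch.randn(5, 28, 28))   # (batch, seq_len, input_size)
print(rnn(dummy).size())                   # torch.Size([5, 10])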

View File

@@ -38,13 +38,15 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
 class RNN(nn.Module):
     def __init__(self, input_size, hidden_size, num_layers, num_classes):
         super(RNN, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
 
     def forward(self, x):
         # Set initial states
-        h0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size))
-        c0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size))
+        h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
+        c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
 
         # Forward propagate RNN
         out, _ = self.lstm(x, (h0, c0))

View File

@@ -38,14 +38,16 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
 class BiRNN(nn.Module):
     def __init__(self, input_size, hidden_size, num_layers, num_classes):
         super(BiRNN, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                             batch_first=True, bidirectional=True)
         self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
 
     def forward(self, x):
         # Set initial states
-        h0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)).cuda() # 2 for bidirection
-        c0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)).cuda()
+        h0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)).cuda() # 2 for bidirection
+        c0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)).cuda()
 
         # Forward propagate RNN
         out, _ = self.lstm(x, (h0, c0))
@@ -91,4 +93,4 @@ for images, labels in test_loader:
 print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
 
 # Save the Model
-torch.save(rnn, 'rnn.pkl')
+torch.save(rnn, 'rnn.pkl')
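
Same idea for the bidirectional model: the initial states are sized self.num_layers*2 because each layer has a forward and a backward direction, and the classifier consumes hidden_size*2 since the two directions' outputs are concatenated. A small shape check, with sizes assumed for illustration only:

import torch
import torch.nn as nn
from torch.autograd import Variable

num_layers, hidden_size, batch = 2, 128, 5   # assumed sizes, not from the diff
lstm = nn.LSTM(28, hidden_size, num_layers, batch_first=True, bidirectional=True)

x = Variable(torch.randn(batch, 28, 28))
h0 = Variable(torch.zeros(num_layers * 2, batch, hidden_size))   # 2 for bidirection
c0 = Variable(torch.zeros(num_layers * 2, batch, hidden_size))

out, _ = lstm(x, (h0, c0))
print(out.size())   # torch.Size([5, 28, 256]): last dim is hidden_size*2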

View File

@@ -38,14 +38,16 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
 class BiRNN(nn.Module):
     def __init__(self, input_size, hidden_size, num_layers, num_classes):
         super(BiRNN, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                             batch_first=True, bidirectional=True)
         self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
 
     def forward(self, x):
         # Set initial states
-        h0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)) # 2 for bidirection
-        c0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size))
+        h0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)) # 2 for bidirection
+        c0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size))
 
         # Forward propagate RNN
         out, _ = self.lstm(x, (h0, c0))
@@ -91,4 +93,4 @@ for images, labels in test_loader:
 print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
 
 # Save the Model
-torch.save(rnn, 'rnn.pkl')
+torch.save(rnn, 'rnn.pkl')
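
The save line at the end of each script pickles the entire module object; restoring it later is a single torch.load call. A minimal sketch, with the path and variable name taken from the scripts above:

import torch

rnn = torch.load('rnn.pkl')   # restores the full module, architecture included
rnn.eval()

Saving rnn.state_dict() and loading it into a freshly constructed RNN(...) is the other common pattern, and it benefits directly from the constructor no longer depending on globals.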