mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-05 00:24:02 +08:00
Merge pull request #7 from DingKe/master
make RNN definition independent of global variables
This commit is contained in:
@ -38,13 +38,15 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||
class RNN(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_layers, num_classes):
|
||||
super(RNN, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
|
||||
self.fc = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
# Set initial states
|
||||
h0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size).cuda())
|
||||
c0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size).cuda())
|
||||
h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size).cuda())
|
||||
c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size).cuda())
|
||||
|
||||
# Forward propagate RNN
|
||||
out, _ = self.lstm(x, (h0, c0))
|
||||
@ -90,4 +92,4 @@ for images, labels in test_loader:
|
||||
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
|
||||
|
||||
# Save the Model
|
||||
torch.save(rnn, 'rnn.pkl')
|
||||
torch.save(rnn, 'rnn.pkl')
|
||||
|
@ -38,13 +38,15 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||
class RNN(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_layers, num_classes):
|
||||
super(RNN, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
|
||||
self.fc = nn.Linear(hidden_size, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
# Set initial states
|
||||
h0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size))
|
||||
c0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size))
|
||||
h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
|
||||
c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
|
||||
|
||||
# Forward propagate RNN
|
||||
out, _ = self.lstm(x, (h0, c0))
|
||||
|
@ -38,14 +38,16 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||
class BiRNN(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_layers, num_classes):
|
||||
super(BiRNN, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
|
||||
batch_first=True, bidirectional=True)
|
||||
self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
|
||||
|
||||
def forward(self, x):
|
||||
# Set initial states
|
||||
h0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)).cuda() # 2 for bidirection
|
||||
c0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)).cuda()
|
||||
h0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)).cuda() # 2 for bidirection
|
||||
c0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)).cuda()
|
||||
|
||||
# Forward propagate RNN
|
||||
out, _ = self.lstm(x, (h0, c0))
|
||||
@ -91,4 +93,4 @@ for images, labels in test_loader:
|
||||
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
|
||||
|
||||
# Save the Model
|
||||
torch.save(rnn, 'rnn.pkl')
|
||||
torch.save(rnn, 'rnn.pkl')
|
||||
|
@ -38,14 +38,16 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
|
||||
class BiRNN(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_layers, num_classes):
|
||||
super(BiRNN, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
|
||||
batch_first=True, bidirectional=True)
|
||||
self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
|
||||
|
||||
def forward(self, x):
|
||||
# Set initial states
|
||||
h0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)) # 2 for bidirection
|
||||
c0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size))
|
||||
h0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)) # 2 for bidirection
|
||||
c0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size))
|
||||
|
||||
# Forward propagate RNN
|
||||
out, _ = self.lstm(x, (h0, c0))
|
||||
@ -91,4 +93,4 @@ for images, labels in test_loader:
|
||||
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
|
||||
|
||||
# Save the Model
|
||||
torch.save(rnn, 'rnn.pkl')
|
||||
torch.save(rnn, 'rnn.pkl')
|
||||
|
Reference in New Issue
Block a user