From 86a0872430f36146e8c1b4dc4c0efecb81e6107f Mon Sep 17 00:00:00 2001
From: yunjey
Date: Sat, 11 Mar 2017 14:54:46 +0900
Subject: [PATCH] code for saving the model is added

---
 tutorials/01 - Linear Regression/main.py           |   5 +-
 tutorials/02 - Logistic Regression/main.py         |   9 +-
 .../main-gpu.py                                    |   4 +-
 .../03 - Feedforward Neural Network/main.py        |   4 +-
 .../main-gpu.py                                    |  13 +-
 .../04 - Convolutional Neural Network/main.py      |  11 +-
 .../05 - Deep Residual Network/main-gpu.py         |  14 +-
 tutorials/05 - Deep Residual Network/main.py       |  10 +-
 .../06 - Recurrent Neural Network/main-gpu.py      |   9 +-
 .../06 - Recurrent Neural Network/main.py          |   9 +-
 .../main-gpu.py                                    |   9 +-
 .../main.py                                        |   9 +-
 tutorials/08 - Language Model/data_utils.py        |  46 ++++++
 tutorials/08 - Language Model/main-gpu.py          | 124 ++++++++++++++++
 tutorials/08 - Language Model/main.py              | 123 ++++++++++++++++
 .../main-gpu.py                                    | 134 ++++++++++++++++++
 .../main.py                                        | 134 ++++++++++++++++++
 17 files changed, 630 insertions(+), 37 deletions(-)
 create mode 100644 tutorials/08 - Language Model/data_utils.py
 create mode 100644 tutorials/08 - Language Model/main-gpu.py
 create mode 100644 tutorials/08 - Language Model/main.py
 create mode 100644 tutorials/10 - Generative Adversarial Network/main-gpu.py
 create mode 100644 tutorials/10 - Generative Adversarial Network/main.py

diff --git a/tutorials/01 - Linear Regression/main.py b/tutorials/01 - Linear Regression/main.py
index 13ea21d..274305f 100644
--- a/tutorials/01 - Linear Regression/main.py
+++ b/tutorials/01 - Linear Regression/main.py
@@ -58,4 +58,7 @@ predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
 plt.plot(x_train, y_train, 'ro', label='Original data')
 plt.plot(x_train, predicted, label='Fitted line')
 plt.legend()
-plt.show()
\ No newline at end of file
+plt.show()
+
+# Save the Model
+torch.save(model, 'model.pkl')
\ No newline at end of file
diff --git a/tutorials/02 - Logistic Regression/main.py b/tutorials/02 - Logistic Regression/main.py
index 0a82db0..c648433 100644
--- a/tutorials/02 - Logistic Regression/main.py
+++ b/tutorials/02 - Logistic Regression/main.py
@@ -13,12 +13,12 @@ batch_size = 100
 learning_rate = 0.001
 
 # MNIST Dataset (Images and Labels)
-train_dataset = dsets.MNIST(root='./data',
+train_dataset = dsets.MNIST(root='../data',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data',
+test_dataset = dsets.MNIST(root='../data',
                            train=False,
                            transform=transforms.ToTensor())
 
@@ -76,4 +76,7 @@ for images, labels in test_loader:
     total += labels.size(0)
     correct += (predicted == labels).sum()
 
-print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
+
+# Save the Model
+torch.save(model, 'model.pkl')
\ No newline at end of file
diff --git a/tutorials/03 - Feedforward Neural Network/main-gpu.py b/tutorials/03 - Feedforward Neural Network/main-gpu.py
index 58afd7f..b0480a8 100644
--- a/tutorials/03 - Feedforward Neural Network/main-gpu.py
+++ b/tutorials/03 - Feedforward Neural Network/main-gpu.py
@@ -14,12 +14,12 @@ batch_size = 100
 learning_rate = 0.001
 
 # MNIST Dataset
-train_dataset = dsets.MNIST(root='./data',
+train_dataset = dsets.MNIST(root='../data',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data',
+test_dataset = dsets.MNIST(root='../data',
                            train=False,
                            transform=transforms.ToTensor())
 
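A note on the pattern this patch introduces: torch.save(model, 'model.pkl') pickles the entire nn.Module object, so loading it back requires the defining class to be importable. A minimal sketch of the two standard ways to restore it (the state_dict file name and the LinearRegression(input_size, output_size) constructor are assumptions based on the tutorial, not part of this patch):

    # Whole-module pickle: simplest, but tied to where the class is defined.
    model = torch.load('model.pkl')

    # Parameters only: more robust across refactors (sketch).
    torch.save(model.state_dict(), 'model_params.pkl')
    model = LinearRegression(input_size, output_size)  # re-create the architecture first
    model.load_state_dict(torch.load('model_params.pkl'))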
diff --git a/tutorials/03 - Feedforward Neural Network/main.py b/tutorials/03 - Feedforward Neural Network/main.py
index 80d2ece..c0f28cd 100644
--- a/tutorials/03 - Feedforward Neural Network/main.py
+++ b/tutorials/03 - Feedforward Neural Network/main.py
@@ -14,12 +14,12 @@ batch_size = 100
 learning_rate = 0.001
 
 # MNIST Dataset
-train_dataset = dsets.MNIST(root='./data',
+train_dataset = dsets.MNIST(root='../data',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data',
+test_dataset = dsets.MNIST(root='../data',
                            train=False,
                            transform=transforms.ToTensor())
 
diff --git a/tutorials/04 - Convolutional Neural Network/main-gpu.py b/tutorials/04 - Convolutional Neural Network/main-gpu.py
index 33c7096..5040bd6 100644
--- a/tutorials/04 - Convolutional Neural Network/main-gpu.py
+++ b/tutorials/04 - Convolutional Neural Network/main-gpu.py
@@ -11,12 +11,12 @@ batch_size = 100
 learning_rate = 0.001
 
 # MNIST Dataset
-train_dataset = dsets.MNIST(root='./data/',
+train_dataset = dsets.MNIST(root='../data/',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data/',
+test_dataset = dsets.MNIST(root='../data/',
                            train=False,
                            transform=transforms.ToTensor())
 
@@ -77,7 +77,7 @@ for epoch in range(num_epochs):
                %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data[0]))
 
 # Test the Model
-cnn.eval()
+cnn.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
 correct = 0
 total = 0
 for images, labels in test_loader:
@@ -85,6 +85,9 @@ for images, labels in test_loader:
     outputs = cnn(images)
     _, predicted = torch.max(outputs.data, 1)
     total += labels.size(0)
-    correct += (predicted == labels).sum()
+    correct += (predicted.cpu() == labels).sum()
 
-print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
+
+# Save the Trained Model
+torch.save(cnn, 'cnn.pkl')
\ No newline at end of file
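The comment added next to cnn.eval() above is the substance of that hunk: eval mode switches BatchNorm to its accumulated moving mean/var (and disables dropout) instead of per-batch statistics. If training were to continue after the test loop, the mode would need to be switched back, roughly:

    cnn.eval()   # inference: BatchNorm uses moving averages
    # ... run the test loop ...
    cnn.train()  # restore training mode before any further updates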
diff --git a/tutorials/04 - Convolutional Neural Network/main.py b/tutorials/04 - Convolutional Neural Network/main.py
index d159580..5013b27 100644
--- a/tutorials/04 - Convolutional Neural Network/main.py
+++ b/tutorials/04 - Convolutional Neural Network/main.py
@@ -11,12 +11,12 @@ batch_size = 100
 learning_rate = 0.001
 
 # MNIST Dataset
-train_dataset = dsets.MNIST(root='./data/',
+train_dataset = dsets.MNIST(root='../data/',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data/',
+test_dataset = dsets.MNIST(root='../data/',
                            train=False,
                            transform=transforms.ToTensor())
 
@@ -77,7 +77,7 @@ for epoch in range(num_epochs):
                %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data[0]))
 
 # Test the Model
-cnn.eval()
+cnn.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
 correct = 0
 total = 0
 for images, labels in test_loader:
@@ -87,4 +87,7 @@ for images, labels in test_loader:
     total += labels.size(0)
     correct += (predicted == labels).sum()
 
-print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
+
+# Save the Trained Model
+torch.save(cnn, 'cnn.pkl')
\ No newline at end of file
diff --git a/tutorials/05 - Deep Residual Network/main-gpu.py b/tutorials/05 - Deep Residual Network/main-gpu.py
index 3066093..2de3165 100644
--- a/tutorials/05 - Deep Residual Network/main-gpu.py
+++ b/tutorials/05 - Deep Residual Network/main-gpu.py
@@ -14,12 +14,12 @@ transform = transforms.Compose([
     transforms.ToTensor()])
 
 # CIFAR-10 Dataset
-train_dataset = dsets.CIFAR10(root='./data/',
+train_dataset = dsets.CIFAR10(root='../data/',
                               train=True,
                               transform=transform,
                               download=True)
 
-test_dataset = dsets.CIFAR10(root='./data/',
+test_dataset = dsets.CIFAR10(root='../data/',
                              train=False,
                              transform=transforms.ToTensor())
 
@@ -109,7 +109,7 @@ lr = 0.001
 optimizer = torch.optim.Adam(resnet.parameters(), lr=lr)
 
 # Training
-for epoch in range(40):
+for epoch in range(80):
     for i, (images, labels) in enumerate(train_loader):
         images = Variable(images.cuda())
         labels = Variable(labels.cuda())
@@ -122,7 +122,7 @@ for epoch in range(40):
         optimizer.step()
 
         if (i+1) % 100 == 0:
-            print ("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f" %(epoch+1, 40, i+1, 500, loss.data[0]))
+            print ("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f" %(epoch+1, 80, i+1, 500, loss.data[0]))
 
     # Decaying Learning Rate
     if (epoch+1) % 20 == 0:
@@ -130,6 +130,7 @@ for epoch in range(40):
         optimizer = torch.optim.Adam(resnet.parameters(), lr=lr)
 
 # Test
+resnet.eval()
 correct = 0
 total = 0
 for images, labels in test_loader:
@@ -139,4 +140,7 @@ for images, labels in test_loader:
     total += labels.size(0)
     correct += (predicted.cpu() == labels).sum()
 
-print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))
+
+# Save the Model
+torch.save(resnet, 'resnet.pkl')
\ No newline at end of file
diff --git a/tutorials/05 - Deep Residual Network/main.py b/tutorials/05 - Deep Residual Network/main.py
index 13e9272..d849f88 100644
--- a/tutorials/05 - Deep Residual Network/main.py
+++ b/tutorials/05 - Deep Residual Network/main.py
@@ -14,12 +14,12 @@ transform = transforms.Compose([
     transforms.ToTensor()])
 
 # CIFAR-10 Dataset
-train_dataset = dsets.CIFAR10(root='./data/',
+train_dataset = dsets.CIFAR10(root='../data/',
                               train=True,
                               transform=transform,
                               download=True)
 
-test_dataset = dsets.CIFAR10(root='./data/',
+test_dataset = dsets.CIFAR10(root='../data/',
                              train=False,
                              transform=transforms.ToTensor())
 
@@ -130,6 +130,7 @@ for epoch in range(80):
         optimizer = torch.optim.Adam(resnet.parameters(), lr=lr)
 
 # Test
+resnet.eval()
 correct = 0
 total = 0
 for images, labels in test_loader:
@@ -139,4 +140,7 @@ for images, labels in test_loader:
     total += labels.size(0)
     correct += (predicted == labels).sum()
 
-print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Accuracy of the model on the test images: %d %%' % (100 * correct / total))
+
+# Save the Model
+torch.save(resnet, 'resnet.pkl')
\ No newline at end of file
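The ResNet scripts decay the learning rate every 20 epochs by constructing a brand-new Adam optimizer, which also discards Adam's moment estimates. A hedged alternative sketch that mutates the existing optimizer in place via its param_groups (the decay factor below is an assumption; the hunk context does not show the original update to lr):

    # Decaying Learning Rate without resetting optimizer state (sketch)
    if (epoch+1) % 20 == 0:
        lr /= 3  # assumed factor
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr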
diff --git a/tutorials/06 - Recurrent Neural Network/main-gpu.py b/tutorials/06 - Recurrent Neural Network/main-gpu.py
index 5aa4160..492c347 100644
--- a/tutorials/06 - Recurrent Neural Network/main-gpu.py
+++ b/tutorials/06 - Recurrent Neural Network/main-gpu.py
@@ -16,12 +16,12 @@ num_epochs = 2
 learning_rate = 0.01
 
 # MNIST Dataset
-train_dataset = dsets.MNIST(root='./data/',
+train_dataset = dsets.MNIST(root='../data/',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data/',
+test_dataset = dsets.MNIST(root='../data/',
                            train=False,
                            transform=transforms.ToTensor())
 
@@ -87,4 +87,7 @@ for images, labels in test_loader:
     total += labels.size(0)
     correct += (predicted.cpu() == labels).sum()
 
-print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
+
+# Save the Model
+torch.save(rnn, 'rnn.pkl')
\ No newline at end of file
diff --git a/tutorials/06 - Recurrent Neural Network/main.py b/tutorials/06 - Recurrent Neural Network/main.py
index 2510003..b512b15 100644
--- a/tutorials/06 - Recurrent Neural Network/main.py
+++ b/tutorials/06 - Recurrent Neural Network/main.py
@@ -16,12 +16,12 @@ num_epochs = 2
 learning_rate = 0.01
 
 # MNIST Dataset
-train_dataset = dsets.MNIST(root='./data/',
+train_dataset = dsets.MNIST(root='../data/',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data/',
+test_dataset = dsets.MNIST(root='../data/',
                            train=False,
                            transform=transforms.ToTensor())
 
@@ -87,4 +87,7 @@ for images, labels in test_loader:
     total += labels.size(0)
     correct += (predicted == labels).sum()
 
-print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
+
+# Save the Model
+torch.save(rnn, 'rnn.pkl')
\ No newline at end of file
diff --git a/tutorials/07 - Bidirectional Recurrent Neural Network/main-gpu.py b/tutorials/07 - Bidirectional Recurrent Neural Network/main-gpu.py
index d44341a..a1b49e1 100644
--- a/tutorials/07 - Bidirectional Recurrent Neural Network/main-gpu.py
+++ b/tutorials/07 - Bidirectional Recurrent Neural Network/main-gpu.py
@@ -16,12 +16,12 @@ num_epochs = 2
 learning_rate = 0.003
 
 # MNIST Dataset
-train_dataset = dsets.MNIST(root='./data/',
+train_dataset = dsets.MNIST(root='../data/',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data/',
+test_dataset = dsets.MNIST(root='../data/',
                            train=False,
                            transform=transforms.ToTensor())
 
@@ -88,4 +88,7 @@ for images, labels in test_loader:
     total += labels.size(0)
     correct += (predicted.cpu() == labels).sum()
 
-print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
+
+# Save the Model
+torch.save(rnn, 'rnn.pkl')
\ No newline at end of file
diff --git a/tutorials/07 - Bidirectional Recurrent Neural Network/main.py b/tutorials/07 - Bidirectional Recurrent Neural Network/main.py
index 2f3e0de..7adbf6f 100644
--- a/tutorials/07 - Bidirectional Recurrent Neural Network/main.py
+++ b/tutorials/07 - Bidirectional Recurrent Neural Network/main.py
@@ -16,12 +16,12 @@ num_epochs = 2
 learning_rate = 0.003
 
 # MNIST Dataset
-train_dataset = dsets.MNIST(root='./data/',
+train_dataset = dsets.MNIST(root='../data/',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
 
-test_dataset = dsets.MNIST(root='./data/',
+test_dataset = dsets.MNIST(root='../data/',
                            train=False,
                            transform=transforms.ToTensor())
 
@@ -88,4 +88,7 @@ for images, labels in test_loader:
     total += labels.size(0)
     correct += (predicted == labels).sum()
 
-print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
\ No newline at end of file
+print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
+
+# Save the Model
+torch.save(rnn, 'rnn.pkl')
\ No newline at end of file
diff --git a/tutorials/08 - Language Model/data_utils.py b/tutorials/08 - Language Model/data_utils.py
new file mode 100644
index 0000000..e0238b8
--- /dev/null
+++ b/tutorials/08 - Language Model/data_utils.py
@@ -0,0 +1,46 @@
+import torch
+import os
+
+class Dictionary(object):
+    def __init__(self):
+        self.word2idx = {}
+        self.idx2word = {}
+        self.idx = 0
+
+    def add_word(self, word):
+        if word not in self.word2idx:
+            self.word2idx[word] = self.idx
+            self.idx2word[self.idx] = word
+            self.idx += 1
+
+    def __len__(self):
+        return len(self.word2idx)
+
+class Corpus(object):
+    def __init__(self, path='./data'):
+        self.dictionary = Dictionary()
+        self.train = os.path.join(path, 'train.txt')
+        self.test = os.path.join(path, 'test.txt')
+
+    def get_data(self, path, batch_size=20):
+        # Add words to the dictionary
+        with open(path, 'r') as f:
+            tokens = 0
+            for line in f:
+                words = line.split() + ['<eos>']
+                tokens += len(words)
+                for word in words:
+                    self.dictionary.add_word(word)
+
+        # Tokenize the file content
+        ids = torch.LongTensor(tokens)
+        token = 0
+        with open(path, 'r') as f:
+            for line in f:
+                words = line.split() + ['<eos>']
+                for word in words:
+                    ids[token] = self.dictionary.word2idx[word]
+                    token += 1
+        num_batches = ids.size(0) // batch_size
+        ids = ids[:num_batches*batch_size]
+        return ids.view(batch_size, -1)
\ No newline at end of file
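A quick sketch of how Corpus.get_data is consumed by the language-model scripts below: each of the batch_size rows of the returned LongTensor is an independent token stream, so fixed-width column slices give the training windows (paths and sizes assume the tutorial's defaults):

    corpus = Corpus()
    ids = corpus.get_data('./data/train.txt', batch_size=20)
    print(ids.size())        # (20, num_tokens // 20): 20 parallel token streams
    inputs  = ids[:, 0:30]   # one seq_length=30 window
    targets = ids[:, 1:31]   # the same window shifted by one token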
diff --git a/tutorials/08 - Language Model/main-gpu.py b/tutorials/08 - Language Model/main-gpu.py
new file mode 100644
index 0000000..9d6a244
--- /dev/null
+++ b/tutorials/08 - Language Model/main-gpu.py
@@ -0,0 +1,124 @@
+# RNN Based Language Model on the Penn Treebank dataset.
+# Some parts of the code were referenced from:
+# https://github.com/pytorch/examples/tree/master/word_language_model
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.autograd import Variable
+from data_utils import Dictionary, Corpus
+
+# Hyper Parameters
+embed_size = 128
+hidden_size = 1024
+num_layers = 1
+num_epochs = 5
+num_samples = 1000   # number of words to be sampled
+batch_size = 20
+seq_length = 30
+learning_rate = 0.002
+
+# Load Penn Treebank Dataset
+train_path = './data/train.txt'
+sample_path = './sample.txt'
+corpus = Corpus()
+ids = corpus.get_data(train_path, batch_size)
+vocab_size = len(corpus.dictionary)
+num_batches = ids.size(1) // seq_length
+
+# RNN Based Language Model
+class RNNLM(nn.Module):
+    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
+        super(RNNLM, self).__init__()
+        self.embed = nn.Embedding(vocab_size, embed_size)
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+        self.linear = nn.Linear(hidden_size, vocab_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+        self.embed.weight.data.uniform_(-0.1, 0.1)
+        self.linear.bias.data.fill_(0)
+        self.linear.weight.data.uniform_(-0.1, 0.1)
+
+    def forward(self, x, h):
+        # Embed word ids to vectors
+        x = self.embed(x)
+
+        # Forward propagate RNN
+        out, h = self.lstm(x, h)
+
+        # Reshape output to (batch_size*sequence_length, hidden_size)
+        out = out.contiguous().view(out.size(0)*out.size(1), out.size(2))
+
+        # Decode hidden states of all time steps
+        out = self.linear(out)
+        return out, h
+
+model = RNNLM(vocab_size, embed_size, hidden_size, num_layers)
+model.cuda()
+
+# Loss and Optimizer
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+# Truncated Backpropagation
+def detach(states):
+    return [Variable(state.data) for state in states]
+
+# Training
+for epoch in range(num_epochs):
+    # Initial hidden and memory states
+    states = (Variable(torch.zeros(num_layers, batch_size, hidden_size)).cuda(),
+              Variable(torch.zeros(num_layers, batch_size, hidden_size)).cuda())
+
+    for i in range(0, ids.size(1) - seq_length, seq_length):
+        # Get batch inputs and targets
+        inputs = Variable(ids[:, i:i+seq_length]).cuda()
+        targets = Variable(ids[:, (i+1):(i+1)+seq_length].contiguous()).cuda()
+
+        # Forward + Backward + Optimize
+        model.zero_grad()
+        states = detach(states)
+        outputs, states = model(inputs, states)
+        loss = criterion(outputs, targets.view(-1))
+        loss.backward()
+        torch.nn.utils.clip_grad_norm(model.parameters(), 0.5)
+        optimizer.step()
+
+        step = (i+1) // seq_length
+        if step % 100 == 0:
+            print ('Epoch [%d/%d], Step[%d/%d], Loss: %.3f, Perplexity: %5.2f' %
+                   (epoch+1, num_epochs, step, num_batches, loss.data[0], np.exp(loss.data[0])))
+
+# Sampling
+with open(sample_path, 'w') as f:
+    # Set initial hidden and memory states
+    state = (Variable(torch.zeros(num_layers, 1, hidden_size)).cuda(),
+             Variable(torch.zeros(num_layers, 1, hidden_size)).cuda())
+
+    # Select one word id randomly
+    prob = torch.ones(vocab_size)
+    input = Variable(torch.multinomial(prob, num_samples=1).unsqueeze(1),
+                     volatile=True).cuda()
+
+    for i in range(num_samples):
+        # Forward propagate RNN
+        output, state = model(input, state)
+
+        # Sample a word id
+        prob = output.squeeze().data.exp().cpu()
+        word_id = torch.multinomial(prob, 1)[0]
+
+        # Feed sampled word id to next time step
+        input.data.fill_(word_id)
+
+        # File write
+        word = corpus.dictionary.idx2word[word_id]
+        word = '\n' if word == '<eos>' else word + ' '
+        f.write(word)
+
+        if (i+1) % 100 == 0:
+            print('Sampled [%d/%d] words and saved to %s' % (i+1, num_samples, sample_path))
+
+# Save the Trained Model
+torch.save(model, 'model.pkl')
\ No newline at end of file
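The detach helper above is the truncated-BPTT trick: rewrapping the hidden and cell states' tensors in fresh Variables keeps their values but severs the autograd graph, so each loss.backward() spans only one seq_length window rather than the whole history (which would grow without bound and typically errors in this PyTorch version when backpropagating through already-freed buffers). It is equivalent to:

    # values kept, computation graph dropped
    states = [Variable(state.data) for state in states]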
diff --git a/tutorials/08 - Language Model/main.py b/tutorials/08 - Language Model/main.py
new file mode 100644
index 0000000..df98e16
--- /dev/null
+++ b/tutorials/08 - Language Model/main.py
@@ -0,0 +1,123 @@
+# Some parts of the code were referenced from:
+# https://github.com/pytorch/examples/tree/master/word_language_model
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.autograd import Variable
+from data_utils import Dictionary, Corpus
+
+# Hyper Parameters
+embed_size = 128
+hidden_size = 1024
+num_layers = 1
+num_epochs = 5
+num_samples = 1000   # number of words to be sampled
+batch_size = 20
+seq_length = 30
+learning_rate = 0.002
+
+# Load Penn Treebank Dataset
+train_path = './data/train.txt'
+sample_path = './sample.txt'
+corpus = Corpus()
+ids = corpus.get_data(train_path, batch_size)
+vocab_size = len(corpus.dictionary)
+num_batches = ids.size(1) // seq_length
+
+# RNN Based Language Model
+class RNNLM(nn.Module):
+    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
+        super(RNNLM, self).__init__()
+        self.embed = nn.Embedding(vocab_size, embed_size)
+        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+        self.linear = nn.Linear(hidden_size, vocab_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+        self.embed.weight.data.uniform_(-0.1, 0.1)
+        self.linear.bias.data.fill_(0)
+        self.linear.weight.data.uniform_(-0.1, 0.1)
+
+    def forward(self, x, h):
+        # Embed word ids to vectors
+        x = self.embed(x)
+
+        # Forward propagate RNN
+        out, h = self.lstm(x, h)
+
+        # Reshape output to (batch_size*sequence_length, hidden_size)
+        out = out.contiguous().view(out.size(0)*out.size(1), out.size(2))
+
+        # Decode hidden states of all time steps
+        out = self.linear(out)
+        return out, h
+
+model = RNNLM(vocab_size, embed_size, hidden_size, num_layers)
+
+
+# Loss and Optimizer
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+# Truncated Backpropagation
+def detach(states):
+    return [Variable(state.data) for state in states]
+
+# Training
+for epoch in range(num_epochs):
+    # Initial hidden and memory states
+    states = (Variable(torch.zeros(num_layers, batch_size, hidden_size)),
+              Variable(torch.zeros(num_layers, batch_size, hidden_size)))
+
+    for i in range(0, ids.size(1) - seq_length, seq_length):
+        # Get batch inputs and targets
+        inputs = Variable(ids[:, i:i+seq_length])
+        targets = Variable(ids[:, (i+1):(i+1)+seq_length].contiguous())
+
+        # Forward + Backward + Optimize
+        model.zero_grad()
+        states = detach(states)
+        outputs, states = model(inputs, states)
+        loss = criterion(outputs, targets.view(-1))
+        loss.backward()
+        torch.nn.utils.clip_grad_norm(model.parameters(), 0.5)
+        optimizer.step()
+
+        step = (i+1) // seq_length
+        if step % 100 == 0:
+            print ('Epoch [%d/%d], Step[%d/%d], Loss: %.3f, Perplexity: %5.2f' %
+                   (epoch+1, num_epochs, step, num_batches, loss.data[0], np.exp(loss.data[0])))
+
+# Sampling
+with open(sample_path, 'w') as f:
+    # Set initial hidden and memory states
+    state = (Variable(torch.zeros(num_layers, 1, hidden_size)),
+             Variable(torch.zeros(num_layers, 1, hidden_size)))
+
+    # Select one word id randomly
+    prob = torch.ones(vocab_size)
+    input = Variable(torch.multinomial(prob, num_samples=1).unsqueeze(1),
+                     volatile=True)
+
+    for i in range(num_samples):
+        # Forward propagate RNN
+        output, state = model(input, state)
+
+        # Sample a word id
+        prob = output.squeeze().data.exp()
+        word_id = torch.multinomial(prob, 1)[0]
+
+        # Feed sampled word id to next time step
+        input.data.fill_(word_id)
+
+        # File write
+        word = corpus.dictionary.idx2word[word_id]
+        word = '\n' if word == '<eos>' else word + ' '
+        f.write(word)
+
+        if (i+1) % 100 == 0:
+            print('Sampled [%d/%d] words and saved to %s' % (i+1, num_samples, sample_path))
+
+# Save the Trained Model
+torch.save(model, 'model.pkl')
\ No newline at end of file
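torch.multinomial normalizes its weight vector, so drawing from output.exp() is exactly a draw from the softmax over the vocabulary. A common variation not present in this commit is a temperature knob (the value below is a hypothetical example):

    temperature = 0.8  # <1 sharpens the distribution, >1 flattens it
    prob = (output.squeeze().data / temperature).exp()
    word_id = torch.multinomial(prob, 1)[0]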
diff --git a/tutorials/10 - Generative Adversarial Network/main-gpu.py b/tutorials/10 - Generative Adversarial Network/main-gpu.py
new file mode 100644
index 0000000..e14f768
--- /dev/null
+++ b/tutorials/10 - Generative Adversarial Network/main-gpu.py
@@ -0,0 +1,134 @@
+import torch
+import torchvision
+import torch.nn as nn
+import torchvision.datasets as dsets
+import torchvision.transforms as transforms
+from torch.autograd import Variable
+
+# Image Preprocessing
+transform = transforms.Compose([
+    transforms.Scale(36),
+    transforms.RandomCrop(32),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+
+# CIFAR-10 Dataset
+train_dataset = dsets.CIFAR10(root='../data/',
+                              train=True,
+                              transform=transform,
+                              download=True)
+
+# Data Loader (Input Pipeline)
+train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                           batch_size=100,
+                                           shuffle=True)
+
+# 4x4 Convolution
+def conv4x4(in_channels, out_channels, stride):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=4,
+                     stride=stride, padding=1, bias=False)
+
+# Discriminator Model
+class Discriminator(nn.Module):
+    def __init__(self):
+        super(Discriminator, self).__init__()
+        self.model = nn.Sequential(
+            conv4x4(3, 16, 2),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv4x4(16, 32, 2),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv4x4(32, 64, 2),
+            nn.BatchNorm2d(64),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(64, 1, kernel_size=4),
+            nn.Sigmoid())
+
+    def forward(self, x):
+        out = self.model(x)
+        out = out.view(out.size(0), -1)
+        return out
+
+# 4x4 Transposed convolution
+def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
+    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
+                              stride=stride, padding=padding, bias=bias)
+
+# Generator Model
+class Generator(nn.Module):
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.model = nn.Sequential(
+            conv_transpose4x4(128, 64, padding=0),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            conv_transpose4x4(64, 32, 2),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            conv_transpose4x4(32, 16, 2),
+            nn.BatchNorm2d(16),
+            nn.ReLU(inplace=True),
+            conv_transpose4x4(16, 3, 2, bias=True),
+            nn.Tanh())
+
+    def forward(self, x):
+        x = x.view(x.size(0), 128, 1, 1)
+        out = self.model(x)
+        return out
+
+discriminator = Discriminator()
+generator = Generator()
+discriminator.cuda()
+generator.cuda()
+
+# Loss and Optimizer
+criterion = nn.BCELoss()
+lr = 0.002
+d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
+g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
+
+# Training
+for epoch in range(50):
+    for i, (images, _) in enumerate(train_loader):
+        images = Variable(images.cuda())
+        real_labels = Variable(torch.ones(images.size(0)).cuda())
+        fake_labels = Variable(torch.zeros(images.size(0)).cuda())
+
+        # Train the discriminator
+        discriminator.zero_grad()
+        outputs = discriminator(images)
+        real_loss = criterion(outputs, real_labels)
+        real_score = outputs
+
+        noise = Variable(torch.randn(images.size(0), 128).cuda())
+        fake_images = generator(noise)
+        outputs = discriminator(fake_images)
+        fake_loss = criterion(outputs, fake_labels)
+        fake_score = outputs
+
+        d_loss = real_loss + fake_loss
+        d_loss.backward()
+        d_optimizer.step()
+
+        # Train the generator
+        generator.zero_grad()
+        noise = Variable(torch.randn(images.size(0), 128).cuda())
+        fake_images = generator(noise)
+        outputs = discriminator(fake_images)
+        g_loss = criterion(outputs, real_labels)
+        g_loss.backward()
+        g_optimizer.step()
+
+        if (i+1) % 100 == 0:
+            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
+                  'D(x): %.2f, D(G(z)): %.2f'
+                  %(epoch+1, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
+                    real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
+
+    # Save the sampled images
+    torchvision.utils.save_image(fake_images.data,
+        './data/fake_samples_%d_%d.png' %(epoch+1, i+1))
+
+# Save the Models
+torch.save(generator, './generator.pkl')
+torch.save(discriminator, './discriminator.pkl')
\ No newline at end of file
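One practical note on the sampled images: the pipeline normalizes inputs to [-1, 1] and the generator ends in Tanh, so fake_images also lives in [-1, 1], while save_image expects pixels in [0, 1]; without mapping back, the saved samples come out dark and clipped. A hedged sketch of the usual denormalization (the denorm helper is ours, not part of this patch):

    def denorm(x):
        # invert Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): [-1, 1] -> [0, 1]
        return (x + 1) / 2

    torchvision.utils.save_image(denorm(fake_images.data),
        './data/fake_samples_%d_%d.png' % (epoch+1, i+1))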
diff --git a/tutorials/10 - Generative Adversarial Network/main.py b/tutorials/10 - Generative Adversarial Network/main.py
new file mode 100644
index 0000000..9f2ae1a
--- /dev/null
+++ b/tutorials/10 - Generative Adversarial Network/main.py
@@ -0,0 +1,134 @@
+import torch
+import torchvision
+import torch.nn as nn
+import torchvision.datasets as dsets
+import torchvision.transforms as transforms
+from torch.autograd import Variable
+
+# Image Preprocessing
+transform = transforms.Compose([
+    transforms.Scale(36),
+    transforms.RandomCrop(32),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+
+# CIFAR-10 Dataset
+train_dataset = dsets.CIFAR10(root='../data/',
+                              train=True,
+                              transform=transform,
+                              download=True)
+
+# Data Loader (Input Pipeline)
+train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                           batch_size=100,
+                                           shuffle=True)
+
+# 4x4 Convolution
+def conv4x4(in_channels, out_channels, stride):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=4,
+                     stride=stride, padding=1, bias=False)
+
+# Discriminator Model
+class Discriminator(nn.Module):
+    def __init__(self):
+        super(Discriminator, self).__init__()
+        self.model = nn.Sequential(
+            conv4x4(3, 16, 2),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv4x4(16, 32, 2),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.2, inplace=True),
+            conv4x4(32, 64, 2),
+            nn.BatchNorm2d(64),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(64, 1, kernel_size=4),
+            nn.Sigmoid())
+
+    def forward(self, x):
+        out = self.model(x)
+        out = out.view(out.size(0), -1)
+        return out
+
+# 4x4 Transposed convolution
+def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
+    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
+                              stride=stride, padding=padding, bias=bias)
+
+# Generator Model
+class Generator(nn.Module):
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.model = nn.Sequential(
+            conv_transpose4x4(128, 64, padding=0),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            conv_transpose4x4(64, 32, 2),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            conv_transpose4x4(32, 16, 2),
+            nn.BatchNorm2d(16),
+            nn.ReLU(inplace=True),
+            conv_transpose4x4(16, 3, 2, bias=True),
+            nn.Tanh())
+
+    def forward(self, x):
+        x = x.view(x.size(0), 128, 1, 1)
+        out = self.model(x)
+        return out
+
+discriminator = Discriminator()
+generator = Generator()
+
+
+
+# Loss and Optimizer
+criterion = nn.BCELoss()
+lr = 0.0002
+d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
+g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
+
+# Training
+for epoch in range(50):
+    for i, (images, _) in enumerate(train_loader):
+        images = Variable(images)
+        real_labels = Variable(torch.ones(images.size(0)))
+        fake_labels = Variable(torch.zeros(images.size(0)))
+
+        # Train the discriminator
+        discriminator.zero_grad()
+        outputs = discriminator(images)
+        real_loss = criterion(outputs, real_labels)
+        real_score = outputs
+
+        noise = Variable(torch.randn(images.size(0), 128))
+        fake_images = generator(noise)
+        outputs = discriminator(fake_images)
+        fake_loss = criterion(outputs, fake_labels)
+        fake_score = outputs
+
+        d_loss = real_loss + fake_loss
+        d_loss.backward()
+        d_optimizer.step()
+
+        # Train the generator
+        generator.zero_grad()
+        noise = Variable(torch.randn(images.size(0), 128))
+        fake_images = generator(noise)
+        outputs = discriminator(fake_images)
+        g_loss = criterion(outputs, real_labels)
+        g_loss.backward()
+        g_optimizer.step()
+
+        if (i+1) % 100 == 0:
+            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
+                  'D(x): %.2f, D(G(z)): %.2f'
+                  %(epoch+1, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
+                    real_score.data.mean(), fake_score.data.mean()))
+
+    # Save the sampled images
+    torchvision.utils.save_image(fake_images.data,
+        './data/fake_samples_%d_%d.png' %(epoch+1, i+1))
+
+# Save the Models
+torch.save(generator, './generator.pkl')
+torch.save(discriminator, './discriminator.pkl')
\ No newline at end of file
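Since both models are pickled whole, reloading the generator for later sampling is a one-liner; a minimal sketch (assumes the ./generator.pkl produced above and that the Generator class is importable):

    generator = torch.load('./generator.pkl')
    noise = Variable(torch.randn(1, 128))
    sample = generator(noise)  # 1 x 3 x 32 x 32, values in [-1, 1]
    torchvision.utils.save_image(sample.data, './sample.png')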