Update tutorials for pytorch 0.4.0

This commit is contained in:
yunjey
2018-05-10 17:52:01 +09:00
parent 9087fe6427
commit 78c6afe681
40 changed files with 44263 additions and 0 deletions

View File

@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
input_size = 784
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data',
train=False,
transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
model = NeuralNet(input_size, hidden_size, num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# Move tensors to the configured device
images = images.reshape(-1, 28*28).to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, 28*28).to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

View File

@ -0,0 +1,55 @@
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Hyper-parameters
input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001
# Toy dataset
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
[9.779], [6.182], [7.59], [2.167], [7.042],
[10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
[3.366], [2.596], [2.53], [1.221], [2.827],
[3.465], [1.65], [2.904], [1.3]], dtype=np.float32)
# Linear regression model
model = nn.Linear(input_size, output_size)
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Train the model
for epoch in range(num_epochs):
# Convert numpy arrays to torch tensors
inputs = torch.from_numpy(x_train)
targets = torch.from_numpy(y_train)
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch+1) % 5 == 0:
print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
# Plot the graph
predicted = model(torch.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

View File

@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Hyper-parameters
input_size = 784
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001
# MNIST dataset (images and labels)
train_dataset = torchvision.datasets.MNIST(root='../../data',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data',
train=False,
transform=transforms.ToTensor())
# Data loader (input pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Logistic regression model
model = nn.Linear(input_size, num_classes)
# Loss and optimizer
# nn.CrossEntropyLoss() computes softmax internally
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# Reshape images to (batch_size, input_size)
images = images.reshape(-1, 28*28)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
# In test phase, we don't need to compute gradients (for memory efficieny)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, 28*28)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

View File

@ -0,0 +1,189 @@
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
# ================================================================== #
# Table of Contents #
# ================================================================== #
# 1. Basic autograd example 1 (Line 25 to 39)
# 2. Basic autograd example 2 (Line 46 to 83)
# 3. Loading data from numpy (Line 90 to 97)
# 4. Input pipline (Line 104 to 129)
# 5. Input pipline for custom dataset (Line 136 to 156)
# 6. Pretrained model (Line 163 to 176)
# 7. Save and load model (Line 183 to 189)
# ================================================================== #
# 1. Basic autograd example 1 #
# ================================================================== #
# Create tensors.
x = torch.tensor(1, requires_grad=True)
w = torch.tensor(2, requires_grad=True)
b = torch.tensor(3, requires_grad=True)
# Build a computational graph.
y = w * x + b # y = 2 * x + 3
# Compute gradients.
y.backward()
# Print out the gradients.
print(x.grad) # x.grad = 2
print(w.grad) # w.grad = 1
print(b.grad) # b.grad = 1
# ================================================================== #
# 2. Basic autograd example 2 #
# ================================================================== #
# Create tensors of shape (10, 3) and (10, 2).
x = torch.randn(10, 3)
y = torch.randn(10, 2)
# Build a fully connected layer.
linear = nn.Linear(3, 2)
print ('w: ', linear.weight)
print ('b: ', linear.bias)
# Build loss function and optimizer.
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)
# Forward pass.
pred = linear(x)
# Compute loss.
loss = criterion(pred, y)
print('loss: ', loss.item())
# Backward pass.
loss.backward()
# Print out the gradients.
print ('dL/dw: ', linear.weight.grad)
print ('dL/db: ', linear.bias.grad)
# 1-step gradient descent.
optimizer.step()
# You can also perform gradient descent at the low level.
# linear.weight.data.sub_(0.01 * linear.weight.grad.data)
# linear.bias.data.sub_(0.01 * linear.bias.grad.data)
# Print out the loss after 1-step gradient descent.
pred = linear(x)
loss = criterion(pred, y)
print('loss after 1 step optimization: ', loss.item())
# ================================================================== #
# 3. Loading data from numpy #
# ================================================================== #
# Create a numpy array.
x = np.array([[1, 2], [3, 4]])
# Convert the numpy array to a torch tensor.
y = torch.from_numpy(x)
# Convert the torch tensor to a numpy array.
z = y.numpy()
# ================================================================== #
# 4. Input pipline #
# ================================================================== #
# Download and construct CIFAR-10 dataset.
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
# Fetch one data pair (read data from disk).
image, label = train_dataset[0]
print (image.size())
print (label)
# Data loader (this provides queues and threads in a very simple way).
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=64,
shuffle=True)
# When iteration starts, queue and thread start to load data from files.
data_iter = iter(train_loader)
# Mini-batch images and labels.
images, labels = data_iter.next()
# Actual usage of the data loader is as below.
for images, labels in train_loader:
# Training code should be written here.
pass
# ================================================================== #
# 5. Input pipline for custom dataset #
# ================================================================== #
# You should your build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
# TODO
# 1. Initialize file paths or a list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
# You can then use the prebuilt data loader.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
batch_size=64,
shuffle=True)
# ================================================================== #
# 6. Pretrained model #
# ================================================================== #
# Download and load the pretrained ResNet-18.
resnet = torchvision.models.resnet18(pretrained=True)
# If you want to finetune only the top layer of the model, set as below.
for param in resnet.parameters():
param.requires_grad = False
# Replace the top layer for finetuning.
resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is an example.
# Forward pass.
images = torch.randn(64, 3, 224, 224)
outputs = resnet(images)
print (outputs.size()) # (64, 100)
# ================================================================== #
# 7. Save and load the model #
# ================================================================== #
# Save and load the entire model.
torch.save(resnet, 'model.ckpt')
model = torch.load('model.ckpt')
# Save and load only the model parameters (recommended).
torch.save(resnet.state_dict(), 'params.ckpt')
resnet.load_state_dict(torch.load('params.ckpt'))

View File

@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data/',
train=False,
transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(BiRNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
def forward(self, x):
# Set initial states
h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection
c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
# Forward propagate LSTM
out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size*2)
# Decode the hidden state of the last time step
out = self.fc(out[:, -1, :])
return out
model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

View File

@ -0,0 +1,100 @@
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Hyper parameters
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data/',
train=False,
transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
super(ConvNet, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.fc = nn.Linear(7*7*32, num_classes)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out
model = ConvNet(num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

View File

@ -0,0 +1,169 @@
# ---------------------------------------------------------------------------- #
# An implementation of https://arxiv.org/pdf/1512.03385.pdf #
# See section 4.2 for the model architecture on CIFAR-10 #
# Some part of the code was referenced from below #
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py #
# ---------------------------------------------------------------------------- #
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
num_epochs = 80
learning_rate = 0.001
# Image preprocessing modules
transform = transforms.Compose([
transforms.Pad(4),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32),
transforms.ToTensor()])
# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='../../data/',
train=True,
transform=transform,
download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
train=False,
transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=100,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=100,
shuffle=False)
# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
# Residual block
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(ResidualBlock, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(out_channels, out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
# ResNet
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 16
self.conv = conv3x3(3, 16)
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self.make_layer(block, 16, layers[0])
self.layer2 = self.make_layer(block, 32, layers[0], 2)
self.layer3 = self.make_layer(block, 64, layers[1], 2)
self.avg_pool = nn.AvgPool2d(8)
self.fc = nn.Linear(64, num_classes)
def make_layer(self, block, out_channels, blocks, stride=1):
downsample = None
if (stride != 1) or (self.in_channels != out_channels):
downsample = nn.Sequential(
conv3x3(self.in_channels, out_channels, stride=stride),
nn.BatchNorm2d(out_channels))
layers = []
layers.append(block(self.in_channels, out_channels, stride, downsample))
self.in_channels = out_channels
for i in range(1, blocks):
layers.append(block(out_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv(x)
out = self.bn(out)
out = self.relu(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.avg_pool(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
model = ResNet(ResidualBlock, [2, 2, 2, 2]).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# For updating learning rate
def update_lr(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Train the model
total_step = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Decay learning rate
if (epoch+1) % 20 == 0:
curr_lr /= 3
update_lr(optimizer, curr_lr)
# Test the model
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'resnet.ckpt')

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,46 @@
import torch
import os
class Dictionary(object):
def __init__(self):
self.word2idx = {}
self.idx2word = {}
self.idx = 0
def add_word(self, word):
if not word in self.word2idx:
self.word2idx[word] = self.idx
self.idx2word[self.idx] = word
self.idx += 1
def __len__(self):
return len(self.word2idx)
class Corpus(object):
def __init__(self):
self.dictionary = Dictionary()
def get_data(self, path, batch_size=20):
# Add words to the dictionary
with open(path, 'r') as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
self.dictionary.add_word(word)
# Tokenize the file content
ids = torch.LongTensor(tokens)
token = 0
with open(path, 'r') as f:
for line in f:
words = line.split() + ['<eos>']
for word in words:
ids[token] = self.dictionary.word2idx[word]
token += 1
num_batches = ids.size(0) // batch_size
ids = ids[:num_batches*batch_size]
return ids.view(batch_size, -1)

View File

@ -0,0 +1,120 @@
# Some part of the code was referenced from below.
# https://github.com/pytorch/examples/tree/master/word_language_model
import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils import clip_grad_norm
from data_utils import Dictionary, Corpus
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
embed_size = 128
hidden_size = 1024
num_layers = 1
num_epochs = 5
num_samples = 1000 # number of words to be sampled
batch_size = 20
seq_length = 30
learning_rate = 0.002
# Load "Penn Treebank" dataset
corpus = Corpus()
ids = corpus.get_data('data/train.txt', batch_size)
vocab_size = len(corpus.dictionary)
num_batches = ids.size(1) // seq_length
# RNN based language model
class RNNLM(nn.Module):
def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
super(RNNLM, self).__init__()
self.embed = nn.Embedding(vocab_size, embed_size)
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
self.linear = nn.Linear(hidden_size, vocab_size)
def forward(self, x, h):
# Embed word ids to vectors
x = self.embed(x)
# Forward propagate LSTM
out, (h, c) = self.lstm(x, h)
# Reshape output to (batch_size*sequence_length, hidden_size)
out = out.reshape(out.size(0)*out.size(1), out.size(2))
# Decode hidden states of all time steps
out = self.linear(out)
return out, (h, c)
model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Truncated backpropagation
def detach(states):
return [state.detach() for state in states]
# Train the model
for epoch in range(num_epochs):
# Set initial hidden and cell states
states = (torch.zeros(num_layers, batch_size, hidden_size).to(device),
torch.zeros(num_layers, batch_size, hidden_size).to(device))
for i in range(0, ids.size(1) - seq_length, seq_length):
# Get mini-batch inputs and targets
inputs = ids[:, i:i+seq_length].to(device)
targets = ids[:, (i+1):(i+1)+seq_length].to(device)
# Forward pass
states = detach(states)
outputs, states = model(inputs, states)
loss = criterion(outputs, targets.reshape(-1))
# Backward and optimize
model.zero_grad()
loss.backward()
clip_grad_norm(model.parameters(), 0.5)
optimizer.step()
step = (i+1) // seq_length
if step % 100 == 0:
print ('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
.format(epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))
# Test the model
with torch.no_grad():
with open('sample.txt', 'w') as f:
# Set intial hidden ane cell states
state = (torch.zeros(num_layers, 1, hidden_size).to(device),
torch.zeros(num_layers, 1, hidden_size).to(device))
# Select one word id randomly
prob = torch.ones(vocab_size)
input = torch.multinomial(prob, num_samples=1).unsqueeze(1).to(device)
for i in range(num_samples):
# Forward propagate RNN
output, state = model(input, state)
# Sample a word id
prob = output.exp()
word_id = torch.multinomial(prob, num_samples=1).item()
# Fill input with sampled word id for the next time step
input.fill_(word_id)
# File write
word = corpus.dictionary.idx2word[word_id]
word = '\n' if word == '<eos>' else word + ' '
f.write(word)
if (i+1) % 100 == 0:
print('Sampled [{}/{}] words and save to {}'.format(i+1, num_samples, 'sample.txt'))
# Save the model checkpoints
torch.save(model.state_dict(), 'model.ckpt')

View File

@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data/',
train=False,
transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Recurrent neural network (many-to-one)
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
# Set initial hidden and cell states
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
# Forward propagate LSTM
out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size)
# Decode the hidden state of the last time step
out = self.fc(out[:, -1, :])
return out
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

View File

@ -0,0 +1,144 @@
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
# Create a directory if not exists
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
# Image processing
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels
std=(0.5, 0.5, 0.5))])
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transform,
download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
# Discriminator
D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, 1),
nn.Sigmoid())
# Generator
G = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, image_size),
nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
def denorm(x):
out = (x + 1) / 2
return out.clamp(0, 1)
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()
# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
for i, (images, _) in enumerate(data_loader):
images = images.reshape(batch_size, -1).to(device)
# Create the labels which are later used as input for the BCE loss
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ================================================================== #
# Train the discriminator #
# ================================================================== #
# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
# Second term of the loss is always zero since real_labels == 1
outputs = D(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
# Compute BCELoss using fake images
# First term of the loss is always zero since fake_labels == 0
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
# Backprop and optimize
d_loss = d_loss_real + d_loss_fake
reset_grad()
d_loss.backward()
d_optimizer.step()
# ================================================================== #
# Train the generator #
# ================================================================== #
# Compute loss with fake images
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
g_loss = criterion(outputs, real_labels)
# Backprop and optimize
reset_grad()
g_loss.backward()
g_optimizer.step()
if (i+1) % 200 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
# Save real images
if (epoch+1) == 1:
images = images.reshape(images.size(0), 1, 28, 28)
save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
# Save sampled images
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# Save the model checkpoints
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

View File

@ -0,0 +1,59 @@
# Image Captioning
The goal of image captioning is to convert a given input image into a natural language description. The encoder-decoder framework is widely used for this task. The image encoder is a convolutional neural network (CNN). In this tutorial, we used [resnet-152](https://arxiv.org/abs/1512.03385) model pretrained on the [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) image classification dataset. The decoder is a long short-term memory (LSTM) network.
![alt text](png/model.png)
#### Training phase
For the encoder part, the pretrained CNN extracts the feature vector from a given input image. The feature vector is linearly transformed to have the same dimension as the input dimension of the LSTM network. For the decoder part, source and target texts are predefined. For example, if the image description is **"Giraffes standing next to each other"**, the source sequence is a list containing **['\<start\>', 'Giraffes', 'standing', 'next', 'to', 'each', 'other']** and the target sequence is a list containing **['Giraffes', 'standing', 'next', 'to', 'each', 'other', '\<end\>']**. Using these source and target sequences and the feature vector, the LSTM decoder is trained as a language model conditioned on the feature vector.
#### Test phase
In the test phase, the encoder part is almost same as the training phase. The only difference is that batchnorm layer uses moving average and variance instead of mini-batch statistics. This can be easily implemented using [encoder.eval()](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/sample.py#L43). For the decoder part, there is a significant difference between the training phase and the test phase. In the test phase, the LSTM decoder can't see the image description. To deal with this problem, the LSTM decoder feeds back the previosly generated word to the next input. This can be implemented using a [for-loop](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/model.py#L48).
## Usage
#### 1. Clone the repositories
```bash
$ git clone https://github.com/pdollar/coco.git
$ cd coco/PythonAPI/
$ make
$ python setup.py build
$ python setup.py install
$ cd ../../
$ git clone https://github.com/yunjey/pytorch-tutorial.git
$ cd pytorch-tutorial/tutorials/03-advanced/image_captioning/
```
#### 2. Download the dataset
```bash
$ pip install -r requirements.txt
$ chmod +x download.sh
$ ./download.sh
```
#### 3. Preprocessing
```bash
$ python build_vocab.py
$ python resize.py
```
#### 4. Train the model
```bash
$ python train.py
```
#### 5. Test the model
```bash
$ python sample.py --image='png/example.png'
```
<br>
## Pretrained model
If you do not want to train the model from scratch, you can use a pretrained model. You can download the pretrained model [here](https://www.dropbox.com/s/ne0ixz5d58ccbbz/pretrained_model.zip?dl=0) and the vocabulary file [here](https://www.dropbox.com/s/26adb7y9m98uisa/vocap.zip?dl=0). You should extract pretrained_model.zip to `./models/` and vocab.pkl to `./data/` using `unzip` command.

View File

@ -0,0 +1,76 @@
import nltk
import pickle
import argparse
from collections import Counter
from pycocotools.coco import COCO
class Vocabulary(object):
"""Simple vocabulary wrapper."""
def __init__(self):
self.word2idx = {}
self.idx2word = {}
self.idx = 0
def add_word(self, word):
if not word in self.word2idx:
self.word2idx[word] = self.idx
self.idx2word[self.idx] = word
self.idx += 1
def __call__(self, word):
if not word in self.word2idx:
return self.word2idx['<unk>']
return self.word2idx[word]
def __len__(self):
return len(self.word2idx)
def build_vocab(json, threshold):
"""Build a simple vocabulary wrapper."""
coco = COCO(json)
counter = Counter()
ids = coco.anns.keys()
for i, id in enumerate(ids):
caption = str(coco.anns[id]['caption'])
tokens = nltk.tokenize.word_tokenize(caption.lower())
counter.update(tokens)
if (i+1) % 1000 == 0:
print("[{}/{}] Tokenized the captions.".format(i+1, len(ids)))
# If the word frequency is less than 'threshold', then the word is discarded.
words = [word for word, cnt in counter.items() if cnt >= threshold]
# Create a vocab wrapper and add some special tokens.
vocab = Vocabulary()
vocab.add_word('<pad>')
vocab.add_word('<start>')
vocab.add_word('<end>')
vocab.add_word('<unk>')
# Add the words to the vocabulary.
for i, word in enumerate(words):
vocab.add_word(word)
return vocab
def main(args):
vocab = build_vocab(json=args.caption_path, threshold=args.threshold)
vocab_path = args.vocab_path
with open(vocab_path, 'wb') as f:
pickle.dump(vocab, f)
print("Total vocabulary size: {}".format(len(vocab)))
print("Saved the vocabulary wrapper to '{}'".format(vocab_path))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--caption_path', type=str,
default='data/annotations/captions_train2014.json',
help='path for train annotation file')
parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
help='path for saving vocabulary wrapper')
parser.add_argument('--threshold', type=int, default=4,
help='minimum word count threshold')
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,105 @@
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
import nltk
from PIL import Image
from build_vocab import Vocabulary
from pycocotools.coco import COCO
class CocoDataset(data.Dataset):
"""COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
def __init__(self, root, json, vocab, transform=None):
"""Set the path for images, captions and vocabulary wrapper.
Args:
root: image directory.
json: coco annotation file path.
vocab: vocabulary wrapper.
transform: image transformer.
"""
self.root = root
self.coco = COCO(json)
self.ids = list(self.coco.anns.keys())
self.vocab = vocab
self.transform = transform
def __getitem__(self, index):
"""Returns one data pair (image and caption)."""
coco = self.coco
vocab = self.vocab
ann_id = self.ids[index]
caption = coco.anns[ann_id]['caption']
img_id = coco.anns[ann_id]['image_id']
path = coco.loadImgs(img_id)[0]['file_name']
image = Image.open(os.path.join(self.root, path)).convert('RGB')
if self.transform is not None:
image = self.transform(image)
# Convert caption (string) to word ids.
tokens = nltk.tokenize.word_tokenize(str(caption).lower())
caption = []
caption.append(vocab('<start>'))
caption.extend([vocab(token) for token in tokens])
caption.append(vocab('<end>'))
target = torch.Tensor(caption)
return image, target
def __len__(self):
return len(self.ids)
def collate_fn(data):
"""Creates mini-batch tensors from the list of tuples (image, caption).
We should build custom collate_fn rather than using default collate_fn,
because merging caption (including padding) is not supported in default.
Args:
data: list of tuple (image, caption).
- image: torch tensor of shape (3, 256, 256).
- caption: torch tensor of shape (?); variable length.
Returns:
images: torch tensor of shape (batch_size, 3, 256, 256).
targets: torch tensor of shape (batch_size, padded_length).
lengths: list; valid length for each padded caption.
"""
# Sort a data list by caption length (descending order).
data.sort(key=lambda x: len(x[1]), reverse=True)
images, captions = zip(*data)
# Merge images (from tuple of 3D tensor to 4D tensor).
images = torch.stack(images, 0)
# Merge captions (from tuple of 1D tensor to 2D tensor).
lengths = [len(cap) for cap in captions]
targets = torch.zeros(len(captions), max(lengths)).long()
for i, cap in enumerate(captions):
end = lengths[i]
targets[i, :end] = cap[:end]
return images, targets, lengths
def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
"""Returns torch.utils.data.DataLoader for custom coco dataset."""
# COCO caption dataset
coco = CocoDataset(root=root,
json=json,
vocab=vocab,
transform=transform)
# Data loader for COCO dataset
# This will return (images, captions, lengths) for each iteration.
# images: a tensor of shape (batch_size, 3, 224, 224).
# captions: a tensor of shape (batch_size, padded_length).
# lengths: a list indicating valid length for each caption. length is (batch_size).
data_loader = torch.utils.data.DataLoader(dataset=coco,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=collate_fn)
return data_loader

View File

@ -0,0 +1,11 @@
mkdir data
wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P ./data/
wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip -P ./data/
wget http://msvocds.blob.core.windows.net/coco2014/val2014.zip -P ./data/
unzip ./data/captions_train-val2014.zip -d ./data/
rm ./data/captions_train-val2014.zip
unzip ./data/train2014.zip -d ./data/
rm ./data/train2014.zip
unzip ./data/val2014.zip -d ./data/
rm ./data/val2014.zip

View File

@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence
class EncoderCNN(nn.Module):
def __init__(self, embed_size):
"""Load the pretrained ResNet-152 and replace top fc layer."""
super(EncoderCNN, self).__init__()
resnet = models.resnet152(pretrained=True)
modules = list(resnet.children())[:-1] # delete the last fc layer.
self.resnet = nn.Sequential(*modules)
self.linear = nn.Linear(resnet.fc.in_features, embed_size)
self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
def forward(self, images):
"""Extract feature vectors from input images."""
with torch.no_grad():
features = self.resnet(images)
features = features.reshape(features.size(0), -1)
features = self.bn(self.linear(features))
return features
class DecoderRNN(nn.Module):
def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
"""Set the hyper-parameters and build the layers."""
super(DecoderRNN, self).__init__()
self.embed = nn.Embedding(vocab_size, embed_size)
self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
self.linear = nn.Linear(hidden_size, vocab_size)
self.max_seg_length = max_seq_length
def forward(self, features, captions, lengths):
"""Decode image feature vectors and generates captions."""
embeddings = self.embed(captions)
embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
hiddens, _ = self.lstm(packed)
outputs = self.linear(hiddens[0])
return outputs
def sample(self, features, states=None):
"""Generate captions for given image features using greedy search."""
sampled_ids = []
inputs = features.unsqueeze(1)
for i in range(self.max_seg_length):
hiddens, states = self.lstm(inputs, states) # hiddens: (batch_size, 1, hidden_size)
outputs = self.linear(hiddens.squeeze(1)) # outputs: (batch_size, vocab_size)
_, predicted = outputs.max(1) # predicted: (batch_size)
sampled_ids.append(predicted)
inputs = self.embed(predicted) # inputs: (batch_size, embed_size)
inputs = inputs.unsqueeze(1) # inputs: (batch_size, 1, embed_size)
sampled_ids = torch.stack(sampled_ids, 1) # sampled_ids: (batch_size, max_seq_length)
return sampled_ids

Binary file not shown.

After

Width:  |  Height:  |  Size: 220 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 246 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 246 KiB

View File

@ -0,0 +1,5 @@
matplotlib
nltk
numpy
Pillow
argparse

View File

@ -0,0 +1,42 @@
import argparse
import os
from PIL import Image
def resize_image(image, size):
"""Resize an image to the given size."""
return image.resize(size, Image.ANTIALIAS)
def resize_images(image_dir, output_dir, size):
"""Resize the images in 'image_dir' and save into 'output_dir'."""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
images = os.listdir(image_dir)
num_images = len(images)
for i, image in enumerate(images):
with open(os.path.join(image_dir, image), 'r+b') as f:
with Image.open(f) as img:
img = resize_image(img, size)
img.save(os.path.join(output_dir, image), img.format)
if (i+1) % 100 == 0:
print ("[{}/{}] Resized the images and saved into '{}'."
.format(i+1, num_images, output_dir))
def main(args):
image_dir = args.image_dir
output_dir = args.output_dir
image_size = [args.image_size, args.image_size]
resize_images(image_dir, output_dir, image_size)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--image_dir', type=str, default='./data/train2014/',
help='directory for train images')
parser.add_argument('--output_dir', type=str, default='./data/resized2014/',
help='directory for saving resized images')
parser.add_argument('--image_size', type=int, default=256,
help='size for image after processing')
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,81 @@
import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
import os
from torchvision import transforms
from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
from PIL import Image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_image(image_path, transform=None):
image = Image.open(image_path)
image = image.resize([224, 224], Image.LANCZOS)
if transform is not None:
image = transform(image).unsqueeze(0)
return image
def main(args):
# Image preprocessing
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))])
# Load vocabulary wrapper
with open(args.vocab_path, 'rb') as f:
vocab = pickle.load(f)
# Build models
encoder = EncoderCNN(args.embed_size).eval() # eval mode (batchnorm uses moving mean/variance)
decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers)
encoder = encoder.to(device)
decoder = decoder.to(device)
# Load the trained model parameters
encoder.load_state_dict(torch.load(args.encoder_path))
decoder.load_state_dict(torch.load(args.decoder_path))
# Prepare an image
image = load_image(args.image, transform)
image_tensor = image.to(device)
# Generate an caption from the image
feature = encoder(image_tensor)
sampled_ids = decoder.sample(feature)
sampled_ids = sampled_ids[0].cpu().numpy() # (1, max_seq_length) -> (max_seq_length)
# Convert word_ids to words
sampled_caption = []
for word_id in sampled_ids:
word = vocab.idx2word[word_id]
sampled_caption.append(word)
if word == '<end>':
break
sentence = ' '.join(sampled_caption)
# Print out the image and the generated caption
print (sentence)
image = Image.open(args.image)
plt.imshow(np.asarray(image))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--image', type=str, required=True, help='input image for generating caption')
parser.add_argument('--encoder_path', type=str, default='models/encoder-2-1000.ckpt', help='path for trained encoder')
parser.add_argument('--decoder_path', type=str, default='models/decoder-2-1000.ckpt', help='path for trained decoder')
parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')
# Model parameters (should be same as paramters in train.py)
parser.add_argument('--embed_size', type=int , default=256, help='dimension of word embedding vectors')
parser.add_argument('--hidden_size', type=int , default=512, help='dimension of lstm hidden states')
parser.add_argument('--num_layers', type=int , default=1, help='number of layers in lstm')
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,101 @@
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
from data_loader import get_loader
from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def main(args):
# Create model directory
if not os.path.exists(args.model_path):
os.makedirs(args.model_path)
# Image preprocessing, normalization for the pretrained resnet
transform = transforms.Compose([
transforms.RandomCrop(args.crop_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))])
# Load vocabulary wrapper
with open(args.vocab_path, 'rb') as f:
vocab = pickle.load(f)
# Build data loader
data_loader = get_loader(args.image_dir, args.caption_path, vocab,
transform, args.batch_size,
shuffle=True, num_workers=args.num_workers)
# Build the models
encoder = EncoderCNN(args.embed_size).to(device)
decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
optimizer = torch.optim.Adam(params, lr=args.learning_rate)
# Train the models
total_step = len(data_loader)
for epoch in range(args.num_epochs):
for i, (images, captions, lengths) in enumerate(data_loader):
# Set mini-batch dataset
images = images.to(device)
captions = captions.to(device)
targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
# Forward, backward and optimize
features = encoder(images)
outputs = decoder(features, captions, lengths)
loss = criterion(outputs, targets)
decoder.zero_grad()
encoder.zero_grad()
loss.backward()
optimizer.step()
# Print log info
if i % args.log_step == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
.format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item())))
# Save the model checkpoints
if (i+1) % args.save_step == 0:
torch.save(decoder.state_dict(), os.path.join(
args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
torch.save(encoder.state_dict(), os.path.join(
args.model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='models/' , help='path for saving trained models')
parser.add_argument('--crop_size', type=int, default=224 , help='size for randomly cropping images')
parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')
parser.add_argument('--image_dir', type=str, default='data/resized2014', help='directory for resized images')
parser.add_argument('--caption_path', type=str, default='data/annotations/captions_train2014.json', help='path for train annotation json file')
parser.add_argument('--log_step', type=int , default=10, help='step size for prining log info')
parser.add_argument('--save_step', type=int , default=1000, help='step size for saving trained models')
# Model parameters
parser.add_argument('--embed_size', type=int , default=256, help='dimension of word embedding vectors')
parser.add_argument('--hidden_size', type=int , default=512, help='dimension of lstm hidden states')
parser.add_argument('--num_layers', type=int , default=1, help='number of layers in lstm')
parser.add_argument('--num_epochs', type=int, default=5)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args()
print(args)
main(args)

View File

@ -0,0 +1,33 @@
# Neural Style Transfer
[Neural style transfer](https://arxiv.org/abs/1508.06576) is an algorithm that combines the content of one image with the style of another image using CNN. Given a content image and a style image, the goal is to generate a target image that minimizes the content difference with the content image and the style difference with the style image.
<p align="center"><img width="100%" src="png/neural_style2.png" /></p>
#### Content loss
To minimize the content difference, we forward propagate the content image and the target image to pretrained [VGGNet](https://arxiv.org/abs/1409.1556) respectively, and extract feature maps from multiple convolutional layers. Then, the target image is updated to minimize the [mean-squared error](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/neural_style_transfer/main.py#L92-L93) between the feature maps of the content image and its feature maps.
#### Style loss
As in computing the content loss, we forward propagate the style image and the target image to the VGGNet and extract convolutional feature maps. To generate a texture that matches the style of the style image, we update the target image by minimizing the mean-squared error between the Gram matrix of the style image and the Gram matrix of the target image (feature correlation minimization). See [here](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/neural_style_transfer/main.py#L95-L105) for how to compute the style loss.
<br>
## Usage
```bash
$ pip install -r requirements.txt
$ python main.py --content='png/content.png' --style='png/style.png'
```
<br>
## Results
The following is the result of applying variaous styles of artwork to Anne Hathaway's photograph.
![alt text](png/neural_style.png)

View File

@ -0,0 +1,126 @@
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_image(image_path, transform=None, max_size=None, shape=None):
"""Load an image and convert it to a torch tensor."""
image = Image.open(image_path)
if max_size:
scale = max_size / max(image.size)
size = np.array(image.size) * scale
image = image.resize(size.astype(int), Image.ANTIALIAS)
if shape:
image = image.resize(shape, Image.LANCZOS)
if transform:
image = transform(image).unsqueeze(0)
return image.to(device)
class VGGNet(nn.Module):
def __init__(self):
"""Select conv1_1 ~ conv5_1 activation maps."""
super(VGGNet, self).__init__()
self.select = ['0', '5', '10', '19', '28']
self.vgg = models.vgg19(pretrained=True).features
def forward(self, x):
"""Extract multiple convolutional feature maps."""
features = []
for name, layer in self.vgg._modules.items():
x = layer(x)
if name in self.select:
features.append(x)
return features
def main(config):
# Image preprocessing
# VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
# We use the same normalization statistics here.
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))])
# Load content and style images
# Make the style image same size as the content image
content = load_image(config.content, transform, max_size=config.max_size)
style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])
# Initialize a target image with the content image
target = content.clone().requires_grad_(True)
optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
vgg = VGGNet().to(device).eval()
for step in range(config.total_step):
# Extract multiple(5) conv feature vectors
target_features = vgg(target)
content_features = vgg(content)
style_features = vgg(style)
style_loss = 0
content_loss = 0
for f1, f2, f3 in zip(target_features, content_features, style_features):
# Compute content loss with target and content images
content_loss += torch.mean((f1 - f2)**2)
# Reshape convolutional feature maps
_, c, h, w = f1.size()
f1 = f1.view(c, h * w)
f3 = f3.view(c, h * w)
# Compute gram matrix
f1 = torch.mm(f1, f1.t())
f3 = torch.mm(f3, f3.t())
# Compute style loss with target and style images
style_loss += torch.mean((f1 - f3)**2) / (c * h * w)
# Compute total loss, backprop and optimize
loss = content_loss + config.style_weight * style_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (step+1) % config.log_step == 0:
print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}'
.format(step+1, config.total_step, content_loss.item(), style_loss.item()))
if (step+1) % config.sample_step == 0:
# Save the generated image
denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
img = target.clone().squeeze()
img = denorm(img).clamp_(0, 1)
torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--content', type=str, default='png/content.png')
parser.add_argument('--style', type=str, default='png/style.png')
parser.add_argument('--max_size', type=int, default=400)
parser.add_argument('--total_step', type=int, default=2000)
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=500)
parser.add_argument('--style_weight', type=float, default=100)
parser.add_argument('--lr', type=float, default=0.003)
config = parser.parse_args()
print(config)
main(config)

Binary file not shown.

After

Width:  |  Height:  |  Size: 599 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 493 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 682 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 947 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 MiB

View File

@ -0,0 +1,4 @@
argparse
torch
torchvision
Pillow

View File

@ -0,0 +1,101 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create a directory if not exists
sample_dir = 'samples'
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
# Hyper-parameters
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3
# MNIST dataset
dataset = torchvision.datasets.MNIST(root='../../data',
train=True,
transform=transforms.ToTensor(),
download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True)
# VAE model
class VAE(nn.Module):
def __init__(self, image_size=784, h_dim=400, z_dim=20):
super(VAE, self).__init__()
self.fc1 = nn.Linear(image_size, h_dim)
self.fc2 = nn.Linear(h_dim, z_dim)
self.fc3 = nn.Linear(h_dim, z_dim)
self.fc4 = nn.Linear(z_dim, h_dim)
self.fc5 = nn.Linear(h_dim, image_size)
def encode(self, x):
h = F.relu(self.fc1(x))
return self.fc2(h), self.fc3(h)
def reparameterize(self, mu, log_var):
std = torch.exp(log_var/2)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h = F.relu(self.fc4(z))
return F.sigmoid(self.fc5(h))
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_reconst = self.decode(z)
return x_reconst, mu, log_var
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Start training
for epoch in range(num_epochs):
for i, (x, _) in enumerate(data_loader):
# Forward pass
x = x.to(device).view(-1, image_size)
x_reconst, mu, log_var = model(x)
# Compute reconstruction loss and kl divergence
# For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
# Backprop and optimize
loss = reconst_loss + kl_div
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 10 == 0:
print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
.format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))
with torch.no_grad():
# Save the sampled images
z = torch.randn(batch_size, z_dim).to(device)
out = model.decode(z).view(-1, 1, 28, 28)
save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))
# Save the reconstructed images
out, _, _ = model(x)
x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

View File

@ -0,0 +1,25 @@
# TensorBoard in PyTorch
In this tutorial, we implement a MNIST classifier using a simple neural network and visualize the training process using [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard). In training phase, we plot the loss and accuracy functions through `scalar_summary` and visualize the training images through `image_summary`. In addition, we visualize the weight and gradient values of the parameters of the neural network using `histogram_summary`. PyTorch code for handling these summary functions can be found [here](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/tensorboard/main.py#L81-L97).
![alt text](gif/tensorboard.gif)
<br>
## Usage
#### 1. Install the dependencies
```bash
$ pip install -r requirements.txt
```
#### 2. Train the model
```bash
$ python main.py
```
#### 3. Open the TensorBoard
To run the TensorBoard, open a new terminal and run the command below. Then, open http://localhost:6006/ on your web browser.
```bash
$ tensorboard --logdir='./logs' --port=6006
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 561 KiB

View File

@ -0,0 +1,71 @@
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import tensorflow as tf
import numpy as np
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
class Logger(object):
def __init__(self, log_dir):
"""Create a summary writer logging to log_dir."""
self.writer = tf.summary.FileWriter(log_dir)
def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
def image_summary(self, tag, images, step):
"""Log a list of images."""
img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
try:
s = StringIO()
except:
s = BytesIO()
scipy.misc.toimage(img).save(s, format="png")
# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)
def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""
# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)
# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2))
# Drop the start of the first bin
bin_edges = bin_edges[1:]
# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)
# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()

View File

@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from logger import Logger
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# MNIST dataset
dataset = torchvision.datasets.MNIST(root='../../data',
train=True,
transform=transforms.ToTensor(),
download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=100,
shuffle=True)
# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
def __init__(self, input_size=784, hidden_size=500, num_classes=10):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
model = NeuralNet().to(device)
logger = Logger('./logs')
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
data_iter = iter(data_loader)
iter_per_epoch = len(data_loader)
total_step = 50000
# Start training
for step in range(total_step):
# Reset the data_iter
if (step+1) % iter_per_epoch == 0:
data_iter = iter(data_loader)
# Fetch images and labels
images, labels = next(data_iter)
images, labels = images.view(images.size(0), -1).to(device), labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Compute accuracy
_, argmax = torch.max(outputs, 1)
accuracy = (labels == argmax.squeeze()).float().mean()
if (step+1) % 100 == 0:
print ('Step [{}/{}], Loss: {:.4f}, Acc: {:.2f}'
.format(step+1, total_step, loss.item(), accuracy.item()))
# ================================================================== #
# Tensorboard Logging #
# ================================================================== #
# 1. Log scalar values (scalar summary)
info = { 'loss': loss.item(), 'accuracy': accuracy.item() }
for tag, value in info.items():
logger.scalar_summary(tag, value, step+1)
# 2. Log values and gradients of the parameters (histogram summary)
for tag, value in model.named_parameters():
tag = tag.replace('.', '/')
logger.histo_summary(tag, value.data.cpu().numpy(), step+1)
logger.histo_summary(tag+'/grad', value.grad.data.cpu().numpy(), step+1)
# 3. Log training images (image summary)
info = { 'images': images.view(-1, 28, 28)[:10].cpu().numpy() }
for tag, images in info.items():
logger.image_summary(tag, images, step+1)

View File

@ -0,0 +1,5 @@
tensorflow
torch
torchvision
scipy
numpy