tutorial updated

This commit is contained in:
yunjey
2017-05-28 20:06:40 +09:00
parent c53c48809e
commit c548e2ae9f
71 changed files with 1102 additions and 1123 deletions


@ -0,0 +1,24 @@
## Variational Auto-Encoder
[Variational Auto-Encoder (VAE)](https://arxiv.org/abs/1312.6114) is a generative model. From a neural network perspective, the only difference between the VAE and the Auto-Encoder (AE) is that the latent vector z in the VAE is sampled stochastically. This addresses a weakness of the AE, which tends to learn an identity mapping and therefore does not have to form a meaningful representation in its latent space. To keep backpropagation possible despite the sampling step, the VAE uses the [reparameterization trick](https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/variational_auto_encoder/main.py#L40-L44): instead of sampling z directly from the predicted mean and variance, it samples noise from a standard normal distribution and scales and shifts that noise by the predicted standard deviation and mean.
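A minimal sketch of the trick (not part of this repository's code; the variable names mirror `main.py`), assuming the encoder outputs a mean `mu` and a log variance `log_var`:

```python
import torch

# Dummy encoder outputs for a batch of 100 latent vectors of size 20.
mu = torch.zeros(100, 20)       # mean of q(z|x)
log_var = torch.zeros(100, 20)  # log variance of q(z|x)

# Reparameterization: z = mu + eps * sigma, with eps ~ N(0, I).
# The randomness is isolated in eps, so gradients flow through mu and log_var.
eps = torch.randn(mu.size(0), mu.size(1))
z = mu + eps * torch.exp(log_var / 2)   # exp(log_var / 2) is the standard deviation
```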
#### VAE loss
As in a conventional auto-encoder, the VAE minimizes the reconstruction loss between the input image and the generated image. In addition, a KL-divergence term pushes the approximate posterior over z toward the standard normal distribution, so that at test time new images can be generated by sampling z from N(0, 1) and feeding it to the decoder.
<p align="center"><img width="100%" src="png/vae.png" /></p>
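For reference, here is a self-contained sketch of the two loss terms as they appear in `main.py` below; the tensors are dummy stand-ins, and the KL term is the closed form for a diagonal Gaussian against N(0, I):

```python
import torch
import torch.nn.functional as F

# Dummy stand-ins: a batch of 100 flattened 28x28 images and encoder outputs.
images = torch.rand(100, 784)   # input pixels in [0, 1]
out = torch.rand(100, 784)      # decoder output (sigmoid activations)
mu = torch.zeros(100, 20)       # mean of q(z|x)
log_var = torch.zeros(100, 20)  # log variance of q(z|x)

# Reconstruction term: pixel-wise binary cross-entropy, summed over the batch.
reconst_loss = F.binary_cross_entropy(out, images, size_average=False)

# KL(q(z|x) || N(0, I)) in closed form: 0.5 * sum(mu^2 + var - log_var - 1).
kl_divergence = torch.sum(0.5 * (mu**2 + torch.exp(log_var) - log_var - 1))

total_loss = reconst_loss + kl_divergence
```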
## Usage
```bash
$ pip install -r requirements.txt
$ python main.py
```
<br>
## Results
Real image | Reconstructed image
:-------------------------:|:-------------------------:
![alt text](png/real.png) | ![alt text](png/reconst.png)


@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
import torchvision
# MNIST dataset
dataset = datasets.MNIST(root='./data',
                         train=True,
                         transform=transforms.ToTensor(),
                         download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=100,
                                          shuffle=True)
def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)
# VAE model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(image_size, h_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(h_dim, z_dim*2))   # 2*z_dim outputs: mean and log variance.

        self.decoder = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, image_size),
            nn.Sigmoid())

    def reparametrize(self, mu, log_var):
        """z = mean + eps * sigma, where eps is sampled from N(0, 1)."""
        eps = to_var(torch.randn(mu.size(0), mu.size(1)))
        z = mu + eps * torch.exp(log_var/2)   # exp(log_var/2) converts log variance to std.
        return z

    def forward(self, x):
        h = self.encoder(x)
        mu, log_var = torch.chunk(h, 2, dim=1)   # mean and log variance.
        z = self.reparametrize(mu, log_var)
        out = self.decoder(z)
        return out, mu, log_var

    def sample(self, z):
        return self.decoder(z)
vae = VAE()

if torch.cuda.is_available():
    vae.cuda()
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
iter_per_epoch = len(data_loader)
data_iter = iter(data_loader)
# fixed inputs for debugging
fixed_z = to_var(torch.randn(100, 20))
fixed_x, _ = next(data_iter)
torchvision.utils.save_image(fixed_x.data.cpu(), './data/real_images.png')
fixed_x = to_var(fixed_x.view(fixed_x.size(0), -1))
for epoch in range(50):
    for i, (images, _) in enumerate(data_loader):

        images = to_var(images.view(images.size(0), -1))
        out, mu, log_var = vae(images)

        # Compute reconstruction loss and KL divergence.
        # For kl_divergence, see Appendix B in the paper or http://yunjey47.tistory.com/43
        reconst_loss = F.binary_cross_entropy(out, images, size_average=False)
        kl_divergence = torch.sum(0.5 * (mu**2 + torch.exp(log_var) - log_var - 1))

        # Backprop + Optimize
        total_loss = reconst_loss + kl_divergence
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print ("Epoch[%d/%d], Step [%d/%d], Total Loss: %.4f, "
                   "Reconst Loss: %.4f, KL Div: %.7f"
                   %(epoch+1, 50, i+1, iter_per_epoch, total_loss.data[0],
                     reconst_loss.data[0], kl_divergence.data[0]))

    # Save the reconstructed images for this epoch.
    reconst_images, _, _ = vae(fixed_x)
    reconst_images = reconst_images.view(reconst_images.size(0), 1, 28, 28)
    torchvision.utils.save_image(reconst_images.data.cpu(),
                                 './data/reconst_images_%d.png' %(epoch+1))
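
# (Not part of the original script) A minimal sampling sketch: fixed_z and
# VAE.sample() are defined above but never used in the training loop. The
# output file name below is only an example.
sampled_images = vae.sample(fixed_z)
sampled_images = sampled_images.view(sampled_images.size(0), 1, 28, 28)
torchvision.utils.save_image(sampled_images.data.cpu(), './data/sampled_images.png')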

Three binary image files added (61 KiB, 71 KiB, 189 KiB); not shown.


@ -0,0 +1,2 @@
torch
torchvision