vanilla gan added

This commit is contained in:
yunjey
2017-03-26 18:25:42 +09:00
parent a438f8e6fc
commit 69145a6685
2 changed files with 14 additions and 10 deletions

View File

@ -27,11 +27,13 @@ class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 1)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 1)
def forward(self, x):
h = F.relu(self.fc1(x))
out = F.sigmoid(self.fc2(h))
h = F.relu(self.fc2(h))
out = F.sigmoid(self.fc3(h))
return out
# Generator Model
@ -39,8 +41,8 @@ class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fc1 = nn.Linear(128, 256)
self.fc2 = nn.Linear(256, 512)
self.fc3 = nn.Linear(512, 784)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 784)
def forward(self, x):
h = F.leaky_relu(self.fc1(x))
@ -101,7 +103,7 @@ for epoch in range(200):
# Save the sampled images
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
torchvision.utils.save_image(fake_images.data,
'./data2/fake_samples_%d.png' %epoch+1)
'./data/fake_samples_%d.png' %(epoch+1))
# Save the Models
torch.save(generator.state_dict(), './generator.pkl')

View File

@ -27,11 +27,13 @@ class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 1)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 1)
def forward(self, x):
h = F.relu(self.fc1(x))
out = F.sigmoid(self.fc2(h))
h = F.relu(self.fc2(h))
out = F.sigmoid(self.fc3(h))
return out
# Generator Model
@ -39,8 +41,8 @@ class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fc1 = nn.Linear(128, 256)
self.fc2 = nn.Linear(256, 512)
self.fc3 = nn.Linear(512, 784)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 784)
def forward(self, x):
h = F.leaky_relu(self.fc1(x))
@ -101,7 +103,7 @@ for epoch in range(200):
# Save the sampled images
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
torchvision.utils.save_image(fake_images.data,
'./data2/fake_samples_%d.png' %epoch+1)
'./data/fake_samples_%d.png' %(epoch+1))
# Save the Models
torch.save(generator.state_dict(), './generator.pkl')