From 69145a6685c08ecf1adcceec54bbb0fbb0d062ca Mon Sep 17 00:00:00 2001 From: yunjey Date: Sun, 26 Mar 2017 18:25:42 +0900 Subject: [PATCH] vanilla gan added --- .../10 - Generative Adversarial Network/main-gpu.py | 12 +++++++----- .../10 - Generative Adversarial Network/main.py | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tutorials/10 - Generative Adversarial Network/main-gpu.py b/tutorials/10 - Generative Adversarial Network/main-gpu.py index 9ca986d..4a741bc 100644 --- a/tutorials/10 - Generative Adversarial Network/main-gpu.py +++ b/tutorials/10 - Generative Adversarial Network/main-gpu.py @@ -27,11 +27,13 @@ class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.fc1 = nn.Linear(784, 256) - self.fc2 = nn.Linear(256, 1) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 1) def forward(self, x): h = F.relu(self.fc1(x)) - out = F.sigmoid(self.fc2(h)) + h = F.relu(self.fc2(h)) + out = F.sigmoid(self.fc3(h)) return out # Generator Model @@ -39,8 +41,8 @@ class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.fc1 = nn.Linear(128, 256) - self.fc2 = nn.Linear(256, 512) - self.fc3 = nn.Linear(512, 784) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 784) def forward(self, x): h = F.leaky_relu(self.fc1(x)) @@ -101,7 +103,7 @@ for epoch in range(200): # Save the sampled images fake_images = fake_images.view(fake_images.size(0), 1, 28, 28) torchvision.utils.save_image(fake_images.data, - './data2/fake_samples_%d.png' %epoch+1) + './data/fake_samples_%d.png' %(epoch+1)) # Save the Models torch.save(generator.state_dict(), './generator.pkl') diff --git a/tutorials/10 - Generative Adversarial Network/main.py b/tutorials/10 - Generative Adversarial Network/main.py index 0d7df09..2c7cbc1 100644 --- a/tutorials/10 - Generative Adversarial Network/main.py +++ b/tutorials/10 - Generative Adversarial Network/main.py @@ -27,11 +27,13 @@ class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.fc1 = nn.Linear(784, 256) - self.fc2 = nn.Linear(256, 1) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 1) def forward(self, x): h = F.relu(self.fc1(x)) - out = F.sigmoid(self.fc2(h)) + h = F.relu(self.fc2(h)) + out = F.sigmoid(self.fc3(h)) return out # Generator Model @@ -39,8 +41,8 @@ class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.fc1 = nn.Linear(128, 256) - self.fc2 = nn.Linear(256, 512) - self.fc3 = nn.Linear(512, 784) + self.fc2 = nn.Linear(256, 256) + self.fc3 = nn.Linear(256, 784) def forward(self, x): h = F.leaky_relu(self.fc1(x)) @@ -101,7 +103,7 @@ for epoch in range(200): # Save the sampled images fake_images = fake_images.view(fake_images.size(0), 1, 28, 28) torchvision.utils.save_image(fake_images.data, - './data2/fake_samples_%d.png' %epoch+1) + './data/fake_samples_%d.png' %(epoch+1)) # Save the Models torch.save(generator.state_dict(), './generator.pkl')