mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-27 12:03:31 +08:00
vanilla gan added
This commit is contained in:
@ -27,11 +27,13 @@ class Discriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Discriminator, self).__init__()
|
||||
self.fc1 = nn.Linear(784, 256)
|
||||
self.fc2 = nn.Linear(256, 1)
|
||||
self.fc2 = nn.Linear(256, 256)
|
||||
self.fc3 = nn.Linear(256, 1)
|
||||
|
||||
def forward(self, x):
|
||||
h = F.relu(self.fc1(x))
|
||||
out = F.sigmoid(self.fc2(h))
|
||||
h = F.relu(self.fc2(h))
|
||||
out = F.sigmoid(self.fc3(h))
|
||||
return out
|
||||
|
||||
# Generator Model
|
||||
@ -39,8 +41,8 @@ class Generator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Generator, self).__init__()
|
||||
self.fc1 = nn.Linear(128, 256)
|
||||
self.fc2 = nn.Linear(256, 512)
|
||||
self.fc3 = nn.Linear(512, 784)
|
||||
self.fc2 = nn.Linear(256, 256)
|
||||
self.fc3 = nn.Linear(256, 784)
|
||||
|
||||
def forward(self, x):
|
||||
h = F.leaky_relu(self.fc1(x))
|
||||
@ -101,7 +103,7 @@ for epoch in range(200):
|
||||
# Save the sampled images
|
||||
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||
torchvision.utils.save_image(fake_images.data,
|
||||
'./data2/fake_samples_%d.png' %epoch+1)
|
||||
'./data/fake_samples_%d.png' %(epoch+1))
|
||||
|
||||
# Save the Models
|
||||
torch.save(generator.state_dict(), './generator.pkl')
|
||||
|
@ -27,11 +27,13 @@ class Discriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Discriminator, self).__init__()
|
||||
self.fc1 = nn.Linear(784, 256)
|
||||
self.fc2 = nn.Linear(256, 1)
|
||||
self.fc2 = nn.Linear(256, 256)
|
||||
self.fc3 = nn.Linear(256, 1)
|
||||
|
||||
def forward(self, x):
|
||||
h = F.relu(self.fc1(x))
|
||||
out = F.sigmoid(self.fc2(h))
|
||||
h = F.relu(self.fc2(h))
|
||||
out = F.sigmoid(self.fc3(h))
|
||||
return out
|
||||
|
||||
# Generator Model
|
||||
@ -39,8 +41,8 @@ class Generator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Generator, self).__init__()
|
||||
self.fc1 = nn.Linear(128, 256)
|
||||
self.fc2 = nn.Linear(256, 512)
|
||||
self.fc3 = nn.Linear(512, 784)
|
||||
self.fc2 = nn.Linear(256, 256)
|
||||
self.fc3 = nn.Linear(256, 784)
|
||||
|
||||
def forward(self, x):
|
||||
h = F.leaky_relu(self.fc1(x))
|
||||
@ -101,7 +103,7 @@ for epoch in range(200):
|
||||
# Save the sampled images
|
||||
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||
torchvision.utils.save_image(fake_images.data,
|
||||
'./data2/fake_samples_%d.png' %epoch+1)
|
||||
'./data/fake_samples_%d.png' %(epoch+1))
|
||||
|
||||
# Save the Models
|
||||
torch.save(generator.state_dict(), './generator.pkl')
|
||||
|
Reference in New Issue
Block a user