mirror of
https://github.com/yunjey/pytorch-tutorial.git
synced 2025-07-27 12:03:31 +08:00
vanilla gan added
This commit is contained in:
@ -27,11 +27,13 @@ class Discriminator(nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Discriminator, self).__init__()
|
super(Discriminator, self).__init__()
|
||||||
self.fc1 = nn.Linear(784, 256)
|
self.fc1 = nn.Linear(784, 256)
|
||||||
self.fc2 = nn.Linear(256, 1)
|
self.fc2 = nn.Linear(256, 256)
|
||||||
|
self.fc3 = nn.Linear(256, 1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = F.relu(self.fc1(x))
|
h = F.relu(self.fc1(x))
|
||||||
out = F.sigmoid(self.fc2(h))
|
h = F.relu(self.fc2(h))
|
||||||
|
out = F.sigmoid(self.fc3(h))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# Generator Model
|
# Generator Model
|
||||||
@ -39,8 +41,8 @@ class Generator(nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Generator, self).__init__()
|
super(Generator, self).__init__()
|
||||||
self.fc1 = nn.Linear(128, 256)
|
self.fc1 = nn.Linear(128, 256)
|
||||||
self.fc2 = nn.Linear(256, 512)
|
self.fc2 = nn.Linear(256, 256)
|
||||||
self.fc3 = nn.Linear(512, 784)
|
self.fc3 = nn.Linear(256, 784)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = F.leaky_relu(self.fc1(x))
|
h = F.leaky_relu(self.fc1(x))
|
||||||
@ -101,7 +103,7 @@ for epoch in range(200):
|
|||||||
# Save the sampled images
|
# Save the sampled images
|
||||||
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||||
torchvision.utils.save_image(fake_images.data,
|
torchvision.utils.save_image(fake_images.data,
|
||||||
'./data2/fake_samples_%d.png' %epoch+1)
|
'./data/fake_samples_%d.png' %(epoch+1))
|
||||||
|
|
||||||
# Save the Models
|
# Save the Models
|
||||||
torch.save(generator.state_dict(), './generator.pkl')
|
torch.save(generator.state_dict(), './generator.pkl')
|
||||||
|
@ -27,11 +27,13 @@ class Discriminator(nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Discriminator, self).__init__()
|
super(Discriminator, self).__init__()
|
||||||
self.fc1 = nn.Linear(784, 256)
|
self.fc1 = nn.Linear(784, 256)
|
||||||
self.fc2 = nn.Linear(256, 1)
|
self.fc2 = nn.Linear(256, 256)
|
||||||
|
self.fc3 = nn.Linear(256, 1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = F.relu(self.fc1(x))
|
h = F.relu(self.fc1(x))
|
||||||
out = F.sigmoid(self.fc2(h))
|
h = F.relu(self.fc2(h))
|
||||||
|
out = F.sigmoid(self.fc3(h))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# Generator Model
|
# Generator Model
|
||||||
@ -39,8 +41,8 @@ class Generator(nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Generator, self).__init__()
|
super(Generator, self).__init__()
|
||||||
self.fc1 = nn.Linear(128, 256)
|
self.fc1 = nn.Linear(128, 256)
|
||||||
self.fc2 = nn.Linear(256, 512)
|
self.fc2 = nn.Linear(256, 256)
|
||||||
self.fc3 = nn.Linear(512, 784)
|
self.fc3 = nn.Linear(256, 784)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = F.leaky_relu(self.fc1(x))
|
h = F.leaky_relu(self.fc1(x))
|
||||||
@ -101,7 +103,7 @@ for epoch in range(200):
|
|||||||
# Save the sampled images
|
# Save the sampled images
|
||||||
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
|
||||||
torchvision.utils.save_image(fake_images.data,
|
torchvision.utils.save_image(fake_images.data,
|
||||||
'./data2/fake_samples_%d.png' %epoch+1)
|
'./data/fake_samples_%d.png' %(epoch+1))
|
||||||
|
|
||||||
# Save the Models
|
# Save the Models
|
||||||
torch.save(generator.state_dict(), './generator.pkl')
|
torch.save(generator.state_dict(), './generator.pkl')
|
||||||
|
Reference in New Issue
Block a user