මෙය PyTorch ගැඹුරු සංකෝචන උත්පාදක අහිතකර ජාලයන් සමඟ කඩදාසි අධීක්ෂණය නොකළ නියෝජන ඉගෙනුම් ක්රියාත්මක කිරීමයි.
මෙමක්රියාත්මක කිරීම PyTorch DCGAN නිබන්ධනයමත පදනම් වේ.
15import torch.nn as nn
16
17from labml import experiment
18from labml.configs import calculate
19from labml_helpers.module import Module
20from labml_nn.gan.original.experiment import Configsමෙයසෙලෙබා මුහුණු සඳහා භාවිතා කරන ද-සංවහන ජාලයට සමාන වන නමුත් MNIST රූප සඳහා වෙනස් කර ඇත.

23class Generator(Module):33 def __init__(self):
34 super().__init__()ආදානයනාලිකා 100 ක් සමඟ ඇත
36 self.layers = nn.Sequential(මෙය ප්රතිදානය ලබා දෙයි
38 nn.ConvTranspose2d(100, 1024, 3, 1, 0, bias=False),
39 nn.BatchNorm2d(1024),
40 nn.ReLU(True),මෙයලබා දෙයි
42 nn.ConvTranspose2d(1024, 512, 3, 2, 0, bias=False),
43 nn.BatchNorm2d(512),
44 nn.ReLU(True),මෙයලබා දෙයි
46 nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
47 nn.BatchNorm2d(256),
48 nn.ReLU(True),මෙයලබා දෙයි
50 nn.ConvTranspose2d(256, 1, 4, 2, 1, bias=False),
51 nn.Tanh()
52 )
53
54 self.apply(_weights_init)56 def forward(self, x):හැඩයෙන්වෙනස් [batch_size, 100]
කරන්න [batch_size, 100, 1, 1]
58 x = x.unsqueeze(-1).unsqueeze(-1)
59 x = self.layers(x)
60 return x63class Discriminator(Module):68 def __init__(self):
69 super().__init__()ආදානයඑක් නාලිකාවක් සමඟ ඇත
71 self.layers = nn.Sequential(මෙයලබා දෙයි
73 nn.Conv2d(1, 256, 4, 2, 1, bias=False),
74 nn.LeakyReLU(0.2, inplace=True),මෙයලබා දෙයි
76 nn.Conv2d(256, 512, 4, 2, 1, bias=False),
77 nn.BatchNorm2d(512),
78 nn.LeakyReLU(0.2, inplace=True),මෙයලබා දෙයි
80 nn.Conv2d(512, 1024, 3, 2, 0, bias=False),
81 nn.BatchNorm2d(1024),
82 nn.LeakyReLU(0.2, inplace=True),මෙයලබා දෙයි
84 nn.Conv2d(1024, 1, 3, 1, 0, bias=False),
85 )
86 self.apply(_weights_init)88 def forward(self, x):
89 x = self.layers(x)
90 return x.view(x.shape[0], -1)93def _weights_init(m):
94 classname = m.__class__.__name__
95 if classname.find('Conv') != -1:
96 nn.init.normal_(m.weight.data, 0.0, 0.02)
97 elif classname.find('BatchNorm') != -1:
98 nn.init.normal_(m.weight.data, 1.0, 0.02)
99 nn.init.constant_(m.bias.data, 0)අපි සරල ගැන් අත්හදා බැලීම් ආනයනය කර උත්පාදක යන්ත්රය සහ වෙනස්කම් කරන ජාල වෙනස් කරමු
104calculate(Configs.generator, 'cnn', lambda c: Generator().to(c.device))
105calculate(Configs.discriminator, 'cnn', lambda c: Discriminator().to(c.device))108def main():
109 conf = Configs()
110 experiment.create(name='mnist_dcgan')
111 experiment.configs(conf,
112 {'discriminator': 'cnn',
113 'generator': 'cnn',
114 'label_smoothing': 0.01})
115 with experiment.start():
116 conf.run()
117
118
119if __name__ == '__main__':
120 main()