සඳහාපුහුණු කේතය මෙයයි StyleGan 2 ආකෘතිය.

මේවා80K පියවර සඳහා පුහුණුවෙන් පසු ජනනය කරන ලද රූප වේ.
අපගේක්රියාත්මක කිරීම අවම වේ StyleGAN 2 ආදර්ශ පුහුණු කේතය. ක්රියාත්මක කිරීම සරල ලෙස තබා ගැනීම සඳහා සහාය වන්නේ තනි GPU පුහුණුවක් පමණි. පුහුණු ලූපය ඇතුළුව කේත පේළි 500 කට වඩා අඩු මට්ටමක තබා ගැනීම සඳහා එය හැකිලීමට අපට හැකි විය.
ඩීඩීපී(බෙදා හරින ලද දත්ත සමාන්තරව) සහ බහු-gpu පුහුණුව නොමැතිව විශාල විභේදන (128+) සඳහා ආකෘතිය පුහුණු කිරීමට නොහැකි වනු ඇත. ඔබට fp16 සහ DDP සමඟ පුහුණු කේතය අවශ්ය නම් ලුසිඩ්රයිව්/ස්ටයිලෙගන්2-පයිටෝච්දෙස බලන්න.
අපිමෙය සෙලෙබා-එච්කියු දත්ත කට්ටලයමත පුහුණු කළෙමු. fast.ai හි මෙම සාකච්ඡාවේදීබාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/stylegan
ෆෝල්ඩරයතුළ පින්තූර සුරකින්න.
31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torch
36import torch.utils.data
37import torchvision
38from PIL import Image
39
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_helpers.device import DeviceConfigs
43from labml_helpers.train_valid import ModeState, hook_model_outputs
44from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
45from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
46from labml_nn.utils import cycle_dataloader49class Dataset(torch.utils.data.Dataset):path
පින්තූර අඩංගු ෆෝල්ඩරයට යන මාර්ගය image_size
රූපයේ ප්රමාණය56 def __init__(self, path: str, image_size: int):61 super().__init__()සියලුම jpg
ලිපිගොනු වල මාර්ග ලබා ගන්න
64 self.paths = [p for p in Path(path).glob(f'**/*.jpg')]පරිවර්තනය
67 self.transform = torchvision.transforms.Compose([රූපයවෙනස් කරන්න
69 torchvision.transforms.Resize(image_size),PyTorchටෙන්සරය බවට පරිවර්තනය කරන්න
71 torchvision.transforms.ToTensor(),
72 ])රූපගණන
74 def __len__(self):76 return len(self.paths)index
-th රූපය ලබා ගන්න
78 def __getitem__(self, index):80 path = self.paths[index]
81 img = Image.open(path)
82 return self.transform(img)85class Configs(BaseConfigs):ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs
ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි.
93 device: torch.device = DeviceConfigs()96 discriminator: Discriminator98 generator: Generator100 mapping_network: MappingNetworkවෙනස්කම්කරන්නා සහ උත්පාදක නැතිවීමේ කාර්යයන්. අපි වොසර්ස්ටයින් නැතිවීම භාවිතා කරමු
104 discriminator_loss: DiscriminatorLoss
105 generator_loss: GeneratorLossප්රශස්තකරණය
108 generator_optimizer: torch.optim.Adam
109 discriminator_optimizer: torch.optim.Adam
110 mapping_network_optimizer: torch.optim.Adam113 gradient_penalty = GradientPenalty()ශ්රේණියේදණ්ඩන සංගුණකය
115 gradient_penalty_coefficient: float = 10.118 path_length_penalty: PathLengthPenaltyදත්තකාරකය
121 loader: Iteratorකණ්ඩායම්ප්රමාණය
124 batch_size: int = 32මානයන්හි සහ
126 d_latent: int = 512රූපයේඋස/පළල
128 image_size: int = 32සිතියම්කරණජාලයේ ස්ථර ගණන
130 mapping_network_layers: int = 8උත්පාදකසහ වෙනස්කම් කිරීමේ ඉගෙනුම් අනුපාතය
132 learning_rate: float = 1e-3ජාලඉගෙනුම් අනුපාතය සිතියම්ගත කිරීම (අනෙක් ඒවාට වඩා අඩු)
134 mapping_network_learning_rate: float = 1e-5මතඵලය අනුක්රමික සමුච්චය කිරීමට පියවර ගණන. Effective ලදායී කණ්ඩායම් ප්රමාණය වැඩි කිරීමට මෙය භාවිතා කරන්න.
136 gradient_accumulate_steps: int = 1සහ ආදම් ප්රශස්තකරණය සඳහා
138 adam_betas: Tuple[float, float] = (0.0, 0.99)මෝස්තරමිශ්ර කිරීමේ සම්භාවිතාව
140 style_mixing_prob: float = 0.9මුළුපුහුණු පියවර ගණන
143 training_steps: int = 150_000උත්පාදකයන්ත්රයේ බ්ලොක් ගණන (රූප විභේදනය මත පදනම්ව ගණනය කරනු ලැබේ)
146 n_gen_blocks: intනියාමනයකිරීමේ පාඩු ගණනය කිරීම වෙනුවට, කඩදාසි කම්මැලි නියාමනය යෝජනා කරන අතර එහිදී විධිමත් කිරීමේ නියමයන් වරකට වරක් ගණනය කරනු ලැබේ. මෙය පුහුණු කාර්යක්ෂමතාව වැඩි දියුණු කරයි.
අනුක්රමිකද penalty ුවම ගණනය කළ යුතු කාල සීමාව
154 lazy_gradient_penalty_interval: int = 4මාර්ගදිග ද penalty ුවම් ගණනය කිරීමේ පරතරය
156 lazy_path_penalty_interval: int = 32පුහුණුවේආරම්භක අදියරේදී මාර්ග දිග ද penalty ුවම ගණනය කිරීම මඟ හරින්න
158 lazy_path_penalty_after: int = 5_000ජනනයකරන ලද පින්තූර ලොග් කරන්නේ කෙසේද?
161 log_generated_interval: int = 500ආදර්ශමුරපොලවල් සුරැකීමට කොපමණ වාරයක්
163 save_checkpoint_interval: int = 2_000ලොග්වීම සක්රිය කිරීම සඳහා පුහුණු මාදිලියේ තත්වය
166 mode: ModeStateආදර්ශස්ථර ප්රතිදානයන් ලොග් කළ යුතුද යන්න
168 log_layer_outputs: bool = False අපි මෙය සෙලෙබා-එච්කියු දත්ත කට්ටලය මත පුහුණු කළෙමු. fast.ai හි මෙම සාකච්ඡාවේදීබාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/stylegan
ෆෝල්ඩරය තුළ පින්තූර සුරකින්න.
175 dataset_path: str = str(lab.get_data_path() / 'stylegan2')177 def init(self):දත්තසමුදාය සාදන්න
182 dataset = Dataset(self.dataset_path, self.image_size)දත්තපැටවුම සාදන්න
184 dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
185 shuffle=True, drop_last=True, pin_memory=True)අඛණ්ඩ චක්රීය කාරකය
187 self.loader = cycle_dataloader(dataloader)රූප විභේදනයේ
190 log_resolution = int(math.log2(self.image_size))වෙනස්කම්කරන්නා සහ උත්පාදක යන්ත්රය සාදන්න
193 self.discriminator = Discriminator(log_resolution).to(self.device)
194 self.generator = Generator(log_resolution, self.d_latent).to(self.device)ශෛලියසහ ශබ්ද යෙදවුම් නිර්මාණය කිරීම සඳහා උත්පාදක කොටස් ගණන ලබා ගන්න
196 self.n_gen_blocks = self.generator.n_blocksසිතියම්කරණජාලයක් සාදන්න
198 self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)මාර්ගදිග ද penalty ුවම් නැතිවීම
200 self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)ස්ථරප්රතිදානයන් අධීක්ෂණය කිරීම සඳහා ආදර්ශ කොකු එක් කරන්න
203 if self.log_layer_outputs:
204 hook_model_outputs(self.mode, self.discriminator, 'discriminator')
205 hook_model_outputs(self.mode, self.generator, 'generator')
206 hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')වෙනස්කම්කරන්නා සහ උත්පාදක පාඩු
209 self.discriminator_loss = DiscriminatorLoss().to(self.device)
210 self.generator_loss = GeneratorLoss().to(self.device)ප්රශස්තකරණයසාදන්න
213 self.discriminator_optimizer = torch.optim.Adam(
214 self.discriminator.parameters(),
215 lr=self.learning_rate, betas=self.adam_betas
216 )
217 self.generator_optimizer = torch.optim.Adam(
218 self.generator.parameters(),
219 lr=self.learning_rate, betas=self.adam_betas
220 )
221 self.mapping_network_optimizer = torch.optim.Adam(
222 self.mapping_network.parameters(),
223 lr=self.mapping_network_learning_rate, betas=self.adam_betas
224 )ට්රැකර්වින්යාසයන් සකසන්න
227 tracker.set_image("generated", True)මෙමසාම්පල අහඹු ලෙස සහ සිතියම්කරණ ජාලයෙන් ලබා ගන්න.
ඒවගේම අපි ගුප්ත විචල්යයන් දෙකක් ජනනය හා අනුරූප ලබා එහිදී සමහර විට ශෛලිය මිශ්ර අදාළ වේ. ඉන්පසු අපි අහඹු ලෙස හරස් ඕවර් ලක්ෂ්යයක් සාම්පල කර හරස් ඕවර් ලක්ෂ්යයට පෙර උත්පාදක කුට්ටි වලට සහ පසුව කුට්ටි වලට අදාළ වෙමු.
229 def get_w(self, batch_size: int):මෝස්තරමිශ්ර කරන්න
243 if torch.rand(()).item() < self.style_mixing_prob:අහඹුහරස් ඕවර් ලක්ෂ්යය
245 cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)නියැදිය සහ
247 z2 = torch.randn(batch_size, self.d_latent).to(self.device)
248 z1 = torch.randn(batch_size, self.d_latent).to(self.device)ලබා ගන්න
250 w1 = self.mapping_network(z1)
251 w2 = self.mapping_network(z2)උත්පාදකකුට්ටි සහ සංයුක්ත කිරීම සඳහා පුළුල් කරන්න
253 w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
254 w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
255 return torch.cat((w1, w2), dim=0)මිශ්රනොකර
257 else:නියැදිය සහ
259 z = torch.randn(batch_size, self.d_latent).to(self.device)ලබා ගන්න
261 w = self.mapping_network(z)උත්පාදකකොටස් සඳහා පුළුල් කරන්න
263 return w[None, :, :].expand(self.n_gen_blocks, -1, -1)265 def get_noise(self, batch_size: int):ශබ්දයගබඩා කිරීමට ලැයිස්තුව
272 noise = []ශබ්දවිභේදනය ආරම්භ වන්නේ
274 resolution = 4එක්එක් උත්පාදක කොටස සඳහා ශබ්දය ජනනය කරන්න
277 for i in range(self.n_gen_blocks):පළමුකොටස ඇත්තේ එක් සංවහනයක් පමණි
279 if i == 0:
280 n1 = Noneපළමුකැටි ගැසුණු ස්ථරයෙන් පසු එකතු කිරීම සඳහා ශබ්දය ජනනය කරන්න
282 else:
283 n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)දෙවනකැටි ගැසුණු ස්ථරයට පසු එකතු කිරීම සඳහා ශබ්දය ජනනය කරන්න
285 n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)ලැයිස්තුවටශබ්ද ආතතීන් එක් කරන්න
288 noise.append((n1, n2))ඊළඟවාරණයට විභේදනයක් ඇත
291 resolution *= 2ආපසුශබ්ද ආතතීන්
294 return noise296 def generate_images(self, batch_size: int):ලබාගන්න
304 w = self.get_w(batch_size)ශබ්දයලබා ගන්න
306 noise = self.get_noise(batch_size)රූපජනනය කරන්න
309 images = self.generator(w, noise)ආපසුපින්තූර සහ
312 return images, w314 def step(self, idx: int):වෙනස්කම්කරන්නා පුහුණු කරන්න
320 with monit.section('Discriminator'):අනුක්රමිකයළි පිහිටුවන්න
322 self.discriminator_optimizer.zero_grad()සඳහාඅනුක්රමික සමුච්චය gradient_accumulate_steps
325 for i in range(self.gradient_accumulate_steps):යාවත්කාලීනකිරීම mode
. සක්රිය කිරීම ලොග් කළ යුතුද යන්න සකසන්න
327 with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):උත්පාදකයන්ත්රයෙන් නියැදි පින්තූර
329 generated_images, _ = self.generate_images(self.batch_size)ජනනයකරන ලද රූප සඳහා වෙනස්කම් කිරීමේ වර්ගීකරණය
331 fake_output = self.discriminator(generated_images.detach())දත්තපැටවුමෙන් සැබෑ රූප ලබා ගන්න
334 real_images = next(self.loader).to(self.device)අපිඵලය අනුක්රමික දඬුවම සඳහා wr. සැබෑ රූප ගණනය කිරීමට අවශ්ය
336 if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
337 real_images.requires_grad_()සැබෑරූප සඳහා වෙනස්කම් කිරීමේ වර්ගීකරණය
339 real_output = self.discriminator(real_images)වෙනස්කම්කරන්නාගේ පාඩුව ලබා ගන්න
342 real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
343 disc_loss = real_loss + fake_lossශ්රේණියේද penalty ුවම එකතු කරන්න
346 if (idx + 1) % self.lazy_gradient_penalty_interval == 0:ශ්රේණියේද penalty ුවම ගණනය කර ලොග් කරන්න
348 gp = self.gradient_penalty(real_images, real_output)
349 tracker.add('loss.gp', gp)සංගුණකයමගින් ගුණ කර ශ්රේණියේ ද penalty ුවම එකතු කරන්න
351 disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_intervalඅනුක්රමිකගණනය
354 disc_loss.backward()ලොග්වෙනස්කම් කරන්නාගේ අලාභය
357 tracker.add('loss.discriminator', disc_loss)
358
359 if (idx + 1) % self.log_generated_interval == 0:විටින්විට ලොග් වෙනස්කම් කිරීමේ ආකෘති පරාමිතීන්
361 tracker.add('discriminator', self.discriminator)ස්ථායීකරණයසඳහා ක්ලිප් අනුක්රමික
364 torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)ප්රශස්තිකරණපියවර ගන්න
366 self.discriminator_optimizer.step()උත්පාදකයන්ත්රය පුහුණු කරන්න
369 with monit.section('Generator'):අනුක්රමිකයළි පිහිටුවන්න
371 self.generator_optimizer.zero_grad()
372 self.mapping_network_optimizer.zero_grad()සඳහාඅනුක්රමික සමුච්චය gradient_accumulate_steps
375 for i in range(self.gradient_accumulate_steps):උත්පාදකයන්ත්රයෙන් නියැදි පින්තූර
377 generated_images, w = self.generate_images(self.batch_size)ජනනයකරන ලද රූප සඳහා වෙනස්කම් කිරීමේ වර්ගීකරණය
379 fake_output = self.discriminator(generated_images)උත්පාදකනැතිවීම ලබා ගන්න
382 gen_loss = self.generator_loss(fake_output)මාර්ගදිග ද penalty ුවම එකතු
385 if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:මාර්ගයදිග ද penalty ුවම ගණනය
387 plp = self.path_length_penalty(w, generated_images)නම්නොසලකා හරින්න nan
389 if not torch.isnan(plp):
390 tracker.add('loss.plp', plp)
391 gen_loss = gen_loss + plpඅනුක්රමිකගණනය කරන්න
394 gen_loss.backward()ලොග්උත්පාදක නැතිවීම
397 tracker.add('loss.generator', gen_loss)
398
399 if (idx + 1) % self.log_generated_interval == 0:විටින්විට ලොග් වෙනස්කම් කිරීමේ ආකෘති පරාමිතීන්
401 tracker.add('generator', self.generator)
402 tracker.add('mapping_network', self.mapping_network)ස්ථායීකරණයසඳහා ක්ලිප් අනුක්රමික
405 torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
406 torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)ප්රශස්තිකරණපියවර ගන්න
409 self.generator_optimizer.step()
410 self.mapping_network_optimizer.step()ලොග්ජනනය කරන ලද රූප
413 if (idx + 1) % self.log_generated_interval == 0:
414 tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))ආදර්ශමුරපොලවල් සුරකින්න
416 if (idx + 1) % self.save_checkpoint_interval == 0:
417 experiment.save_checkpoint()ෆ්ලෂ්ට්රැකර්
420 tracker.save()422 def train(self):ලූප්සඳහා training_steps
428 for i in monit.loop(self.training_steps):පුහුණුපියවරක් ගන්න
430 self.step(i)432 if (i + 1) % self.log_generated_interval == 0:
433 tracker.new_line()436def main():අත්හදාබැලීමක් සාදන්න
442 experiment.create(name='stylegan2')වින්යාසවස්තුව සාදන්න
444 configs = Configs()වින්යාසයන්සකසන්න සහ සමහර ඒවා අභිබවා යන්න
447 experiment.configs(configs, {
448 'device.cuda_device': 0,
449 'image_size': 64,
450 'log_generated_interval': 200
451 })ආරම්භකරන්න
454 configs.init()ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න
456 experiment.add_pytorch_models(mapping_network=configs.mapping_network,
457 generator=configs.generator,
458 discriminator=configs.discriminator)අත්හදාබැලීම ආරම්භ කරන්න
461 with experiment.start():පුහුණුලූපය ධාවනය කරන්න
463 configs.train()467if __name__ == '__main__':
468 main()