StyleGan 2 ආදර්ශ පුහුණුව

සඳහාපුහුණු කේතය මෙයයි StyleGan 2 ආකෘතිය.

Generated Images

මේවා80K පියවර සඳහා පුහුණුවෙන් පසු ජනනය කරන ලද රූප වේ.

අපගේක්රියාත්මක කිරීම අවම වේ StyleGAN 2 ආදර්ශ පුහුණු කේතය. ක්රියාත්මක කිරීම සරල ලෙස තබා ගැනීම සඳහා සහාය වන්නේ තනි GPU පුහුණුවක් පමණි. පුහුණු ලූපය ඇතුළුව කේත පේළි 500 කට වඩා අඩු මට්ටමක තබා ගැනීම සඳහා එය හැකිලීමට අපට හැකි විය.

ඩීඩීපී(බෙදා හරින ලද දත්ත සමාන්තරව) සහ බහු-gpu පුහුණුව නොමැතිව විශාල විභේදන (128+) සඳහා ආකෘතිය පුහුණු කිරීමට නොහැකි වනු ඇත. ඔබට fp16 සහ DDP සමඟ පුහුණු කේතය අවශ්ය නම් ලුසිඩ්රයිව්/ස්ටයිලෙගන්2-පයිටෝච්දෙස බලන්න.

අපිමෙය සෙලෙබා-එච්කියු දත්ත කට්ටලයමත පුහුණු කළෙමු. fast.ai හි මෙම සාකච්ඡාවේදීබාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/stylegan ෆෝල්ඩරයතුළ පින්තූර සුරකින්න.

31import math
32from pathlib import Path
33from typing import Iterator, Tuple
34
35import torch
36import torch.utils.data
37import torchvision
38from PIL import Image
39
40from labml import tracker, lab, monit, experiment
41from labml.configs import BaseConfigs
42from labml_helpers.device import DeviceConfigs
43from labml_helpers.train_valid import ModeState, hook_model_outputs
44from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
45from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
46from labml_nn.utils import cycle_dataloader

දත්තකට්ටලය

මෙයපුහුණු දත්ත කට්ටලය පටවන අතර එය ලබා දෙන රූපයේ ප්රමාණයට වෙනස් කරයි.

49class Dataset(torch.utils.data.Dataset):
  • path පින්තූර අඩංගු ෆෝල්ඩරයට යන මාර්ගය
  • image_size රූපයේ ප්රමාණය
56    def __init__(self, path: str, image_size: int):
61        super().__init__()

සියලුම jpg ලිපිගොනු වල මාර්ග ලබා ගන්න

64        self.paths = [p for p in Path(path).glob(f'**/*.jpg')]

පරිවර්තනය

67        self.transform = torchvision.transforms.Compose([

රූපයවෙනස් කරන්න

69            torchvision.transforms.Resize(image_size),

PyTorchටෙන්සරය බවට පරිවර්තනය කරන්න

71            torchvision.transforms.ToTensor(),
72        ])

රූපගණන

74    def __len__(self):
76        return len(self.paths)

index -th රූපය ලබා ගන්න

78    def __getitem__(self, index):
80        path = self.paths[index]
81        img = Image.open(path)
82        return self.transform(img)

වින්යාසකිරීම්

85class Configs(BaseConfigs):

ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි.

93    device: torch.device = DeviceConfigs()
96    discriminator: Discriminator
98    generator: Generator
100    mapping_network: MappingNetwork

වෙනස්කම්කරන්නා සහ උත්පාදක නැතිවීමේ කාර්යයන්. අපි වොසර්ස්ටයින් නැතිවීම භාවිතා කරමු

104    discriminator_loss: DiscriminatorLoss
105    generator_loss: GeneratorLoss

ප්රශස්තකරණය

108    generator_optimizer: torch.optim.Adam
109    discriminator_optimizer: torch.optim.Adam
110    mapping_network_optimizer: torch.optim.Adam
113    gradient_penalty = GradientPenalty()

ශ්රේණියේදණ්ඩන සංගුණකය

115    gradient_penalty_coefficient: float = 10.
118    path_length_penalty: PathLengthPenalty

දත්තකාරකය

121    loader: Iterator

කණ්ඩායම්ප්රමාණය

124    batch_size: int = 32

මානයන්හි සහ

126    d_latent: int = 512

රූපයේඋස/පළල

128    image_size: int = 32

සිතියම්කරණජාලයේ ස්ථර ගණන

130    mapping_network_layers: int = 8

උත්පාදකසහ වෙනස්කම් කිරීමේ ඉගෙනුම් අනුපාතය

132    learning_rate: float = 1e-3

ජාලඉගෙනුම් අනුපාතය සිතියම්ගත කිරීම (අනෙක් ඒවාට වඩා අඩු)

134    mapping_network_learning_rate: float = 1e-5

මතඵලය අනුක්රමික සමුච්චය කිරීමට පියවර ගණන. Effective ලදායී කණ්ඩායම් ප්රමාණය වැඩි කිරීමට මෙය භාවිතා කරන්න.

136    gradient_accumulate_steps: int = 1

සහ ආදම් ප්රශස්තකරණය සඳහා

138    adam_betas: Tuple[float, float] = (0.0, 0.99)

මෝස්තරමිශ්ර කිරීමේ සම්භාවිතාව

140    style_mixing_prob: float = 0.9

මුළුපුහුණු පියවර ගණන

143    training_steps: int = 150_000

උත්පාදකයන්ත්රයේ බ්ලොක් ගණන (රූප විභේදනය මත පදනම්ව ගණනය කරනු ලැබේ)

146    n_gen_blocks: int

කම්මැලිවිධිමත් කිරීම

නියාමනයකිරීමේ පාඩු ගණනය කිරීම වෙනුවට, කඩදාසි කම්මැලි නියාමනය යෝජනා කරන අතර එහිදී විධිමත් කිරීමේ නියමයන් වරකට වරක් ගණනය කරනු ලැබේ. මෙය පුහුණු කාර්යක්ෂමතාව වැඩි දියුණු කරයි.

අනුක්රමිකද penalty ුවම ගණනය කළ යුතු කාල සීමාව

154    lazy_gradient_penalty_interval: int = 4

මාර්ගදිග ද penalty ුවම් ගණනය කිරීමේ පරතරය

156    lazy_path_penalty_interval: int = 32

පුහුණුවේආරම්භක අදියරේදී මාර්ග දිග ද penalty ුවම ගණනය කිරීම මඟ හරින්න

158    lazy_path_penalty_after: int = 5_000

ජනනයකරන ලද පින්තූර ලොග් කරන්නේ කෙසේද?

161    log_generated_interval: int = 500

ආදර්ශමුරපොලවල් සුරැකීමට කොපමණ වාරයක්

163    save_checkpoint_interval: int = 2_000

ලොග්වීම සක්රිය කිරීම සඳහා පුහුණු මාදිලියේ තත්වය

166    mode: ModeState

ආදර්ශස්ථර ප්රතිදානයන් ලොග් කළ යුතුද යන්න

168    log_layer_outputs: bool = False

අපි මෙය සෙලෙබා-එච්කියු දත්ත කට්ටලය මත පුහුණු කළෙමු. fast.ai හි මෙම සාකච්ඡාවේදීබාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/stylegan ෆෝල්ඩරය තුළ පින්තූර සුරකින්න.

175    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

ආරම්භකරන්න

177    def init(self):

දත්තසමුදාය සාදන්න

182        dataset = Dataset(self.dataset_path, self.image_size)

දත්තපැටවුම සාදන්න

184        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
185                                                 shuffle=True, drop_last=True, pin_memory=True)
187        self.loader = cycle_dataloader(dataloader)

රූප විභේදනයේ

190        log_resolution = int(math.log2(self.image_size))

වෙනස්කම්කරන්නා සහ උත්පාදක යන්ත්රය සාදන්න

193        self.discriminator = Discriminator(log_resolution).to(self.device)
194        self.generator = Generator(log_resolution, self.d_latent).to(self.device)

ශෛලියසහ ශබ්ද යෙදවුම් නිර්මාණය කිරීම සඳහා උත්පාදක කොටස් ගණන ලබා ගන්න

196        self.n_gen_blocks = self.generator.n_blocks

සිතියම්කරණජාලයක් සාදන්න

198        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)

මාර්ගදිග ද penalty ුවම් නැතිවීම

200        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

ස්ථරප්රතිදානයන් අධීක්ෂණය කිරීම සඳහා ආදර්ශ කොකු එක් කරන්න

203        if self.log_layer_outputs:
204            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
205            hook_model_outputs(self.mode, self.generator, 'generator')
206            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')

වෙනස්කම්කරන්නා සහ උත්පාදක පාඩු

209        self.discriminator_loss = DiscriminatorLoss().to(self.device)
210        self.generator_loss = GeneratorLoss().to(self.device)

ප්රශස්තකරණයසාදන්න

213        self.discriminator_optimizer = torch.optim.Adam(
214            self.discriminator.parameters(),
215            lr=self.learning_rate, betas=self.adam_betas
216        )
217        self.generator_optimizer = torch.optim.Adam(
218            self.generator.parameters(),
219            lr=self.learning_rate, betas=self.adam_betas
220        )
221        self.mapping_network_optimizer = torch.optim.Adam(
222            self.mapping_network.parameters(),
223            lr=self.mapping_network_learning_rate, betas=self.adam_betas
224        )

ට්රැකර්වින්යාසයන් සකසන්න

227        tracker.set_image("generated", True)

නියැදිය

මෙමසාම්පල අහඹු ලෙස සහ සිතියම්කරණ ජාලයෙන් ලබා ගන්න.

ඒවගේම අපි ගුප්ත විචල්යයන් දෙකක් ජනනය හා අනුරූප ලබා එහිදී සමහර විට ශෛලිය මිශ්ර අදාළ වේ. ඉන්පසු අපි අහඹු ලෙස හරස් ඕවර් ලක්ෂ්යයක් සාම්පල කර හරස් ඕවර් ලක්ෂ්යයට පෙර උත්පාදක කුට්ටි වලට සහ පසුව කුට්ටි වලට අදාළ වෙමු.

229    def get_w(self, batch_size: int):

මෝස්තරමිශ්ර කරන්න

243        if torch.rand(()).item() < self.style_mixing_prob:

අහඹුහරස් ඕවර් ලක්ෂ්යය

245            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)

නියැදිය සහ

247            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
248            z1 = torch.randn(batch_size, self.d_latent).to(self.device)

ලබා ගන්න

250            w1 = self.mapping_network(z1)
251            w2 = self.mapping_network(z2)

උත්පාදකකුට්ටි සහ සංයුක්ත කිරීම සඳහා පුළුල් කරන්න

253            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
254            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
255            return torch.cat((w1, w2), dim=0)

මිශ්රනොකර

257        else:

නියැදිය සහ

259            z = torch.randn(batch_size, self.d_latent).to(self.device)

ලබා ගන්න

261            w = self.mapping_network(z)

උත්පාදකකොටස් සඳහා පුළුල් කරන්න

263            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

ශබ්දයජනනය කරන්න

මෙයඑක් එක් උත්පාදක කොටසසඳහා ශබ්දය ජනනය කරයි

265    def get_noise(self, batch_size: int):

ශබ්දයගබඩා කිරීමට ලැයිස්තුව

272        noise = []

ශබ්දවිභේදනය ආරම්භ වන්නේ

274        resolution = 4

එක්එක් උත්පාදක කොටස සඳහා ශබ්දය ජනනය කරන්න

277        for i in range(self.n_gen_blocks):

පළමුකොටස ඇත්තේ එක් සංවහනයක් පමණි

279            if i == 0:
280                n1 = None

පළමුකැටි ගැසුණු ස්ථරයෙන් පසු එකතු කිරීම සඳහා ශබ්දය ජනනය කරන්න

282            else:
283                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

දෙවනකැටි ගැසුණු ස්ථරයට පසු එකතු කිරීම සඳහා ශබ්දය ජනනය කරන්න

285            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

ලැයිස්තුවටශබ්ද ආතතීන් එක් කරන්න

288            noise.append((n1, n2))

ඊළඟවාරණයට විභේදනයක් ඇත

291            resolution *= 2

ආපසුශබ්ද ආතතීන්

294        return noise

රූපජනනය කරන්න

මෙයඋත්පාදක යන්ත්රය භාවිතයෙන් රූප ජනනය කරයි

296    def generate_images(self, batch_size: int):

ලබාගන්න

304        w = self.get_w(batch_size)

ශබ්දයලබා ගන්න

306        noise = self.get_noise(batch_size)

රූපජනනය කරන්න

309        images = self.generator(w, noise)

ආපසුපින්තූර සහ

312        return images, w

පුහුණුපියවර

314    def step(self, idx: int):

වෙනස්කම්කරන්නා පුහුණු කරන්න

320        with monit.section('Discriminator'):

අනුක්රමිකයළි පිහිටුවන්න

322            self.discriminator_optimizer.zero_grad()

සඳහාඅනුක්රමික සමුච්චය gradient_accumulate_steps

325            for i in range(self.gradient_accumulate_steps):

යාවත්කාලීනකිරීම mode . සක්රිය කිරීම ලොග් කළ යුතුද යන්න සකසන්න

327                with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):

උත්පාදකයන්ත්රයෙන් නියැදි පින්තූර

329                    generated_images, _ = self.generate_images(self.batch_size)

ජනනයකරන ලද රූප සඳහා වෙනස්කම් කිරීමේ වර්ගීකරණය

331                    fake_output = self.discriminator(generated_images.detach())

දත්තපැටවුමෙන් සැබෑ රූප ලබා ගන්න

334                    real_images = next(self.loader).to(self.device)

අපිඵලය අනුක්රමික දඬුවම සඳහා wr. සැබෑ රූප ගණනය කිරීමට අවශ්ය

336                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
337                        real_images.requires_grad_()

සැබෑරූප සඳහා වෙනස්කම් කිරීමේ වර්ගීකරණය

339                    real_output = self.discriminator(real_images)

වෙනස්කම්කරන්නාගේ පාඩුව ලබා ගන්න

342                    real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
343                    disc_loss = real_loss + fake_loss

ශ්රේණියේද penalty ුවම එකතු කරන්න

346                    if (idx + 1) % self.lazy_gradient_penalty_interval == 0:

ශ්රේණියේද penalty ුවම ගණනය කර ලොග් කරන්න

348                        gp = self.gradient_penalty(real_images, real_output)
349                        tracker.add('loss.gp', gp)

සංගුණකයමගින් ගුණ කර ශ්රේණියේ ද penalty ුවම එකතු කරන්න

351                        disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

අනුක්රමිකගණනය

354                    disc_loss.backward()

ලොග්වෙනස්කම් කරන්නාගේ අලාභය

357                    tracker.add('loss.discriminator', disc_loss)
358
359            if (idx + 1) % self.log_generated_interval == 0:

විටින්විට ලොග් වෙනස්කම් කිරීමේ ආකෘති පරාමිතීන්

361                tracker.add('discriminator', self.discriminator)

ස්ථායීකරණයසඳහා ක්ලිප් අනුක්රමික

364            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)

ප්රශස්තිකරණපියවර ගන්න

366            self.discriminator_optimizer.step()

උත්පාදකයන්ත්රය පුහුණු කරන්න

369        with monit.section('Generator'):

අනුක්රමිකයළි පිහිටුවන්න

371            self.generator_optimizer.zero_grad()
372            self.mapping_network_optimizer.zero_grad()

සඳහාඅනුක්රමික සමුච්චය gradient_accumulate_steps

375            for i in range(self.gradient_accumulate_steps):

උත්පාදකයන්ත්රයෙන් නියැදි පින්තූර

377                generated_images, w = self.generate_images(self.batch_size)

ජනනයකරන ලද රූප සඳහා වෙනස්කම් කිරීමේ වර්ගීකරණය

379                fake_output = self.discriminator(generated_images)

උත්පාදකනැතිවීම ලබා ගන්න

382                gen_loss = self.generator_loss(fake_output)

මාර්ගදිග ද penalty ුවම එකතු

385                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:

මාර්ගයදිග ද penalty ුවම ගණනය

387                    plp = self.path_length_penalty(w, generated_images)

නම්නොසලකා හරින්න nan

389                    if not torch.isnan(plp):
390                        tracker.add('loss.plp', plp)
391                        gen_loss = gen_loss + plp

අනුක්රමිකගණනය කරන්න

394                gen_loss.backward()

ලොග්උත්පාදක නැතිවීම

397                tracker.add('loss.generator', gen_loss)
398
399            if (idx + 1) % self.log_generated_interval == 0:

විටින්විට ලොග් වෙනස්කම් කිරීමේ ආකෘති පරාමිතීන්

401                tracker.add('generator', self.generator)
402                tracker.add('mapping_network', self.mapping_network)

ස්ථායීකරණයසඳහා ක්ලිප් අනුක්රමික

405            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
406            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

ප්රශස්තිකරණපියවර ගන්න

409            self.generator_optimizer.step()
410            self.mapping_network_optimizer.step()

ලොග්ජනනය කරන ලද රූප

413        if (idx + 1) % self.log_generated_interval == 0:
414            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))

ආදර්ශමුරපොලවල් සුරකින්න

416        if (idx + 1) % self.save_checkpoint_interval == 0:
417            experiment.save_checkpoint()

ෆ්ලෂ්ට්රැකර්

420        tracker.save()

දුම්රියආකෘතිය

422    def train(self):

ලූප්සඳහා training_steps

428        for i in monit.loop(self.training_steps):

පුහුණුපියවරක් ගන්න

430            self.step(i)

432            if (i + 1) % self.log_generated_interval == 0:
433                tracker.new_line()

දුම්රියවිලායGan2

436def main():

අත්හදාබැලීමක් සාදන්න

442    experiment.create(name='stylegan2')

වින්යාසවස්තුව සාදන්න

444    configs = Configs()

වින්යාසයන්සකසන්න සහ සමහර ඒවා අභිබවා යන්න

447    experiment.configs(configs, {
448        'device.cuda_device': 0,
449        'image_size': 64,
450        'log_generated_interval': 200
451    })

ආරම්භකරන්න

454    configs.init()

ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න

456    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
457                                  generator=configs.generator,
458                                  discriminator=configs.discriminator)

අත්හදාබැලීම ආරම්භ කරන්න

461    with experiment.start():

පුහුණුලූපය ධාවනය කරන්න

463        configs.train()

467if __name__ == '__main__':
468    main()