මෙය සෙලෙබා එච්කියු දත්ත කට්ටලය මත ඩීඩීපීඑම් පදනම් කරගත් ආකෘතියක් පුහුණු කරයි. fast.ai හි මෙම සාකච්ඡාවේදී බාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/celebA
ෆෝල්ඩරය තුළ පින්තූර සුරකින්න.
කඩදාසි ක ක්ෂය සමග ආදර්ශ ඝාතීය වෙනස්වන සාමාන්යය භාවිතා කර ඇත. සරල බව සඳහා අපි මෙය මඟ හැර ඇත්තෙමු.
20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet34class Configs(BaseConfigs):ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs
 ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි. 
41    device: torch.device = DeviceConfigs()සඳහාU-Net ආකෘතිය
44    eps_model: UNet46    diffusion: DenoiseDiffusionරූපයේනාලිකා ගණන. RGB සඳහා.
49    image_channels: int = 3රූපප්රමාණය
51    image_size: int = 32ආරම්භකවිශේෂාංග සිතියමේ නාලිකා ගණන
53    n_channels: int = 64එක්එක් විභේදනයේ නාලිකා අංක ලැයිස්තුව. නාලිකා ගණන වේ channel_multipliers[i] * n_channels
 
56    channel_multipliers: List[int] = [1, 2, 2, 4]එක්එක් යෝජනාවේදී අවධානය භාවිතා කළ යුතුද යන්න පෙන්වන බූලියන් ලැයිස්තුව
58    is_attention: List[int] = [False, False, False, True]කාලපියවර ගණන
61    n_steps: int = 1_000කණ්ඩායම්ප්රමාණය
63    batch_size: int = 64උත්පාදනයකිරීමට සාම්පල ගණන
65    n_samples: int = 16ඉගෙනුම්අනුපාතය
67    learning_rate: float = 2e-5පුහුණුඑපොච් ගණන
70    epochs: int = 1_000දත්තකට්ටලය
73    dataset: torch.utils.data.Datasetදත්තකාරකය
75    data_loader: torch.utils.data.DataLoaderආදම්ප්රශස්තකරණය
78    optimizer: torch.optim.Adam80    def init(self):ආකෘතිය සාදන්න
82        self.eps_model = UNet(
83            image_channels=self.image_channels,
84            n_channels=self.n_channels,
85            ch_mults=self.channel_multipliers,
86            is_attn=self.is_attention,
87        ).to(self.device)DDPM පන්ති නිර්මාණය
90        self.diffusion = DenoiseDiffusion(
91            eps_model=self.eps_model,
92            n_steps=self.n_steps,
93            device=self.device,
94        )දත්තකාරකය සාදන්න
97        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)ප්රශස්තකරණයසාදන්න
99        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)රූපලොග් වීම
102        tracker.set_image("sample", True)104    def sample(self):108        with torch.no_grad():110            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111                            device=self.device)පියවර සඳහා ශබ්දය ඉවත් කරන්න
114            for t_ in monit.iterate('Sample', self.n_steps):116                t = self.n_steps - t_ - 1වෙතින්නියැදිය
118                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))ලොග්සාම්පල
121            tracker.save('sample', x)123    def train(self):දත්තසමුදාය හරහා නැවත කරන්න
129        for data in monit.iterate('Train', self.data_loader):ගෝලීයපියවර වැඩි කිරීම
131            tracker.add_global_step()උපාංගයවෙත දත්ත ගෙනයන්න
133            data = data.to(self.device)අනුක්රමිකශුන්ය කරන්න
136            self.optimizer.zero_grad()අලාභයගණනය කරන්න
138            loss = self.diffusion.loss(data)අනුක්රමිකගණනය
140            loss.backward()ප්රශස්තිකරණපියවරක් ගන්න
142            self.optimizer.step()අලාභයලුහුබඳින්න
144            tracker.save('loss', loss)146    def run(self):150        for _ in monit.loop(self.epochs):ආකෘතියපුහුණු කරන්න
152            self.train()පින්තූරකිහිපයක් සාම්පල කරන්න
154            self.sample()කොන්සෝලයේනව රේඛාවක්
156            tracker.new_line()ආකෘතියසුරකින්න
158            experiment.save_checkpoint()161class CelebADataset(torch.utils.data.Dataset):166    def __init__(self, image_size: int):
167        super().__init__()සෙලෙබාපින්තූර ෆෝල්ඩරය
170        folder = lab.get_data_path() / 'celebA'ගොනුලැයිස්තුව
172        self._files = [p for p in folder.glob(f'**/*.jpg')]රූපයවෙනස් කර ටෙන්සර් බවට පරිවර්තනය කිරීම සඳහා පරිවර්තනයන්
175        self._transform = torchvision.transforms.Compose([
176            torchvision.transforms.Resize(image_size),
177            torchvision.transforms.ToTensor(),
178        ])දත්තසමුදාය ප්රමාණය
180    def __len__(self):184        return len(self._files)රූපයක්ලබා ගන්න
186    def __getitem__(self, index: int):190        img = Image.open(self._files[index])
191        return self._transform(img)සෙලෙබාදත්ත කට්ටලය සාදන්න
194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):199    return CelebADataset(c.image_size)202class MNISTDataset(torchvision.datasets.MNIST):207    def __init__(self, image_size):
208        transform = torchvision.transforms.Compose([
209            torchvision.transforms.Resize(image_size),
210            torchvision.transforms.ToTensor(),
211        ])
212
213        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)215    def __getitem__(self, item):
216        return super().__getitem__(item)[0]MNISTදත්ත සමුදාය සාදන්න
219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):224    return MNISTDataset(c.image_size)227def main():අත්හදාබැලීම සාදන්න
229    experiment.create(name='diffuse', writers={'screen', 'labml'})වින්යාසයන්සාදන්න
232    configs = Configs()වින්යාසයන්සකසන්න. ශබ්දකෝෂයේ අගයන් සම්මත කිරීමෙන් ඔබට පෙරනිමි අභිබවා යා හැකිය.
235    experiment.configs(configs, {
236        'dataset': 'CelebA',  # 'MNIST'
237        'image_channels': 3,  # 1,
238        'epochs': 100,  # 5,
239    })ආරම්භකරන්න
242    configs.init()ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න
245    experiment.add_pytorch_models({'eps_model': configs.eps_model})පුහුණුලූපය ආරම්භ කර ක්රියාත්මක කරන්න
248    with experiment.start():
249        configs.run()253if __name__ == '__main__':
254    main()