විසරණ සම්භාවිතා ආකෘති (ඩීඩීපීඑම්) පුහුණුව නිරූපණය කිරීම

Open In Colab

මෙය සෙලෙබා එච්කියු දත්ත කට්ටලය මත ඩීඩීපීඑම් පදනම් කරගත් ආකෘතියක් පුහුණු කරයි. fast.ai හි මෙම සාකච්ඡාවේදී බාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/celebA ෆෝල්ඩරය තුළ පින්තූර සුරකින්න.

කඩදාසි ක ක්ෂය සමග ආදර්ශ ඝාතීය වෙනස්වන සාමාන්යය භාවිතා කර ඇත. සරල බව සඳහා අපි මෙය මඟ හැර ඇත්තෙමු.

20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet

වින්යාසකිරීම්

34class Configs(BaseConfigs):

ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි.

41    device: torch.device = DeviceConfigs()

සඳහාU-Net ආකෘතිය

44    eps_model: UNet
46    diffusion: DenoiseDiffusion

රූපයේනාලිකා ගණන. RGB සඳහා.

49    image_channels: int = 3

රූපප්රමාණය

51    image_size: int = 32

ආරම්භකවිශේෂාංග සිතියමේ නාලිකා ගණන

53    n_channels: int = 64

එක්එක් විභේදනයේ නාලිකා අංක ලැයිස්තුව. නාලිකා ගණන වේ channel_multipliers[i] * n_channels

56    channel_multipliers: List[int] = [1, 2, 2, 4]

එක්එක් යෝජනාවේදී අවධානය භාවිතා කළ යුතුද යන්න පෙන්වන බූලියන් ලැයිස්තුව

58    is_attention: List[int] = [False, False, False, True]

කාලපියවර ගණන

61    n_steps: int = 1_000

කණ්ඩායම්ප්රමාණය

63    batch_size: int = 64

උත්පාදනයකිරීමට සාම්පල ගණන

65    n_samples: int = 16

ඉගෙනුම්අනුපාතය

67    learning_rate: float = 2e-5

පුහුණුඑපොච් ගණන

70    epochs: int = 1_000

දත්තකට්ටලය

73    dataset: torch.utils.data.Dataset

දත්තකාරකය

75    data_loader: torch.utils.data.DataLoader

ආදම්ප්රශස්තකරණය

78    optimizer: torch.optim.Adam
80    def init(self):

ආකෘතිය සාදන්න

82        self.eps_model = UNet(
83            image_channels=self.image_channels,
84            n_channels=self.n_channels,
85            ch_mults=self.channel_multipliers,
86            is_attn=self.is_attention,
87        ).to(self.device)

DDPM පන්ති නිර්මාණය

90        self.diffusion = DenoiseDiffusion(
91            eps_model=self.eps_model,
92            n_steps=self.n_steps,
93            device=self.device,
94        )

දත්තකාරකය සාදන්න

97        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)

ප්රශස්තකරණයසාදන්න

99        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

රූපලොග් වීම

102        tracker.set_image("sample", True)

නියැදිරූප

104    def sample(self):
108        with torch.no_grad():

110            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111                            device=self.device)

පියවර සඳහා ශබ්දය ඉවත් කරන්න

114            for t_ in monit.iterate('Sample', self.n_steps):

116                t = self.n_steps - t_ - 1

වෙතින්නියැදිය

118                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

ලොග්සාම්පල

121            tracker.save('sample', x)

දුම්රිය

123    def train(self):

දත්තසමුදාය හරහා නැවත කරන්න

129        for data in monit.iterate('Train', self.data_loader):

ගෝලීයපියවර වැඩි කිරීම

131            tracker.add_global_step()

උපාංගයවෙත දත්ත ගෙනයන්න

133            data = data.to(self.device)

අනුක්රමිකශුන්ය කරන්න

136            self.optimizer.zero_grad()

අලාභයගණනය කරන්න

138            loss = self.diffusion.loss(data)

අනුක්රමිකගණනය

140            loss.backward()

ප්රශස්තිකරණපියවරක් ගන්න

142            self.optimizer.step()

අලාභයලුහුබඳින්න

144            tracker.save('loss', loss)

පුහුණුලූපය

146    def run(self):
150        for _ in monit.loop(self.epochs):

ආකෘතියපුහුණු කරන්න

152            self.train()

පින්තූරකිහිපයක් සාම්පල කරන්න

154            self.sample()

කොන්සෝලයේනව රේඛාවක්

156            tracker.new_line()

ආකෘතියසුරකින්න

158            experiment.save_checkpoint()

සෙලෙබාමූලස්ථානය දත්ත කට්ටලය

161class CelebADataset(torch.utils.data.Dataset):
166    def __init__(self, image_size: int):
167        super().__init__()

සෙලෙබාපින්තූර ෆෝල්ඩරය

170        folder = lab.get_data_path() / 'celebA'

ගොනුලැයිස්තුව

172        self._files = [p for p in folder.glob(f'**/*.jpg')]

රූපයවෙනස් කර ටෙන්සර් බවට පරිවර්තනය කිරීම සඳහා පරිවර්තනයන්

175        self._transform = torchvision.transforms.Compose([
176            torchvision.transforms.Resize(image_size),
177            torchvision.transforms.ToTensor(),
178        ])

දත්තසමුදාය ප්රමාණය

180    def __len__(self):
184        return len(self._files)

රූපයක්ලබා ගන්න

186    def __getitem__(self, index: int):
190        img = Image.open(self._files[index])
191        return self._transform(img)

සෙලෙබාදත්ත කට්ටලය සාදන්න

194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):
199    return CelebADataset(c.image_size)

MNISTදත්ත කට්ටලය

202class MNISTDataset(torchvision.datasets.MNIST):
207    def __init__(self, image_size):
208        transform = torchvision.transforms.Compose([
209            torchvision.transforms.Resize(image_size),
210            torchvision.transforms.ToTensor(),
211        ])
212
213        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
215    def __getitem__(self, item):
216        return super().__getitem__(item)[0]

MNISTදත්ත සමුදාය සාදන්න

219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):
224    return MNISTDataset(c.image_size)
227def main():

අත්හදාබැලීම සාදන්න

229    experiment.create(name='diffuse', writers={'screen', 'labml'})

වින්යාසයන්සාදන්න

232    configs = Configs()

වින්යාසයන්සකසන්න. ශබ්දකෝෂයේ අගයන් සම්මත කිරීමෙන් ඔබට පෙරනිමි අභිබවා යා හැකිය.

235    experiment.configs(configs, {
236        'dataset': 'CelebA',  # 'MNIST'
237        'image_channels': 3,  # 1,
238        'epochs': 100,  # 5,
239    })

ආරම්භකරන්න

242    configs.init()

ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න

245    experiment.add_pytorch_models({'eps_model': configs.eps_model})

පුහුණුලූපය ආරම්භ කර ක්රියාත්මක කරන්න

248    with experiment.start():
249        configs.run()

253if __name__ == '__main__':
254    main()