මෙයසෙලෙබා එච්කියු දත්ත කට්ටලය මත ඩීඩීපීඑම් පදනම් කරගත් ආකෘතියක් පුහුණු කරයි. fast.ai හි මෙම සාකච්ඡාවේදීබාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/celebA
ෆෝල්ඩරයතුළ පින්තූර සුරකින්න.
කඩදාසික ක්ෂය සමග ආදර්ශ ඝාතීය වෙනස්වන සාමාන්යය භාවිතා කර ඇත. සරල බව සඳහා අපි මෙය මඟ හැර ඇත්තෙමු.
21from typing import List
22
23import torch
24import torch.utils.data
25import torchvision
26from PIL import Image
27
28from labml import lab, tracker, experiment, monit
29from labml.configs import BaseConfigs, option
30from labml_helpers.device import DeviceConfigs
31from labml_nn.diffusion.ddpm import DenoiseDiffusion
32from labml_nn.diffusion.ddpm.unet import UNet35class Configs(BaseConfigs):ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs
ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි.
42 device: torch.device = DeviceConfigs()සඳහාU-Net ආකෘතිය
45 eps_model: UNet47 diffusion: DenoiseDiffusionරූපයේනාලිකා ගණන. RGB සඳහා.
50 image_channels: int = 3රූපප්රමාණය
52 image_size: int = 32ආරම්භකවිශේෂාංග සිතියමේ නාලිකා ගණන
54 n_channels: int = 64එක්එක් විභේදනයේ නාලිකා අංක ලැයිස්තුව. නාලිකා ගණන වේ channel_multipliers[i] * n_channels
57 channel_multipliers: List[int] = [1, 2, 2, 4]එක්එක් යෝජනාවේදී අවධානය භාවිතා කළ යුතුද යන්න පෙන්වන බූලියන් ලැයිස්තුව
59 is_attention: List[int] = [False, False, False, True]කාලපියවර ගණන
62 n_steps: int = 1_000කණ්ඩායම්ප්රමාණය
64 batch_size: int = 64උත්පාදනයකිරීමට සාම්පල ගණන
66 n_samples: int = 16ඉගෙනුම්අනුපාතය
68 learning_rate: float = 2e-5පුහුණුඑපොච් ගණන
71 epochs: int = 1_000දත්තකට්ටලය
74 dataset: torch.utils.data.Datasetදත්තකාරකය
76 data_loader: torch.utils.data.DataLoaderආදම්ප්රශස්තකරණය
79 optimizer: torch.optim.Adam81 def init(self):ආකෘතිය සාදන්න
83 self.eps_model = UNet(
84 image_channels=self.image_channels,
85 n_channels=self.n_channels,
86 ch_mults=self.channel_multipliers,
87 is_attn=self.is_attention,
88 ).to(self.device)DDPM පන්ති නිර්මාණය
91 self.diffusion = DenoiseDiffusion(
92 eps_model=self.eps_model,
93 n_steps=self.n_steps,
94 device=self.device,
95 )දත්තකාරකය සාදන්න
98 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)ප්රශස්තකරණයසාදන්න
100 self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)රූපලොග් වීම
103 tracker.set_image("sample", True)105 def sample(self):109 with torch.no_grad():111 x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
112 device=self.device)පියවර සඳහා ශබ්දය ඉවත් කරන්න
115 for t_ in monit.iterate('Sample', self.n_steps):117 t = self.n_steps - t_ - 1වෙතින්නියැදිය
119 x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))ලොග්සාම්පල
122 tracker.save('sample', x)124 def train(self):දත්තසමුදාය හරහා නැවත කරන්න
130 for data in monit.iterate('Train', self.data_loader):ගෝලීයපියවර වැඩි කිරීම
132 tracker.add_global_step()උපාංගයවෙත දත්ත ගෙනයන්න
134 data = data.to(self.device)අනුක්රමිකශුන්ය කරන්න
137 self.optimizer.zero_grad()අලාභයගණනය කරන්න
139 loss = self.diffusion.loss(data)අනුක්රමිකගණනය
141 loss.backward()ප්රශස්තිකරණපියවරක් ගන්න
143 self.optimizer.step()අලාභයලුහුබඳින්න
145 tracker.save('loss', loss)147 def run(self):151 for _ in monit.loop(self.epochs):ආකෘතියපුහුණු කරන්න
153 self.train()පින්තූරකිහිපයක් සාම්පල කරන්න
155 self.sample()කොන්සෝලයේනව රේඛාවක්
157 tracker.new_line()ආකෘතියසුරකින්න
159 experiment.save_checkpoint()162class CelebADataset(torch.utils.data.Dataset):167 def __init__(self, image_size: int):
168 super().__init__()සෙලෙබාපින්තූර ෆෝල්ඩරය
171 folder = lab.get_data_path() / 'celebA'ගොනුලැයිස්තුව
173 self._files = [p for p in folder.glob(f'**/*.jpg')]රූපයවෙනස් කර ටෙන්සර් බවට පරිවර්තනය කිරීම සඳහා පරිවර්තනයන්
176 self._transform = torchvision.transforms.Compose([
177 torchvision.transforms.Resize(image_size),
178 torchvision.transforms.ToTensor(),
179 ])දත්තසමුදාය ප්රමාණය
181 def __len__(self):185 return len(self._files)රූපයක්ලබා ගන්න
187 def __getitem__(self, index: int):191 img = Image.open(self._files[index])
192 return self._transform(img)සෙලෙබාදත්ත කට්ටලය සාදන්න
195@option(Configs.dataset, 'CelebA')
196def celeb_dataset(c: Configs):200 return CelebADataset(c.image_size)203class MNISTDataset(torchvision.datasets.MNIST):208 def __init__(self, image_size):
209 transform = torchvision.transforms.Compose([
210 torchvision.transforms.Resize(image_size),
211 torchvision.transforms.ToTensor(),
212 ])
213
214 super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)216 def __getitem__(self, item):
217 return super().__getitem__(item)[0]MNISTදත්ත සමුදාය සාදන්න
220@option(Configs.dataset, 'MNIST')
221def mnist_dataset(c: Configs):225 return MNISTDataset(c.image_size)228def main():අත්හදාබැලීම සාදන්න
230 experiment.create(name='diffuse', writers={'screen', 'comet'})වින්යාසයන්සාදන්න
233 configs = Configs()වින්යාසයන්සකසන්න. ශබ්දකෝෂයේ අගයන් සම්මත කිරීමෙන් ඔබට පෙරනිමි අභිබවා යා හැකිය.
236 experiment.configs(configs, {
237 'dataset': 'CelebA', # 'MNIST'
238 'image_channels': 3, # 1,
239 'epochs': 100, # 5,
240 })ආරම්භකරන්න
243 configs.init()ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න
246 experiment.add_pytorch_models({'eps_model': configs.eps_model})පුහුණුලූපය ආරම්භ කර ක්රියාත්මක කරන්න
249 with experiment.start():
250 configs.run()254if __name__ == '__main__':
255 main()