මෙය සෙලෙබා එච්කියු දත්ත කට්ටලය මත ඩීඩීපීඑම් පදනම් කරගත් ආකෘතියක් පුහුණු කරයි. fast.ai හි මෙම සාකච්ඡාවේදී බාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/celebA
ෆෝල්ඩරය තුළ පින්තූර සුරකින්න.
කඩදාසි ක ක්ෂය සමග ආදර්ශ ඝාතීය වෙනස්වන සාමාන්යය භාවිතා කර ඇත. සරල බව සඳහා අපි මෙය මඟ හැර ඇත්තෙමු.
20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
31from labml_nn.diffusion.ddpm.unet import UNet
34class Configs(BaseConfigs):
ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs
ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි.
41 device: torch.device = DeviceConfigs()
සඳහාU-Net ආකෘතිය
44 eps_model: UNet
46 diffusion: DenoiseDiffusion
රූපයේනාලිකා ගණන. RGB සඳහා.
49 image_channels: int = 3
රූපප්රමාණය
51 image_size: int = 32
ආරම්භකවිශේෂාංග සිතියමේ නාලිකා ගණන
53 n_channels: int = 64
එක්එක් විභේදනයේ නාලිකා අංක ලැයිස්තුව. නාලිකා ගණන වේ channel_multipliers[i] * n_channels
56 channel_multipliers: List[int] = [1, 2, 2, 4]
එක්එක් යෝජනාවේදී අවධානය භාවිතා කළ යුතුද යන්න පෙන්වන බූලියන් ලැයිස්තුව
58 is_attention: List[int] = [False, False, False, True]
කාලපියවර ගණන
61 n_steps: int = 1_000
කණ්ඩායම්ප්රමාණය
63 batch_size: int = 64
උත්පාදනයකිරීමට සාම්පල ගණන
65 n_samples: int = 16
ඉගෙනුම්අනුපාතය
67 learning_rate: float = 2e-5
පුහුණුඑපොච් ගණන
70 epochs: int = 1_000
දත්තකට්ටලය
73 dataset: torch.utils.data.Dataset
දත්තකාරකය
75 data_loader: torch.utils.data.DataLoader
ආදම්ප්රශස්තකරණය
78 optimizer: torch.optim.Adam
80 def init(self):
ආකෘතිය සාදන්න
82 self.eps_model = UNet(
83 image_channels=self.image_channels,
84 n_channels=self.n_channels,
85 ch_mults=self.channel_multipliers,
86 is_attn=self.is_attention,
87 ).to(self.device)
DDPM පන්ති නිර්මාණය
90 self.diffusion = DenoiseDiffusion(
91 eps_model=self.eps_model,
92 n_steps=self.n_steps,
93 device=self.device,
94 )
දත්තකාරකය සාදන්න
97 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
ප්රශස්තකරණයසාදන්න
99 self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
රූපලොග් වීම
102 tracker.set_image("sample", True)
104 def sample(self):
108 with torch.no_grad():
110 x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
111 device=self.device)
පියවර සඳහා ශබ්දය ඉවත් කරන්න
114 for t_ in monit.iterate('Sample', self.n_steps):
116 t = self.n_steps - t_ - 1
වෙතින්නියැදිය
118 x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
ලොග්සාම්පල
121 tracker.save('sample', x)
123 def train(self):
දත්තසමුදාය හරහා නැවත කරන්න
129 for data in monit.iterate('Train', self.data_loader):
ගෝලීයපියවර වැඩි කිරීම
131 tracker.add_global_step()
උපාංගයවෙත දත්ත ගෙනයන්න
133 data = data.to(self.device)
අනුක්රමිකශුන්ය කරන්න
136 self.optimizer.zero_grad()
අලාභයගණනය කරන්න
138 loss = self.diffusion.loss(data)
අනුක්රමිකගණනය
140 loss.backward()
ප්රශස්තිකරණපියවරක් ගන්න
142 self.optimizer.step()
අලාභයලුහුබඳින්න
144 tracker.save('loss', loss)
146 def run(self):
150 for _ in monit.loop(self.epochs):
ආකෘතියපුහුණු කරන්න
152 self.train()
පින්තූරකිහිපයක් සාම්පල කරන්න
154 self.sample()
කොන්සෝලයේනව රේඛාවක්
156 tracker.new_line()
ආකෘතියසුරකින්න
158 experiment.save_checkpoint()
161class CelebADataset(torch.utils.data.Dataset):
166 def __init__(self, image_size: int):
167 super().__init__()
සෙලෙබාපින්තූර ෆෝල්ඩරය
170 folder = lab.get_data_path() / 'celebA'
ගොනුලැයිස්තුව
172 self._files = [p for p in folder.glob(f'**/*.jpg')]
රූපයවෙනස් කර ටෙන්සර් බවට පරිවර්තනය කිරීම සඳහා පරිවර්තනයන්
175 self._transform = torchvision.transforms.Compose([
176 torchvision.transforms.Resize(image_size),
177 torchvision.transforms.ToTensor(),
178 ])
දත්තසමුදාය ප්රමාණය
180 def __len__(self):
184 return len(self._files)
රූපයක්ලබා ගන්න
186 def __getitem__(self, index: int):
190 img = Image.open(self._files[index])
191 return self._transform(img)
සෙලෙබාදත්ත කට්ටලය සාදන්න
194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):
199 return CelebADataset(c.image_size)
202class MNISTDataset(torchvision.datasets.MNIST):
207 def __init__(self, image_size):
208 transform = torchvision.transforms.Compose([
209 torchvision.transforms.Resize(image_size),
210 torchvision.transforms.ToTensor(),
211 ])
212
213 super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
215 def __getitem__(self, item):
216 return super().__getitem__(item)[0]
MNISTදත්ත සමුදාය සාදන්න
219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):
224 return MNISTDataset(c.image_size)
227def main():
අත්හදාබැලීම සාදන්න
229 experiment.create(name='diffuse', writers={'screen', 'labml'})
වින්යාසයන්සාදන්න
232 configs = Configs()
වින්යාසයන්සකසන්න. ශබ්දකෝෂයේ අගයන් සම්මත කිරීමෙන් ඔබට පෙරනිමි අභිබවා යා හැකිය.
235 experiment.configs(configs, {
236 'dataset': 'CelebA', # 'MNIST'
237 'image_channels': 3, # 1,
238 'epochs': 100, # 5,
239 })
ආරම්භකරන්න
242 configs.init()
ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න
245 experiment.add_pytorch_models({'eps_model': configs.eps_model})
පුහුණුලූපය ආරම්භ කර ක්රියාත්මක කරන්න
248 with experiment.start():
249 configs.run()
253if __name__ == '__main__':
254 main()