රූපඋත්පාදනය කිරීමට සහ ලබා දී ඇති රූප අතර අන්තර්ක්රියාකාරිත්වයන් නිර්මාණය කිරීමේ කේතය මෙයයි.
14import numpy as np
15import torch
16from matplotlib import pyplot as plt
17from torchvision.transforms.functional import to_pil_image, resize
18
19from labml import experiment, monit
20from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
21from labml_nn.diffusion.ddpm.experiment import Configs
24class Sampler:
diffusion
DenoiseDiffusion
උදාහරණයක් වේ image_channels
යනු රූපයේ ඇති නාලිකා ගණන image_size
යනු රූපයේ ප්රමාණයයි device
ආකෘතියේ උපාංගය වේ29 def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
36 self.device = device
37 self.image_size = image_size
38 self.image_channels = image_channels
39 self.diffusion = diffusion
42 self.n_steps = diffusion.n_steps
44 self.eps_model = diffusion.eps_model
46 self.beta = diffusion.beta
48 self.alpha = diffusion.alpha
50 self.alpha_bar = diffusion.alpha_bar
52 alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])
64 self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
66 self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
68 self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
70 self.sigma2 = self.beta
රූපයක්ප්රදර්ශනය කිරීම සඳහා උපකාරක කාර්යය
72 def show_image(self, img, title=""):
74 img = img.clip(0, 1)
75 img = img.cpu().numpy()
76 plt.imshow(img.transpose(1, 2, 0))
77 plt.title(title)
78 plt.show()
වීඩියෝවක්නිර්මාණය කිරීම සඳහා උපකාරක කාර්යය
80 def make_video(self, frames, path="video.mp4"):
82 import imageio
20දෙවන වීඩියෝව
84 writer = imageio.get_writer(path, fps=len(frames) // 20)
එක්එක් රූපය එකතු කරන්න
86 for f in frames:
87 f = f.clip(0, 1)
88 f = to_pil_image(resize(f, [368, 368]))
89 writer.append_data(np.array(f))
91 writer.close()
අපිපියවරෙන් පියවර රූපයක් සාම්පල කර ඇති අතර සෑම පියවරකදීම ඇස්තමේන්තුව පෙන්වයි
93 def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
104 xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)
ලොග්වීමට පරතරය
107 interval = self.n_steps // n_frames
වීඩියෝසඳහා රාමු
109 frames = []
නියැදි පියවර
111 for t_inv in monit.iterate('Denoise', self.n_steps):
113 t_ = self.n_steps - t_inv - 1
එය tensor දී
115 t = xt.new_full((1,), t_, dtype=torch.long)
117 eps_theta = self.eps_model(xt, t)
118 if t_ % interval == 0:
රාමුලබා ගන්න සහ එකතු කරන්න
120 x0 = self.p_x0(xt, t, eps_theta)
121 frames.append(x0[0])
122 if not create_video:
123 self.show_image(x0[0], f"{t_}")
වෙතින්නියැදිය
125 xt = self.p_sample(xt, t, eps_theta)
වීඩියෝසාදන්න
128 if create_video:
129 self.make_video(frames)
අපිලබා ගන්න සහ .
ඉන්පසුඅන්තර්පොලේට් කරන්න
ඉන්පසුලබා ගන්න
x1
වේ x2
වේ lambda_
වේ t_
වේ 131 def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
සාම්පලගණන
150 n_samples = x1.shape[0]
ටෙන්සර්
152 t = torch.full((n_samples,), t_, device=self.device)
154 xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)
157 return self._sample_x0(xt, t_)
x1
වේ x2
වේ n_frames
යනු රූපය සඳහා රාමු ගණන t_
වේ create_video
වීඩියෝවක් සෑදිය යුතුද නැතහොත් එක් එක් රාමුව පෙන්විය යුතුද යන්න නියම කරයි159 def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
160 create_video=True):
මුල්රූප පෙන්වන්න
172 self.show_image(x1, "x1")
173 self.show_image(x2, "x2")
කණ්ඩායම්මානය එක් කරන්න
175 x1 = x1[None, :, :, :]
176 x2 = x2[None, :, :, :]
ටෙන්සර්
178 t = torch.full((1,), t_, device=self.device)
180 x1t = self.diffusion.q_sample(x1, t)
182 x2t = self.diffusion.q_sample(x2, t)
183
184 frames = []
විවිධරාමු ලබා ගන්න
186 for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
188 lambda_ = i / n_frames
190 xt = (1 - lambda_) * x1t + lambda_ * x2t
192 x0 = self._sample_x0(xt, t_)
රාමුවලට එකතු කරන්න
194 frames.append(x0[0])
රාමුවපෙන්වන්න
196 if not create_video:
197 self.show_image(x0[0], f"{lambda_ :.2f}")
වීඩියෝසාදන්න
200 if create_video:
201 self.make_video(frames)
203 def _sample_x0(self, xt: torch.Tensor, n_steps: int):
නියැදිගණන
212 n_samples = xt.shape[0]
පියවර ගන්නා තෙක් නැවත ක්රියාත්මක වන්න
214 for t_ in monit.iterate('Denoise', n_steps):
215 t = n_steps - t_ - 1
වෙතින්නියැදිය
217 xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))
ආපසු
220 return xt
222 def sample(self, n_samples: int = 16):
227 xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)
230 x0 = self._sample_x0(xt, self.n_steps)
පින්තූරපෙන්වන්න
233 for i in range(n_samples):
234 self.show_image(x0[i])
236 def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
249 alpha_bar = gather(self.alpha_bar, t)
251 alpha = gather(self.alpha, t)
253 eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
256 mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
258 var = gather(self.sigma2, t)
261 eps = torch.randn(xt.shape, device=xt.device)
නියැදිය
263 return mean + (var ** .5) * eps
265 def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
273 alpha_bar = gather(self.alpha_bar, t)
277 return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
සාම්පලජනනය කරන්න
280def main():
පුහුණුඅත්හදා බැලීම UUID
284 run_uuid = "a44333ea251411ec8007d1a1762ed686"
ඇගයීමක්ආරම්භ කරන්න
287 experiment.evaluate()
වින්යාසසාදන්න
290 configs = Configs()
පුහුණුධාවනයේ අභිරුචි වින්යාසය පූරණය කරන්න
292 configs_dict = experiment.load_configs(run_uuid)
වින්යාසයන්සකසන්න
294 experiment.configs(configs, configs_dict)
ආරම්භකරන්න
297 configs.init()
ඉතිරිකිරීම සහ පැටවීම සඳහා පයිටෝච් මොඩියුල සකසන්න
300 experiment.add_pytorch_models({'eps_model': configs.eps_model})
පුහුණුඅත්හදා බැලීම
303 experiment.load(run_uuid)
නියැදියසාදන්න
306 sampler = Sampler(diffusion=configs.diffusion,
307 image_channels=configs.image_channels,
308 image_size=configs.image_size,
309 device=configs.device)
ඇගයීමආරම්භ කරන්න
312 with experiment.start():
කිසිදුඅනුක්රමික
314 with torch.no_grad():
නිරූපණයකරන සජීවිකරණයක් සහිත රූපයක් සාම්පල කරන්න
316 sampler.sample_animation()
317
318 if False:
දත්තවලින් පින්තූර කිහිපයක් ලබා ගන්න
320 data = next(iter(configs.data_loader)).to(configs.device)
අන්තර්නිවේෂණයසජීවිකරණයක් සාදන්න
323 sampler.interpolate_animate(data[0], data[1])
327if __name__ == '__main__':
328 main()