මෙයකඩදාසි පිළිබඳ විවරණ කරන ලද PyTorch ක්රියාත්මක කිරීම ස්කෙච් ඇඳීම්වල ස්නායුක නිරූපණයකි .
ස්කෙච්ආර්එන්එන් යනු අනුක්රම-සිට-අනුක්රමික විචල්යතා ස්වයංක්රීය ආකේතයකි. එන්කෝඩරය සහ විකේතකය යන දෙකම පුනරාවර්තන ස්නායුක ජාල ආකෘති වේ. ආ roke ාත මාලාවක් පුරෝකථනය කිරීමෙන් ආ roke ාතය පදනම් කරගත් සරල ඇඳීම් ප්රතිනිර්මාණය කිරීමට එය ඉගෙන ගනී. එක් එක් ආ roke ාතය ගවුසියානු මිශ්රණයක් ලෙස විකේතකය පුරෝකථනය කරයි.
ඉක්මන් වෙතින් දත්ත බාගන්න, අඳින්න! දත්ත කට්ටලය. Sketch-RNNහි npz
ගොනු බාගත කිරීම සඳහා සබැඳියක් ඇත QuickDraw දත්ත සමුදාය කියවීමේ කොටස. බාගත කළ npz
ගොනුව (ය) data/sketch
ෆෝල්ඩරයේ තබන්න. මෙම කේතය bicycle
දත්ත කට්ටලය භාවිතා කිරීමට වින්යාස කර ඇත. ඔබට මෙය වින්යාසයන් තුළ වෙනස් කළ හැකිය.
ඇලෙක්සිස් ඩේවිඩ් ජැක් විසින් පයිටෝර්ච් ස්කෙච් ආර්එන්එන් ව්යාපෘතියෙන් උදව් ලබා ගත්තේය
32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
50class StrokesDataset(Dataset):
dataset
Seq_len හැඩයේ හිරිවැටීම් අරා ලැයිස්තුවකි, 3. එය ආ ro ාත අනුපිළිවෙලක් වන අතර සෑම පහරක්ම නිඛිල 3 කින් නිරූපණය කෙරේ. පළමු දෙක x සහ y දිගේ විස්ථාපන වේ (, ) සහ අවසාන නිඛිලය පෑනෙහි තත්වය නිරූපණය කරයි, එය කඩදාසි ස්පර්ශ කරන්නේ නම් සහ එසේ නොමැතිනම්.
57 def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67 data = []
අපිඑක් එක් අනුපිළිවෙල සහ පෙරීම හරහා නැවත නැවතත්
69 for seq in dataset:
ආro ාත අනුක්රමයේ දිග අපගේ පරාසය තුළ තිබේ නම් පෙරහන් කරන්න
71 if 10 < len(seq) <= max_seq_length:
කලම්ප , කිරීමට
73 seq = np.minimum(seq, 1000)
74 seq = np.maximum(seq, -1000)
පාවෙනලක්ෂ්ය අරාවකට පරිවර්තනය කර එකතු කරන්න data
76 seq = np.array(seq, dtype=np.float32)
77 data.append(seq)
එවිටඅපි (, ) ඒකාබද්ධ සම්මත අපගමනය වන පරිමාණ සාධකය ගණනය. මධ්යන්යය කෙසේ වෙතත් සමීප බැවින් සරල බව සඳහා මධ්යන්යය සකස් කර නොමැති බව කඩදාසි සටහන් කරයි.
83 if scale is None:
84 scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85 self.scale = scale
සියලුමඅනුපිළිවෙලවල් අතර දිගම අනුක්රමික දිග ලබා ගන්න
88 longest_seq_len = max([len(seq) for seq in data])
ආරම්භකඅනුක්රමය (sos) සහ අවසාන අනුක්රමය (eos) සඳහා අමතර පියවර දෙකක් සමඟ අපි PyTorch දත්ත අරාව ආරම්භ කරමු. සෑම පියවරක්ම දෛශිකයකි . එකක් පමණක් වන අතර අනෙක් ඒවා වේ . ඒවා නියෝජනය කරන්නේ පෑන පහළට, පෑන ඉහළටසහ එම අනුපිළිවෙලට අනුක්රමය අවසන්කිරීමයි. ඊළඟ පියවරේදී පෑන කඩදාසි ස්පර්ශ කරන්නේ නම්. ඊළඟ පියවරේදී පෑන කඩදාසි ස්පර්ශ නොකරන්නේ නම්. එය චිත්රයේ අවසානය නම් වේ.
98 self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
වෙස්අරා අවශ්ය වන්නේ එක් අමතර පියවරක් පමණි, මන්ද එය විකේතනයේ ප්රතිදානයන් සඳහා වන data[:-1]
අතර එය ඊළඟ පියවර ගනී.
101 self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103 for i, seq in enumerate(data):
104 seq = torch.from_numpy(seq)
105 len_seq = len(seq)
පරිමාණයසහ කට්ටලය
107 self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
109 self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
111 self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
113 self.data[i, len_seq + 1:, 4] = 1
අනුක්රමයඅවසන් වන තෙක් මාස්ක් ක්රියාත්මක වේ
115 self.mask[i, :len_seq + 1] = 1
අනුපිළිවෙලආරම්භ කිරීම
118 self.data[:, 0, 2] = 1
දත්තසමුදාය ප්රමාණය
120 def __len__(self):
122 return len(self.data)
නියැදියක්ලබා ගන්න
124 def __getitem__(self, idx: int):
126 return self.data[idx], self.mask[idx]
මිශ්රණයනියෝජනය වන්නේ සහ . මෙම පන්තිය උෂ්ණත්වය සකස් කරන අතර පරාමිතීන්ගෙන් වර්ගීකරණ හා ගුසියානු බෙදාහැරීම් නිර්මාණය කරයි.
129class BivariateGaussianMixture:
139 def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141 self.pi_logits = pi_logits
142 self.mu_x = mu_x
143 self.mu_y = mu_y
144 self.sigma_x = sigma_x
145 self.sigma_y = sigma_y
146 self.rho_xy = rho_xy
මිශ්රණයේබෙදාහැරීම් ගණන,
148 @property
149 def n_distributions(self):
151 return self.pi_logits.shape[-1]
උෂ්ණත්වයඅනුව සකසන්න
153 def set_temperature(self, temperature: float):
158 self.pi_logits /= temperature
160 self.sigma_x *= math.sqrt(temperature)
162 self.sigma_y *= math.sqrt(temperature)
164 def get_distribution(self):
කලම්ප , සහ NaN
s ලබා වළක්වා ගැනීමට
166 sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167 sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168 rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)
මාධ්යයන්ලබා ගන්න
171 mean = torch.stack([self.mu_x, self.mu_y], -1)
කෝවිචියන්ස්අනුකෘතිය ලබා ගන්න
173 cov = torch.stack([
174 sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175 rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176 ], -1)
177 cov = cov.view(*sigma_y.shape, 2, 2)
ද්වි-විචල්යසාමාන්ය ව්යාප්තියක් සාදන්න.
📝එය [[a, 0], [b, c]]
කොහෙද ලෙස scale_tril
අනුකෘතිය කාර්යක්ෂම වනු ඇත. නමුත් සරල බව සඳහා අපි සම-විචල්යතා අනුකෘතිය භාවිතා කරමු. ද්වි-විචල්ය බෙදාහැරීම්, ඒවායේ සම-විචල්යතා අනුකෘතිය සහසම්භාවිතා dens නත්ව ක්රියාකාරිත්වය ගැන වැඩිදුර කියවීමට ඔබට අවශ්ය නම් මෙය හොඳ සම්පතකි .
188 multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
පිවිසුම් වලින් වර්ගීකරණ බෙදාහැරීමක් සාදන්න
191 cat_dist = torch.distributions.Categorical(logits=self.pi_logits)
194 return cat_dist, multi_dist
197class EncoderRNN(Module):
204 def __init__(self, d_z: int, enc_hidden_size: int):
205 super().__init__()
ආදානය ලෙස අනුක්රමයක් ගනිමින් ද්විපාර්ශ්වික LSTM සාදන්න.
208 self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
ලබාගැනීමට ප්රධානියා
210 self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
ලබාගැනීමට ප්රධානියා
212 self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214 def forward(self, inputs: torch.Tensor, state=None):
ද්විපාර්ශ්විකඑල්එස්ටීඑම් හි සැඟවුණු තත්වය යනු අවසාන ටෝකනයේ ප්රතිදානය ඉදිරි දිශාවට සහ ප්රතිලෝම දිශාවට පළමු ටෝකනය සංයුක්ත කිරීමයි, එය අපට අවශ්ය දෙයයි.
221 _, (hidden, cell) = self.lstm(inputs.float(), state)
රාජ්යයටහැඩය ඇත [2, batch_size, hidden_size]
, එහිදී පළමු මානය දිශාව වේ. ලබා ගැනීම සඳහා අපි එය නැවත සකස්
225 hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
228 mu = self.mu_head(hidden)
230 sigma_hat = self.sigma_head(hidden)
232 sigma = torch.exp(sigma_hat / 2.)
නියැදිය
235 z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
238 return z, mu, sigma_hat
241class DecoderRNN(Module):
248 def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
249 super().__init__()
LSTMආදානය ලෙස ගනී
251 self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
LSTMහි ආරම්භක තත්වය වේ . init_state
මේ සඳහා රේඛීය පරිවර්තනයයි
255 self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
මෙමස්තරය එක් එක් සඳහා ප්රතිදානයන් නිෂ්පාදනය කරයි n_distributions
. සෑම ව්යාප්තියකටම පරාමිතීන් හයක් අවශ්ය
260 self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
මෙමහිස පිවිසුම් සඳහා වේ
263 self.q_head = nn.Linear(dec_hidden_size, 3)
මෙය කොතැනද යන්න ගණනය කිරීමයි
266 self.q_log_softmax = nn.LogSoftmax(-1)
මෙමපරාමිතීන් අනාගත යොමු කිරීම සඳහා ගබඩා කර ඇත
269 self.n_distributions = n_distributions
270 self.dec_hidden_size = dec_hidden_size
272 def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
ආරම්භකතත්වය ගණනය කරන්න
274 if state is None:
276 h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
h
සහ c
හැඩයන් [batch_size, lstm_size]
ඇත. LSTM හි භාවිතා කරන හැඩය [1, batch_size, lstm_size]
නිසා අපට ඒවා හැඩගස්වා ගැනීමට අවශ්යය.
279 state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())
LSTMධාවනය කරන්න
282 outputs, state = self.lstm(x, state)
ලබාගන්න
285 q_logits = self.q_log_softmax(self.q_head(outputs))
ලබාගන්න . torch.split
ප්රතිදානය මානයක් self.n_distribution
හරහා ප්රමාණය tensors 6 බවට splits 2
.
291 pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
292 torch.split(self.mixtures(outputs), self.n_distributions, 2)
ද්වි-විචල්යGaussian මිශ්රණයක් සාදන්න සහ කොහේද සහ
මිශ්රණයෙන් බෙදා හැරීම තෝරා ගැනීමේ වර්ගීකරණ සම්භාවිතාවන් වේ.
305 dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
306 torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
309 return dist, q_logits, state
312class ReconstructionLoss(Module):
317 def forward(self, mask: torch.Tensor, target: torch.Tensor,
318 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
ලබා ගන්න
320 pi, mix = dist.get_distribution()
target
අවසාන මානය ලක්ෂණ [seq_len, batch_size, 5]
වන හැඩය ඇත . අපට අවශ්ය වන්නේ y ලබා ගැනීමට සහ මිශ්රණයේ ඇති එක් එක් බෙදාහැරීම් වලින් සම්භාවිතාවන් ලබා ගැනීමයි .
xy
හැඩය ඇත [seq_len, batch_size, n_distributions, 2]
327 xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
සම්භාවිතාවගණනය කරන්න
333 probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
(longest_seq_len
) මූලද්රව්ය probs
තිබුණද, එකතුව ගනු ලබන්නේ ඉතිරිය වන බැවිනි වෙස් වලාගෙන.
අපඑකතුව ගෙන බෙදිය යුතු යැයි හැඟෙනු ඇත , නමුත් මෙය කෙටි අනුපිළිවෙලින් තනි අනාවැකි සඳහා වැඩි බරක් ලබා දෙනු ඇත. අපි බෙදෙන විට අපි එක් එක් අනාවැකිය සමාන බර
දෙන්න342 loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
345 loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
348 return loss_stroke + loss_pen
351class KLDivLoss(Module):
358 def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
360 return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
363class Sampler:
370 def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
371 self.decoder = decoder
372 self.encoder = encoder
374 def sample(self, data: torch.Tensor, temperature: float):
376 longest_seq_len = len(data)
එන්කෝඩරයෙන් ලබා ගන්න
379 z, _, _ = self.encoder(data)
ආරම්භකඅනුක්රමය ආ roke ාතය වේ
382 s = data.new_tensor([0, 0, 1, 0, 0])
383 seq = [s]
ආරම්භකවිකේතකය වේ None
. විකේතකය එය ආරම්භ කරනු ඇත
386 state = None
අපටඅනුක්රමික අවශ්ය නොවේ
389 with torch.no_grad():
නියැදි පහරවල්
391 for i in range(longest_seq_len):
යනු විකේතකයට ආදානය කිරීමයි
393 data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
විකේතකයෙන් ඊළඟ තත්වය ලබා ගන්න
396 dist, q_logits, state = self.decoder(data, z, state)
ආඝාතයසාම්පලයක්
398 s = self._sample_step(dist, q_logits, temperature)
ආro ාත අනුපිළිවෙලට නව ආ roke ාතය එක් කරන්න
400 seq.append(s)
නියැදීමනවත්වන්න නම් . මෙයින් ඇඟවෙන්නේ ස්කීච් කිරීම නතර වී ඇති
බවයි402 if s[4] == 1:
403 break
ආro ාත අනුපිළිවෙලෙහි පයිටෝච් ටෙන්සරයක් සාදන්න
406 seq = torch.stack(seq)
ආro ාත අනුපිළිවෙල සැලසුම් කරන්න
409 self.plot(seq)
411 @staticmethod
412 def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):
නියැදීම් සඳහා උෂ්ණත්වය සකසන්න. මෙය පන්තියේ ක්රියාත්මක වේ BivariateGaussianMixture
.
414 dist.set_temperature(temperature)
උෂ්ණත්වයසකස් කර ලබා ගන්න
416 pi, mix = dist.get_distribution()
මිශ්රණයෙන්භාවිතා කිරීම සඳහා බෙදා හැරීමේ දර්ශකයෙන් නියැදිය
418 idx = pi.sample()[0, 0]
ලොග්-සම්භාවිතාවන් q_logits
සමඟ වර්ගීකරණ බෙදාහැරීමක් සාදන්න
421 q = torch.distributions.Categorical(logits=q_logits / temperature)
වෙතින්නියැදිය
423 q_idx = q.sample()[0, 0]
මිශ්රණයේසාමාන්ය බෙදාහැරීම් වලින් නියැදිය සහ සුචිගත කරන ලද එක තෝරන්න idx
426 xy = mix.sample()[0, 0, idx]
හිස්ආඝාතය සාදන්න
429 stroke = q_logits.new_zeros(5)
සකසන්න
431 stroke[:2] = xy
සකසන්න
433 stroke[q_idx + 2] = 1
435 return stroke
437 @staticmethod
438 def plot(seq: torch.Tensor):
ලබාගැනීම සඳහා සමුච්චිත සාරාංශ ගන්න
440 seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
පෝරමයේනව අංකුර අරා සාදන්න
442 seq[:, 2] = seq[:, 3]
443 seq = seq[:, 0:3].detach().cpu().numpy()
කොහෙද ලකුණු දී අරාව බෙදන්න . i.e. පෑන කඩදාසි සිට ඔසවා එහිදී ලකුණු දී පිලි අරාව බෙදී. මෙය ආ ro ාත අනුපිළිවෙල ලැයිස්තුවක් ලබා දෙයි.
448 strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
පහරවල්එක් එක් අනුපිළිවෙල සැලසුම් කරන්න
450 for s in strokes:
451 plt.plot(s[:, 0], -s[:, 1])
අක්ෂපෙන්වන්න එපා
453 plt.axis('off')
කුමන්ත්රණයපෙන්වන්න
455 plt.show()
458class Configs(TrainValidConfigs):
අත්හදාබැලීම ක්රියාත්මක කිරීම සඳහා උපාංගය තෝරා ගැනීමට උපාංග වින්යාසයන්
466 device: torch.device = DeviceConfigs()
468 encoder: EncoderRNN
469 decoder: DecoderRNN
470 optimizer: optim.Adam
471 sampler: Sampler
472
473 dataset_name: str
474 train_loader: DataLoader
475 valid_loader: DataLoader
476 train_dataset: StrokesDataset
477 valid_dataset: StrokesDataset
එන්කෝඩර්සහ විකේතක ප්රමාණ
480 enc_hidden_size = 256
481 dec_hidden_size = 512
කණ්ඩායම්ප්රමාණය
484 batch_size = 100
විශේෂාංගගණන
487 d_z = 128
මිශ්රණයේබෙදාහැරීම් ගණන,
489 n_distributions = 20
KLඅපසරනය අඞු කිරීමට බර,
492 kl_div_loss_weight = 0.5
ශ්රේණියේක්ලිපින්
494 grad_clip = 1.
නියැදීම සඳහා උෂ්ණත්වය
496 temperature = 0.4
වඩාදිගු ආ roke ාත අනුපිළිවෙල පෙරහන් කරන්න
499 max_seq_length = 200
500
501 epochs = 100
502
503 kl_div_loss = KLDivLoss()
504 reconstruction_loss = ReconstructionLoss()
506 def init(self):
එන්කෝඩරයසහ විකේතකය ආරම්භ කරන්න
508 self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
509 self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
ප්රශස්තකරණයසකසන්න. ප්රශස්තිකරණ වර්ගය සහ ඉගෙනුම් අනුපාතය වැනි දේවල් වින්යාසගත කළ හැකිය
512 optimizer = OptimizerConfigs()
513 optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
514 self.optimizer = optimizer
නියැදියසාදන්න
517 self.sampler = Sampler(self.encoder, self.decoder)
npz
ගොනු මාර්ගය වේ data/sketch/[DATASET NAME].npz
520 path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
අංකිතගොනුව පූරණය කරන්න
522 dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
පුහුණුදත්ත සමුදාය සාදන්න
525 self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
වලංගුදත්ත සමුදාය සාදන්න
527 self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
පුහුණුදත්ත පැටවුම සාදන්න
530 self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
වලංගුදත්ත පැටවුම සාදන්න
532 self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
ටෙන්සෝර්බෝඩ්හි ස්ථර ප්රතිදානයන් නිරීක්ෂණය කිරීම සඳහා කොකු එක් කරන්න
535 hook_model_outputs(self.mode, self.encoder, 'encoder')
536 hook_model_outputs(self.mode, self.decoder, 'decoder')
සම්පූර්ණදුම්රිය/වලංගු කිරීමේ අලාභය මුද්රණය කිරීම සඳහා ට්රැකර් සකසන්න
539 tracker.set_scalar("loss.total.*", True)
540
541 self.state_modules = []
543 def step(self, batch: Any, batch_idx: BatchIndex):
544 self.encoder.train(self.mode.is_train)
545 self.decoder.train(self.mode.is_train)
උපාංගය mask
වෙත ගෙන data
ගොස් අනුක්රමය සහ කණ්ඩායම් මානයන් මාරු කරන්න. data
හැඩය ඇති [seq_len, batch_size, 5]
අතර හැඩය mask
ඇත [seq_len, batch_size]
.
550 data = batch[0].to(self.device).transpose(0, 1)
551 mask = batch[1].to(self.device).transpose(0, 1)
පුහුණුමාදිලියේ වර්ධක පියවර
554 if self.mode.is_train:
555 tracker.add_global_step(len(data))
ආro ාත අනුපිළිවෙල කේතනය කරන්න
558 with monit.section("encoder"):
ලබාගන්න , සහ
560 z, mu, sigma_hat = self.encoder(data)
බෙදාහැරීම්මිශ්රණය විකේතනය කිරීම සහ
563 with monit.section("decoder"):
කොන්කැටෙනේට්
565 z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
566 inputs = torch.cat([data[:-1], z_stack], 2)
බෙදාහැරීම්මිශ්රණය ලබා ගන්න
568 dist, q_logits, _ = self.decoder(inputs, z, None)
අලාභයගණනය කරන්න
571 with monit.section('loss'):
573 kl_loss = self.kl_div_loss(sigma_hat, mu)
575 reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
577 loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss
පාඩුලුහුබඳින්න
580 tracker.add("loss.kl.", kl_loss)
581 tracker.add("loss.reconstruction.", reconstruction_loss)
582 tracker.add("loss.total.", loss)
අපපුහුණු තත්වයේ සිටී නම් පමණි
585 if self.mode.is_train:
ධාවනයප්රශස්තකරණය
587 with monit.section('optimize'):
grad
ශුන්යයට සකසන්න
589 self.optimizer.zero_grad()
අනුක්රමිකගණනය
591 loss.backward()
ලොග්ආදර්ශ පරාමිතීන් සහ අනුක්රමික
593 if batch_idx.is_last:
594 tracker.add(encoder=self.encoder, decoder=self.decoder)
ක්ලිප්අනුක්රමික
596 nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
597 nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
ප්රශස්තකරන්න
599 self.optimizer.step()
600
601 tracker.save()
603 def sample(self):
අහඹුලෙස වලංගු දත්ත කට්ටලයේ සිට ආකේතකය දක්වා නියැදියක් තෝරන්න
605 data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
කණ්ඩායම්මානයන් එක් කර එය උපාංගයට ගෙන යන්න
607 data = data.unsqueeze(1).to(self.device)
නියැදිය
609 self.sampler.sample(data, self.temperature)
612def main():
613 configs = Configs()
614 experiment.create(name="sketch_rnn")
වින්යාසකිරීමේ ශබ්දකෝෂයක් සම්මත කරන්න
617 experiment.configs(configs, {
618 'optimizer.optimizer': 'Adam',
ප්රතිresults ල වේගයෙන් දැකිය හැකි 1e-3
නිසා අපි ඉගෙනුම් අනුපාතයක් භාවිතා කරමු. කඩදාසි යෝජනා කර 1e-4
ඇත.
621 'optimizer.learning_rate': 1e-3,
දත්තසමුදාය නම
623 'dataset_name': 'bicycle',
පුහුණුව, වලංගු කිරීම සහ නියැදීම අතර මාරුවීම සඳහා එපෝච් තුළ අභ්යන්තර පුනරාවර්තන ගණන.
625 'inner_iterations': 10
626 })
627
628 with experiment.start():
අත්හදාබැලීම ක්රියාත්මක කරන්න
630 configs.run()
631
632
633if __name__ == "__main__":
634 main()