කටුසටහනක් RNN

මෙයකඩදාසි පිළිබඳ විවරණ කරන ලද PyTorch ක්රියාත්මක කිරීම ස්කෙච් ඇඳීම්වල ස්නායුක නිරූපණයකි .

ස්කෙච්ආර්එන්එන් යනු අනුක්රම-සිට-අනුක්රමික විචල්යතා ස්වයංක්රීය ආකේතයකි. එන්කෝඩරය සහ විකේතකය යන දෙකම පුනරාවර්තන ස්නායුක ජාල ආකෘති වේ. ආ roke ාත මාලාවක් පුරෝකථනය කිරීමෙන් ආ roke ාතය පදනම් කරගත් සරල ඇඳීම් ප්රතිනිර්මාණය කිරීමට එය ඉගෙන ගනී. එක් එක් ආ roke ාතය ගවුසියානු මිශ්රණයක් ලෙස විකේතකය පුරෝකථනය කරයි.

දත්තලබා ගැනීම

ඉක්මන් වෙතින් දත්ත බාගන්න, අඳින්න! දත්ත කට්ටලය. Sketch-RNNහි npz ගොනු බාගත කිරීම සඳහා සබැඳියක් ඇත QuickDraw දත්ත සමුදාය කියවීමේ කොටස. බාගත කළ npz ගොනුව (ය) data/sketch ෆෝල්ඩරයේ තබන්න. මෙම කේතය bicycle දත්ත කට්ටලය භාවිතා කිරීමට වින්යාස කර ඇත. ඔබට මෙය වින්යාසයන් තුළ වෙනස් කළ හැකිය.

පිළිගැනීම්

ඇලෙක්සිස් ඩේවිඩ් ජැක් විසින් පයිටෝර්ච් ස්කෙච් ආර්එන්එන් ව්යාපෘතියෙන් උදව් ලබා ගත්තේය

32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex

දත්තකට්ටලය

මෙමපන්තිය දත්ත පැටවීම සහ පූර්ව සැකසීම.

50class StrokesDataset(Dataset):

dataset Seq_len හැඩයේ හිරිවැටීම් අරා ලැයිස්තුවකි, 3. එය ආ ro ාත අනුපිළිවෙලක් වන අතර සෑම පහරක්ම නිඛිල 3 කින් නිරූපණය කෙරේ. පළමු දෙක x සහ y දිගේ විස්ථාපන වේ (, ) සහ අවසාන නිඛිලය පෑනෙහි තත්වය නිරූපණය කරයි, එය කඩදාසි ස්පර්ශ කරන්නේ නම් සහ එසේ නොමැතිනම්.

57    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67        data = []

අපිඑක් එක් අනුපිළිවෙල සහ පෙරීම හරහා නැවත නැවතත්

69        for seq in dataset:

ආro ාත අනුක්රමයේ දිග අපගේ පරාසය තුළ තිබේ නම් පෙරහන් කරන්න

71            if 10 < len(seq) <= max_seq_length:

කලම්ප , කිරීමට

73                seq = np.minimum(seq, 1000)
74                seq = np.maximum(seq, -1000)

පාවෙනලක්ෂ්ය අරාවකට පරිවර්තනය කර එකතු කරන්න data

76                seq = np.array(seq, dtype=np.float32)
77                data.append(seq)

එවිටඅපි (, ) ඒකාබද්ධ සම්මත අපගමනය වන පරිමාණ සාධකය ගණනය. මධ්යන්යය කෙසේ වෙතත් සමීප බැවින් සරල බව සඳහා මධ්යන්යය සකස් කර නොමැති බව කඩදාසි සටහන් කරයි.

83        if scale is None:
84            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85        self.scale = scale

සියලුමඅනුපිළිවෙලවල් අතර දිගම අනුක්රමික දිග ලබා ගන්න

88        longest_seq_len = max([len(seq) for seq in data])

ආරම්භකඅනුක්රමය (sos) සහ අවසාන අනුක්රමය (eos) සඳහා අමතර පියවර දෙකක් සමඟ අපි PyTorch දත්ත අරාව ආරම්භ කරමු. සෑම පියවරක්ම දෛශිකයකි . එකක් පමණක් වන අතර අනෙක් ඒවා වේ . ඒවා නියෝජනය කරන්නේ පෑන පහළට, පෑන ඉහළටසහ එම අනුපිළිවෙලට අනුක්රමය අවසන්කිරීමයි. ඊළඟ පියවරේදී පෑන කඩදාසි ස්පර්ශ කරන්නේ නම්. ඊළඟ පියවරේදී පෑන කඩදාසි ස්පර්ශ නොකරන්නේ නම්. එය චිත්රයේ අවසානය නම් වේ.

98        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)

වෙස්අරා අවශ්ය වන්නේ එක් අමතර පියවරක් පමණි, මන්ද එය විකේතනයේ ප්රතිදානයන් සඳහා වන data[:-1] අතර එය ඊළඟ පියවර ගනී.

101        self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103        for i, seq in enumerate(data):
104            seq = torch.from_numpy(seq)
105            len_seq = len(seq)

පරිමාණයසහ කට්ටලය

107            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale

109            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]

111            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]

113            self.data[i, len_seq + 1:, 4] = 1

අනුක්රමයඅවසන් වන තෙක් මාස්ක් ක්රියාත්මක වේ

115            self.mask[i, :len_seq + 1] = 1

අනුපිළිවෙලආරම්භ කිරීම

118        self.data[:, 0, 2] = 1

දත්තසමුදාය ප්රමාණය

120    def __len__(self):
122        return len(self.data)

නියැදියක්ලබා ගන්න

124    def __getitem__(self, idx: int):
126        return self.data[idx], self.mask[idx]

ද්වි-විචල්යගවුසියානු මිශ්රණය

මිශ්රණයනියෝජනය වන්නේ සහ . මෙම පන්තිය උෂ්ණත්වය සකස් කරන අතර පරාමිතීන්ගෙන් වර්ගීකරණ හා ගුසියානු බෙදාහැරීම් නිර්මාණය කරයි.

129class BivariateGaussianMixture:
139    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141        self.pi_logits = pi_logits
142        self.mu_x = mu_x
143        self.mu_y = mu_y
144        self.sigma_x = sigma_x
145        self.sigma_y = sigma_y
146        self.rho_xy = rho_xy

මිශ්රණයේබෙදාහැරීම් ගණන,

148    @property
149    def n_distributions(self):
151        return self.pi_logits.shape[-1]

උෂ්ණත්වයඅනුව සකසන්න

153    def set_temperature(self, temperature: float):

158        self.pi_logits /= temperature

160        self.sigma_x *= math.sqrt(temperature)

162        self.sigma_y *= math.sqrt(temperature)
164    def get_distribution(self):

කලම්ප , සහ NaN s ලබා වළක්වා ගැනීමට

166        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)

මාධ්යයන්ලබා ගන්න

171        mean = torch.stack([self.mu_x, self.mu_y], -1)

කෝවිචියන්ස්අනුකෘතිය ලබා ගන්න

173        cov = torch.stack([
174            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176        ], -1)
177        cov = cov.view(*sigma_y.shape, 2, 2)

ද්වි-විචල්යසාමාන්ය ව්යාප්තියක් සාදන්න.

📝එය [[a, 0], [b, c]] කොහෙද ලෙස scale_tril අනුකෘතිය කාර්යක්ෂම වනු ඇත. නමුත් සරල බව සඳහා අපි සම-විචල්යතා අනුකෘතිය භාවිතා කරමු. ද්වි-විචල්ය බෙදාහැරීම්, ඒවායේ සම-විචල්යතා අනුකෘතිය සහසම්භාවිතා dens නත්ව ක්රියාකාරිත්වය ගැන වැඩිදුර කියවීමට ඔබට අවශ්ය නම් මෙය හොඳ සම්පතකි .

188        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)

පිවිසුම් වලින් වර්ගීකරණ බෙදාහැරීමක් සාදන්න

191        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)

194        return cat_dist, multi_dist

එන්කෝඩර්මොඩියුලය

මෙයද්විපාර්ශ්වික LSTM කින් සමන්විත වේ

197class EncoderRNN(Module):
204    def __init__(self, d_z: int, enc_hidden_size: int):
205        super().__init__()

ආදානය ලෙස අනුක්රමයක් ගනිමින් ද්විපාර්ශ්වික LSTM සාදන්න.

208        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)

ලබාගැනීමට ප්රධානියා

210        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)

ලබාගැනීමට ප්රධානියා

212        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214    def forward(self, inputs: torch.Tensor, state=None):

ද්විපාර්ශ්විකඑල්එස්ටීඑම් හි සැඟවුණු තත්වය යනු අවසාන ටෝකනයේ ප්රතිදානය ඉදිරි දිශාවට සහ ප්රතිලෝම දිශාවට පළමු ටෝකනය සංයුක්ත කිරීමයි, එය අපට අවශ්ය දෙයයි.

221        _, (hidden, cell) = self.lstm(inputs.float(), state)

රාජ්යයටහැඩය ඇත [2, batch_size, hidden_size] , එහිදී පළමු මානය දිශාව වේ. ලබා ගැනීම සඳහා අපි එය නැවත සකස්

කරමු
225        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')

228        mu = self.mu_head(hidden)

230        sigma_hat = self.sigma_head(hidden)

232        sigma = torch.exp(sigma_hat / 2.)

නියැදිය

235        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))

238        return z, mu, sigma_hat

විකේතකමොඩියුලය

මෙයLSTM කින් සමන්විත වේ

241class DecoderRNN(Module):
248    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
249        super().__init__()

LSTMආදානය ලෙස ගනී

251        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

LSTMහි ආරම්භක තත්වය වේ . init_state මේ සඳහා රේඛීය පරිවර්තනයයි

255        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)

මෙමස්තරය එක් එක් සඳහා ප්රතිදානයන් නිෂ්පාදනය කරයි n_distributions . සෑම ව්යාප්තියකටම පරාමිතීන් හයක් අවශ්ය

වේ
260        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)

මෙමහිස පිවිසුම් සඳහා වේ

263        self.q_head = nn.Linear(dec_hidden_size, 3)

මෙය කොතැනද යන්න ගණනය කිරීමයි

266        self.q_log_softmax = nn.LogSoftmax(-1)

මෙමපරාමිතීන් අනාගත යොමු කිරීම සඳහා ගබඩා කර ඇත

269        self.n_distributions = n_distributions
270        self.dec_hidden_size = dec_hidden_size
272    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):

ආරම්භකතත්වය ගණනය කරන්න

274        if state is None:

276            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)

h සහ c හැඩයන් [batch_size, lstm_size] ඇත. LSTM හි භාවිතා කරන හැඩය [1, batch_size, lstm_size] නිසා අපට ඒවා හැඩගස්වා ගැනීමට අවශ්යය.

279            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

LSTMධාවනය කරන්න

282        outputs, state = self.lstm(x, state)

ලබාගන්න

285        q_logits = self.q_log_softmax(self.q_head(outputs))

ලබාගන්න . torch.split ප්රතිදානය මානයක් self.n_distribution හරහා ප්රමාණය tensors 6 බවට splits 2 .

291        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
292            torch.split(self.mixtures(outputs), self.n_distributions, 2)

ද්වි-විචල්යGaussian මිශ්රණයක් සාදන්න සහ කොහේද සහ

මිශ්රණයෙන් බෙදා හැරීම තෝරා ගැනීමේ වර්ගීකරණ සම්භාවිතාවන් වේ.

305        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
306                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

309        return dist, q_logits, state

ප්රතිසංස්කරණඅලාභය

312class ReconstructionLoss(Module):
317    def forward(self, mask: torch.Tensor, target: torch.Tensor,
318                 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):

ලබා ගන්න

320        pi, mix = dist.get_distribution()

target අවසාන මානය ලක්ෂණ [seq_len, batch_size, 5] වන හැඩය ඇත . අපට අවශ්ය වන්නේ y ලබා ගැනීමට සහ මිශ්රණයේ ඇති එක් එක් බෙදාහැරීම් වලින් සම්භාවිතාවන් ලබා ගැනීමයි .

xy හැඩය ඇත [seq_len, batch_size, n_distributions, 2]

327        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)

සම්භාවිතාවගණනය කරන්න

333        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)

(longest_seq_len ) මූලද්රව්ය probs තිබුණද, එකතුව ගනු ලබන්නේ ඉතිරිය වන බැවිනි වෙස් වලාගෙන.

අපඑකතුව ගෙන බෙදිය යුතු යැයි හැඟෙනු ඇත , නමුත් මෙය කෙටි අනුපිළිවෙලින් තනි අනාවැකි සඳහා වැඩි බරක් ලබා දෙනු ඇත. අපි බෙදෙන විට අපි එක් එක් අනාවැකිය සමාන බර

දෙන්න
342        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))

345        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

348        return loss_stroke + loss_pen

එල්. එල්-අපසරනය අහිමි

මෙයලබා දී ඇති සාමාන්ය බෙදාහැරීමක් අතර KL අපසරනය ගණනය කරයි

351class KLDivLoss(Module):
358    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):

360        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))

නියැදිකරු

මෙයවිකේතකයෙන් රූප සටහනක් සාම්පල කර එය ගොඩබෑමට ලක් කරයි

363class Sampler:
370    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
371        self.decoder = decoder
372        self.encoder = encoder
374    def sample(self, data: torch.Tensor, temperature: float):

376        longest_seq_len = len(data)

එන්කෝඩරයෙන් ලබා ගන්න

379        z, _, _ = self.encoder(data)

ආරම්භකඅනුක්රමය ආ roke ාතය වේ

382        s = data.new_tensor([0, 0, 1, 0, 0])
383        seq = [s]

ආරම්භකවිකේතකය වේ None . විකේතකය එය ආරම්භ කරනු ඇත

386        state = None

අපටඅනුක්රමික අවශ්ය නොවේ

389        with torch.no_grad():

නියැදි පහරවල්

391            for i in range(longest_seq_len):

යනු විකේතකයට ආදානය කිරීමයි

393                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)

විකේතකයෙන් ඊළඟ තත්වය ලබා ගන්න

396                dist, q_logits, state = self.decoder(data, z, state)

ආඝාතයසාම්පලයක්

398                s = self._sample_step(dist, q_logits, temperature)

ආro ාත අනුපිළිවෙලට නව ආ roke ාතය එක් කරන්න

400                seq.append(s)

නියැදීමනවත්වන්න නම් . මෙයින් ඇඟවෙන්නේ ස්කීච් කිරීම නතර වී ඇති

බවයි
402                if s[4] == 1:
403                    break

ආro ාත අනුපිළිවෙලෙහි පයිටෝච් ටෙන්සරයක් සාදන්න

406        seq = torch.stack(seq)

ආro ාත අනුපිළිවෙල සැලසුම් කරන්න

409        self.plot(seq)
411    @staticmethod
412    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):

නියැදීම් සඳහා උෂ්ණත්වය සකසන්න. මෙය පන්තියේ ක්රියාත්මක වේ BivariateGaussianMixture .

414        dist.set_temperature(temperature)

උෂ්ණත්වයසකස් කර ලබා ගන්න

416        pi, mix = dist.get_distribution()

මිශ්රණයෙන්භාවිතා කිරීම සඳහා බෙදා හැරීමේ දර්ශකයෙන් නියැදිය

418        idx = pi.sample()[0, 0]

ලොග්-සම්භාවිතාවන් q_logits සමඟ වර්ගීකරණ බෙදාහැරීමක් සාදන්න

421        q = torch.distributions.Categorical(logits=q_logits / temperature)

වෙතින්නියැදිය

423        q_idx = q.sample()[0, 0]

මිශ්රණයේසාමාන්ය බෙදාහැරීම් වලින් නියැදිය සහ සුචිගත කරන ලද එක තෝරන්න idx

426        xy = mix.sample()[0, 0, idx]

හිස්ආඝාතය සාදන්න

429        stroke = q_logits.new_zeros(5)

සකසන්න

431        stroke[:2] = xy

සකසන්න

433        stroke[q_idx + 2] = 1

435        return stroke
437    @staticmethod
438    def plot(seq: torch.Tensor):

ලබාගැනීම සඳහා සමුච්චිත සාරාංශ ගන්න

440        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)

පෝරමයේනව අංකුර අරා සාදන්න

442        seq[:, 2] = seq[:, 3]
443        seq = seq[:, 0:3].detach().cpu().numpy()

කොහෙද ලකුණු දී අරාව බෙදන්න . i.e. පෑන කඩදාසි සිට ඔසවා එහිදී ලකුණු දී පිලි අරාව බෙදී. මෙය ආ ro ාත අනුපිළිවෙල ලැයිස්තුවක් ලබා දෙයි.

448        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)

පහරවල්එක් එක් අනුපිළිවෙල සැලසුම් කරන්න

450        for s in strokes:
451            plt.plot(s[:, 0], -s[:, 1])

අක්ෂපෙන්වන්න එපා

453        plt.axis('off')

කුමන්ත්රණයපෙන්වන්න

455        plt.show()

වින්යාසකිරීම්

මේවාපෙරනිමි වින්යාසයන් වන අතර ඒවා පසුව a පසුකර යාමෙන් සකස් කළ හැකිය dict .

458class Configs(TrainValidConfigs):

අත්හදාබැලීම ක්රියාත්මක කිරීම සඳහා උපාංගය තෝරා ගැනීමට උපාංග වින්යාසයන්

466    device: torch.device = DeviceConfigs()

468    encoder: EncoderRNN
469    decoder: DecoderRNN
470    optimizer: optim.Adam
471    sampler: Sampler
472
473    dataset_name: str
474    train_loader: DataLoader
475    valid_loader: DataLoader
476    train_dataset: StrokesDataset
477    valid_dataset: StrokesDataset

එන්කෝඩර්සහ විකේතක ප්රමාණ

480    enc_hidden_size = 256
481    dec_hidden_size = 512

කණ්ඩායම්ප්රමාණය

484    batch_size = 100

විශේෂාංගගණන

487    d_z = 128

මිශ්රණයේබෙදාහැරීම් ගණන,

489    n_distributions = 20

KLඅපසරනය අඞු කිරීමට බර,

492    kl_div_loss_weight = 0.5

ශ්රේණියේක්ලිපින්

494    grad_clip = 1.

නියැදීම සඳහා උෂ්ණත්වය

496    temperature = 0.4

වඩාදිගු ආ roke ාත අනුපිළිවෙල පෙරහන් කරන්න

499    max_seq_length = 200
500
501    epochs = 100
502
503    kl_div_loss = KLDivLoss()
504    reconstruction_loss = ReconstructionLoss()
506    def init(self):

එන්කෝඩරයසහ විකේතකය ආරම්භ කරන්න

508        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
509        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

ප්රශස්තකරණයසකසන්න. ප්රශස්තිකරණ වර්ගය සහ ඉගෙනුම් අනුපාතය වැනි දේවල් වින්යාසගත කළ හැකිය

512        optimizer = OptimizerConfigs()
513        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
514        self.optimizer = optimizer

නියැදියසාදන්න

517        self.sampler = Sampler(self.encoder, self.decoder)

npz ගොනු මාර්ගය වේ data/sketch/[DATASET NAME].npz

520        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'

අංකිතගොනුව පූරණය කරන්න

522        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

පුහුණුදත්ත සමුදාය සාදන්න

525        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)

වලංගුදත්ත සමුදාය සාදන්න

527        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

පුහුණුදත්ත පැටවුම සාදන්න

530        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)

වලංගුදත්ත පැටවුම සාදන්න

532        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

ටෙන්සෝර්බෝඩ්හි ස්ථර ප්රතිදානයන් නිරීක්ෂණය කිරීම සඳහා කොකු එක් කරන්න

535        hook_model_outputs(self.mode, self.encoder, 'encoder')
536        hook_model_outputs(self.mode, self.decoder, 'decoder')

සම්පූර්ණදුම්රිය/වලංගු කිරීමේ අලාභය මුද්රණය කිරීම සඳහා ට්රැකර් සකසන්න

539        tracker.set_scalar("loss.total.*", True)
540
541        self.state_modules = []
543    def step(self, batch: Any, batch_idx: BatchIndex):
544        self.encoder.train(self.mode.is_train)
545        self.decoder.train(self.mode.is_train)

උපාංගය mask වෙත ගෙන data ගොස් අනුක්රමය සහ කණ්ඩායම් මානයන් මාරු කරන්න. data හැඩය ඇති [seq_len, batch_size, 5] අතර හැඩය mask ඇත [seq_len, batch_size] .

550        data = batch[0].to(self.device).transpose(0, 1)
551        mask = batch[1].to(self.device).transpose(0, 1)

පුහුණුමාදිලියේ වර්ධක පියවර

554        if self.mode.is_train:
555            tracker.add_global_step(len(data))

ආro ාත අනුපිළිවෙල කේතනය කරන්න

558        with monit.section("encoder"):

ලබාගන්න , සහ

560            z, mu, sigma_hat = self.encoder(data)

බෙදාහැරීම්මිශ්රණය විකේතනය කිරීම සහ

563        with monit.section("decoder"):

කොන්කැටෙනේට්

565            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
566            inputs = torch.cat([data[:-1], z_stack], 2)

බෙදාහැරීම්මිශ්රණය ලබා ගන්න

568            dist, q_logits, _ = self.decoder(inputs, z, None)

අලාභයගණනය කරන්න

571        with monit.section('loss'):

573            kl_loss = self.kl_div_loss(sigma_hat, mu)

575            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)

577            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

පාඩුලුහුබඳින්න

580            tracker.add("loss.kl.", kl_loss)
581            tracker.add("loss.reconstruction.", reconstruction_loss)
582            tracker.add("loss.total.", loss)

අපපුහුණු තත්වයේ සිටී නම් පමණි

585        if self.mode.is_train:

ධාවනයප්රශස්තකරණය

587            with monit.section('optimize'):

grad ශුන්යයට සකසන්න

589                self.optimizer.zero_grad()

අනුක්රමිකගණනය

591                loss.backward()

ලොග්ආදර්ශ පරාමිතීන් සහ අනුක්රමික

593                if batch_idx.is_last:
594                    tracker.add(encoder=self.encoder, decoder=self.decoder)

ක්ලිප්අනුක්රමික

596                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
597                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

ප්රශස්තකරන්න

599                self.optimizer.step()
600
601        tracker.save()
603    def sample(self):

අහඹුලෙස වලංගු දත්ත කට්ටලයේ සිට ආකේතකය දක්වා නියැදියක් තෝරන්න

605        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]

කණ්ඩායම්මානයන් එක් කර එය උපාංගයට ගෙන යන්න

607        data = data.unsqueeze(1).to(self.device)

නියැදිය

609        self.sampler.sample(data, self.temperature)
612def main():
613    configs = Configs()
614    experiment.create(name="sketch_rnn")

වින්යාසකිරීමේ ශබ්දකෝෂයක් සම්මත කරන්න

617    experiment.configs(configs, {
618        'optimizer.optimizer': 'Adam',

ප්රතිresults ල වේගයෙන් දැකිය හැකි 1e-3 නිසා අපි ඉගෙනුම් අනුපාතයක් භාවිතා කරමු. කඩදාසි යෝජනා කර 1e-4 ඇත.

621        'optimizer.learning_rate': 1e-3,

දත්තසමුදාය නම

623        'dataset_name': 'bicycle',

පුහුණුව, වලංගු කිරීම සහ නියැදීම අතර මාරුවීම සඳහා එපෝච් තුළ අභ්යන්තර පුනරාවර්තන ගණන.

625        'inner_iterations': 10
626    })
627
628    with experiment.start():

අත්හදාබැලීම ක්රියාත්මක කරන්න

630        configs.run()
631
632
633if __name__ == "__main__":
634    main()