න්යෂ්ටිකනියැදීම

මෙයන්යෂ්ටික නියැදීම් ක්රියාත්මක කිරීමක් වන අතර එය කඩදාසි වලින් හඳුන්වා දී ඇත ස්නායු පෙළ පරිහානිය පිළිබඳ කුතුහලය.

කදම්බසෙවීම, පිරිසිදු නියැදීම, උෂ්ණත්ව නියැදීම සහ ඉහළ කේනියැදීමවැනි වෙනත් නියැදික්රමවල ඇති ගැටළු පිළිබඳව කඩදාසි සාකච්ඡා කරයි. මෙම පත්රිකාව න්යෂ්ටික නියැදීම් පිළිබඳ අදහස හඳුන්වා දෙයි, එය පෙළ උත්පාදනය සඳහා වෙනත් නියැදි ක්රමවලට වඩා ප්රායෝගිකව ක්රියා කරයි.

න්යෂ්ටිකනියැදීම පළමුව වචන මාලාවේ උප කුලකයක් තෝරා ගනී , එහිදී කුඩාම ටෝකන කට්ටලයක්

එනම්, ඒවායේ සම්භාවිතාවන්ගේ එකතුව අඩු වන තෙක් අපි ඉහළම සම්භාවිතාව සහිත ටෝකන තෝරා ගනිමු .

ඉන්පසුඅපි තෝරාගත් ටෝකන වලින් සාම්පල ලබා ගනිමු.

මෙන්නමෙම නියැදි ශිල්පීය ක්රම භාවිතා කරන අත්හදා බැලීමක් .

29import torch
30from torch import nn
31
32from labml_nn.sampling import Sampler

න්යෂ්ටිකනියැදි

35class NucleusSampler(Sampler):
  • p යනු තෝරා ගැනීමට ටෝකන වල සම්භාවිතාවන්ගේ එකතුවයි
  • sampler තෝරාගත් ටෝකන සඳහා භාවිතා කළ යුතු නියැදිකරු වේ
  • 39    def __init__(self, p: float, sampler: Sampler):
    44        self.p = p
    45        self.sampler = sampler

    පිවිසුම් වලින් ගණනය කිරීමට සොෆ්ට්මැක්ස්

    47        self.softmax = nn.Softmax(dim=-1)

    න්යෂ්ටිකනියැදීම සමඟ පිවිසුම් වලින් නියැදිය

    49    def __call__(self, logits: torch.Tensor):

    සම්භාවිතාවලබා ගන්න

    55        probs = self.softmax(logits)

    බැසීමේඅනුපිළිවෙලෙහි සම්භාවිතාවන් වර්ග කරන්න

    58        sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)

    වර්ගකළ අනුපිළිවෙලෙහි සමුච්චිත සම්භාවිතාවන්ගේ එකතුව ලබා ගන්න

    60        cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)

    වඩාඅඩු සමුච්චිත මුදලක් සොයා ගන්න .

    62        nucleus = cum_sum_probs < self.p

    සමුච්චිතසම්භාවිතාව අඩු අවම ටෝකන සංඛ්යාවෙන් පසුව අපි එක් ටෝකනයක් එකතු කරන පරිදි ඒවා සකස් කරන්න .

    65        nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)

    ලොග්සම්භාවිතාව ලබා ගන්න සහ න්යෂ්ටිය නොවන වසං කරන්න

    68        sorted_log_probs = torch.log(sorted_probs)
    69        sorted_log_probs[~nucleus] = float('-inf')

    නියැදියෙන්නියැදිය

    72        sampled_sorted_indexes = self.sampler(sorted_log_probs)

    සත්යදර්ශක ලබා ගන්න

    75        res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))

    78        return res.squeeze(-1)