වර්ගීකරණඅවිනිශ්චිතතාව ප්රමාණ කිරීම සඳහා ගැඹුරු ඉගෙනීම

මෙය වර්ගීකරණ අවිනිශ්චිතතාව ප්රමාණ කිරීම සඳහා පැහැදිලි ගැඹුරු ඉගෙනුම් කඩදාසි ක්රියාත්මක කිරීම PyTorch ක්රියාත්මක කිරීමයි.

Dampster-Shafer Theory of Evidence of Dampster-Shafer Theory of Dampster-Shafer Theory of Evidence සියලු උප කුලවල ස්කන්ධයන්ගේ එකතුව වේ . තනි පුද්ගල පන්ති සම්භාවිතාවන් (plausibilities) මෙම ස්කන්ධයන්ගෙන් ලබා ගත හැකිය.

සියලුමපංතිවල කට්ටලයට ස්කන්ධයක් පැවරීම යන්නෙන් අදහස් කරන්නේ එය ඕනෑම පන්තියක් විය හැකි බවයි; එනම් “මම නොදනිමි” යැයි පැවසීම.

පංති තිබේ නම්, අපි එක් එක් පංතිවලට ස්කන්ධයන් සහ සියලු පංතිවලට සමස්ත අවිනිශ්චිත ස්කන්ධයක් පවරමු.

විශ්වාසජනතාව සහ සාක්ෂි වලින් ගණනය කළ හැකිය , ලෙස සහ කොතැනද . යම් පන්තියකට වර්ගීකරණය කිරීම සඳහා නියැදියකට පක්ෂව දත්ත වලින් එකතු කරන ලද ආධාරක ප්රමාණය මැනීමක් ලෙස කඩදාසි කාලීන සාක්ෂි භාවිතා කරයි.

මෙයපරාමිතීන් සහිත ඩයිරිච්ලට් ව්යාප්තියට අනුරූප වන අතර එය ඩයිරිච්ලට් ශක්තිය ලෙස හැඳින්වේ. ඩයිරිච්ලට් බෙදා හැරීම යනු වර්ගීකරණ බෙදාහැරීමකට වඩා බෙදා හැරීමකි; i.e. ඔබට ඩයිරිච්ලට් බෙදාහැරීමෙන් පන්ති සම්භාවිතාව සාම්පල ලබා ගත හැකිය. පන්තිය සඳහා අපේක්ෂිත සම්භාවිතාව වේ .

දීඇති ආදානයක් සඳහා සාක්ෂි ප්රතිදානය කිරීමේ ආකෘතිය අපට ලැබේ. ලබා ගැනීම සඳහා අපි අවසාන ස්ථරයේ RelU හෝ සොෆ්ට්ප්ලස් වැනි ශ්රිතයක් භාවිතා කරමු .

අපිපහත ක්රියාත්මක කර ඇති ආකෘතිය පුහුණු කිරීම සඳහා පාඩු කාර්යයන් කිහිපයක් පත්රිකාව යෝජනා කරයි.

MNISTදත්ත කට්ටලයේ ආකෘතියක් පුහුණු experiment.py කිරීම සඳහා පුහුණු කේතය මෙන්න.

View Run

54import torch
55
56from labml import tracker
57from labml_helpers.module import Module

IIවර්ගය උපරිම සම්භාවිතාව අඞු කිරීමට

බෙදාහැරීම සම්භාවිතාව මත පෙර වන අතර , සෘණ ලඝු-සටහන ආන්තික සම්භාවිතාව පන්ති සම්භාවිතාව කට ඒකාබද්ධ විසින් ගණනය කරනු ලැබේ .

ඉලක්කසම්භාවිතාවන් (එක්-උණුසුම් ඉලක්ක) දී ඇති නියැදියක් සඳහා නම් අලාභය නම්,

60class MaximumLikelihoodLoss(Module):
  • evidence හැඩය සමඟ ඇත [batch_size, n_classes]
  • target හැඩය සමඟ ඇත [batch_size, n_classes]
85    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

91        alpha = evidence + 1.

93        strength = alpha.sum(dim=-1)

පාඩු

96        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)

කණ්ඩායමටවඩා මධ්යන්ය අලාභය

99        return loss.mean()

කුරුසඑන්ට්රොපි අඞු කිරීමට සමග අවදානම් Bayes

බේස්අවදානම යනු වැරදි ඇස්තමේන්තු සකස් කිරීමේ සමස්ත උපරිම පිරිවැයයි. වැරදි තක්සේරුවක් කිරීමේ පිරිවැය පිරිවැය ලබා දෙන පිරිවැය ශ්රිතයක් ගන්නා අතර සම්භාවිතා ව්යාප්තිය මත පදනම්ව හැකි සෑම ප්රති come ලයකටම වඩා එය සාරාංශ කරයි.

මෙන්නපිරිවැය ශ්රිතය හරස් එන්ට්රොපි අලාභයයි, එක්-උණුසුම් කේතනය කර ඇත

අපිමෙම පිරිවැය සියල්ලටම වඩා ඒකාබද්ධ කරමු

ශ්රිතය කොහේද?

102class CrossEntropyBayesRisk(Module):
  • evidence හැඩය සමඟ ඇත [batch_size, n_classes]
  • target හැඩය සමඟ ඇත [batch_size, n_classes]
132    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

138        alpha = evidence + 1.

140        strength = alpha.sum(dim=-1)

පාඩු

143        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)

කණ්ඩායමටවඩා මධ්යන්ය අලාභය

146        return loss.mean()

චතුරස්රාකාරදෝෂ නැතිවීම සමඟ බේස් අවදානම

මෙහිපිරිවැය ශ්රිතය කොටු දෝෂයකි,

අපිමෙම පිරිවැය සියල්ලටම වඩා ඒකාබද්ධ කරමු

ඩයිරිච්ලට්ව්යාප්තියෙන් නියැදි කළ විට අපේක්ෂිත සම්භාවිතාව කොතැනද සහ විචලතාව කොහේද?

මෙයලබා දෙයි,

සමීකරණයේමෙම පළමු කොටස දෝෂ පදය වන අතර දෙවන කොටස විචලනය වේ.

149class SquaredErrorBayesRisk(Module):
  • evidence හැඩය සමඟ ඇත [batch_size, n_classes]
  • target හැඩය සමඟ ඇත [batch_size, n_classes]
195    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

201        alpha = evidence + 1.

203        strength = alpha.sum(dim=-1)

205        p = alpha / strength[:, None]

දෝෂයකි

208        err = (target - p) ** 2

විචලතාව

210        var = p * (1 - p) / (strength[:, None] + 1)

ඒවායේඑකතුව

213        loss = (err + var).sum(dim=-1)

කණ්ඩායමටවඩා මධ්යන්ය අලාභය

216        return loss.mean()

KLඅපසරනය නියාමනය කිරීමේ අලාභය

නියැදියනිවැරදිව වර්ගීකරණය කළ නොහැකි නම් සම්පූර්ණ සාක්ෂි ශුන්යයට හැකිලීමට මෙය උත්සාහ කරයි.

පළමුවඅපි නිවැරදි සාක්ෂි ඉවත් කිරීමෙන් පසු ඩයිරිච්ලට් පරාමිතීන් ගණනය කරමු.

ගැමා ශ්රිතය කොහෙද, ශ්රිතය සහ

219class KLDivergenceLoss(Module):
  • evidence හැඩය සමඟ ඇත [batch_size, n_classes]
  • target හැඩය සමඟ ඇත [batch_size, n_classes]
243    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

249        alpha = evidence + 1.

පන්තිගණන

251        n_classes = evidence.shape[-1]

නොමඟයවන සාක්ෂි ඉවත් කරන්න

254        alpha_tilde = target + (1 - target) * alpha

256        strength_tilde = alpha_tilde.sum(dim=-1)

පළමුපදය

267        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
268                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
269                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))

දෙවනවාරය

274        second = (
275                (alpha_tilde - 1) *
276                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
277        ).sum(dim=-1)

කොන්දේසිඑකතුව

280        loss = first + second

කණ්ඩායමටවඩා මධ්යන්ය අලාභය

283        return loss.mean()

සංඛ්යාලේඛනනිරීක්ෂණය කරන්න

මෙමමොඩියුලය සංඛ්යාලේඛන ගණනය කර ඒවා ලැබ්මිලිසමඟ නිරීක්ෂණය කරයි tracker .

286class TrackStatistics(Module):
294    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

පන්තිගණන

296        n_classes = evidence.shape[-1]

ඉලක්කයසමඟ නිවැරදිව ගැලපෙන අනාවැකි (වැඩිම සම්භාවිතාව මත පදනම්ව කෑදර නියැදීම්)

298        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))

ලුහුබැඳීමේනිරවද්යතාවය

300        tracker.add('accuracy.', match.sum() / match.shape[0])

303        alpha = evidence + 1.

305        strength = alpha.sum(dim=-1)

308        expected_probability = alpha / strength[:, None]

තෝරාගත්(කෑදර හයිසෙට් සම්භාවිතාව) පන්තියේ අපේක්ෂිත සම්භාවිතාව

310        expected_probability, _ = expected_probability.max(dim=-1)

අවිනිශ්චිතතාස්කන්ධය

313        uncertainty_mass = n_classes / strength

නිවැරදිවඅනාවැකි සඳහා ලුහුබඳින්න

316        tracker.add('u.succ.', uncertainty_mass.masked_select(match))

වැරදිඅනාවැකි සඳහා ලුහුබඳින්න

318        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))

නිවැරදිවඅනාවැකි සඳහා ලුහුබඳින්න

320        tracker.add('prob.succ.', expected_probability.masked_select(match))

වැරදිඅනාවැකි සඳහා ලුහුබඳින්න

322        tracker.add('prob.fail.', expected_probability.masked_select(~match))