වොසර්ස්ටයින්GAN (WGAN)

මෙය වොසර්ස්ටයින් GANක්රියාත්මක කිරීමයි.

මුල්GAN අලාභය පදනම් වී ඇත්තේ සැබෑ බෙදා හැරීම සහ ජනනය කරන ලද බෙදා හැරීම අතර ජෙන්සන්-ෂැනන් (JS) අපසරනය මත ය. වොසර්ස්ටයින් GAN පදනම් වී ඇත්තේ මෙම බෙදාහැරීම් අතර පෘථිවි මූවර් දුර මත ය.

යනු සියළුම ඒකාබද්ධ බෙදාහැරීම් වල කට්ටලය වන අතර ඒවායේ ආන්තික සම්භාවිතාවන් වේ .

දී ඇති ඒකාබද්ධ ව්යාප්තියක් සඳහා පෘථිවි චලනය දුර වේ ( සහ සම්භාවිතාවන් වේ).

සැබෑබෙදා හැරීම සහ ජනනය කරන ලද බෙදා හැරීම අතර ඕනෑම ඒකාබද්ධ බෙදාහැරීමක් සඳහා අවම වශයෙන් පෘථිවි මන්තීට දුරට සමාන වේ.

කඩදාසිපෙන්වන්නේ ජෙන්සන්-ෂැනන් (ජේඑස්) අපසරනය සහ සම්භාවිතා බෙදාහැරීම් දෙකක් අතර වෙනස සඳහා වෙනත් ක්රියාමාර්ග සුමට නොවන බවයි. එබැවින් අපි එක් සම්භාවිතා බෙදාහැරීමක් මත (පරාමිතිකරණය කර ඇති) ශ්රේණියේ සම්භවයක් කරන්නේ නම් එය අභිසාරී නොවනු ඇත.

කන්තෝරෝවිච්-රුබින්ස්ටයින්ද්විත්ව භාවය මත පදනම්ව,

සියලු 1-Lipschitz කාර්යයන් කොහෙද.

එනම්, එය සියලු 1-Lipschitz කාර්යයන් අතර විශාලතම වෙනසට සමාන වේ.

-ලිප්ස්චිට්ස් කාර්යයන් සඳහා,

සියලු -Lipschitz ශ්රිතයන් පරාමිතිකරණය කර ඇති ස්ථානය ලෙස නිරූපණය කළ හැකි නම් ,

උත්පාදකයන්ත්රයකින් නිරූපණය වන්නේ නම් සහ දන්නා බෙදාහැරීමක් නම් ,

දැන් සමඟ අභිසාරී වීමට ඉහත සූත්රය අවම කිරීම සඳහා අපට අනුක්රමික සම්භවය ලබා ගත හැකිය.

ඒහා සමානව අපට නැගී සිටීමෙන් සොයාගත හැකිය. මායිම් තබා ගත හැකි එක් ක්රමයක් නම් පරාසයක් තුළ ක්ලිප් කර ඇති නිර්වචනය කරන ස්නායුක ජාලයේ සියලුම බර ක්ලිප් කිරීමයි.

සරල MNIST පරම්පරාවේ අත්හදා බැලීමකදී මෙය අත්හදා බැලීමටකේතය මෙන්න.

Open In Colab

87import torch.utils.data
88from torch.nn import functional as F
89
90from labml_helpers.module import Module

වෙනස්කම්කරන්නාගේ පාඩුව

අපිඋපරිම කිරීමට සොයා ගැනීමට අවශ්ය , ඒ නිසා අපි අවම,

93class DiscriminatorLoss(Module):
  • f_real වේ
  • f_fake වේ
  • මෙයපාඩු සහිත ටියුපල් නැවත ලබා දෙන අතර ඒවා පසුව එකතු කරනු ලැබේ. ඒවා ලොග් වීම සඳහා වෙනම තබා ඇත.

    104    def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):

    පරාසය තබා ගැනීම සඳහා අලාභය ක්ලිප් කිරීමට අපි RelUs භාවිතා කරමු.

    115        return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean()

    උත්පාදකනැතිවීම

    අවමකිරීම සඳහා අපට සොයා ගැනීමට අවශ්යයි පළමු සංරචකය ස්වාධීන වේ , එබැවින් අපි අවම කරමු,

    118class GeneratorLoss(Module):
    • f_fake වේ
    130    def forward(self, f_fake: torch.Tensor):
    134        return -f_fake.mean()