ස්නායුකජාලයක දැනුම ආසවනය කිරීම

මෙය ස්නායුක ජාලයක දැනුම ආසවනය කරන කඩදාසි PyTorch ක්රියාත්මක කිරීම/නිබන්ධනයකි.

එයපුහුණු විශාල ජාලයක දැනුම භාවිතා කරමින් කුඩා ජාලයක් පුහුණු කිරීමේ ක්රමයකි; එනම් විශාල ජාලයෙන් දැනුම ආසවනය කිරීම.

විධිමත්කිරීමක් සහිත විශාල ආකෘතියක් හෝ ආකෘති සමූහයක් (ඩ්රොප් අවුට් භාවිතා කරමින්) දත්ත සහ ලේබල් මත කෙලින්ම පුහුණු කරන විට කුඩා ආකෘතියකට වඩා හොඳ සාමාන්යකරණය කරයි. කෙසේ වෙතත්, විශාල ආකෘතියක් ආධාරයෙන් වඩා හොඳ සාමාන්යකරණය කිරීම සඳහා කුඩා ආකෘතියක් පුහුණු කළ හැකිය. කුඩා ආකෘති නිෂ්පාදනය වඩා හොඳ වේ: වේගවත්, අඩු ගණනය, අඩු මතකය.

පුහුණුආකෘතියක නිමැවුම් සම්භාවිතාවන් ලේබල් වලට වඩා වැඩි තොරතුරු ලබා දෙන්නේ එය ශුන්ය නොවන සම්භාවිතාවන් වැරදි පංතිවලට පවරන බැවිනි. මෙම සම්භාවිතාවන් අපට පවසන්නේ නියැදියකට ඇතැම් පංතිවලට අයත් වීමේ අවස්ථාවක් ඇති බවයි. නිදසුනක් වශයෙන්, ඉලක්කම් වර්ගීකරණය කිරීමේදී, ඉලක්කම් 7 හි රූපයක් ලබා දුන් විට, සාමාන්යකරණය කරන ලද ආකෘතියක් 7ට වැඩි සම්භාවිතාවක් සහ කුඩා නමුත් ශුන්ය නොවන සම්භාවිතාව 2 දක්වා ලබා දෙනු ඇත, අනෙක් ඉලක්කම් වලට පාහේ ශුන්ය සම්භාවිතාව පවරයි. ආසවනය කුඩා ආකෘතියක් වඩා හොඳින් පුහුණු කිරීම සඳහා මෙම තොරතුරු භාවිතා කරයි.

මෘදුඉලක්ක

සම්භාවිතාවන්සාමාන්යයෙන් ගණනය කරනු ලබන්නේ සොෆ්ට්මැක්ස් මෙහෙයුමකින්,

පන්තියසඳහා සම්භාවිතාව කොතැනද සහ පිවිසුම වේ.

හරස්එන්ට්රොපිය හෝ කේඑල් අපසරනය අවම කිරීම සඳහා අපි කුඩා ආකෘතිය පුහුණු කරමු එහි නිමැවුම් සම්භාවිතා ව්යාප්තිය සහ විශාල ජාලයේ නිමැවුම් සම්භාවිතා ව්යාප්තිය (මෘදු ඉලක්ක) අතර.

මෙහිඇති එක් ගැටළුවක් නම් විශාල ජාලය විසින් වැරදි පන්ති සඳහා පවරා ඇති සම්භාවිතාවන් බොහෝ විට ඉතා කුඩා වන අතර අලාභයට දායක නොවීමයි. ඒ නිසා ඔවුන් උෂ්ණත්වය යෙදීමෙන් සම්භාවිතාව මෘදු කරයි ,

සඳහාඉහළ අගයන් මෘදු සම්භාවිතාව ඇති කරනු ඇත.

පුහුණු

කුඩාආකෘතිය පුහුණු කිරීමේදී සැබෑ ලේබල් පුරෝකථනය කිරීම සඳහා දෙවන පාඩු පදය එක් කිරීමට කඩදාසි යෝජනා කරයි. පාඩු පද දෙකේ බර තැබූ එකතුව ලෙස අපි සංයුක්ත අලාභය ගණනය කරමු: මෘදු ඉලක්ක සහ සත්ය ලේබල්.

ආසවනයසඳහා දත්ත සමුදාය මාරු කට්ටලයලෙස හැඳින්වේ, සහ කඩදාසි එකම පුහුණු දත්ත භාවිතා යෝජනා කරයි.

අපගේඅත්හදා බැලීම

අපිCIFA-10 දත්ත කට්ටලය මත පුහුණු කරමු. ඩ්රොප් අවුට් සමඟ පරාමිතීන් ඇති විශාල ආකෘතියක් අපි පුහුණු කරන අතර එය වලංගු කිරීමේ කට්ටලය මත 85% ක නිරවද්යතාවයක් ලබා දෙයි. පරාමිතීන් සහිත කුඩා ආකෘතියක් 80% ක නිරවද්යතාවයක් ලබා දෙයි.

ඉන්පසුඅපි කුඩා ආකෘතිය විශාල ආකෘතියෙන් ආසවනය සමඟ පුහුණු කරන අතර එය 82% ක නිරවද්යතාවයක් ලබා දෙයි; නිරවද්යතාවයේ 2% වැඩි වීමක්.

View Run

74import torch
75import torch.nn.functional
76from torch import nn
77
78from labml import experiment, tracker
79from labml.configs import option
80from labml_helpers.train_valid import BatchIndex
81from labml_nn.distillation.large import LargeModel
82from labml_nn.distillation.small import SmallModel
83from labml_nn.experiments.cifar10 import CIFAR10Configs

වින්යාසකිරීම්

මෙයසියලු දත්ත සමුදාය ආශ්රිත වින්යාසයන්, ප්රශස්තකරණය සහ පුහුණු ලූපයක් නිර්වචනය කරන සිට CIFAR10Configs විහිදේ.

86class Configs(CIFAR10Configs):

කුඩාආකෘතිය

94    model: SmallModel

විශාලආකෘතිය

96    large: LargeModel

මෘදුඉලක්ක සඳහා KL අපසරනය අලාභය

98    kl_div_loss = nn.KLDivLoss(log_target=True)

සැබෑලේබල් නැතිවීම සඳහා හරස් එන්ට්රොපි අලාභය

100    loss_func = nn.CrossEntropyLoss()

උෂ්ණත්වය,

102    temperature: float = 5.

මෘදුඉලක්ක අහිමි වීම සඳහා බර.

මෘදුඉලක්ක මගින් නිපදවන ලද අනුක්රමික පරිමාණයට ලක් වේ. මේ සඳහා වන්දි ගෙවීම සඳහා කඩදාසි යෝජනා කරන්නේ මෘදු ඉලක්ක අලාභය සාධකයක් මගින් පරිමාණය කිරීමයි

108    soft_targets_weight: float = 100.

සැබෑලේබලය හරස් එන්ට්රොපිය අඞු කිරීමට සඳහා සිරුරේ බර

110    label_loss_weight: float = 0.5

පුහුණුව/වලංගුකිරීමේ පියවර

ආසවනයඇතුළත් කිරීම සඳහා අභිරුචි පුහුණුව/වලංගු කිරීමේ පියවරක් අපි අර්ථ දක්වන්නෙමු

112    def step(self, batch: any, batch_idx: BatchIndex):

කුඩාආකෘතිය සඳහා පුහුණුව/ඇගයීම් මාදිලිය

120        self.model.train(self.mode.is_train)

ඇගයීම්මාදිලියේ විශාල ආකෘතිය

122        self.large.eval()

උපාංගයවෙත දත්ත ගෙනයන්න

125        data, target = batch[0].to(self.device), batch[1].to(self.device)

පුහුණුප්රකාරයේදී ගෝලීය පියවර (සැකසූ සාම්පල ගණන) යාවත්කාලීන කරන්න

128        if self.mode.is_train:
129            tracker.add_global_step(len(data))

නිමැවුම්පිවිසුම් ලබා ගන්න, , විශාල ආකෘතියෙන්

132        with torch.no_grad():
133            large_logits = self.large(data)

ප්රතිදානපිවිසුම් ලබා ගන්න, , කුඩා ආකෘතියෙන්

136        output = self.model(data)

මෘදුඉලක්ක

140        soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)

කුඩාආකෘතියේ උෂ්ණත්වය සකස් කළ සම්භාවිතාව

143        soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)

මෘදුඉලක්ක අලාභය ගණනය කරන්න

146        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)

සැබෑලේබල් අලාභය ගණනය කරන්න

148        label_loss = self.loss_func(output, target)

පාඩුදෙකේ බර තැබූ එකතුව

150        loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss

පාඩුලොග් කරන්න

152        tracker.add({"loss.kl_div.": soft_targets_loss,
153                     "loss.nll": label_loss,
154                     "loss.": loss})

ගණනයකිරීම සහ ලොග් කිරීමේ නිරවද්යතාවය

157        self.accuracy(output, target)
158        self.accuracy.track()

ආකෘතියපුහුණු කරන්න

161        if self.mode.is_train:

අනුක්රමිකගණනය කරන්න

163            loss.backward()

ප්රශස්තිකරණපියවර ගන්න

165            self.optimizer.step()

සෑමයුගලයකම අවසාන කණ්ඩායමේ ආදර්ශ පරාමිතීන් සහ අනුක්රමික ලොග් කරන්න

167            if batch_idx.is_last:
168                tracker.add('model', self.model)

අනුක්රමිකඉවත්

170            self.optimizer.zero_grad()

ලුහුබැඳඇති ප්රමිතික සුරකින්න

173        tracker.save()

විශාලආකෘතියක් සාදන්න

176@option(Configs.large)
177def _large_model(c: Configs):
181    return LargeModel().to(c.device)

කුඩාආකෘතියක් සාදන්න

184@option(Configs.model)
185def _small_student_model(c: Configs):
189    return SmallModel().to(c.device)
192def get_saved_model(run_uuid: str, checkpoint: int):
197    from labml_nn.distillation.large import Configs as LargeConfigs

ඇගයීම්ප්රකාරයේදී (පටිගත කිරීමක් නැත)

200    experiment.evaluate()

විශාලආකෘති පුහුණු අත්හදා බැලීමේ වින්යාස ආරම්භ කරන්න

202    conf = LargeConfigs()

සුරකිනලද වින්යාස පැටවීම

204    experiment.configs(conf, experiment.load_configs(run_uuid))

ඉතිරිකිරීම/පැටවීම සඳහා ආකෘති සකසන්න

206    experiment.add_pytorch_models({'model': conf.model})

පූරණයකිරීමට කුමන ධාවනය සහ මුරපොලක් සකසන්න

208    experiment.load(run_uuid, checkpoint)

අත්හදාබැලීම ආරම්භ කරන්න - මෙය ආකෘතිය පටවනු ඇත, සියල්ල සූදානම් කරන්න

210    experiment.start()

ආකෘතියආපසු දෙන්න

213    return conf.model

ආසවනයසමඟ කුඩා ආකෘතියක් පුහුණු කරන්න

216def main(run_uuid: str, checkpoint: int):

සුරකිනලද ආකෘතිය පැටවීම

221    large_model = get_saved_model(run_uuid, checkpoint)

අත්හදාබැලීම සාදන්න

223    experiment.create(name='distillation', comment='cifar10')

වින්යාසයන්සාදන්න

225    conf = Configs()

පටවනලද විශාල ආකෘතිය සකසන්න

227    conf.large = large_model

වින්යාසයන්පූරණය කරන්න

229    experiment.configs(conf, {
230        'optimizer.optimizer': 'Adam',
231        'optimizer.learning_rate': 2.5e-4,
232        'model': '_small_student_model',
233    })

ඉතිරිකිරීම/පැටවීම සඳහා ආකෘතිය සකසන්න

235    experiment.add_pytorch_models({'model': conf.model})

මුලසිටම අත්හදා බැලීම ආරම්භ කරන්න

237    experiment.load(None, None)

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

239    with experiment.start():
240        conf.run()

244if __name__ == '__main__':
245    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)