මෙය PyTorch ක්රියාත්මක කිරීමකි Atari කඩදාසි සෙල්ලම් කිරීම ගැඹුරු ශක්තිමත් කිරීමේ ඉගෙනීම සහ ඩුවලිං ජාලය , ප්රමුඛතා නැවත ධාවනය සහ ද්විත්ව Q ජාලය සමඟ.
25from typing import Tuple
26
27import torch
28from torch import nn
29
30from labml import tracker
31from labml_helpers.module import Module
32from labml_nn.rl.dqn.replay_buffer import ReplayBufferප්රශස්ත ක්රියාකාරී අගය ශ්රිතය සොයා ගැනීමට අපට අවශ්යය.
ස්ථාවරත්වය වැඩි දියුණු කිරීම සඳහා අපි පෙර අත්දැකීම් වලින් අහඹු ලෙස නියැදිය අත්දැකීම් නැවත ධාවනය භාවිතා කරමු. ඉලක්කය ගණනය කිරීම සඳහා වෙනම පරාමිතීන් සමූහයක් සහිත Q ජාලයක් ද අපි භාවිතා කරමු. වරින් වර යාවත්කාලීන වේ. මෙය ගැඹුරු ශක්තිමත් කිරීමේ ඉගෙනීම තුළින් කඩදාසි මානව මට්ටම් පාලනයට අනුව ය.
එබැවින් පාඩු ශ්රිතය වන්නේ,
ඉහත ගණනය කිරීමේ උපරිම ක්රියාකරු හොඳම ක්රියාව තෝරා ගැනීම සහ වටිනාකම ඇගයීම සඳහා එකම ජාලයක් භාවිතා කරයි. එනම්, අපි ද්විත්ව Q- ඉගෙනීම භාවිතා කරමු, එය ලබා ගන්නේ කොතැනින්ද සහ වටිනාකම ලබා ගනී.
පාඩු ශ්රිතය බවට පත්වේ,
35class QFuncLoss(Module):103 def __init__(self, gamma: float):
104 super().__init__()
105 self.gamma = gamma
106 self.huber_loss = nn.SmoothL1Loss(reduction='none')q
- action
- double_q
- target_q
- done
- පියවර ගැනීමෙන් පසු ක්රීඩාව අවසන් වූවාද යන්න reward
- weights
- ප්රමුඛතාවය පළපුරුදු නැවත ධාවනය සිට සාම්පල බර108 def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
109 target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
110 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:122 q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
123 tracker.add('q_sampled_action', q_sampled_action)අනුක්රමිකඅනුක්රමික ප්රචාරය නොකළ යුතුය
131 with torch.no_grad():රාජ්යයෙන්හොඳම ක්රියාව ලබා ගන්න
135 best_next_action = torch.argmax(double_q, -1)රාජ්යහොඳම ක්රියාව සඳහා ඉලක්ක ජාලයෙන් q අගය ලබා ගන්න
141 best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)අපේක්ෂිතQ අගය ගණනය කරන්න. ක්රීඩාව අවසන් වූයේ නම් ඊළඟ රාජ්ය Q අගයන් ශුන්ය (1 - done)
කිරීමට අපි ගුණ කරමු.
152 q_update = reward + self.gamma * best_next_q_value * (1 - done)
153 tracker.add('q_update', q_update)නැවතධාවනය කිරීමේ බෆරයේ සාම්පල කිරා මැන බැලීමට තාවකාලික වෙනස දෝෂය භාවිතා කරයි
156 td_error = q_sampled_action - q_update
157 tracker.add('td_error', td_error)161 losses = self.huber_loss(q_sampled_action, q_update)බරතැබූ ක්රම ලබා ගන්න
163 loss = torch.mean(weights * losses)
164 tracker.add('loss', loss)
165
166 return td_error, loss