ගැඹුරුQ ජාල (DQN)

මෙය PyTorch ක්රියාත්මක කිරීමකි Atari කඩදාසි සෙල්ලම් කිරීම ගැඹුරු ශක්තිමත් කිරීමේ ඉගෙනීම සහ ඩුවලිං ජාලය , ප්රමුඛතා නැවත ධාවනය සහ ද්විත්ව Q ජාලය සමඟ.

මෙන්න අත්හදා බැලීම සහ ආදර්ශ ක්රියාත්මක කිරීම.

Open In Colab View Run

25from typing import Tuple
26
27import torch
28from torch import nn
29
30from labml import tracker
31from labml_helpers.module import Module
32from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

ආකෘතිය පුහුණු කරන්න

ප්රශස්ත ක්රියාකාරී අගය ශ්රිතය සොයා ගැනීමට අපට අවශ්යය.

ඉලක්ක ජාලය 🎯

ස්ථාවරත්වය වැඩි දියුණු කිරීම සඳහා අපි පෙර අත්දැකීම් වලින් අහඹු ලෙස නියැදිය අත්දැකීම් නැවත ධාවනය භාවිතා කරමු. ඉලක්කය ගණනය කිරීම සඳහා වෙනම පරාමිතීන් සමූහයක් සහිත Q ජාලයක් ද අපි භාවිතා කරමු. වරින් වර යාවත්කාලීන වේ. මෙය ගැඹුරු ශක්තිමත් කිරීමේ ඉගෙනීම තුළින් කඩදාසි මානව මට්ටම් පාලනයට අනුව ය.

එබැවින් පාඩු ශ්රිතය වන්නේ,

ද්විත්ව-ඉගෙනුම්

ඉහත ගණනය කිරීමේ උපරිම ක්රියාකරු හොඳම ක්රියාව තෝරා ගැනීම සහ වටිනාකම ඇගයීම සඳහා එකම ජාලයක් භාවිතා කරයි. එනම්, අපි ද්විත්ව Q- ඉගෙනීම භාවිතා කරමු, එය ලබා ගන්නේ කොතැනින්ද සහ වටිනාකම ලබා ගනී.

පාඩු ශ්රිතය බවට පත්වේ,

35class QFuncLoss(Module):
103    def __init__(self, gamma: float):
104        super().__init__()
105        self.gamma = gamma
106        self.huber_loss = nn.SmoothL1Loss(reduction='none')
  • q -
  • action -
  • double_q -
  • target_q -
  • done - පියවර ගැනීමෙන් පසු ක්රීඩාව අවසන් වූවාද යන්න
  • reward -
  • weights - ප්රමුඛතාවය පළපුරුදු නැවත ධාවනය සිට සාම්පල බර
108    def forward(self, q: torch.Tensor, action: torch.Tensor, double_q: torch.Tensor,
109                target_q: torch.Tensor, done: torch.Tensor, reward: torch.Tensor,
110                weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

122        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
123        tracker.add('q_sampled_action', q_sampled_action)

අනුක්රමිකඅනුක්රමික ප්රචාරය නොකළ යුතුය

131        with torch.no_grad():

රාජ්යයෙන්හොඳම ක්රියාව ලබා ගන්න

135            best_next_action = torch.argmax(double_q, -1)

රාජ්යහොඳම ක්රියාව සඳහා ඉලක්ක ජාලයෙන් q අගය ලබා ගන්න

141            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)

අපේක්ෂිතQ අගය ගණනය කරන්න. ක්රීඩාව අවසන් වූයේ නම් ඊළඟ රාජ්ය Q අගයන් ශුන්ය (1 - done) කිරීමට අපි ගුණ කරමු.

152            q_update = reward + self.gamma * best_next_q_value * (1 - done)
153            tracker.add('q_update', q_update)

නැවතධාවනය කිරීමේ බෆරයේ සාම්පල කිරා මැන බැලීමට තාවකාලික වෙනස දෝෂය භාවිතා කරයි

156            td_error = q_sampled_action - q_update
157            tracker.add('td_error', td_error)

එයoutliers අඩු සංවේදී නිසා අපි ඒ වෙනුවට මධ්යන්ය වර්ග දෝෂයක් අහිමි Huber අහිමි ගත

161        losses = self.huber_loss(q_sampled_action, q_update)

බරතැබූ ක්රම ලබා ගන්න

163        loss = torch.mean(weights * losses)
164        tracker.add('loss', loss)
165
166        return td_error, loss