අටාරිබ්රේක්අවුට් සමඟ ඩීQN අත්හදා බැලීම

මෙමඅත්හදා බැලීම OpenAI ජිම් හි අටාරි බ්රේක් අවුට් ක්රීඩාව සඳහා ගැඹුරු Q ජාලයක් (DQN) පුහුණු කරයි. කාර්යක්ෂමව සාම්පල ලබා ගැනීම සඳහා එය බහු ක්රියාවලීන්හි ක්රීඩා පරිසරයන් ක්රියාත්මක කරයි.

Open In Colab View Run

16import numpy as np
17import torch
18
19from labml import tracker, experiment, logger, monit
20from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
21from labml_helpers.schedule import Piecewise
22from labml_nn.rl.dqn import QFuncLoss
23from labml_nn.rl.dqn.model import Model
24from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
25from labml_nn.rl.game import Worker

උපාංගයතෝරන්න

28if torch.cuda.is_available():
29    device = torch.device("cuda:0")
30else:
31    device = torch.device("cpu")

සිට [0, 255] පරිමාණ නිරීක්ෂණ [0, 1]

34def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
36    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.

පුහුණුකරු

39class Trainer:
44    def __init__(self, *,
45                 updates: int, epochs: int,
46                 n_workers: int, worker_steps: int, mini_batch_size: int,
47                 update_target_model: int,
48                 learning_rate: FloatDynamicHyperParam,
49                 ):

කම්කරුවන්සංඛ්යාව

51        self.n_workers = n_workers

එක්එක් යාවත්කාලීනයේ නියැදි පියවර

53        self.worker_steps = worker_steps

පුහුණුපුනරාවර්තන ගණන

55        self.train_epochs = epochs

යාවත්කාලීනගණන

58        self.updates = updates

පුහුණුවසඳහා කුඩා කණ්ඩායමේ ප්රමාණය

60        self.mini_batch_size = mini_batch_size

සෑම250 යාවත්කාලීන ඉලක්ක ජාලය යාවත්කාලීන

63        self.update_target_model = update_target_model

ඉගෙනුම්අනුපාතය

66        self.learning_rate = learning_rate

යාවත්කාලීනකිරීමේ කාර්යයක් ලෙස ගවේෂණය කිරීම

69        self.exploration_coefficient = Piecewise(
70            [
71                (0, 1.0),
72                (25_000, 0.1),
73                (self.updates / 2, 0.01)
74            ], outside_value=0.01)

යාවත්කාලීන කිරීමේ කාර්යයක් ලෙස බෆරය නැවත ධාවනය කිරීම සඳහා

77        self.prioritized_replay_beta = Piecewise(
78            [
79                (0, 0.4),
80                (self.updates, 1)
81            ], outside_value=1)

සමඟබෆරය නැවත ධාවනය කරන්න . නැවත ධාවනය කිරීමේ බෆරයේ ධාරිතාව 2 බලයක් විය යුතුය.

84        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

නියැදීම්හා පුහුණු කිරීම සඳහා ආකෘතිය

87        self.model = Model().to(device)

ලබාගැනීමට ඉලක්ක ආකෘතිය

89        self.target_model = Model().to(device)

කම්කරුවන්නිර්මාණය

92        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

නිරීක්ෂණසඳහා ආතතීන් ආරම්භ කරන්න

95        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)

කම්කරුවන්යළි පිහිටුවන්න

98        for worker in self.workers:
99            worker.child.send(("reset", None))

මූලිකනිරීක්ෂණ ලබා ගන්න

102        for i, worker in enumerate(self.workers):
103            self.obs[i] = worker.child.recv()

පාඩුශ්රිතය

106        self.loss_func = QFuncLoss(0.99)

ප්‍රශස්තකරණය

108        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

කෑදර නියැදීම

නියැදීම්ක්රියා කිරීමේදී අපි කෑදර උපාය මාර්ගයක් භාවිතා කරමු, එහිදී අපි සම්භාවිතාව සමඟ කෑදර ක්රියාමාර්ගයක් ගන්නා අතර අහඹු ක්රියාමාර්ගයක් ගනිමු . අපි හඳුන්වන්නේ exploration_coefficient .

110    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):

නියැදීම්වලට අනුක්රමික අවශ්ය නොවේ

120        with torch.no_grad():

ඉහළමQ- අගය සමඟ ක්රියාව සාම්පල කරන්න. මෙය කෑදර ක්රියාවකි.

122            greedy_action = torch.argmax(q_value, dim=-1)

ඒකාකාරවනියැදිය සහ ක්රියාව

124            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)

කෑදරක්රියාව හෝ අහඹු ක්රියාව තෝරා ගත යුතුද යන්න

126            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient

මතපදනම්ව ක්රියාව තෝරන්න is_choose_rand

128            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

නියැදිදත්ත

130    def sample(self, exploration_coefficient: float):

මේසඳහා අනුක්රමික අවශ්ය නොවේ

134        with torch.no_grad():

නියැදිය worker_steps

136            for t in range(self.worker_steps):

වත්මන්නිරීක්ෂණ සඳහා Q_අගයන් ලබා ගන්න

138                q_value = self.model(obs_to_torch(self.obs))

නියැදික්රියා

140                actions = self._sample_action(q_value, exploration_coefficient)

එක්එක් සේවකයා මත නියැදි ක්රියා ක්රියාත්මක

143                for w, worker in enumerate(self.workers):
144                    worker.child.send(("step", actions[w]))

එක්එක් සේවකයාගෙන් තොරතුරු රැස් කරන්න

147                for w, worker in enumerate(self.workers):

ක්රියාවන්ක්රියාත්මක කිරීමෙන් පසු ප්රති results ල ලබා ගන්න

149                    next_obs, reward, done, info = worker.child.recv()

බෆරයනැවත ධාවනය කිරීම සඳහා සංක්රාන්තිය එක් කරන්න

152                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

කථාංගතොරතුරු යාවත්කාලීන කරන්න. කථාංග තොරතුරු එකතු කරන්න, කථාංගයක් අවසන් වුවහොත් ලබා ගත හැකිය; මෙයට කථාංගයේ සම්පූර්ණ විපාකය සහ දිග ඇතුළත් වේ - එය ක්රියාත්මක වන ආකාරය Game බැලීමට බලන්න.

158                    if info:
159                        tracker.add('reward', info['reward'])
160                        tracker.add('length', info['length'])

වත්මන්නිරීක්ෂණ යාවත්කාලීන කරන්න

163                    self.obs[w] = next_obs

ආකෘතියපුහුණු කරන්න

165    def train(self, beta: float):
169        for _ in range(self.train_epochs):

ප්රමුඛතානැවත ධාවනය කිරීමේ බෆරයෙන් නියැදිය

171            samples = self.replay_buffer.sample(self.mini_batch_size, beta)

පුරෝකථනයකරන ලද Q-අගය ලබා ගන්න

173            q_value = self.model(obs_to_torch(samples['obs']))

ද්විත්ව Q- ඉගෙනීම සඳහා ඊළඟ තත්වයේ Q-අගයන්ලබා ගන්න. මේවා සඳහා අනුක්රමික ප්රචාරණය නොකළ යුතුය

177            with torch.no_grad():

ලබාගන්න

179                double_q_value = self.model(obs_to_torch(samples['next_obs']))

ලබාගන්න

181                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))

තාවකාලිකවෙනස ගණනය කරන්න (TD) දෝෂ , සහ අලාභය, .

184            td_errors, loss = self.loss_func(q_value,
185                                             q_value.new_tensor(samples['action']),
186                                             double_q_value, target_q_value,
187                                             q_value.new_tensor(samples['done']),
188                                             q_value.new_tensor(samples['reward']),
189                                             q_value.new_tensor(samples['weights']))

නැවතධාවනය කිරීමේ බෆරය සඳහා ප්රමුඛතා ගණනය කරන්න

192            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6

නැවතධාවනය කිරීමේ ස්වාරක්ෂක ප්රමුඛතා යාවත්කාලීන කරන්න

194            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

ඉගෙනුම්අනුපාතය සකසන්න

197            for pg in self.optimizer.param_groups:
198                pg['lr'] = self.learning_rate()

කලින්ගණනය කරන ලද අනුක්රමික ශුන්ය කිරීම

200            self.optimizer.zero_grad()

අනුක්රමිකගණනය කරන්න

202            loss.backward()

ක්ලිප්අනුක්රමික

204            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

අනුක්රමිකමත පදනම්ව පරාමිතීන් යාවත්කාලීන කරන්න

206            self.optimizer.step()

පුහුණුලූපය ධාවනය කරන්න

208    def run_training_loop(self):

අවසන්100 කථාංග තොරතුරු

214        tracker.set_queue('reward', 100, True)
215        tracker.set_queue('length', 100, True)

මුලින්ඉලක්කගත ජාලයට පිටපත් කරන්න

218        self.target_model.load_state_dict(self.model.state_dict())
219
220        for update in monit.loop(self.updates):

, ගවේෂණ භාගය

222            exploration = self.exploration_coefficient(update)
223            tracker.add('exploration', exploration)

ප්රමුඛතා නැවත ධාවනය සඳහා

225            beta = self.prioritized_replay_beta(update)
226            tracker.add('beta', beta)

වත්මන්ප්රතිපත්තිය සමඟ නියැදිය

229            self.sample(exploration)

බෆරයපිරී ගිය පසු පුහුණුව ආරම්භ කරන්න

232            if self.replay_buffer.is_full():

ආකෘතියපුහුණු කරන්න

234                self.train(beta)

ඉලක්කජාලය වරින් වර යාවත්කාලීන කරන්න

237                if update % self.update_target_model == 0:
238                    self.target_model.load_state_dict(self.model.state_dict())

ලුහුබැඳඇති දර්ශක සුරකින්න.

241            tracker.save()

වරින්වර තිරයට නව රේඛාවක් එක් කරන්න

243            if (update + 1) % 1_000 == 0:
244                logger.log()

විනාශකරන්න

කම්කරුවන්නවත්වන්න

246    def destroy(self):
251        for worker in self.workers:
252            worker.child.send(("close", None))
255def main():

අත්හදාබැලීම සාදන්න

257    experiment.create(name='dqn')

වින්යාසකිරීම්

260    configs = {

යාවත්කාලීනගණන

262        'updates': 1_000_000,

නියැදිදත්ත සමඟ ආකෘතිය පුහුණු කිරීම සඳහා එපොච් ගණන.

264        'epochs': 8,

සේවකක්රියාවලි ගණන

266        'n_workers': 8,

තනියාවත්කාලීන කිරීම සඳහා එක් එක් ක්රියාවලිය මත ක්රියාත්මක කිරීමට පියවර ගණන

268        'worker_steps': 4,

කුඩාකණ්ඩායම් ප්රමාණය

270        'mini_batch_size': 32,

ඉලක්කගතආකෘති යාවත්කාලීන කිරීමේ පරතරය

272        'update_target_model': 250,

ඉගෙනුම්අනුපාතය.

274        'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
275    }

වින්යාසකිරීම්

278    experiment.configs(configs)

පුහුණුකරුආරම්භ කරන්න

281    m = Trainer(**configs)

අත්හදාබැලීම ධාවනය කර අධීක්ෂණය කරන්න

283    with experiment.start():
284        m.run_training_loop()

කම්කරුවන්නවත්වන්න

286    m.destroy()

එයක්රියාත්මක කරන්න

290if __name__ == "__main__":
291    main()