මෙමඅත්හදා බැලීම OpenAI ජිම් හි අටාරි බ්රේක් අවුට් ක්රීඩාව සඳහා ගැඹුරු Q ජාලයක් (DQN) පුහුණු කරයි. කාර්යක්ෂමව සාම්පල ලබා ගැනීම සඳහා එය බහු ක්රියාවලීන්හි ක්රීඩා පරිසරයන් ක්රියාත්මක කරයි.
16import numpy as np
17import torch
18
19from labml import tracker, experiment, logger, monit
20from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
21from labml_helpers.schedule import Piecewise
22from labml_nn.rl.dqn import QFuncLoss
23from labml_nn.rl.dqn.model import Model
24from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
25from labml_nn.rl.game import Workerඋපාංගයතෝරන්න
28if torch.cuda.is_available():
29 device = torch.device("cuda:0")
30else:
31 device = torch.device("cpu")සිට [0, 255]
පරිමාණ නිරීක්ෂණ [0, 1]
34def obs_to_torch(obs: np.ndarray) -> torch.Tensor:36 return torch.tensor(obs, dtype=torch.float32, device=device) / 255.39class Trainer:44 def __init__(self, *,
45 updates: int, epochs: int,
46 n_workers: int, worker_steps: int, mini_batch_size: int,
47 update_target_model: int,
48 learning_rate: FloatDynamicHyperParam,
49 ):කම්කරුවන්සංඛ්යාව
51 self.n_workers = n_workersඑක්එක් යාවත්කාලීනයේ නියැදි පියවර
53 self.worker_steps = worker_stepsපුහුණුපුනරාවර්තන ගණන
55 self.train_epochs = epochsයාවත්කාලීනගණන
58 self.updates = updatesපුහුණුවසඳහා කුඩා කණ්ඩායමේ ප්රමාණය
60 self.mini_batch_size = mini_batch_sizeසෑම250 යාවත්කාලීන ඉලක්ක ජාලය යාවත්කාලීන
63 self.update_target_model = update_target_modelඉගෙනුම්අනුපාතය
66 self.learning_rate = learning_rateයාවත්කාලීනකිරීමේ කාර්යයක් ලෙස ගවේෂණය කිරීම
69 self.exploration_coefficient = Piecewise(
70 [
71 (0, 1.0),
72 (25_000, 0.1),
73 (self.updates / 2, 0.01)
74 ], outside_value=0.01)යාවත්කාලීන කිරීමේ කාර්යයක් ලෙස බෆරය නැවත ධාවනය කිරීම සඳහා
77 self.prioritized_replay_beta = Piecewise(
78 [
79 (0, 0.4),
80 (self.updates, 1)
81 ], outside_value=1)සමඟබෆරය නැවත ධාවනය කරන්න . නැවත ධාවනය කිරීමේ බෆරයේ ධාරිතාව 2 බලයක් විය යුතුය.
84 self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)නියැදීම්හා පුහුණු කිරීම සඳහා ආකෘතිය
87 self.model = Model().to(device)ලබාගැනීමට ඉලක්ක ආකෘතිය
89 self.target_model = Model().to(device)කම්කරුවන්නිර්මාණය
92 self.workers = [Worker(47 + i) for i in range(self.n_workers)]නිරීක්ෂණසඳහා ආතතීන් ආරම්භ කරන්න
95 self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)කම්කරුවන්යළි පිහිටුවන්න
98 for worker in self.workers:
99 worker.child.send(("reset", None))මූලිකනිරීක්ෂණ ලබා ගන්න
102 for i, worker in enumerate(self.workers):
103 self.obs[i] = worker.child.recv()පාඩුශ්රිතය
106 self.loss_func = QFuncLoss(0.99)ප්රශස්තකරණය
108 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)නියැදීම්ක්රියා කිරීමේදී අපි කෑදර උපාය මාර්ගයක් භාවිතා කරමු, එහිදී අපි සම්භාවිතාව සමඟ කෑදර ක්රියාමාර්ගයක් ගන්නා අතර අහඹු ක්රියාමාර්ගයක් ගනිමු . අපි හඳුන්වන්නේ exploration_coefficient
.
110 def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):නියැදීම්වලට අනුක්රමික අවශ්ය නොවේ
120 with torch.no_grad():ඉහළමQ- අගය සමඟ ක්රියාව සාම්පල කරන්න. මෙය කෑදර ක්රියාවකි.
122 greedy_action = torch.argmax(q_value, dim=-1)ඒකාකාරවනියැදිය සහ ක්රියාව
124 random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)කෑදරක්රියාව හෝ අහඹු ක්රියාව තෝරා ගත යුතුද යන්න
126 is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficientමතපදනම්ව ක්රියාව තෝරන්න is_choose_rand
128 return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()130 def sample(self, exploration_coefficient: float):මේසඳහා අනුක්රමික අවශ්ය නොවේ
134 with torch.no_grad():නියැදිය worker_steps
136 for t in range(self.worker_steps):වත්මන්නිරීක්ෂණ සඳහා Q_අගයන් ලබා ගන්න
138 q_value = self.model(obs_to_torch(self.obs))නියැදික්රියා
140 actions = self._sample_action(q_value, exploration_coefficient)එක්එක් සේවකයා මත නියැදි ක්රියා ක්රියාත්මක
143 for w, worker in enumerate(self.workers):
144 worker.child.send(("step", actions[w]))එක්එක් සේවකයාගෙන් තොරතුරු රැස් කරන්න
147 for w, worker in enumerate(self.workers):ක්රියාවන්ක්රියාත්මක කිරීමෙන් පසු ප්රති results ල ලබා ගන්න
149 next_obs, reward, done, info = worker.child.recv()බෆරයනැවත ධාවනය කිරීම සඳහා සංක්රාන්තිය එක් කරන්න
152 self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)කථාංගතොරතුරු යාවත්කාලීන කරන්න. කථාංග තොරතුරු එකතු කරන්න, කථාංගයක් අවසන් වුවහොත් ලබා ගත හැකිය; මෙයට කථාංගයේ සම්පූර්ණ විපාකය සහ දිග ඇතුළත් වේ - එය ක්රියාත්මක වන ආකාරය Game
බැලීමට බලන්න.
158 if info:
159 tracker.add('reward', info['reward'])
160 tracker.add('length', info['length'])වත්මන්නිරීක්ෂණ යාවත්කාලීන කරන්න
163 self.obs[w] = next_obs165 def train(self, beta: float):169 for _ in range(self.train_epochs):ප්රමුඛතානැවත ධාවනය කිරීමේ බෆරයෙන් නියැදිය
171 samples = self.replay_buffer.sample(self.mini_batch_size, beta)පුරෝකථනයකරන ලද Q-අගය ලබා ගන්න
173 q_value = self.model(obs_to_torch(samples['obs']))ද්විත්ව Q- ඉගෙනීම සඳහා ඊළඟ තත්වයේ Q-අගයන්ලබා ගන්න. මේවා සඳහා අනුක්රමික ප්රචාරණය නොකළ යුතුය
177 with torch.no_grad():ලබාගන්න
179 double_q_value = self.model(obs_to_torch(samples['next_obs']))ලබාගන්න
181 target_q_value = self.target_model(obs_to_torch(samples['next_obs']))තාවකාලිකවෙනස ගණනය කරන්න (TD) දෝෂ , සහ අලාභය, .
184 td_errors, loss = self.loss_func(q_value,
185 q_value.new_tensor(samples['action']),
186 double_q_value, target_q_value,
187 q_value.new_tensor(samples['done']),
188 q_value.new_tensor(samples['reward']),
189 q_value.new_tensor(samples['weights']))නැවතධාවනය කිරීමේ බෆරය සඳහා ප්රමුඛතා ගණනය කරන්න
192 new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6නැවතධාවනය කිරීමේ ස්වාරක්ෂක ප්රමුඛතා යාවත්කාලීන කරන්න
194 self.replay_buffer.update_priorities(samples['indexes'], new_priorities)ඉගෙනුම්අනුපාතය සකසන්න
197 for pg in self.optimizer.param_groups:
198 pg['lr'] = self.learning_rate()කලින්ගණනය කරන ලද අනුක්රමික ශුන්ය කිරීම
200 self.optimizer.zero_grad()අනුක්රමිකගණනය කරන්න
202 loss.backward()ක්ලිප්අනුක්රමික
204 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)අනුක්රමිකමත පදනම්ව පරාමිතීන් යාවත්කාලීන කරන්න
206 self.optimizer.step()208 def run_training_loop(self):අවසන්100 කථාංග තොරතුරු
214 tracker.set_queue('reward', 100, True)
215 tracker.set_queue('length', 100, True)මුලින්ඉලක්කගත ජාලයට පිටපත් කරන්න
218 self.target_model.load_state_dict(self.model.state_dict())
219
220 for update in monit.loop(self.updates):, ගවේෂණ භාගය
222 exploration = self.exploration_coefficient(update)
223 tracker.add('exploration', exploration)ප්රමුඛතා නැවත ධාවනය සඳහා
225 beta = self.prioritized_replay_beta(update)
226 tracker.add('beta', beta)වත්මන්ප්රතිපත්තිය සමඟ නියැදිය
229 self.sample(exploration)බෆරයපිරී ගිය පසු පුහුණුව ආරම්භ කරන්න
232 if self.replay_buffer.is_full():ආකෘතියපුහුණු කරන්න
234 self.train(beta)ඉලක්කජාලය වරින් වර යාවත්කාලීන කරන්න
237 if update % self.update_target_model == 0:
238 self.target_model.load_state_dict(self.model.state_dict())ලුහුබැඳඇති දර්ශක සුරකින්න.
241 tracker.save()වරින්වර තිරයට නව රේඛාවක් එක් කරන්න
243 if (update + 1) % 1_000 == 0:
244 logger.log()246 def destroy(self):251 for worker in self.workers:
252 worker.child.send(("close", None))255def main():අත්හදාබැලීම සාදන්න
257 experiment.create(name='dqn')වින්යාසකිරීම්
260 configs = {යාවත්කාලීනගණන
262 'updates': 1_000_000,නියැදිදත්ත සමඟ ආකෘතිය පුහුණු කිරීම සඳහා එපොච් ගණන.
264 'epochs': 8,සේවකක්රියාවලි ගණන
266 'n_workers': 8,තනියාවත්කාලීන කිරීම සඳහා එක් එක් ක්රියාවලිය මත ක්රියාත්මක කිරීමට පියවර ගණන
268 'worker_steps': 4,කුඩාකණ්ඩායම් ප්රමාණය
270 'mini_batch_size': 32,ඉලක්කගතආකෘති යාවත්කාලීන කිරීමේ පරතරය
272 'update_target_model': 250,ඉගෙනුම්අනුපාතය.
274 'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
275 }වින්යාසකිරීම්
278 experiment.configs(configs)පුහුණුකරුආරම්භ කරන්න
281 m = Trainer(**configs)අත්හදාබැලීම ධාවනය කර අධීක්ෂණය කරන්න
283 with experiment.start():
284 m.run_training_loop()කම්කරුවන්නවත්වන්න
286 m.destroy()290if __name__ == "__main__":
291 main()