この実験では、ディープQネットワーク(DQN)にOpenAI Gymでアタリブレイクアウトゲームをプレイするようにトレーニングします。ゲーム環境を複数のプロセスで実行して効率的にサンプリングします。
15import numpy as np
16import torch
17
18from labml import tracker, experiment, logger, monit
19from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
20from labml_helpers.schedule import Piecewise
21from labml_nn.rl.dqn import QFuncLoss
22from labml_nn.rl.dqn.model import Model
23from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
24from labml_nn.rl.game import Workerデバイスを選択
27if torch.cuda.is_available():
28 device = torch.device("cuda:0")
29else:
30 device = torch.device("cpu")[0, 255]
観測値をからにスケーリング [0, 1]
33def obs_to_torch(obs: np.ndarray) -> torch.Tensor:35 return torch.tensor(obs, dtype=torch.float32, device=device) / 255.38class Trainer:43 def __init__(self, *,
44 updates: int, epochs: int,
45 n_workers: int, worker_steps: int, mini_batch_size: int,
46 update_target_model: int,
47 learning_rate: FloatDynamicHyperParam,
48 ):労働者の数
50 self.n_workers = n_workers更新のたびにサンプリングされるステップ
52 self.worker_steps = worker_stepsトレーニングの反復回数
54 self.train_epochs = epochs更新回数
57 self.updates = updatesトレーニング用ミニバッチのサイズ
59 self.mini_batch_size = mini_batch_size250 回の更新ごとにターゲットネットワークを更新
62 self.update_target_model = update_target_model学習率
65 self.learning_rate = learning_rate更新機能としての探索
68 self.exploration_coefficient = Piecewise(
69 [
70 (0, 1.0),
71 (25_000, 0.1),
72 (self.updates / 2, 0.01)
73 ], outside_value=0.01)更新機能としての再生バッファ用
76 self.prioritized_replay_beta = Piecewise(
77 [
78 (0, 0.4),
79 (self.updates, 1)
80 ], outside_value=1)リプレイバッファは.再生バッファの容量は 2 の累乗でなければなりません
。83 self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)サンプリングとトレーニング用のモデル
86 self.model = Model().to(device)取得する対象モデル
88 self.target_model = Model().to(device)ワーカーを作成
91 self.workers = [Worker(47 + i) for i in range(self.n_workers)]観測用のテンソルを初期化
94 self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)ワーカーをリセット
97 for worker in self.workers:
98 worker.child.send(("reset", None))初期観測値を取得
101 for i, worker in enumerate(self.workers):
102 self.obs[i] = worker.child.recv()損失関数
105 self.loss_func = QFuncLoss(0.99)オプティマイザー
107 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)アクションをサンプリングするときは、-greedy ストラテジーを使用します。つまり、確率のある貪欲なアクションを実行し、確率のあるランダムなアクションを実行します。と呼びますexploration_coefficient
。
109 def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):サンプリングにはグラデーションは必要ありません
119 with torch.no_grad():Q値が最も高いアクションをサンプリングします。これは貪欲な行動です
。121 greedy_action = torch.argmax(q_value, dim=-1)サンプルとアクションを均一に
123 random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)欲張りアクションとランダムアクションのどちらを選ぶか
125 is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient以下に基づいてアクションを選択してください is_choose_rand
127 return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()129 def sample(self, exploration_coefficient: float):これにはグラデーションは必要ありません
133 with torch.no_grad():[サンプル] worker_steps
135 for t in range(self.worker_steps):現在の観測値の Q_value を取得
137 q_value = self.model(obs_to_torch(self.obs))サンプルアクション
139 actions = self._sample_action(q_value, exploration_coefficient)各ワーカーでサンプルアクションを実行
142 for w, worker in enumerate(self.workers):
143 worker.child.send(("step", actions[w]))各作業者から情報を収集する
146 for w, worker in enumerate(self.workers):アクションを実行した後に結果を取得
148 next_obs, reward, done, info = worker.child.recv()再生バッファにトランジションを追加
151 self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)エピソード情報を更新します。エピソードが終了した場合に利用できるエピソード情報を収集します。これには、合計報酬とエピソードの長さが含まれます。仕組みを確認してみてください。Game
157 if info:
158 tracker.add('reward', info['reward'])
159 tracker.add('length', info['length'])現在の観測値を更新
162 self.obs[w] = next_obs164 def train(self, beta: float):168 for _ in range(self.train_epochs):プライオリティ・リプレイ・バッファからのサンプル
170 samples = self.replay_buffer.sample(self.mini_batch_size, beta)予測された Q 値の取得
172 q_value = self.model(obs_to_torch(samples['obs']))二重Q学習の次の状態のQ値を取得します。これらの場合、グラデーションは伝播しないはずです
176 with torch.no_grad():取得
178 double_q_value = self.model(obs_to_torch(samples['next_obs']))取得
180 target_q_value = self.target_model(obs_to_torch(samples['next_obs']))時差 (TD) 誤差、および損失を計算します。
183 td_errors, loss = self.loss_func(q_value,
184 q_value.new_tensor(samples['action']),
185 double_q_value, target_q_value,
186 q_value.new_tensor(samples['done']),
187 q_value.new_tensor(samples['reward']),
188 q_value.new_tensor(samples['weights']))再生バッファの優先度を計算
191 new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6リプレイバッファの優先順位を更新
193 self.replay_buffer.update_priorities(samples['indexes'], new_priorities)学習率を設定
196 for pg in self.optimizer.param_groups:
197 pg['lr'] = self.learning_rate()以前に計算したグラデーションをゼロにします
199 self.optimizer.zero_grad()勾配の計算
201 loss.backward()クリップグラデーション
203 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)グラデーションに基づいてパラメータを更新
205 self.optimizer.step()207 def run_training_loop(self):最新100話の情報
213 tracker.set_queue('reward', 100, True)
214 tracker.set_queue('length', 100, True)最初にターゲットネットワークにコピー
217 self.target_model.load_state_dict(self.model.state_dict())
218
219 for update in monit.loop(self.updates):、探査フラクション
221 exploration = self.exploration_coefficient(update)
222 tracker.add('exploration', exploration)優先再生用
224 beta = self.prioritized_replay_beta(update)
225 tracker.add('beta', beta)現在のポリシーを含むサンプル
228 self.sample(exploration)バッファーがいっぱいになったらトレーニングを開始する
231 if self.replay_buffer.is_full():モデルのトレーニング
233 self.train(beta)ターゲットネットワークを定期的に更新
236 if update % self.update_target_model == 0:
237 self.target_model.load_state_dict(self.model.state_dict())追跡指標を保存します。
240 tracker.save()画面に定期的に新しい行を追加してください
242 if (update + 1) % 1_000 == 0:
243 logger.log()245 def destroy(self):250 for worker in self.workers:
251 worker.child.send(("close", None))254def main():実験を作成
256 experiment.create(name='dqn')コンフィギュレーション
259 configs = {更新回数
261 'updates': 1_000_000,サンプルデータを使用してモデルをトレーニングするエポックの数。
263 'epochs': 8,ワーカープロセスの数
265 'n_workers': 8,1 回の更新で各プロセスで実行するステップの数
267 'worker_steps': 4,ミニバッチサイズ
269 'mini_batch_size': 32,対象モデルの更新間隔
271 'update_target_model': 250,学習率。
273 'learning_rate': FloatDynamicHyperParam(1e-4, (0, 1e-3)),
274 }コンフィギュレーション
277 experiment.configs(configs)トレーナーを初期化
280 m = Trainer(**configs)実験の実行と監視
282 with experiment.start():
283 m.run_training_loop()労働者を止めろ
285 m.destroy()289if __name__ == "__main__":
290 main()