diff --git a/labml_nn/rl/dqn/__init__.py b/labml_nn/rl/dqn/__init__.py
index e69de29b..b23f9d54 100644
--- a/labml_nn/rl/dqn/__init__.py
+++ b/labml_nn/rl/dqn/__init__.py
@@ -0,0 +1,50 @@
+"""
+This is a Deep Q Learning implementation with:
+* Double Q Network
+* Dueling Network
+* Prioritized Replay
+"""
+
+from typing import Tuple
+
+import torch
+
+from labml import tracker
+from labml_helpers.module import Module
+from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
+
+
+class QFuncLoss(Module):
+    def __init__(self, gamma: float):
+        super().__init__()
+        self.gamma = gamma
+
+    def __call__(self, q: torch.Tensor,
+                 action: torch.Tensor,
+                 double_q: torch.Tensor,
+                 target_q: torch.Tensor,
+                 done: torch.Tensor,
+                 reward: torch.Tensor,
+                 weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        q_sampled_action = q.gather(-1, action.to(torch.long).unsqueeze(-1)).squeeze(-1)
+        tracker.add('q_sampled_action', q_sampled_action)
+
+        with torch.no_grad():
+            best_next_action = torch.argmax(double_q, -1)
+            best_next_q_value = target_q.gather(-1, best_next_action.unsqueeze(-1)).squeeze(-1)
+
+            best_next_q_value *= (1 - done)
+
+            q_update = reward + self.gamma * best_next_q_value
+            tracker.add('q_update', q_update)
+
+            td_error = q_sampled_action - q_update
+            tracker.add('td_error', td_error)
+
+        # Huber loss
+        losses = torch.nn.functional.smooth_l1_loss(q_sampled_action, q_update, reduction='none')
+        loss = torch.mean(weights * losses)
+        tracker.add('loss', loss)
+
+        return td_error, loss
+
diff --git a/labml_nn/rl/dqn/experiment.py b/labml_nn/rl/dqn/experiment.py
new file mode 100644
index 00000000..817f7cf0
--- /dev/null
+++ b/labml_nn/rl/dqn/experiment.py
@@ -0,0 +1,385 @@
+"""
+\(
+    \def\hl1#1{{\color{orange}{#1}}}
+    \def\blue#1{{\color{cyan}{#1}}}
+    \def\green#1{{\color{yellowgreen}{#1}}}
+\)
+"""
+
+import numpy as np
+import torch
+from torch import nn
+
+from labml import tracker, experiment, logger, monit
+from labml_helpers.module import Module
+from labml_helpers.schedule import Piecewise
+from labml_nn.rl.dqn import QFuncLoss
+from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
+from labml_nn.rl.game import Worker
+
+if torch.cuda.is_available():
+    device = torch.device("cuda:0")
+else:
+    device = torch.device("cpu")
+
+
+class Model(Module):
+    """
+    ## Neural Network Model for $Q$ Values
+
+    #### Dueling Network ⚔️
+    We are using a [dueling network](https://arxiv.org/abs/1511.06581)
+    to calculate Q-values.
+    The intuition behind the dueling network architecture is that in most states
+    the action doesn't matter,
+    while in some states the action is significant. The dueling network allows
+    this to be represented very well.
+
+    \begin{align}
+        Q^\pi(s,a) &= V^\pi(s) + A^\pi(s, a)
+        \\
+        \mathop{\mathbb{E}}_{a \sim \pi(s)}
+            \Big[
+                A^\pi(s, a)
+            \Big]
+        &= 0
+    \end{align}
+
+    So we create two networks for $V$ and $A$ and get $Q$ from them.
+    $$
+        Q(s, a) = V(s) +
+        \Big(
+            A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')
+        \Big)
+    $$
+    We share the initial layers of the $V$ and $A$ networks.
+    """
+
+    def __init__(self):
+        """
+        ### Initialize
+
+        The training network and the target network are two separate
+        instances of this model.
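+
+        The convolutional layers and the first linear layer are shared by the
+        two heads: `state_score` outputs the scalar $V(s)$ and `action_score`
+        outputs $A(s, a)$ for each of the 4 actions.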
+ """ + + super().__init__() + self.conv = nn.Sequential( + # The first convolution layer takes a + # 84x84 frame and produces a 20x20 frame + nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4), + nn.ReLU(), + + # The second convolution layer takes a + # 20x20 frame and produces a 9x9 frame + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), + nn.ReLU(), + + # The third convolution layer takes a + # 9x9 frame and produces a 7x7 frame + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), + nn.ReLU(), + ) + + # A fully connected layer takes the flattened + # frame from third convolution layer, and outputs + # 512 features + self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512) + + self.state_score = nn.Sequential( + nn.Linear(in_features=512, out_features=256), + nn.ReLU(), + nn.Linear(in_features=256, out_features=1), + ) + self.action_score = nn.Sequential( + nn.Linear(in_features=512, out_features=256), + nn.ReLU(), + nn.Linear(in_features=256, out_features=4), + ) + + # + self.activation = nn.ReLU() + + def __call__(self, obs: torch.Tensor): + h = self.conv(obs) + h = h.reshape((-1, 7 * 7 * 64)) + + h = self.activation(self.lin(h)) + + action_score = self.action_score(h) + state_score = self.state_score(h) + + # $Q(s, a) =V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$ + action_score_centered = action_score - action_score.mean(dim=-1, keepdim=True) + q = state_score + action_score_centered + + return q + + +def obs_to_torch(obs: np.ndarray) -> torch.Tensor: + """Scale observations from `[0, 255]` to `[0, 1]`""" + return torch.tensor(obs, dtype=torch.float32, device=device) / 255. + + +class Main(object): + """ + ## Main class + This class runs the training loop. + It initializes TensorFlow, handles logging and monitoring, + and runs workers as multiple processes. 
+ """ + + def __init__(self): + """ + ### Initialize + """ + + # #### Configurations + + # number of workers + self.n_workers = 8 + # steps sampled on each update + self.worker_steps = 4 + # number of training iterations + self.train_epochs = 8 + + # number of updates + self.updates = 1_000_000 + # size of mini batch for training + self.mini_batch_size = 32 + + # exploration as a function of time step + self.exploration_coefficient = Piecewise( + [ + (0, 1.0), + (25_000, 0.1), + (self.updates / 2, 0.01) + ], outside_value=0.01) + + # update target network every 250 update + self.update_target_model = 250 + + # $\beta$ for replay buffer as a function of time steps + self.prioritized_replay_beta = Piecewise( + [ + (0, 0.4), + (self.updates, 1) + ], outside_value=1) + + # replay buffer + self.replay_buffer = ReplayBuffer(2 ** 14, 0.6) + + self.model = Model().to(device) + self.target_model = Model().to(device) + + # last observation for each worker + # create workers + self.workers = [Worker(47 + i) for i in range(self.n_workers)] + + # initialize tensors for observations + self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8) + for worker in self.workers: + worker.child.send(("reset", None)) + for i, worker in enumerate(self.workers): + self.obs[i] = worker.child.recv() + + self.loss_func = QFuncLoss(0.99) + # optimizer + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4) + + def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float): + """ + #### $\epsilon$-greedy Sampling + When sampling actions we use a $\epsilon$-greedy strategy, where we + take a greedy action with probabiliy $1 - \epsilon$ and + take a random action with probability $\epsilon$. + We refer to $\epsilon$ as *exploration*. + """ + + with torch.no_grad(): + greedy_action = torch.argmax(q_value, dim=-1) + random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device) + + is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient + + return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy() + + def sample(self, exploration_coefficient: float): + """### Sample data""" + + with torch.no_grad(): + # sample `SAMPLE_STEPS` + for t in range(self.worker_steps): + # sample actions + q_value = self.model(obs_to_torch(self.obs)) + actions = self._sample_action(q_value, exploration_coefficient) + + # run sampled actions on each worker + for w, worker in enumerate(self.workers): + worker.child.send(("step", actions[w])) + + # collect information from each worker + for w, worker in enumerate(self.workers): + # get results after executing the actions + next_obs, reward, done, info = worker.child.recv() + + # add transition to replay buffer + self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done) + + # update episode information + # collect episode info, which is available if an episode finished; + # this includes total reward and length of the episode - + # look at `Game` to see how it works. + # We also add a game frame to it for monitoring. + if info: + tracker.add('reward', info['reward']) + tracker.add('length', info['length']) + + # update current observation + self.obs[w] = next_obs + + def train(self, beta: float): + """ + ### Train the model + + We want to find optimal action-value function. + + \begin{align} + Q^*(s,a) &= \max_\pi \mathbb{E} \Big[ + r_t + \gamma r_{t + 1} + \gamma^2 r_{t + 2} + ... 
+                | s_t = s, a_t = a, \pi
+            \Big]
+            \\
+            Q^*(s,a) &= \mathop{\mathbb{E}}_{s' \sim \large{\varepsilon}} \Big[
+                r + \gamma \max_{a'} Q^* (s', a') | s, a
+            \Big]
+        \end{align}
+
+        #### Target network 🎯
+        In order to improve stability we use experience replay that randomly samples
+        from previous experience $U(D)$. We also use a Q network
+        with a separate set of parameters $\hl1{\theta_i^{-}}$ to calculate the target.
+        $\hl1{\theta_i^{-}}$ is updated periodically.
+        This is according to the [paper by DeepMind](https://deepmind.com/research/dqn/).
+
+        So the loss function is,
+        $$
+        \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
+        \bigg[
+            \Big(
+                r + \gamma \max_{a'} Q(s', a'; \hl1{\theta_i^{-}}) - Q(s,a;\theta_i)
+            \Big) ^ 2
+        \bigg]
+        $$
+
+        #### Double $Q$-Learning
+        The max operator in the above calculation uses the same network for both
+        selecting the best action and for evaluating the value.
+        That is,
+        $$
+        \max_{a'} Q(s', a'; \theta) = \blue{Q}
+        \Big(
+            s', \mathop{\operatorname{argmax}}_{a'}
+            \blue{Q}(s', a'; \blue{\theta}); \blue{\theta}
+        \Big)
+        $$
+        We use [double Q-learning](https://arxiv.org/abs/1509.06461), where
+        the $\operatorname{argmax}$ is taken from $\theta_i$ and
+        the value is taken from $\theta_i^{-}$.
+
+        And the loss function becomes,
+        \begin{align}
+            \mathcal{L}_i(\theta_i) = \mathop{\mathbb{E}}_{(s,a,r,s') \sim U(D)}
+            \Bigg[
+                \bigg(
+                    &r + \gamma \blue{Q}
+                    \Big(
+                        s',
+                        \mathop{\operatorname{argmax}}_{a'}
+                        \green{Q}(s', a'; \green{\theta_i}); \blue{\theta_i^{-}}
+                    \Big)
+                    \\
+                    - &Q(s,a;\theta_i)
+                \bigg) ^ 2
+            \Bigg]
+        \end{align}
+        """
+
+        for _ in range(self.train_epochs):
+            # sample from the prioritized replay buffer
+            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
+            # compute Q-values of the sampled observations
+            q_value = self.model(obs_to_torch(samples['obs']))
+
+            with torch.no_grad():
+                double_q_value = self.model(obs_to_torch(samples['next_obs']))
+                target_q_value = self.target_model(obs_to_torch(samples['next_obs']))
+
+            td_errors, loss = self.loss_func(q_value,
+                                             q_value.new_tensor(samples['action']),
+                                             double_q_value, target_q_value,
+                                             q_value.new_tensor(samples['done']),
+                                             q_value.new_tensor(samples['reward']),
+                                             q_value.new_tensor(samples['weights']))
+
+            # $p_i = |\delta_i| + \epsilon$
+            new_priorities = np.abs(td_errors.cpu().numpy()) + 1e-6
+            # update replay buffer priorities
+            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)
+
+            # compute gradients
+            self.optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
+            self.optimizer.step()
+
+    def run_training_loop(self):
+        """
+        ### Run training loop
+        """
+
+        # copy the model to the target network initially
+        self.target_model.load_state_dict(self.model.state_dict())
+
+        # last 100 episode information
+        tracker.set_queue('reward', 100, True)
+        tracker.set_queue('length', 100, True)
+
+        for update in monit.loop(self.updates):
+            # $\epsilon$, exploration coefficient
+            exploration = self.exploration_coefficient(update)
+            tracker.add('exploration', exploration)
+            # $\beta$ for prioritized replay
+            beta = self.prioritized_replay_beta(update)
+            tracker.add('beta', beta)
+
+            # sample with the current policy
+            self.sample(exploration)
+
+            if self.replay_buffer.is_full():
+                # train the model
+                self.train(beta)
+
+                # periodically update the target network
+                if update % self.update_target_model == 0:
+                    self.target_model.load_state_dict(self.model.state_dict())
+
+            tracker.save()
+            if (update + 1) % 1_000 == 0:
+                logger.log()
+
+    def destroy(self):
+        """
+        ### Destroy
+        Stop the workers
+        """
+        for worker in self.workers:
+            worker.child.send(("close", None))
+
+
+# ## Run it
+if __name__ == "__main__":
+    experiment.create(name='dqn')
+    m = Main()
+    with experiment.start():
+        m.run_training_loop()
+    m.destroy()
diff --git a/labml_nn/rl/dqn/replay_buffer.py b/labml_nn/rl/dqn/replay_buffer.py
new file mode 100644
index 00000000..48d79f44
--- /dev/null
+++ b/labml_nn/rl/dqn/replay_buffer.py
@@ -0,0 +1,236 @@
+import random
+
+import numpy as np
+
+
+class ReplayBuffer:
+    """
+    ## Buffer for Prioritized Experience Replay
+    [Prioritized experience replay](https://arxiv.org/abs/1511.05952)
+    samples important transitions more frequently.
+    The transitions are prioritized by the Temporal Difference error.
+
+    We sample transition $i$ with probability,
+    $$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$
+    where $\alpha$ is a hyper-parameter that determines how much
+    prioritization is used, with $\alpha = 0$ corresponding to the uniform case.
+
+    We use proportional prioritization $p_i = |\delta_i| + \epsilon$ where
+    $\delta_i$ is the temporal difference error for transition $i$.
+
+    We correct the bias introduced by prioritized replay with
+    importance-sampling (IS) weights
+    $$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
+    which fully compensate for the bias when $\beta = 1$.
+    We normalize weights by $\frac{1}{\max_i w_i}$ for stability.
+    Unbiased updates matter most near convergence at the end of training,
+    so we increase $\beta$ towards $1$ as training progresses.
+
+    ### Binary Segment Trees
+    We use binary segment trees to efficiently calculate
+    $\sum_{k=1}^{i} p_k^\alpha$, the cumulative sum of priorities,
+    which is needed for sampling.
+    We also use a binary segment tree to find $\min_i p_i^\alpha$,
+    which is needed for $\frac{1}{\max_i w_i}$.
+    We could also use a min-heap for the latter.
+
+    This is how a binary segment tree works for the sum;
+    it is similar for the minimum.
+    Let $x_i$ be the list of $N$ values we want to represent.
+    Let $b_{i,j}$ be the $j^{\mathop{th}}$ node of the $i^{\mathop{th}}$ row
+    in the binary tree.
+    That is, the two children of node $b_{i,j}$ are $b_{i+1,2j}$ and $b_{i+1,2j + 1}$.
+
+    The leaf nodes on row $D = \left\lceil {1 + \log_2 N} \right\rceil$
+    will have the values of $x$.
+    Every node keeps the sum of its two child nodes.
+    So the root node keeps the sum of the entire array of values.
+    The two children of the root node keep
+    the sum of the first half of the array and
+    the sum of the second half of the array, and so on.
+
+    $$b_{i,j} = \sum_{k = (j - 1) * 2^{D - i} + 1}^{j * 2^{D - i}} x_k$$
+
+    The number of nodes in row $i$ is
+    $$N_i = \left\lceil \frac{N}{2^{D - i}} \right\rceil$$
+    which is also (approximately) the number of nodes in all rows above $i$.
+    So we can use a single array $a$ to store the tree, where,
+    $$b_{i,j} = a_{N_i + j}$$
+
+    Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
+    That is,
+    $$a_i = a_{2i} + a_{2i + 1}$$
+
+    This way of maintaining binary trees is very easy to program.
+    *Note that we are indexing from 1*.
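+
+    For example, with $N = 4$ values the sum tree is stored as
+    $a = [\cdot,\; x_1 + x_2 + x_3 + x_4,\; x_1 + x_2,\; x_3 + x_4,\; x_1,\; x_2,\; x_3,\; x_4]$,
+    where index $0$ is unused, $a_1$ is the root,
+    and $a_4, \dots, a_7$ are the leaves (as in the implementation below).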
+ """ + + def __init__(self, capacity, alpha): + """ + ### Initialize + """ + # we use a power of 2 for capacity to make it easy to debug + self.capacity = capacity + # we refill the queue once it reaches capacity + self.next_idx = 0 + # $\alpha$ + self.alpha = alpha + + # maintain segment binary trees to take sum and find minimum over a range + self.priority_sum = [0 for _ in range(2 * self.capacity)] + self.priority_min = [float('inf') for _ in range(2 * self.capacity)] + + # current max priority, $p$, to be assigned to new transitions + self.max_priority = 1. + + # arrays for buffer + self.data = { + 'obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8), + 'action': np.zeros(shape=capacity, dtype=np.int32), + 'reward': np.zeros(shape=capacity, dtype=np.float32), + 'next_obs': np.zeros(shape=(capacity, 4, 84, 84), dtype=np.uint8), + 'done': np.zeros(shape=capacity, dtype=np.bool) + } + + # size of the buffer + self.size = 0 + + def add(self, obs, action, reward, next_obs, done): + """ + ### Add sample to queue + """ + + idx = self.next_idx + + # store in the queue + self.data['obs'][idx] = obs + self.data['action'][idx] = action + self.data['reward'][idx] = reward + self.data['next_obs'][idx] = next_obs + self.data['done'][idx] = done + + # increment head of the queue and calculate the size + self.next_idx = (idx + 1) % self.capacity + self.size = min(self.capacity, self.size + 1) + + # $p_i^\alpha$, new samples get `max_priority` + priority_alpha = self.max_priority ** self.alpha + self._set_priority_min(idx, priority_alpha) + self._set_priority_sum(idx, priority_alpha) + + def _set_priority_min(self, idx, priority_alpha): + """ + #### Set priority in binary segment tree for minimum + """ + + # leaf of the binary tree + idx += self.capacity + self.priority_min[idx] = priority_alpha + + # update tree, by traversing along ancestors + while idx >= 2: + idx //= 2 + self.priority_min[idx] = min(self.priority_min[2 * idx], + self.priority_min[2 * idx + 1]) + + def _set_priority_sum(self, idx, priority): + """ + #### Set priority in binary segment tree for sum + """ + + # leaf of the binary tree + idx += self.capacity + self.priority_sum[idx] = priority + + # update tree, by traversing along ancestors + while idx >= 2: + idx //= 2 + self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1] + + def _sum(self): + """ + #### $\sum_k p_k^\alpha$ + """ + return self.priority_sum[1] + + def _min(self): + """ + #### $\min_k p_k^\alpha$ + """ + return self.priority_min[1] + + def find_prefix_sum_idx(self, prefix_sum): + """ + #### Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha \le P$ + """ + + # start from the root + idx = 1 + while idx < self.capacity: + # if the sum of the left branch is higher than required sum + if self.priority_sum[idx * 2] > prefix_sum: + # go to left branch if the tree if the + idx = 2 * idx + else: + # otherwise go to right branch and reduce the sum of left + # branch from required sum + prefix_sum -= self.priority_sum[idx * 2] + idx = 2 * idx + 1 + + return idx - self.capacity + + def sample(self, batch_size, beta): + """ + ### Sample from buffer + """ + + samples = { + 'weights': np.zeros(shape=batch_size, dtype=np.float32), + 'indexes': np.zeros(shape=batch_size, dtype=np.int32) + } + + # get samples + for i in range(batch_size): + p = random.random() * self._sum() + idx = self.find_prefix_sum_idx(p) + samples['indexes'][i] = idx + + # $\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$ + prob_min = self._min() / self._sum() 
+        # $\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$
+        max_weight = (prob_min * self.size) ** (-beta)
+
+        for i in range(batch_size):
+            idx = samples['indexes'][i]
+            # $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$
+            prob = self.priority_sum[idx + self.capacity] / self._sum()
+            # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
+            weight = (prob * self.size) ** (-beta)
+            # normalize by $\frac{1}{\max_i w_i}$,
+            # which also cancels out the $\frac{1}{N}$ term
+            samples['weights'][i] = weight / max_weight
+
+        # get the sample data
+        for k, v in self.data.items():
+            samples[k] = v[samples['indexes']]
+
+        return samples
+
+    def update_priorities(self, indexes, priorities):
+        """
+        ### Update priorities
+        """
+        for idx, priority in zip(indexes, priorities):
+            self.max_priority = max(self.max_priority, priority)
+
+            # $p_i^\alpha$
+            priority_alpha = priority ** self.alpha
+            self._set_priority_min(idx, priority_alpha)
+            self._set_priority_sum(idx, priority_alpha)
+
+    def is_full(self):
+        """
+        ### Is the buffer full
+
+        We only start sampling after the buffer is full.
+        """
+        return self.capacity == self.size
diff --git a/labml_nn/rl/ppo/experiment.py b/labml_nn/rl/ppo/experiment.py
index 3f034959..e8d40b55 100644
--- a/labml_nn/rl/ppo/experiment.py
+++ b/labml_nn/rl/ppo/experiment.py
@@ -128,7 +128,7 @@ class Trainer:
         # Value Loss
         self.value_loss = ClippedValueFunctionLoss()
 
-    def sample(self) -> (Dict[str, np.ndarray], List):
+    def sample(self) -> Dict[str, torch.Tensor]:
         """
         ### Sample data with current policy
         """