"""
---
title: Atari wrapper with multi-processing
summary: This implements the Atari games with multi-processing.
---
# Atari wrapper with multi-processing
"""
import multiprocessing
import multiprocessing.connection
import cv2
import gym
import numpy as np


class Game:
    """
    <a id="GameEnvironment"></a>

    ## Game environment

    This is a wrapper for an OpenAI Gym Atari game environment.
    We do a few things here:

    1. Apply the same action for four frames and keep the pixel-wise maximum
       of the last two frames (this removes flicker in some games)
    2. Convert observation frames to gray and scale them to (84, 84)
    3. Stack the four frames of the last four actions
    4. Add episode information (total reward for the entire episode) for monitoring
    5. Restrict an episode to a single life (the game has 5 lives;
       we reset after every life lost)

    #### Observation format
    The observation is a tensor of size (4, 84, 84): four frames
    (images of the game screen) stacked on the first axis,
    i.e., each channel is a frame.
    """

    def __init__(self, seed: int):
        # create environment
        self.env = gym.make('BreakoutNoFrameskip-v4')
        self.env.seed(seed)

        # tensor for a stack of 4 frames
        self.obs_4 = np.zeros((4, 84, 84))

        # buffer to keep the maximum of last 2 frames
        self.obs_2_max = np.zeros((2, 84, 84))

        # keep track of the episode rewards
        self.rewards = []
        # and number of lives left
        self.lives = 0

    def step(self, action):
        """
        ### Step
        Executes `action` for 4 time steps and
        returns a tuple of (observation, reward, done, episode_info).

        * `observation`: stacked 4 frames (this frame and the frames for the last 3 actions)
        * `reward`: total reward while the action was executed
        * `done`: whether the episode finished (a life lost)
        * `episode_info`: episode information, if the episode is over
        """
        reward = 0.
        done = None

        # run for 4 steps
        for i in range(4):
            # execute the action in the OpenAI Gym environment
            obs, r, done, info = self.env.step(action)

            # keep the last two frames for the pixel-wise maximum
            if i >= 2:
                self.obs_2_max[i % 2] = self._process_obs(obs)

            reward += r

            # get number of lives left
            lives = self.env.unwrapped.ale.lives()
            # reset if a life is lost
            if lives < self.lives:
                done = True
                break

        # maintain rewards for each step
        self.rewards.append(reward)

        if done:
            # if the episode is over, collect episode information and reset
            episode_info = {"reward": sum(self.rewards), "length": len(self.rewards)}
            self.reset()
        else:
            episode_info = None

            # get the maximum of the last two frames
            obs = self.obs_2_max.max(axis=0)

            # push it to the stack of 4 frames
            self.obs_4 = np.roll(self.obs_4, shift=-1, axis=0)
            self.obs_4[-1] = obs

        return self.obs_4, reward, done, episode_info

    def reset(self):
        """
        ### Reset environment
        Clean up episode info and the 4 frame stack
        """

        # reset the OpenAI Gym environment
        obs = self.env.reset()

        # reset caches
        obs = self._process_obs(obs)
        for i in range(4):
            self.obs_4[i] = obs
        self.rewards = []

        self.lives = self.env.unwrapped.ale.lives()

        return self.obs_4

    @staticmethod
    def _process_obs(obs):
        """
        #### Process game frames
        Convert game frames to gray and rescale to 84x84
        """
        obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
        return obs
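

# A minimal single-process sanity check, added here as an illustration only:
# the helper name `_demo_single_game` is hypothetical (not part of the
# original wrapper), and sampling a random action from `env.action_space`
# is an arbitrary choice.
def _demo_single_game(seed: int = 47):
    # create a wrapped game (hypothetical demo helper)
    game = Game(seed)
    # `reset` returns the initial stack of four processed frames
    obs = game.reset()
    assert obs.shape == (4, 84, 84)
    # `step` repeats one (random) action for four environment frames
    obs, reward, done, episode_info = game.step(game.env.action_space.sample())
    return obs, reward, done, episode_info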


def worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """
    ## Worker Process

    Each worker process runs this method
    """

    # create game
    game = Game(seed)

    # wait for instructions from the connection and execute them
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break
        else:
            raise NotImplementedError


class Worker:
    """
    Creates a new worker and runs it in a separate process.
    """

    def __init__(self, seed):
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()
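

# A minimal usage sketch, assuming this module is run as a script; the worker
# count and the NOOP action (0 in Breakout) are illustrative choices, not part
# of the original file. Each `Worker` is driven through its pipe using the
# `(cmd, data)` messages that `worker_process` understands.
if __name__ == '__main__':
    # start a few workers, each seeded differently
    workers = [Worker(seed) for seed in range(4)]

    # reset every game and collect the initial observations
    for worker in workers:
        worker.child.send(("reset", None))
    observations = [worker.child.recv() for worker in workers]

    # step every game in parallel with a NOOP action
    for worker in workers:
        worker.child.send(("step", 0))
    results = [worker.child.recv() for worker in workers]

    # shut the workers down cleanly
    for worker in workers:
        worker.child.send(("close", None))
        worker.process.join()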