Commit 22249e9 (1 parent: 8f3eb8c)

5 files changed: +414 / -0 lines

deep-learning/Deep-Reinforcement-Learning-Complete-Collection/PyTorch-cpp/gym_server/__init__.py

Whitespace-only changes.
deep-learning/Deep-Reinforcement-Learning-Complete-Collection/PyTorch-cpp/gym_server/envs.py

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
"""
Adapted from:
github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/a2c_ppo_acktr/envs.py

Provides utility functions for making Gym environments.
"""
import gym
from gym.spaces import Box
import numpy as np

from baselines.common.vec_env import VecEnvWrapper
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import (VecNormalize
                                                    as VecNormalize_)


class TransposeImage(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(TransposeImage, self).__init__(env)
        obs_shape = self.observation_space.shape
        self.observation_space = Box(
            self.observation_space.low[0, 0, 0],
            self.observation_space.high[0, 0, 0],
            [obs_shape[2], obs_shape[1], obs_shape[0]],
            dtype=self.observation_space.dtype)

    def observation(self, observation):
        return observation.transpose(2, 0, 1)


class VecFrameStack(VecEnvWrapper):
    def __init__(self, venv, nstack):
        self.venv = venv
        self.nstack = nstack
        wos = venv.observation_space  # wrapped ob space
        low = np.repeat(wos.low, self.nstack, axis=0)
        high = np.repeat(wos.high, self.nstack, axis=0)
        self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
        observation_space = gym.spaces.Box(
            low=low, high=high, dtype=venv.observation_space.dtype)
        VecEnvWrapper.__init__(self, venv, observation_space=observation_space)

    def step_wait(self):
        obs, rews, news, infos = self.venv.step_wait()
        # Frames are stacked along the channel axis (axis 1 of the batched
        # observations), so shift the oldest frame's channels out and write
        # the newest observation into the last channel slots.
        self.stackedobs = np.roll(self.stackedobs, shift=-obs.shape[1],
                                  axis=1)
        for (i, new) in enumerate(news):
            if new:
                self.stackedobs[i] = 0
        self.stackedobs[:, -obs.shape[1]:] = obs
        return self.stackedobs, rews, news, infos

    def reset(self):
        obs = self.venv.reset()
        self.stackedobs[...] = 0
        self.stackedobs[:, -obs.shape[1]:] = obs
        return self.stackedobs


class VecRewardInfo(VecEnvWrapper):
    def __init__(self, venv):
        self.venv = venv
        VecEnvWrapper.__init__(self, venv)

    def step_wait(self):
        obs, rews, news, infos = self.venv.step_wait()
        infos = {'reward': np.expand_dims(rews, -1)}
        return obs, rews, news, infos

    def reset(self):
        obs = self.venv.reset()
        return obs


class VecNormalize(VecNormalize_):
    def __init__(self, *args, **kwargs):
        super(VecNormalize, self).__init__(*args, **kwargs)
        self.training = True

    def _obfilt(self, obs):
        if self.ob_rms:
            if self.training:
                self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean)
                          / np.sqrt(self.ob_rms.var + self.epsilon),
                          -self.clipob, self.clipob)
        return obs

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def step_wait(self):
        obs, rews, news, infos = self.venv.step_wait()
        infos = {'reward': np.expand_dims(rews, -1)}
        self.ret = self.ret * self.gamma + rews
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(self.ret)
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon),
                           -self.cliprew,
                           self.cliprew)
        self.ret[news] = 0.
        return obs, rews, news, infos


def make_env(env_id, seed, rank):
    def _thunk():
        env = gym.make(env_id)

        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)

        env.seed(seed + rank)

        obs_shape = env.observation_space.shape

        if is_atari:
            if len(env.observation_space.shape) == 3:
                env = wrap_deepmind(env)
        elif len(env.observation_space.shape) == 3:
            raise NotImplementedError("CNN models work only for atari,\n"
                                      "please use a custom wrapper for a "
                                      "custom pixel input env.\n See "
                                      "wrap_deepmind for an example.")

        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:
            env = TransposeImage(env)

        return env
    return _thunk


def make_vec_envs(env_name, seed, num_processes, gamma, num_frame_stack=None):
    envs = [make_env(env_name, seed, i) for i in range(num_processes)]

    if len(envs) > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    if len(envs.observation_space.shape) == 1:
        if gamma is None or gamma == -1:
            envs = VecNormalize(envs, ret=False)
        else:
            envs = VecNormalize(envs, gamma=gamma)
    else:
        envs = VecRewardInfo(envs)

    if num_frame_stack is not None:
        envs = VecFrameStack(envs, num_frame_stack)
    elif len(envs.observation_space.shape) == 3:
        envs = VecFrameStack(envs, 4)

    return envs
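
A minimal usage sketch for make_vec_envs, not part of the commit: it assumes the baselines package and the Atari dependencies are installed, and the environment id, seed, process count and discount factor below are illustrative values only.

from gym_server.envs import make_vec_envs

# Build 8 Atari environments in subprocesses; observations come back
# transposed to (C, H, W) and stacked to 4 frames by VecFrameStack.
envs = make_vec_envs("PongNoFrameskip-v4", seed=0, num_processes=8,
                     gamma=0.99)
obs = envs.reset()
print(obs.shape)  # expected (8, 4, 84, 84) with the default deepmind wrappers
envs.close()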
deep-learning/Deep-Reinforcement-Learning-Complete-Collection/PyTorch-cpp/gym_server/messages.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
"""
Classes for building requests to send to pytorch-cpp-rl.
"""
from abc import ABC, abstractmethod
import numpy as np
import msgpack


class Message(ABC):
    """
    Base class for messages.
    """
    @abstractmethod
    def to_msg(self) -> bytes:
        """
        Creates the serialized (MessagePack) bytes for the request.
        """


class InfoMessage(Message):
    """
    Builds the message returning the result of an info() action.
    """

    def __init__(self, action_space_type, action_space_shape,
                 observation_space_type, observation_space_shape):
        self.action_space_type = action_space_type
        self.action_space_shape = action_space_shape
        self.observation_space_type = observation_space_type
        self.observation_space_shape = observation_space_shape

    def to_msg(self) -> bytes:
        request = {
            "action_space_type": self.action_space_type,
            "action_space_shape": self.action_space_shape,
            "observation_space_type": self.observation_space_type,
            "observation_space_shape": self.observation_space_shape
        }
        return msgpack.packb(request)


class MakeMessage(Message):
    """
    Builds the message returning the result of a make_env() action.
    """

    def to_msg(self) -> bytes:
        request = {
            "result": "OK"
        }
        return msgpack.packb(request)


class ResetMessage(Message):
    """
    Builds the message returning the result of an env.reset() action.
    """

    def __init__(self, observation: np.ndarray):
        self.observation = observation

    def to_msg(self) -> bytes:
        request = {
            "observation": self.observation.tolist()
        }
        return msgpack.packb(request)


class StepMessage(Message):
    """
    Builds the message returning the result of an env.step() action.
    """

    def __init__(self,
                 observation: np.ndarray,
                 reward: np.ndarray,
                 done: np.ndarray,
                 real_reward: np.ndarray):
        self.observation = observation
        self.reward = reward
        self.done = done
        self.real_reward = real_reward

    def to_msg(self) -> bytes:
        request = {
            "observation": self.observation.tolist(),
            "reward": self.reward.tolist(),
            "done": self.done.tolist(),
            "real_reward": self.real_reward.tolist()
        }
        return msgpack.packb(request)
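
A quick round-trip sketch, not part of the commit, showing the wire format these classes produce; the observation, reward and done values are made-up examples for two environments.

import numpy as np
import msgpack

from gym_server.messages import StepMessage

# Example payload for two environments with 3-dimensional observations.
msg = StepMessage(observation=np.zeros((2, 3), dtype=np.float32),
                  reward=np.array([[0.5], [-1.0]]),
                  done=np.array([[False], [True]]),
                  real_reward=np.array([[0.5], [-1.0]]))

payload = msg.to_msg()                       # MessagePack-encoded bytes
decoded = msgpack.unpackb(payload, raw=False)
print(decoded["done"])                       # [[False], [True]]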
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
"""
Contains a class that serves Gym environments for training an agent.
"""
import logging
from typing import Tuple
import numpy as np
import gym

from gym_server.envs import make_vec_envs
from gym_server.messages import (InfoMessage, MakeMessage, ResetMessage,
                                 StepMessage)
from gym_server.zmq_client import ZmqClient


RUNNING_REWARD_HORIZON = 10


class Server:
    """
    When `Server.serve()` is called, provides a ZMQ-based API for training
    RL agents on OpenAI gym environments.
    """

    def __init__(self, zmq_client: ZmqClient):
        self.zmq_client: ZmqClient = zmq_client
        self.env: gym.Env = None
        logging.info("Gym server initialized")

    def serve(self):
        """
        Run the server.
        """
        logging.info("Serving")
        try:
            self.__serve()
        except KeyboardInterrupt:
            pass

    def _serve(self):
        while True:
            request = self.zmq_client.receive()
            method = request['method']
            param = request['param']

            if method == 'info':
                (action_space_type,
                 action_space_shape,
                 observation_space_type,
                 observation_space_shape) = self.__info()
                self.zmq_client.send(InfoMessage(action_space_type,
                                                 action_space_shape,
                                                 observation_space_type,
                                                 observation_space_shape))

            elif method == 'make':
                self.__make(param['env_name'], param['num_envs'],
                            param['gamma'])
                self.zmq_client.send(MakeMessage())

            elif method == 'reset':
                observation = self.__reset()
                self.zmq_client.send(ResetMessage(observation))

            elif method == 'step':
                if 'render' in param:
                    result = self.__step(
                        np.array(param['actions']), param['render'])
                else:
                    result = self.__step(np.array(param['actions']))
                self.zmq_client.send(StepMessage(result[0],
                                                 result[1],
                                                 result[2],
                                                 result[3]['reward']))

    def info(self):
        """
        Return info about the currently loaded environment.
        """
        action_space_type = self.env.action_space.__class__.__name__
        if action_space_type == 'Discrete':
            action_space_shape = [self.env.action_space.n]
        else:
            action_space_shape = self.env.action_space.shape
        observation_space_type = self.env.observation_space.__class__.__name__
        observation_space_shape = self.env.observation_space.shape
        return (action_space_type, action_space_shape, observation_space_type,
                observation_space_shape)

    def make(self, env_name, num_envs, gamma):
        """
        Makes a vectorized environment of the type and number specified.
        """
        logging.info("Making %d %ss", num_envs, env_name)
        self.env = make_vec_envs(env_name, 0, num_envs, gamma)

    def reset(self) -> np.ndarray:
        """
        Resets the environments.
        """
        logging.info("Resetting environments")
        return self.env.reset()

    def step(self,
             actions: np.ndarray,
             render: bool = False) -> Tuple[np.ndarray, np.ndarray,
                                            np.ndarray, np.ndarray]:
        """
        Steps the environments.
        """
        if isinstance(self.env.action_space, gym.spaces.Discrete):
            actions = actions.squeeze(-1)
        observation, reward, done, info = self.env.step(actions)
        if isinstance(self.env.action_space, gym.spaces.Discrete):
            reward = np.expand_dims(reward, -1)
            done = np.expand_dims(done, -1)
        if render:
            self.env.render()
        return observation, reward, done, info

    # Name-mangled private aliases: the dispatch loop calls these copies so
    # that subclass overrides of the public methods do not change its behavior.
    __info = info
    __make = make
    __reset = reset
    __serve = _serve
    __step = step
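
For reference, a sketch of the request format the dispatch loop above expects, not part of the commit: each request is a dict with 'method' and 'param' keys, presumably MessagePack-encoded by the C++ client before ZmqClient.receive() decodes it (the ZmqClient implementation is not shown in this diff); the environment id, counts and actions below are illustrative values.

import msgpack

make_request = {
    "method": "make",
    "param": {"env_name": "PongNoFrameskip-v4", "num_envs": 8,
              "gamma": 0.99},
}
step_request = {
    "method": "step",
    "param": {"actions": [[0], [1], [2], [3], [0], [1], [2], [3]],
              "render": False},
}

# Encode as a client might, then decode as the server side would.
wire = msgpack.packb(make_request)
print(msgpack.unpackb(wire, raw=False)["method"])  # 'make'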
