Skip to content

Commit 64ca0c4

Browse files
committed
Partial conversion to latest mlagents
1 parent c05fd80 commit 64ca0c4

File tree

1 file changed

+23
-32
lines changed

1 file changed

+23
-32
lines changed

obstacle_tower_env.py

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import time
1010
from collections import deque
1111
from gym import error, spaces
12-
from mlagents.envs.environment import UnityEnvironment
12+
from mlagents_envs.environment import UnityEnvironment
1313

1414

1515
class UnityGymException(error.Error):
@@ -24,7 +24,7 @@ class UnityGymException(error.Error):
2424

2525

2626
class ObstacleTowerEnv(gym.Env):
27-
ALLOWED_VERSIONS = ['3.1']
27+
ALLOWED_VERSIONS = ['4.0?team=0']
2828

2929
def __init__(self, environment_filename=None, docker_training=False, worker_id=0, retro=True,
3030
timeout_wait=30, realtime_mode=False, config=None, greyscale=False):
@@ -41,11 +41,12 @@ def __init__(self, environment_filename=None, docker_training=False, worker_id=0
4141
"""
4242
self._env = UnityEnvironment(environment_filename,
4343
worker_id,
44-
docker_training=docker_training,
4544
timeout_wait=timeout_wait)
4645

47-
split_name = self._env.academy_name.split('-v')
48-
if len(split_name) == 2 and split_name[0] == "ObstacleTower":
46+
self._env.reset()
47+
behavior_name = list(self._env.behavior_specs)[0]
48+
split_name = behavior_name.split('-v')
49+
if len(split_name) == 2 and split_name[0] == "ObstacleTowerAgent":
4950
self.name, self.version = split_name
5051
else:
5152
raise UnityGymException(
@@ -80,48 +81,41 @@ def __init__(self, environment_filename=None, docker_training=False, worker_id=0
8081
flatten_branched = self.retro
8182
uint8_visual = self.retro
8283

83-
# Check brain configuration
84-
if len(self._env.brains) != 1:
84+
# Check behavior configuration
85+
if len(self._env.behavior_specs) != 1:
8586
raise UnityGymException(
86-
"There can only be one brain in a UnityEnvironment "
87+
"There can only be one agent in this environment "
8788
"if it is wrapped in a gym.")
88-
self.brain_name = self._env.external_brain_names[0]
89-
brain = self._env.brains[self.brain_name]
89+
self.behavior_name = behavior_name
90+
behavior_spec = self._env.behavior_specs[behavior_name]
9091

91-
if brain.number_visual_observations == 0:
92-
raise UnityGymException("Environment provides no visual observations.")
92+
if len(behavior_spec) < 2:
93+
raise UnityGymException("Environment provides too few observations.")
9394

9495
self.uint8_visual = uint8_visual
9596

96-
if brain.number_visual_observations > 1:
97-
logger.warning("The environment contains more than one visual observation. "
98-
"Please note that only the first will be provided in the observation.")
99-
10097
# Check for number of agents in scene.
101-
initial_info = self._env.reset(train_mode=not self.realtime_mode)[self.brain_name]
102-
self._check_agents(len(initial_info.agents))
98+
initial_info, _ = self._env.get_steps(behavior_name)
99+
self._check_agents(len(initial_info))
103100

104101
# Set observation and action spaces
105-
if len(brain.vector_action_space_size) == 1:
106-
self._action_space = spaces.Discrete(brain.vector_action_space_size[0])
102+
if len(behavior_spec.action_shape) == 1:
103+
self._action_space = spaces.Discrete(behavior_spec.action_shape[0])
107104
else:
108105
if flatten_branched:
109-
self._flattener = ActionFlattener(brain.vector_action_space_size)
106+
self._flattener = ActionFlattener(behavior_spec.action_shape)
110107
self._action_space = self._flattener.action_space
111108
else:
112-
self._action_space = spaces.MultiDiscrete(brain.vector_action_space_size)
113-
114-
high = np.array([np.inf] * brain.vector_observation_space_size)
115-
self.action_meanings = brain.vector_action_descriptions
109+
self._action_space = spaces.MultiDiscrete(behavior_spec.action_shape)
116110

117111
if self._greyscale:
118112
depth = 1
119113
else:
120114
depth = 3
121115
image_space_max = 1.0
122116
image_space_dtype = np.float32
123-
camera_height = brain.camera_resolutions[0]["height"]
124-
camera_width = brain.camera_resolutions[0]["width"]
117+
camera_height = behavior_spec.observation_shapes[0][0]
118+
camera_width = behavior_spec.observation_shapes[0][1]
125119
if self.retro:
126120
image_space_max = 255
127121
image_space_dtype = np.uint8
@@ -163,7 +157,7 @@ def reset(self, config=None):
163157

164158
self.reset_params = self._env.reset_parameters
165159
info = self._env.reset(config=reset_params,
166-
train_mode=not self.realtime_mode)[self.brain_name]
160+
train_mode=not self.realtime_mode)[self.behavior_name]
167161
n_agents = len(info.agents)
168162
self._check_agents(n_agents)
169163
self.game_over = False
@@ -191,7 +185,7 @@ def step(self, action):
191185
# Translate action into list
192186
action = self._flattener.lookup_action(action)
193187

194-
info = self._env.step(action)[self.brain_name]
188+
info = self._env.step(action)[self.behavior_name]
195189
n_agents = len(info.agents)
196190
self._check_agents(n_agents)
197191
self._current_state = info
@@ -246,9 +240,6 @@ def close(self):
246240
"""
247241
self._env.close()
248242

249-
def get_action_meanings(self):
250-
return self.action_meanings
251-
252243
def seed(self, seed=None):
253244
"""Sets a fixed seed for this env's random number generator(s).
254245
The valid range for seeds is [0, 99999). By default a random seed

0 commit comments

Comments
 (0)