Skip to content

Commit e1ea0c8

Browse files
committed
added support for done_callback
1 parent 5d9336b commit e1ea0c8

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

multiagent/environment.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class MultiAgentEnv(gym.Env):
1212

1313
def __init__(self, world, reset_callback=None, reward_callback=None,
1414
observation_callback=None, info_callback=None,
15-
shared_viewer=True):
15+
done_callback=None, shared_viewer=True):
1616

1717
self.world = world
1818
self.agents = self.world.policy_agents
@@ -23,6 +23,7 @@ def __init__(self, world, reset_callback=None, reward_callback=None,
2323
self.reward_callback = reward_callback
2424
self.observation_callback = observation_callback
2525
self.info_callback = info_callback
26+
self.done_callback = done_callback
2627
# environment parameters
2728
self.discrete_action_space = True
2829
# if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
@@ -90,7 +91,7 @@ def _step(self, action_n):
9091
for agent in self.agents:
9192
obs_n.append(self._get_obs(agent))
9293
reward_n.append(self._get_reward(agent))
93-
done_n.append(False)
94+
done_n.append(self._get_done(agent))
9495

9596
info_n['n'].append(self._get_info(agent))
9697

@@ -125,6 +126,12 @@ def _get_obs(self, agent):
125126
return np.zeros(0)
126127
return self.observation_callback(agent, self.world)
127128

129+
# get dones for a particular agent
130+
def _get_done(self, agent):
131+
if self.done_callback is None:
132+
return False
133+
return self.observation_callback(agent, self.world)
134+
128135
# get reward for a particular agent
129136
def _get_reward(self, agent):
130137
if self.reward_callback is None:

0 commit comments

Comments
 (0)