@@ -12,7 +12,7 @@ class MultiAgentEnv(gym.Env):

     def __init__(self, world, reset_callback=None, reward_callback=None,
                  observation_callback=None, info_callback=None,
-                 shared_viewer=True):
+                 done_callback=None, shared_viewer=True):

         self.world = world
         self.agents = self.world.policy_agents
@@ -23,6 +23,7 @@ def __init__(self, world, reset_callback=None, reward_callback=None,
         self.reward_callback = reward_callback
         self.observation_callback = observation_callback
         self.info_callback = info_callback
+        self.done_callback = done_callback
         # environment parameters
         self.discrete_action_space = True
         # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
@@ -90,7 +91,7 @@ def _step(self, action_n):
         for agent in self.agents:
             obs_n.append(self._get_obs(agent))
             reward_n.append(self._get_reward(agent))
-            done_n.append(False)
+            done_n.append(self._get_done(agent))

             info_n['n'].append(self._get_info(agent))

@@ -125,6 +126,12 @@ def _get_obs(self, agent):
             return np.zeros(0)
         return self.observation_callback(agent, self.world)

+    # get dones for a particular agent
+    def _get_done(self, agent):
+        if self.done_callback is None:
+            return False
+        return self.done_callback(agent, self.world)
+
     # get reward for a particular agent
     def _get_reward(self, agent):
         if self.reward_callback is None:
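For context, the new callback has the same (agent, world) signature as the reward and observation callbacks and should return a bool; _step now collects one value per agent into done_n instead of hardcoding False. Below is a minimal usage sketch, not part of this commit: the episode_done rule, the simple_spread wiring, and the one-hot no-op actions are illustrative assumptions modeled on the repo's usual make_env pattern.

import numpy as np
from multiagent.environment import MultiAgentEnv
import multiagent.scenarios as scenarios

# hypothetical termination rule: an agent is done once it sits within
# 0.1 units of any landmark (threshold chosen arbitrarily for illustration)
def episode_done(agent, world):
    return any(np.linalg.norm(agent.state.p_pos - lm.state.p_pos) < 0.1
               for lm in world.landmarks)

scenario = scenarios.load("simple_spread.py").Scenario()
world = scenario.make_world()
env = MultiAgentEnv(world,
                    reset_callback=scenario.reset_world,
                    reward_callback=scenario.reward,
                    observation_callback=scenario.observation,
                    done_callback=episode_done)

obs_n = env.reset()   # gym of this era dispatches reset/step to _reset/_step
act_n = [np.eye(sp.n)[0] for sp in env.action_space]  # one-hot "no-op" per agent
obs_n, reward_n, done_n, info_n = env.step(act_n)     # done_n: list of bools

Leaving done_callback as None preserves the old behaviour (every agent's done flag stays False), so existing scenarios keep working unchanged.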