@@ -12,7 +12,8 @@ class MultiAgentEnv(gym.Env):
1212
1313 def __init__ (self , world , reset_callback = None , reward_callback = None ,
1414 observation_callback = None , info_callback = None ,
15- done_callback = None , shared_viewer = True , discrete_action = False ):
15+ done_callback = None , post_step_callback = None ,
16+ shared_viewer = True , discrete_action = False ):
1617
1718 self .world = world
1819 self .agents = self .world .policy_agents
@@ -24,6 +25,7 @@ def __init__(self, world, reset_callback=None, reward_callback=None,
2425 self .observation_callback = observation_callback
2526 self .info_callback = info_callback
2627 self .done_callback = done_callback
28+ self .post_step_callback = post_step_callback
2729 # environment parameters
2830 self .discrete_action_space = discrete_action
2931 # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
@@ -47,10 +49,7 @@ def __init__(self, world, reset_callback=None, reward_callback=None,
4749 if agent .movable :
4850 total_action_space .append (u_action_space )
4951 # communication action space
50- if self .discrete_action_space :
51- c_action_space = spaces .Discrete (world .dim_c )
52- else :
53- c_action_space = spaces .Box (low = 0.0 , high = 1.0 , shape = (world .dim_c ,))
52+ c_action_space = spaces .Discrete (world .dim_c )
5453 if not agent .silent :
5554 total_action_space .append (c_action_space )
5655 # total action space
@@ -105,7 +104,8 @@ def _step(self, action_n):
105104 reward = np .sum (reward_n )
106105 if self .shared_reward :
107106 reward_n = [reward ] * self .n
108-
107+ if self .post_step_callback is not None :
108+ self .post_step_callback (self .world )
109109 return obs_n , reward_n , done_n , info_n
110110
111111 def _reset (self ):
@@ -240,16 +240,32 @@ def _render(self, mode='human', close=True):
240240 from multiagent import rendering
241241 self .render_geoms = []
242242 self .render_geoms_xform = []
243+ self .comm_geoms = []
243244 for entity in self .world .entities :
244245 geom = rendering .make_circle (entity .size )
245246 xform = rendering .Transform ()
247+ entity_comm_geoms = []
246248 if 'agent' in entity .name :
247249 geom .set_color (* entity .color , alpha = 0.5 )
250+ if not entity .silent :
251+ dim_c = self .world .dim_c
252+ # make circles to represent communication
253+ for ci in range (dim_c ):
254+ comm = rendering .make_circle (entity .size / dim_c )
255+ comm .set_color (1 , 1 , 1 )
256+ comm .add_attr (xform )
257+ offset = rendering .Transform ()
258+ comm_size = (entity .size / dim_c )
259+ offset .set_translation (ci * comm_size * 2 -
260+ entity .size + comm_size , 0 )
261+ comm .add_attr (offset )
262+ entity_comm_geoms .append (comm )
248263 else :
249264 geom .set_color (* entity .color )
250265 geom .add_attr (xform )
251266 self .render_geoms .append (geom )
252267 self .render_geoms_xform .append (xform )
268+ self .comm_geoms .append (entity_comm_geoms )
253269 for wall in self .world .walls :
254270 corners = ((wall .axis_pos - 0.5 * wall .width , wall .endpoints [0 ]),
255271 (wall .axis_pos - 0.5 * wall .width , wall .endpoints [1 ]),
@@ -269,6 +285,9 @@ def _render(self, mode='human', close=True):
269285 viewer .geoms = []
270286 for geom in self .render_geoms :
271287 viewer .add_geom (geom )
288+ for entity_comm_geoms in self .comm_geoms :
289+ for geom in entity_comm_geoms :
290+ viewer .add_geom (geom )
272291
273292 results = []
274293 for i in range (len (self .viewers )):
@@ -283,6 +302,14 @@ def _render(self, mode='human', close=True):
283302 # update geometry positions
284303 for e , entity in enumerate (self .world .entities ):
285304 self .render_geoms_xform [e ].set_translation (* entity .state .p_pos )
305+ if 'agent' in entity .name :
306+ self .render_geoms [e ].set_color (* entity .color , alpha = 0.5 )
307+ if not entity .silent :
308+ for ci in range (self .world .dim_c ):
309+ color = 1 - entity .state .c [ci ]
310+ self .comm_geoms [e ][ci ].set_color (color , color , color )
311+ else :
312+ self .render_geoms [e ].set_color (* entity .color )
286313 # render to display or array
287314 results .append (self .viewers [i ].render (return_rgb_array = mode == 'rgb_array' ))
288315
0 commit comments