Commit a3ccb43

Merge branch 'master' of github.com:shariqiqbal2810/multiagent-particle-envs

2 parents fbf8ead + 81311a1
2 files changed: +49 -7 lines changed

multiagent/core.py

Lines changed: 16 additions & 1 deletion
@@ -1,4 +1,5 @@
 import numpy as np
+import seaborn as sns

 # physical/external base state of all entites
 class EntityState(object):
@@ -164,6 +165,20 @@ def calculate_distances(self):
         self.cached_dist_mag = np.linalg.norm(self.cached_dist_vect, axis=2)
         self.cached_collisions = (self.cached_dist_mag <= self.min_dists)

+    def assign_agent_colors(self):
+        n_dummies = 0
+        if hasattr(self.agents[0], 'dummy'):
+            n_dummies = len([a for a in self.agents if a.dummy])
+        n_adversaries = 0
+        if hasattr(self.agents[0], 'adversary'):
+            n_adversaries = len([a for a in self.agents if a.adversary])
+        n_good_agents = len(self.agents) - n_adversaries - n_dummies
+        dummy_colors = [(0, 0, 0)] * n_dummies
+        adv_colors = sns.color_palette("OrRd_d", n_adversaries)
+        good_colors = sns.color_palette("GnBu_d", n_good_agents)
+        colors = dummy_colors + adv_colors + good_colors
+        for color, agent in zip(colors, self.agents):
+            agent.color = color

     # update state of the world
     def step(self):
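
Note: the new assign_agent_colors helper above gives dummy agents black, adversaries the red seaborn "OrRd_d" palette, and the remaining good agents the blue-green "GnBu_d" palette. A minimal usage sketch (not part of the diff; requires seaborn, and the adversary flags are set by hand here):

    from multiagent.core import World, Agent

    world = World()
    world.agents = [Agent() for _ in range(4)]
    for i, agent in enumerate(world.agents):
        agent.adversary = (i < 2)       # first two agents play adversaries
    world.assign_agent_colors()         # adversaries: OrRd_d, rest: GnBu_d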
@@ -192,7 +207,7 @@ def apply_action_force(self, p_force):
         for i,agent in enumerate(self.agents):
             if agent.movable:
                 noise = np.random.randn(*agent.action.u.shape) * agent.u_noise if agent.u_noise else 0.0
-                p_force[i] = agent.action.u + noise
+                p_force[i] = (agent.mass * agent.accel if agent.accel is not None else agent.mass) * agent.action.u + noise
         return p_force

     # gather physical forces acting on entities
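
Note: apply_action_force previously treated agent.action.u as a raw force; after this change it is scaled by agent.mass * agent.accel (or by agent.mass alone when accel is unset), so the control input behaves as a commanded acceleration. A small sketch of the arithmetic (assuming the repo's usual integrator, which divides the summed force by mass):

    import numpy as np

    mass, accel, dt = 1.0, 5.0, 0.1
    action_u = np.array([1.0, 0.0])        # unit control input
    p_force = (mass * accel) * action_u    # scaling introduced above
    delta_v = (p_force / mass) * dt        # integration divides mass back out
    # delta_v == accel * action_u * dt: action.u now commands acceleration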

multiagent/environment.py

Lines changed: 33 additions & 6 deletions
@@ -12,7 +12,8 @@ class MultiAgentEnv(gym.Env):

     def __init__(self, world, reset_callback=None, reward_callback=None,
                  observation_callback=None, info_callback=None,
-                 done_callback=None, shared_viewer=True, discrete_action=False):
+                 done_callback=None, post_step_callback=None,
+                 shared_viewer=True, discrete_action=False):

         self.world = world
         self.agents = self.world.policy_agents
@@ -24,6 +25,7 @@ def __init__(self, world, reset_callback=None, reward_callback=None,
         self.observation_callback = observation_callback
         self.info_callback = info_callback
         self.done_callback = done_callback
+        self.post_step_callback = post_step_callback
         # environment parameters
         self.discrete_action_space = discrete_action
         # if true, action is a number 0...N, otherwise action is a one-hot N-dimensional vector
@@ -47,10 +49,7 @@ def __init__(self, world, reset_callback=None, reward_callback=None,
             if agent.movable:
                 total_action_space.append(u_action_space)
             # communication action space
-            if self.discrete_action_space:
-                c_action_space = spaces.Discrete(world.dim_c)
-            else:
-                c_action_space = spaces.Box(low=0.0, high=1.0, shape=(world.dim_c,))
+            c_action_space = spaces.Discrete(world.dim_c)
             if not agent.silent:
                 total_action_space.append(c_action_space)
             # total action space
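
Note: the communication action space is now always Discrete, even when discrete_action_space is False; the Box alternative is removed. A minimal sketch (dim_c below is an example value, not taken from this diff):

    from gym import spaces

    dim_c = 4
    c_action_space = spaces.Discrete(dim_c)   # one of dim_c communication symbols
    symbol = c_action_space.sample()          # an int in [0, dim_c)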
@@ -105,7 +104,8 @@ def _step(self, action_n):
         reward = np.sum(reward_n)
         if self.shared_reward:
             reward_n = [reward] * self.n
-
+        if self.post_step_callback is not None:
+            self.post_step_callback(self.world)
         return obs_n, reward_n, done_n, info_n

     def _reset(self):
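
Note: the new post_step_callback fires once per environment step, after rewards are collected, and receives the whole world rather than a single agent. A hypothetical usage sketch (log_positions is an illustrative name, not from this diff; world would come from a scenario's make_world(), as in the repo's make_env.py):

    from multiagent.environment import MultiAgentEnv

    def log_positions(world):
        # runs after every step, sees the full world state
        for agent in world.agents:
            print(agent.name, agent.state.p_pos)

    env = MultiAgentEnv(world, post_step_callback=log_positions)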
@@ -240,16 +240,32 @@ def _render(self, mode='human', close=True):
             from multiagent import rendering
             self.render_geoms = []
             self.render_geoms_xform = []
+            self.comm_geoms = []
             for entity in self.world.entities:
                 geom = rendering.make_circle(entity.size)
                 xform = rendering.Transform()
+                entity_comm_geoms = []
                 if 'agent' in entity.name:
                     geom.set_color(*entity.color, alpha=0.5)
+                    if not entity.silent:
+                        dim_c = self.world.dim_c
+                        # make circles to represent communication
+                        for ci in range(dim_c):
+                            comm = rendering.make_circle(entity.size / dim_c)
+                            comm.set_color(1, 1, 1)
+                            comm.add_attr(xform)
+                            offset = rendering.Transform()
+                            comm_size = (entity.size / dim_c)
+                            offset.set_translation(ci * comm_size * 2 -
+                                                   entity.size + comm_size, 0)
+                            comm.add_attr(offset)
+                            entity_comm_geoms.append(comm)
                 else:
                     geom.set_color(*entity.color)
                 geom.add_attr(xform)
                 self.render_geoms.append(geom)
                 self.render_geoms_xform.append(xform)
+                self.comm_geoms.append(entity_comm_geoms)
             for wall in self.world.walls:
                 corners = ((wall.axis_pos - 0.5 * wall.width, wall.endpoints[0]),
                            (wall.axis_pos - 0.5 * wall.width, wall.endpoints[1]),
@@ -269,6 +285,9 @@ def _render(self, mode='human', close=True):
                 viewer.geoms = []
                 for geom in self.render_geoms:
                     viewer.add_geom(geom)
+                for entity_comm_geoms in self.comm_geoms:
+                    for geom in entity_comm_geoms:
+                        viewer.add_geom(geom)

         results = []
         for i in range(len(self.viewers)):
@@ -283,6 +302,14 @@ def _render(self, mode='human', close=True):
             # update geometry positions
             for e, entity in enumerate(self.world.entities):
                 self.render_geoms_xform[e].set_translation(*entity.state.p_pos)
+                if 'agent' in entity.name:
+                    self.render_geoms[e].set_color(*entity.color, alpha=0.5)
+                    if not entity.silent:
+                        for ci in range(self.world.dim_c):
+                            color = 1 - entity.state.c[ci]
+                            self.comm_geoms[e][ci].set_color(color, color, color)
+                else:
+                    self.render_geoms[e].set_color(*entity.color)
             # render to display or array
             results.append(self.viewers[i].render(return_rgb_array = mode=='rgb_array'))
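
Note: each speaking agent now renders dim_c small circles whose shade tracks its communication state: a channel value c is drawn as the gray (1-c, 1-c, 1-c), so an idle channel renders white and a saturated one black (assuming channel values lie in [0, 1]). Sketch of the mapping with example values only:

    import numpy as np

    c = np.array([0.0, 0.5, 1.0])        # example entity.state.c
    shades = 1 - c                       # matches color = 1 - entity.state.c[ci]
    rgb = [(s, s, s) for s in shades]    # white, mid-gray, black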
