Skip to content

Commit a36a9e3

Browse files
committed
Fixed bugs in the action space; interactive mode now works on simple.py
1 parent 0ff36f2 commit a36a9e3

File tree

3 files changed

+20
-21
lines changed

3 files changed

+20
-21
lines changed

bin/interactive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# create world
1919
world = scenario.make_world()
2020
# create multiagent environment
21-
env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation, shared_viewer = False)
21+
env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation, info_callback=None, shared_viewer = False)
2222
# render call to create viewer window (necessary only for interactive policies)
2323
env.render()
2424
# create interactive policies for each agent

multiagent/environment.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
import numpy as np
55
import tensorflow as tf
66

7-
# TODO: make description of class?
8-
97
# environment for all agents in the multiagent world
10-
# TODO: currently code assumes that no agents will be created/destroyed at runtime!
8+
# currently code assumes that no agents will be created/destroyed at runtime!
119
class MultiAgentEnv(gym.Env):
1210
metadata = {
1311
'render.modes' : ['human', 'rgb_array']
@@ -94,6 +92,7 @@ def _step(self, action_n):
9492
obs_n.append(self._get_obs(agent))
9593
reward_n.append(self._get_reward(agent))
9694
done_n.append(False)
95+
9796
info_n['n'].append(self._get_info(agent))
9897

9998
# all agents get total reward in cooperative case
@@ -146,13 +145,12 @@ def _set_action(self, action, agent, action_space, time=None):
146145
act.append(action[index:(index+s)])
147146
index += s
148147
action = act
149-
#else:
150-
# action = [action] # TODO: why is this necessary??
148+
else:
149+
action = [action] # TODO: why is this necessary??
151150

152151
if agent.movable:
153152
# physical action
154153
if self.discrete_action_input:
155-
print(action)
156154
agent.action.u = np.zeros(self.world.dim_p)
157155
# process discrete action
158156
if action[0] == 1: agent.action.u[0] = -1.0
@@ -190,7 +188,6 @@ def _reset_render(self):
190188

191189
# render environment
192190
def _render(self, mode='human', close=True):
193-
# TODO: render text in viewer instead
194191
if mode == 'human':
195192
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
196193
message = ''
@@ -231,10 +228,7 @@ def _render(self, mode='human', close=True):
231228
for entity in self.world.entities:
232229
geom = rendering.make_circle(entity.size)
233230
xform = rendering.Transform()
234-
if 'forest' in entity.name:
235-
print(entity.color)
236-
geom.set_color(*entity.color, alpha=0.5)
237-
elif 'agent' in entity.name:
231+
if 'agent' in entity.name:
238232
geom.set_color(*entity.color, alpha=0.5)
239233
else:
240234
geom.set_color(*entity.color)
@@ -252,7 +246,7 @@ def _render(self, mode='human', close=True):
252246
for i in range(len(self.viewers)):
253247
from multiagent import rendering
254248
# update bounds to center around agent
255-
cam_range = 1.2
249+
cam_range = 1
256250
if self.shared_viewer:
257251
pos = np.zeros(self.world.dim_p)
258252
else:

multiagent/policy.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,20 @@ def action(self, obs):
2929
if self.move[2]: u = 4
3030
if self.move[3]: u = 3
3131
else:
32-
u = np.array([0.0,0.0,0.0,0.0])
33-
if self.move[0]: u[0] += 1.0
34-
if self.move[1]: u[1] += 1.0
35-
if self.move[2]: u[2] += 1.0
32+
u = np.array([0.0,0.0,0.0,0.0,0.0]) # 5-d because of no-move action
33+
if self.move[0]: u[1] += 1.0
34+
if self.move[1]: u[2] += 1.0
3635
if self.move[3]: u[3] += 1.0
37-
c = 0
38-
for i in range(len(self.comm)):
39-
if self.comm[i]: c = i+1
40-
return [u, c]
36+
if self.move[2]: u[4] += 1.0
37+
if True not in self.move:
38+
u[0] += 1.0
39+
if self.env.world.dim_c == 0:
40+
return u
41+
else:
42+
c = 0
43+
for i in range(len(self.comm)):
44+
if self.comm[i]: c = i+1
45+
return [u, c]
4146

4247
# keyboard event callbacks
4348
def key_press(self, k, mod):

0 commit comments

Comments
 (0)