Skip to content

Commit aa1803e

Browse files
authored
Add support for loading PufferLib policies (Metta-AI#233)
# Add support for loading PufferLib policies

This PR adds support for loading PufferLib policies from local files using the `puffer://` URI scheme. It includes:

- A new `PolicyState` tensorclass to manage policy state
- Support for loading PufferLib policies in the policy store
- Implementation of the PufferLib policy and recurrent wrapper classes
- Updates to the simulator so that it works with PufferLib policies
- Refactoring of the play tool to separate policy loading from simulation
- Simplification of the eval configs, making them shorter by removing redundant fields
- Removal of e3b for now; we plan to add it back later in a cleaner form

These changes allow policies trained with PufferLib to be integrated seamlessly into the Metta environment.
1 parent 2b70fb5 commit aa1803e

21 files changed

+294
-194
lines changed

configs/replay_job.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ defaults:
22
- common
33
- wandb: metta_research
44
- sim: simple
5-
- /sim/simple@replay_job.sim
65
- _self_
76

87
cmd: play
@@ -13,4 +12,5 @@ torch_deterministic: true
1312
eval_db_uri: null
1413

1514
replay_job:
15+
sim: ${sim}
1616
policy_uri: ${policy_uri}

configs/sim/all.yaml

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,93 +2,61 @@ defaults:
22
- sim
33
- _self_
44

5-
num_envs: 1
6-
num_episodes: 1
7-
run_dir: ${run_dir}
8-
95
simulations:
106
navigation/emptyspace_withinsight:
117
env: env/mettagrid/navigation/evals/emptyspace_withinsight
12-
policy_agents_pct: 1.0
138
navigation/emptyspace_outofsight:
149
env: env/mettagrid/navigation/evals/emptyspace_outofsight
15-
policy_agents_pct: 1.0
1610
navigation/emptyspace_sparse:
1711
env: env/mettagrid/navigation/evals/emptyspace_sparse
18-
policy_agents_pct: 1.0
1912
navigation/walls_withinsight:
2013
env: env/mettagrid/navigation/evals/walls_withinsight
21-
policy_agents_pct: 1.0
2214
navigation/walls_outofsight:
2315
env: env/mettagrid/navigation/evals/walls_outofsight
24-
policy_agents_pct: 1.0
2516
navigation/walls_sparse:
2617
env: env/mettagrid/navigation/evals/walls_sparse
27-
policy_agents_pct: 1.0
2818
navigation/cylinder:
2919
env: env/mettagrid/navigation/evals/cylinder
30-
policy_agents_pct: 1.0
3120
navigation/obstacles0:
3221
env: env/mettagrid/navigation/evals/obstacles0
33-
policy_agents_pct: 1.0
3422
navigation/obstacles1:
3523
env: env/mettagrid/navigation/evals/obstacles1
36-
policy_agents_pct: 1.0
3724
navigation/obstacles2:
3825
env: env/mettagrid/navigation/evals/obstacles2
39-
policy_agents_pct: 1.0
4026
navigation/obstacles3:
4127
env: env/mettagrid/navigation/evals/obstacles3
42-
policy_agents_pct: 1.0
4328
navigation/corridors:
4429
env: env/mettagrid/navigation/evals/corridors
45-
policy_agents_pct: 1.0
4630
navigation/labyrinth:
4731
env: env/mettagrid/navigation/evals/labyrinth
48-
policy_agents_pct: 1.0
4932
navigation/radialmaze:
5033
env: env/mettagrid/navigation/evals/radialmaze
51-
policy_agents_pct: 1.0
5234
object_use/altar_use_free:
5335
env: env/mettagrid/object_use/evals/altar_use_free
54-
policy_agents_pct: 1.0
5536
object_use/altar_use:
5637
env: env/mettagrid/object_use/evals/altar_use
57-
policy_agents_pct: 1.0
5838
object_use/armory_use_free:
5939
env: env/mettagrid/object_use/evals/armory_use_free
60-
policy_agents_pct: 1.0
6140
object_use/armory_use:
6241
env: env/mettagrid/object_use/evals/armory_use
63-
policy_agents_pct: 1.0
6442
object_use/generator_use_free:
6543
env: env/mettagrid/object_use/evals/generator_use_free
66-
policy_agents_pct: 1.0
67-
object_use/generator_uses:
44+
object_use/generator_use:
6845
env: env/mettagrid/object_use/evals/generator_use
69-
policy_agents_pct: 1.0
7046
object_use/lasery_use_free:
7147
env: env/mettagrid/object_use/evals/lasery_use_free
72-
policy_agents_pct: 1.0
7348
object_use/lasery_use:
7449
env: env/mettagrid/object_use/evals/lasery_use
75-
policy_agents_pct: 1.0
7650
object_use/mine_use:
7751
env: env/mettagrid/object_use/evals/mine_use
78-
policy_agents_pct: 1.0
7952
object_use/shoot_out:
8053
env: env/mettagrid/object_use/evals/shoot_out
81-
policy_agents_pct: 1.0
8254
object_use/swap_in:
8355
env: env/mettagrid/object_use/evals/swap_in
84-
policy_agents_pct: 1.0
8556
object_use/swap_out:
8657
env: env/mettagrid/object_use/evals/swap_out
87-
policy_agents_pct: 1.0
8858
object_use/temple_use_free:
8959
env: env/mettagrid/object_use/evals/temple_use_free
90-
policy_agents_pct: 1.0
9160
simple_npc:
9261
env: env/mettagrid/simple
93-
policy_agents_pct: 0.5
9462
npc_policy_uri: wandb://run/b.daveey.t.8.rdr9.3

configs/sim/memory.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,10 @@ defaults:
22
- sim
33
- _self_
44

5-
num_envs: 1
6-
num_episodes: 1
7-
run_dir: ${run_dir}
8-
95
simulations:
106
navigation/memory_easy:
117
env: env/mettagrid/memory/evals/memory_easy
12-
policy_agents_pct: 1.0
138
navigation/memory_medium:
149
env: env/mettagrid/memory/evals/memory_medium
15-
policy_agents_pct: 1.0
1610
navigation/memory_hard:
1711
env: env/mettagrid/memory/evals/memory_hard
18-
policy_agents_pct: 1.0

configs/sim/navigation.yaml

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,35 @@ defaults:
22
- sim
33
- _self_
44

5-
num_envs: 1
6-
num_episodes: 1
7-
run_dir: ${run_dir}
8-
95
simulations:
106
navigation/emptyspace_withinsight:
117
env: env/mettagrid/navigation/evals/emptyspace_withinsight
12-
policy_agents_pct: 1.0
138
navigation/emptyspace_outofsight:
149
env: env/mettagrid/navigation/evals/emptyspace_outofsight
15-
policy_agents_pct: 1.0
1610
navigation/emptyspace_sparse:
1711
env: env/mettagrid/navigation/evals/emptyspace_sparse
18-
policy_agents_pct: 1.0
1912
navigation/walls_withinsight:
2013
env: env/mettagrid/navigation/evals/walls_withinsight
21-
policy_agents_pct: 1.0
2214
navigation/walls_outofsight:
2315
env: env/mettagrid/navigation/evals/walls_outofsight
24-
policy_agents_pct: 1.0
2516
navigation/walls_sparse:
2617
env: env/mettagrid/navigation/evals/walls_sparse
27-
policy_agents_pct: 1.0
2818
navigation/cylinder:
2919
env: env/mettagrid/navigation/evals/cylinder
30-
policy_agents_pct: 1.0
3120
navigation/obstacles0:
3221
env: env/mettagrid/navigation/evals/obstacles0
33-
policy_agents_pct: 1.0
3422
navigation/obstacles1:
3523
env: env/mettagrid/navigation/evals/obstacles1
36-
policy_agents_pct: 1.0
3724
navigation/obstacles2:
3825
env: env/mettagrid/navigation/evals/obstacles2
39-
policy_agents_pct: 1.0
4026
navigation/obstacles3:
4127
env: env/mettagrid/navigation/evals/obstacles3
42-
policy_agents_pct: 1.0
4328
navigation/corridors:
4429
env: env/mettagrid/navigation/evals/corridors
45-
policy_agents_pct: 1.0
4630
navigation/labyrinth:
4731
env: env/mettagrid/navigation/evals/labyrinth
48-
policy_agents_pct: 1.0
4932
navigation/radialmaze:
5033
env: env/mettagrid/navigation/evals/radialmaze
51-
policy_agents_pct: 1.0
5234
navigation/cylinder_easy:
5335
env: env/mettagrid/navigation/evals/cylinder_easy
5436
navigation/honeypot:

configs/sim/object_use.yaml

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,30 @@ defaults:
22
- sim
33
- _self_
44

5-
num_envs: 10
6-
num_episodes: 10
7-
run_dir: ${run_dir}
8-
95
simulations:
106
object_use/altar_use_free:
117
env: env/mettagrid/object_use/evals/altar_use_free
12-
policy_agents_pct: 1.0
138
object_use/altar_use:
149
env: env/mettagrid/object_use/evals/altar_use
15-
policy_agents_pct: 1.0
1610
object_use/armory_use_free:
1711
env: env/mettagrid/object_use/evals/armory_use_free
18-
policy_agents_pct: 1.0
1912
object_use/armory_use:
2013
env: env/mettagrid/object_use/evals/armory_use
21-
policy_agents_pct: 1.0
2214
object_use/generator_use_free:
2315
env: env/mettagrid/object_use/evals/generator_use_free
24-
policy_agents_pct: 1.0
2516
object_use/generator_use:
2617
env: env/mettagrid/object_use/evals/generator_use
27-
policy_agents_pct: 1.0
2818
object_use/lasery_use_free:
2919
env: env/mettagrid/object_use/evals/lasery_use_free
30-
policy_agents_pct: 1.0
3120
object_use/lasery_use:
3221
env: env/mettagrid/object_use/evals/lasery_use
33-
policy_agents_pct: 1.0
3422
object_use/mine_use:
3523
env: env/mettagrid/object_use/evals/mine_use
36-
policy_agents_pct: 1.0
3724
object_use/shoot_out:
3825
env: env/mettagrid/object_use/evals/shoot_out
39-
policy_agents_pct: 1.0
4026
object_use/swap_in:
4127
env: env/mettagrid/object_use/evals/swap_in
42-
policy_agents_pct: 1.0
4328
object_use/swap_out:
4429
env: env/mettagrid/object_use/evals/swap_out
45-
policy_agents_pct: 1.0
4630
object_use/temple_use_free:
4731
env: env/mettagrid/object_use/evals/temple_use_free
48-
policy_agents_pct: 1.0

configs/sim/sim.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@ device: ${device}
22
vectorization: ${vectorization}
33
eval_db_uri: ${eval_db_uri}
44
env: env/mettagrid/simple
5+
run_dir: ${run_dir}
6+
num_envs: 1
7+
num_episodes: 1

configs/sim/simple.yaml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ defaults:
22
- sim
33
- _self_
44

5-
# This is a single simulation, not a suite.
6-
env: env/mettagrid/simple
7-
policy_agents_pct: 1.0
8-
num_envs: 1
9-
num_episodes: 1
5+
simulations:
6+
simple:
7+
env: env/mettagrid/simple

configs/sim/smoke_test.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@ defaults:
44
- sim
55
- _self_
66

7-
num_envs: 1
8-
num_episodes: 1
9-
run_dir: ${run_dir}
107

118
simulations:
129
emptyspace_withinsight:
1310
env: env/mettagrid/navigation/evals/emptyspace_withinsight
14-
policy_agents_pct: 1.0

configs/user/daveey.yaml

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# @package __global__
22

33
defaults:
4-
- override /agent: simple
5-
- override /eval: simple_solo
6-
- override /analyzer: eval_analyzer
4+
- override /sim: simple
5+
# - override /eval: simple_solo
6+
# - override /analyzer: eval_analyzer
77
- _self_
88

99
trainer:
@@ -23,35 +23,34 @@ trainer:
2323
# policy_uri: wandb://run/b.daveey.t.1.bl
2424
# policy_uri: wandb://run/b.daveey.t.16.dr0
2525

26-
# policy_uri: wandb://run/b.daveey.t.4.lra.muon
27-
policy_uri: wandb://run/b.daveey.t.1.lra.dr.muon
28-
# policy_uri: /tmp/puffer_metta_serial.pt
26+
# policy_uri: wandb://run/b.daveey.dr9.muon.latest
27+
# policy_uri: wandb://run/b.daveey.t.1.lra.dr.muon
28+
policy_uri: puffer:///tmp/puffer_metta.pt
2929

3030
npc_policy_uri: ${policy_uri}
3131
# npc_policy_uri: wandb://run/b.daveey.t.16.dr0
3232
# policy_uri: ${trained_policy_uri}
3333
# npc_policy_uri: ${trained_policy_uri}
3434
# eval_db_uri: wandb://artifacts/daveey_eval_testing
3535

36+
eval_db_uri: ${run_dir}/eval_stats
37+
3638
analyzer:
3739
policy_uri: ${..policy_uri}
38-
eval_stats_uri: ${run_dir}/eval_stats
40+
view_type: latest
3941
analysis:
4042
metrics:
4143
- metric: episode_reward
4244
- metric: "heart.get"
4345

44-
eval:
46+
47+
sim:
4548
env: /env/mettagrid/puffer
46-
eval_db_uri: ${run_dir}/eval_stats
47-
num_envs: 10
48-
num_episodes: 10
49+
num_envs: 1
50+
num_episodes: 1
4951
max_time_s: 600
5052
# policy_agents_pct: 1
5153

52-
policy_uri: ${..policy_uri}
53-
npc_policy_uri: ${..npc_policy_uri}
54-
# eval_db_uri: ${..eval_db_uri} #file://daphne/sweep_stats
5554
# env: /env/mettagrid/reward_dr
5655
# env_overrides:
5756
# # sampling: 0.7
@@ -66,7 +65,7 @@ eval:
6665
wandb:
6766
checkpoint_interval: 1
6867

69-
run_id: 16
68+
run_id: 17
7069
run: ${oc.env:USER}.local.${run_id}
7170
trained_policy_uri: ${run_dir}/checkpoints
7271

metta.code-workspace

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,40 @@
197197
}
198198
},
199199
"terminal.integrated.cwd": "${workspaceFolder}",
200-
"terminal.integrated.splitCwd": "workspaceRoot"
200+
"terminal.integrated.splitCwd": "workspaceRoot",
201+
"cSpell.words": [
202+
"bptt",
203+
"clipfrac",
204+
"coef",
205+
"cooldown",
206+
"dones",
207+
"heavyball",
208+
"kickstarter",
209+
"Kruskal",
210+
"lasery",
211+
"lexsort",
212+
"logratio",
213+
"metta",
214+
"mettagrid",
215+
"minibatch",
216+
"minibatches",
217+
"newlogprob",
218+
"nvec",
219+
"pufferlib",
220+
"pytest",
221+
"PYTHONPATH",
222+
"raylib",
223+
"relu",
224+
"tensorclass",
225+
"tensordict",
226+
"timestep",
227+
"timesteps",
228+
"truncateds",
229+
"unclipped",
230+
"vecenv",
231+
"venv",
232+
"vloss",
233+
"wandb"
234+
]
201235
}
202236
}

0 commit comments

Comments
 (0)