
Commit 5905680

ignoring precommit, grabbing baseline/critic mems from buffer in trainer
1 parent 4b7db51 commit 5905680

File tree: 4 files changed, +39 -8 lines

ml-agents/mlagents/trainers/buffer.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ class BufferKey(enum.Enum):
     MASKS = "masks"
     MEMORY = "memory"
     CRITIC_MEMORY = "critic_memory"
+    BASELINE_MEMORY = "coma_baseline_memory"
     PREV_ACTION = "prev_action"
 
     ADVANTAGES = "advantages"
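
Note: the new enum member sits alongside the existing memory keys and maps to the string shown in the diff. A minimal check (module path taken from the file name above):

from mlagents.trainers.buffer import BufferKey

# The new key mirrors the existing critic-memory key.
print(BufferKey.BASELINE_MEMORY.value)  # "coma_baseline_memory"
print(BufferKey.CRITIC_MEMORY.value)    # "critic_memory"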

ml-agents/mlagents/trainers/coma/optimizer_torch.py

Lines changed: 31 additions & 7 deletions
@@ -264,6 +264,23 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         ]
         if len(memories) > 0:
             memories = torch.stack(memories).unsqueeze(0)
+        value_memories = [
+            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
+            for i in range(
+                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
+            )
+        ]
+
+        baseline_memories = [
+            ModelUtils.list_to_tensor(batch[BufferKey.BASELINE_MEMORY][i])
+            for i in range(
+                0, len(batch[BufferKey.BASELINE_MEMORY]), self.policy.sequence_length
+            )
+        ]
+
+        if len(value_memories) > 0:
+            value_memories = torch.stack(value_memories).unsqueeze(0)
+            baseline_memories = torch.stack(baseline_memories).unsqueeze(0)
 
         log_probs, entropy = self.policy.evaluate_actions(
             current_obs,
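
For context, the comprehensions above sample the memory recorded at the first step of every sequence: with one memory vector stored per step and a stride of sequence_length, indices 0, N, 2N, ... are kept and stacked into a (1, num_sequences, memory_size) tensor. A standalone sketch of the same stride-and-stack pattern (plain torch, toy sizes, not the commit's code):

# Toy illustration of the stride-and-stack pattern used above (assumed sizes).
import torch

sequence_length = 4
memory_size = 8
num_steps = 12  # 3 sequences of 4 steps

# One memory vector was stored per step; keep only the one at each sequence start.
per_step_memories = [torch.randn(memory_size) for _ in range(num_steps)]
seq_start_memories = [
    per_step_memories[i] for i in range(0, num_steps, sequence_length)
]

stacked = torch.stack(seq_start_memories).unsqueeze(0)
print(stacked.shape)  # torch.Size([1, 3, 8]) -> (1, num_sequences, memory_size)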
@@ -274,13 +291,15 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         )
         all_obs = [current_obs] + group_obs
         values, _ = self.critic.critic_pass(
-            all_obs, memories=memories, sequence_length=self.policy.sequence_length
+            all_obs,
+            memories=value_memories,
+            sequence_length=self.policy.sequence_length,
         )
         baselines, _ = self.critic.baseline(
             [current_obs],
             group_obs,
             group_actions,
-            memories=memories,
+            memories=baseline_memories,
             sequence_length=self.policy.sequence_length,
         )
         old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
@@ -380,7 +399,7 @@ def _evaluate_by_sequence_team(
 
         for team_obs, team_action in zip(obs, actions):
             seq_obs = []
-            for (_obs,) in team_obs:
+            for _obs in team_obs:
                 first_seq_obs = _obs[0:first_seq_len]
                 seq_obs.append(first_seq_obs)
             team_seq_obs.append(seq_obs)
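
The one-character change above fixes an unpacking bug: team_obs is a list of observation tensors, not a list of 1-tuples, so the old loop header tried to unpack each tensor. A tiny illustration of the difference (toy data, not from the repo):

# Illustration of the loop-variable fix; the data here is made up.
import torch

team_obs = [torch.zeros(4), torch.zeros(4)]  # a list of observation tensors

for _obs in team_obs:           # correct: _obs is each tensor in turn
    assert _obs.shape == (4,)

# for (_obs,) in team_obs:      # old form: tries to unpack each 4-element tensor
#     ...                       # into a 1-tuple and raises
#                               # "too many values to unpack (expected 1)"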
@@ -534,7 +553,12 @@ def get_trajectory_and_baseline_value_estimates(
             _init_value_mem = self.value_memory_dict[agent_id]
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
-            memory = (
+            _init_value_mem = (
+                torch.zeros((1, 1, self.critic.memory_size))
+                if self.policy.use_recurrent
+                else None
+            )
+            _init_baseline_mem = (
                 torch.zeros((1, 1, self.critic.memory_size))
                 if self.policy.use_recurrent
                 else None
@@ -544,19 +568,19 @@ def get_trajectory_and_baseline_value_estimates(
         all_next_value_mem: Optional[AgentBufferField] = None
         all_next_baseline_mem: Optional[AgentBufferField] = None
         if self.policy.use_recurrent:
-            value_estimates, baseline_estimates, all_next_value_mem, all_next_baseline_mem, next_value_mem, next_baseline_mem = self.critic._evaluate_by_sequence_team(
+            value_estimates, baseline_estimates, all_next_value_mem, all_next_baseline_mem, next_value_mem, next_baseline_mem = self._evaluate_by_sequence_team(
                 current_obs, team_obs, team_actions, _init_value_mem, _init_baseline_mem
             )
         else:
             value_estimates, value_mem = self.critic.critic_pass(
-                all_obs, memory, sequence_length=batch.num_experiences
+                all_obs, _init_value_mem, sequence_length=batch.num_experiences
             )
 
             baseline_estimates, baseline_mem = self.critic.baseline(
                 [current_obs],
                 team_obs,
                 team_actions,
-                memory,
+                _init_baseline_mem,
                 sequence_length=batch.num_experiences,
             )
             # Store the memory for the next trajectory
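
Taken together, the two hunks above make the recurrent path start a trajectory from the memory carried over in value_memory_dict / baseline_memory_dict (keyed by agent id), or from zeros of shape (1, 1, memory_size) when no previous trajectory exists. A hedged sketch of that bootstrapping pattern (the dictionary names follow the diff; the helper, shapes, and values are illustrative assumptions, not the repo's code):

# Sketch of the memory bootstrapping pattern implied by the diff above.
from typing import Dict, Optional
import torch

value_memory_dict: Dict[str, torch.Tensor] = {}
memory_size = 16  # assumed example value
use_recurrent = True

def initial_value_memory(agent_id: str) -> Optional[torch.Tensor]:
    # Reuse the memory carried over from this agent's previous trajectory,
    # otherwise start from zeros (or None when the policy is not recurrent).
    if agent_id in value_memory_dict:
        return value_memory_dict[agent_id]
    return torch.zeros((1, 1, memory_size)) if use_recurrent else None

# After evaluating a trajectory, the final memory would be stored for next time:
value_memory_dict["agent-0"] = torch.zeros((1, 1, memory_size))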

ml-agents/mlagents/trainers/coma/trainer.py

Lines changed: 6 additions & 0 deletions
@@ -80,6 +80,8 @@ def _process_trajectory(self, trajectory: Trajectory) -> None:
             value_estimates,
             baseline_estimates,
             value_next,
+            value_memories,
+            baseline_memories,
         ) = self.optimizer.get_trajectory_and_baseline_value_estimates(
             agent_buffer_trajectory,
             trajectory.next_obs,
@@ -89,6 +91,10 @@
             and not trajectory.interrupted,
         )
 
+        if value_memories is not None:
+            agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories)
+            agent_buffer_trajectory[BufferKey.BASELINE_MEMORY].set(baseline_memories)
+
         for name, v in value_estimates.items():
             agent_buffer_trajectory[RewardSignalUtil.value_estimates_key(name)].extend(
                 v
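
With this block, the trajectory buffer carries one critic memory and one baseline memory per step via AgentBufferField.set(), which is exactly what the strided sampling in update() (optimizer diff above) reads back. A small sketch of that set-then-sample round trip (toy data and sizes are assumptions; only the .set() call and buffer indexing appear in the diff):

# Round trip implied by the trainer + optimizer changes (toy data, assumed sizes).
from mlagents.trainers.buffer import AgentBuffer, BufferKey

sequence_length = 2
traj = AgentBuffer()

# Trainer side: one memory vector per experience, written with .set().
value_memories = [[float(step)] * 4 for step in range(6)]  # 6 steps, size-4 memories
traj[BufferKey.CRITIC_MEMORY].set(value_memories)

# Optimizer side: pick the memory at the start of each sequence.
starts = [
    traj[BufferKey.CRITIC_MEMORY][i]
    for i in range(0, len(traj[BufferKey.CRITIC_MEMORY]), sequence_length)
]
print(len(starts))  # 3 sequences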

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def create_reward_signals(self, reward_signal_configs):
         )
 
     def _evaluate_by_sequence(
-        self, tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
+        self, tensor_obs: List[torch.Tensor], initial_memory: torch.Tensor
     ) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
         """
         Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
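
The annotation change matches how the callers in this commit construct the initial memory: a zero torch tensor rather than a NumPy array. A one-line illustration (the memory size is an assumed example value):

import torch

memory_size = 16  # assumed for the example
initial_memory: torch.Tensor = torch.zeros((1, 1, memory_size))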
