@@ -264,6 +264,23 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         ]
         if len(memories) > 0:
             memories = torch.stack(memories).unsqueeze(0)
+        value_memories = [
+            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
+            for i in range(
+                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
+            )
+        ]
+
+        baseline_memories = [
+            ModelUtils.list_to_tensor(batch[BufferKey.BASELINE_MEMORY][i])
+            for i in range(
+                0, len(batch[BufferKey.BASELINE_MEMORY]), self.policy.sequence_length
+            )
+        ]
+
+        if len(value_memories) > 0:
+            value_memories = torch.stack(value_memories).unsqueeze(0)
+            baseline_memories = torch.stack(baseline_memories).unsqueeze(0)
 
         log_probs, entropy = self.policy.evaluate_actions(
             current_obs,
@@ -274,13 +291,15 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         )
         all_obs = [current_obs] + group_obs
         values, _ = self.critic.critic_pass(
-            all_obs, memories=memories, sequence_length=self.policy.sequence_length
+            all_obs,
+            memories=value_memories,
+            sequence_length=self.policy.sequence_length,
         )
         baselines, _ = self.critic.baseline(
             [current_obs],
             group_obs,
             group_actions,
-            memories=memories,
+            memories=baseline_memories,
             sequence_length=self.policy.sequence_length,
         )
         old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
@@ -380,7 +399,7 @@ def _evaluate_by_sequence_team(
 
         for team_obs, team_action in zip(obs, actions):
             seq_obs = []
-            for (_obs,) in team_obs:
+            for _obs in team_obs:
                 first_seq_obs = _obs[0:first_seq_len]
                 seq_obs.append(first_seq_obs)
             team_seq_obs.append(seq_obs)
@@ -534,7 +553,12 @@ def get_trajectory_and_baseline_value_estimates(
             _init_value_mem = self.value_memory_dict[agent_id]
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
-            memory = (
+            _init_value_mem = (
+                torch.zeros((1, 1, self.critic.memory_size))
+                if self.policy.use_recurrent
+                else None
+            )
+            _init_baseline_mem = (
                 torch.zeros((1, 1, self.critic.memory_size))
                 if self.policy.use_recurrent
                 else None
@@ -544,19 +568,19 @@ def get_trajectory_and_baseline_value_estimates(
         all_next_value_mem: Optional[AgentBufferField] = None
         all_next_baseline_mem: Optional[AgentBufferField] = None
         if self.policy.use_recurrent:
-            value_estimates, baseline_estimates, all_next_value_mem, all_next_baseline_mem, next_value_mem, next_baseline_mem = self.critic._evaluate_by_sequence_team(
+            value_estimates, baseline_estimates, all_next_value_mem, all_next_baseline_mem, next_value_mem, next_baseline_mem = self._evaluate_by_sequence_team(
                 current_obs, team_obs, team_actions, _init_value_mem, _init_baseline_mem
             )
         else:
             value_estimates, value_mem = self.critic.critic_pass(
-                all_obs, memory, sequence_length=batch.num_experiences
+                all_obs, _init_value_mem, sequence_length=batch.num_experiences
             )
 
             baseline_estimates, baseline_mem = self.critic.baseline(
                 [current_obs],
                 team_obs,
                 team_actions,
-                memory,
+                _init_baseline_mem,
                 sequence_length=batch.num_experiences,
             )
         # Store the memory for the next trajectory
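For reference, the new memory handling in update() keeps only the recurrent state recorded at the first step of each training sequence (hence the self.policy.sequence_length stride) and stacks those states into one tensor before handing them to the critic and baseline. The standalone sketch below illustrates just that slicing pattern with made-up sizes; flat_memories, memory_size, and the printed shape are illustrative assumptions, not part of the trainer API.

import torch

# Illustrative sketch only: mirrors the value_memories slicing in update(),
# using stand-in data instead of the AgentBuffer (names here are made up).
memory_size = 4
sequence_length = 3
num_steps = 9  # three sequences of three steps each

# Stand-in for batch[BufferKey.CRITIC_MEMORY]: one recurrent state per step.
flat_memories = [torch.randn(memory_size) for _ in range(num_steps)]

# Keep only the state saved at the first step of each sequence,
# i.e. every sequence_length-th entry.
value_memories = [
    flat_memories[i] for i in range(0, len(flat_memories), sequence_length)
]

if len(value_memories) > 0:
    # Stack to (num_sequences, memory_size), then add a leading axis so the
    # recurrent critic receives a (1, num_sequences, memory_size) tensor.
    value_memories = torch.stack(value_memories).unsqueeze(0)

print(value_memories.shape)  # torch.Size([1, 3, 4])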