Commit c806015

Author: Chris Elion
Fix mypy errors in trainer code. (#3135)
1 parent f504908

4 files changed: +10 -10 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 5 deletions
@@ -10,11 +10,7 @@ repos:
         )$

 -   repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.750
-    # Currently mypy may assert after logging one message. To get all the messages at once, change repo and rev to
-    # repo: https://github.com/chriselion/mypy
-    # rev: 3d0b6164a9487a6c5cf9d144110b86600fd85e25
-    # This is a fork with the assert disabled, although precommit has trouble installing it sometimes.
+    rev: v0.761
     hooks:
     - id: mypy
       name: mypy-ml-agents

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 3 additions & 1 deletion
@@ -70,7 +70,7 @@ def __init__(
         self.load = load
         self.multi_gpu = multi_gpu
         self.seed = seed
-        self.policy: TFPolicy = None
+        self.policy: PPOPolicy = None  # type: ignore

     def process_trajectory(self, trajectory: Trajectory) -> None:
         """
@@ -255,6 +255,8 @@ def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
                     self.__class__.__name__
                 )
             )
+        if not isinstance(policy, PPOPolicy):
+            raise RuntimeError("Non-PPOPolicy passed to PPOTrainer.add_policy()")
         self.policy = policy

     def get_policy(self, name_behavior_id: str) -> TFPolicy:
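
The PPO change narrows the declared type of self.policy from the base TFPolicy to the concrete PPOPolicy, so PPO-specific attributes type-check, and adds a runtime isinstance guard so add_policy cannot silently store a policy of the wrong class. Below is a minimal, self-contained sketch of that pattern; the classes are simplified stand-ins, not the actual ML-Agents code.

class TFPolicy:
    pass


class PPOPolicy(TFPolicy):
    def update(self) -> None:
        pass


class PPOTrainer:
    def __init__(self) -> None:
        # None is not a PPOPolicy, so the initial assignment needs "type: ignore";
        # in exchange, every later use of self.policy is typed as PPOPolicy.
        self.policy: PPOPolicy = None  # type: ignore

    def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
        # The isinstance check narrows TFPolicy to PPOPolicy for mypy and
        # fails fast at runtime if a non-PPO policy is passed in.
        if not isinstance(policy, PPOPolicy):
            raise RuntimeError("Non-PPOPolicy passed to PPOTrainer.add_policy()")
        self.policy = policy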

ml-agents/mlagents/trainers/sac/policy.py

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, Mapping
 import numpy as np
 from mlagents.tf_utils import tf

@@ -206,7 +206,7 @@ def update(
         return update_stats

     def update_reward_signals(
-        self, reward_signal_minibatches: Dict[str, Dict], num_sequences: int
+        self, reward_signal_minibatches: Mapping[str, Dict], num_sequences: int
     ) -> Dict[str, float]:
         """
         Only update the reward signals.
@@ -236,7 +236,7 @@ def add_reward_signal_dicts(
         feed_dict: Dict[tf.Tensor, Any],
         update_dict: Dict[str, tf.Tensor],
         stats_needed: Dict[str, str],
-        reward_signal_minibatches: Dict[str, Dict],
+        reward_signal_minibatches: Mapping[str, Dict],
         num_sequences: int,
     ) -> None:
         """

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 3 additions & 1 deletion
@@ -77,7 +77,7 @@ def __init__(
         self.check_param_keys()
         self.load = load
         self.seed = seed
-        self.policy: TFPolicy = None
+        self.policy: SACPolicy = None  # type: ignore

         self.step = 0
         self.train_interval = (
@@ -337,6 +337,8 @@ def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
                     self.__class__.__name__
                 )
             )
+        if not isinstance(policy, SACPolicy):
+            raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
         self.policy = policy

     def get_policy(self, name_behavior_id: str) -> TFPolicy:
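
The sac/trainer.py change mirrors the PPO one: a concrete SACPolicy annotation with a type-ignore on the None default, plus a runtime guard in add_policy. An alternative the commit does not take would be Optional[SACPolicy], which removes the need for the ignore comment but forces a None check before every use. A rough sketch of that variant, with simplified stand-in names:

from typing import Optional


class SACPolicy:
    def update(self) -> None:
        pass


class SACTrainer:
    def __init__(self) -> None:
        # Optional removes the need for "type: ignore" on the None default ...
        self.policy: Optional[SACPolicy] = None

    def _update_sac_policy(self) -> None:
        # ... but mypy then requires narrowing away None at every call site.
        if self.policy is None:
            raise RuntimeError("No policy has been added to this trainer")
        self.policy.update()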
