Skip to content

Commit 7becc1c

Browse files
authored
feat: use rich logging (colors) with hydra (Metta-AI#208)
## Changes
- Add support for rich colored logging with millisecond precision
- Create a utility to override Hydra's default logger with our rich logger
- Add a time-remaining estimate to epoch progress logs (shown in sec/min/hours/days)

## Why
Hydra sets up its own logger by default, but our rich logger provides better visibility with colors and precise timestamps. The time-remaining estimate helps developers monitor long-running training jobs more effectively.

Fixes Metta-AI#125
1 parent c88043c commit 7becc1c

File tree

11 files changed

+204
-110
lines changed

11 files changed

+204
-110
lines changed

metta/rl/pufferlib/trainer.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,11 @@ def __init__(
109109
if self._master:
110110
print(policy_record.policy())
111111

112-
if policy_record.metadata["action_names"] != self.vecenv.driver_env.action_names():
112+
action_names = self.vecenv.driver_env.action_names()
113+
if policy_record.metadata["action_names"] != action_names:
113114
raise ValueError(
114115
"Action names do not match between policy and environment: "
115-
f"{policy_record.metadata['action_names']} != {self.vecenv.driver_env.action_names()}"
116+
f"{policy_record.metadata['action_names']} != {action_names}"
116117
)
117118

118119
self._initial_pr = policy_record
@@ -175,13 +176,15 @@ def __init__(
175176

176177
def train(self):
177178
self.train_start = time.time()
179+
self.steps_start = self.agent_step
180+
178181
logger.info("Starting training")
179182

183+
# it doesn't make sense to evaluate more often than checkpointing since we need a saved policy to evaluate
180184
if (
181185
self.trainer_cfg.evaluate_interval != 0
182186
and self.trainer_cfg.evaluate_interval < self.trainer_cfg.checkpoint_interval
183187
):
184-
# it doesn't make sense to evaluate more often than checkpointing since we need a saved policy to evaluate
185188
raise ValueError("evaluate_interval must be at least as large as checkpoint_interval")
186189

187190
logger.info(f"Training on {self.device}")
@@ -195,9 +198,25 @@ def train(self):
195198
# Processing stats
196199
self._process_stats()
197200

201+
# log progress
202+
steps_per_second = (self.agent_step - self.steps_start) / (time.time() - self.train_start)
203+
remaining_steps = self.trainer_cfg.total_timesteps - self.agent_step
204+
remaining_time_sec = remaining_steps / steps_per_second
205+
206+
# Format remaining time in appropriate units
207+
if remaining_time_sec < 60:
208+
time_str = f"{remaining_time_sec:.0f} sec"
209+
elif remaining_time_sec < 3600:
210+
time_str = f"{remaining_time_sec / 60:.1f} min"
211+
elif remaining_time_sec < 86400: # Less than a day
212+
time_str = f"{remaining_time_sec / 3600:.1f} hours"
213+
else:
214+
time_str = f"{remaining_time_sec / 86400:.1f} days"
215+
198216
logger.info(
199-
f"Epoch {self.epoch} - {self.agent_step} "
200-
f"({100.00 * self.agent_step / self.trainer_cfg.total_timesteps:.2f}%)"
217+
f"Epoch {self.epoch} - {self.agent_step} [{steps_per_second:.0f}/sec]"
218+
f" ({100.00 * self.agent_step / self.trainer_cfg.total_timesteps:.2f}%)"
219+
f" - {time_str} remaining"
201220
)
202221

203222
# Checkpointing trainer

metta/sim/simulation_config.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
class SimulationConfig(Config):
1111
"""Configuration for a single simulation run."""
12+
1213
env: str
1314
device: str
1415
num_envs: int
@@ -24,26 +25,23 @@ class SimulationConfig(Config):
2425

2526
class SimulationSuiteConfig(SimulationConfig):
2627
"""A suite of named simulations, with suite-level defaults injected."""
28+
2729
run_dir: str
2830
simulations: Dict[str, SimulationConfig]
2931

30-
# —— don't need env bc all the simulations will specify ——
32+
# —— don't need env bc all the simulations will specify ——
3133
env: Optional[str] = None
3234

3335
@model_validator(mode="before")
34-
def _propagate_defaults(cls, values: dict) -> dict:
36+
def propagate_defaults(cls, values: dict) -> dict:
3537
# collect any suite-level overrides that are present & non-None
3638
suite_defaults = {
37-
k: v for k, v in values.items()
38-
if k in ("env", "device", "num_envs", "num_episodes") and v is not None
39+
k: v for k, v in values.items() if k in ("env", "device", "num_envs", "num_episodes") and v is not None
3940
}
40-
4141
raw_sims = values.get("simulations", {}) or {}
4242
merged: Dict[str, dict] = {}
4343
for name, sim_cfg in raw_sims.items():
4444
# sim_cfg is a dict; override only where sim_cfg provides a key
4545
merged[name] = {**suite_defaults, **sim_cfg}
4646
values["simulations"] = merged
4747
return values
48-
49-

metta/util/logging.py

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import logging
12
import os
23
import sys
4+
from datetime import datetime
35

46
from loguru import logger
7+
from rich.logging import RichHandler
58

69

710
def remap_io(logs_path: str):
@@ -12,14 +15,72 @@ def remap_io(logs_path: str):
1215
stderr = open(stderr_log_path, "a")
1316
sys.stderr = stderr
1417
sys.stdout = stdout
15-
logger.remove() # Remove default handler
16-
logger.remove() # Remove default handler
17-
# logger.add(
18-
# sys.stdout, colorize=True,
19-
# format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | "
20-
# "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
21-
# "<level>{message}</level>")
18+
logger.remove()
2219

2320

2421
def restore_io():
    """Point ``sys.stderr`` and ``sys.stdout`` back at the interpreter's original streams."""
    for stream_name in ("stderr", "stdout"):
        # sys.__stderr__ / sys.__stdout__ hold the streams the interpreter started with
        setattr(sys, stream_name, getattr(sys, f"__{stream_name}__"))
24+
25+
26+
# Create a custom formatter that supports milliseconds
27+
class MillisecondFormatter(logging.Formatter):
    """Formatter whose timestamps carry millisecond precision.

    ``strftime`` only offers ``%f`` (six-digit microseconds), so every ``%f``
    in the date format is replaced by the three-digit millisecond value before
    the timestamp is rendered.
    """

    def formatTime(self, record, datefmt=None):
        """Format ``record.created`` as local time with milliseconds.

        Args:
            record: Log record whose ``created`` timestamp is rendered.
            datefmt: ``strftime`` pattern in which ``%f`` stands for three
                millisecond digits. Falls back to ``"[%H:%M:%S.%f]"`` when
                empty or ``None``.

        Returns:
            The formatted timestamp string.
        """
        created = datetime.fromtimestamp(record.created)
        # Truncate microseconds to milliseconds (keep only 3 digits)
        msec = created.microsecond // 1000
        pattern = datefmt or "[%H:%M:%S.%f]"
        # Substitute the digits before strftime runs. The rendered string has
        # no placeholders left, so it must not be %-formatted again (doing so
        # raised "TypeError: not all arguments converted").
        return created.strftime(pattern.replace("%f", f"{msec:03d}"))
38+
39+
40+
# Create a custom handler that always shows the timestamp
41+
class AlwaysShowTimeRichHandler(RichHandler):
    """RichHandler that prints the timestamp on every log line.

    ``RichHandler`` leaves the time column blank when consecutive records
    render to the same timestamp. That behaviour is switched off through the
    handler's own ``omit_repeated_times`` option, so the shared ``LogRecord``
    is never modified and other handlers still see the true creation time.
    """

    def __init__(self, *args, **kwargs):
        """Build the handler, defaulting ``omit_repeated_times`` to ``False``.

        All positional and keyword arguments are forwarded to ``RichHandler``.
        """
        # setdefault keeps an explicit caller choice intact
        kwargs.setdefault("omit_repeated_times", False)
        super().__init__(*args, **kwargs)
46+
47+
48+
def get_log_level(provided_level=None):
    """Determine the log level name to use.

    Priority:
    1. Environment variable LOG_LEVEL
    2. Provided level parameter
    3. Default to INFO

    A candidate is only accepted when it names a standard ``logging`` level
    (case-insensitive, surrounding whitespace ignored). Unrecognised values
    are skipped, so a typo in LOG_LEVEL falls through to the next source
    instead of failing later when the name is looked up on ``logging``.

    Args:
        provided_level: Optional level name (e.g. ``"debug"``) or numeric
            ``logging`` level (e.g. ``logging.DEBUG``).

    Returns:
        An upper-case level name that is an attribute of the ``logging`` module.
    """
    known_levels = ("CRITICAL", "FATAL", "ERROR", "WARNING", "WARN", "INFO", "DEBUG", "NOTSET")

    # Candidates in priority order: environment first, then the argument
    for candidate in (os.environ.get("LOG_LEVEL"), provided_level):
        if not candidate:
            continue
        if isinstance(candidate, int):
            # Map numeric levels such as logging.DEBUG to their name
            candidate = logging.getLevelName(candidate)
        level_name = str(candidate).strip().upper()
        if level_name in known_levels:
            return level_name

    # Default to INFO
    return "INFO"
66+
67+
68+
def setup_mettagrid_logger(name: str, level=None) -> logging.Logger:
    """Install the rich handler on the root logger and return a named logger.

    Every handler currently attached to the root logger is removed and
    replaced by a single ``AlwaysShowTimeRichHandler`` (rich tracebacks on)
    that uses a ``MillisecondFormatter``. The root level is then set from
    ``get_log_level(level)``.

    Args:
        name: Name of the logger to return.
        level: Optional level passed to ``get_log_level``.

    Returns:
        ``logging.getLogger(name)``.
    """
    # Resolve the level name up front, before touching any handlers
    level_name = get_log_level(level)

    root = logging.getLogger()

    # Detach existing handlers; iterate over a copy because removal mutates the list
    for existing in list(root.handlers):
        root.removeHandler(existing)

    # Attach our rich handler with the millisecond-aware formatter
    handler = AlwaysShowTimeRichHandler(rich_tracebacks=True)
    handler.setFormatter(MillisecondFormatter("%(message)s", datefmt="[%H:%M:%S.%f]"))
    root.addHandler(handler)

    # Translate the level name into logging's numeric constant
    root.setLevel(getattr(logging, level_name))

    return logging.getLogger(name)
Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""
2-
Unit‑tests for SimulationSuiteConfig ⇄ SimulationConfig behaviour.
3-
2+
Unit‑tests for SimulationSuiteConfig ⇄ SimulationConfig behavior.
43
Covered
54
-------
65
* suite‑level defaults propagate into children
76
* child‑level overrides win
8-
* missing required keys always raise (allow_missing removed)
7+
* missing required keys always raise (allow_missing removed)
98
"""
109

1110
from typing import Dict
@@ -19,21 +18,29 @@
1918
# ---------------------------------------------------------------------------
2019
# constants
2120
# ---------------------------------------------------------------------------
22-
2321
ROOT_ENV, CHILD_A, CHILD_B = "env/root", "env/a", "env/b"
2422
DEVICE, RUN_DIR = "cpu", "./runs/test"
2523

2624

27-
def _build(cfg: Dict):
28-
return SimulationSuiteConfig(OmegaConf.create(cfg))
25+
@pytest.fixture
def build_simulation_suite_config():
    """Fixture returning a builder that turns a plain dict into a ``SimulationSuiteConfig``."""

    def _build(cfg: Dict):
        # Round-trip through OmegaConf (create, then resolve back to a plain
        # container) and validate the result as a SimulationSuiteConfig.
        resolved = OmegaConf.to_container(OmegaConf.create(cfg), resolve=True)
        return SimulationSuiteConfig.model_validate(resolved)

    return _build
2938

3039

3140
# ---------------------------------------------------------------------------
3241
# propagation & overrides
3342
# ---------------------------------------------------------------------------
34-
35-
36-
def test_propogate_defaults_and_overrides():
43+
def test_propagate_defaults_and_overrides(build_simulation_suite_config):
3744
cfg = {
3845
"env": ROOT_ENV,
3946
"num_envs": 4,
@@ -45,9 +52,8 @@ def test_propogate_defaults_and_overrides():
4552
"b": {"env": CHILD_B, "num_envs": 8}, # overrides num_envs
4653
},
4754
}
48-
suite = _build(cfg)
55+
suite = build_simulation_suite_config(cfg)
4956
a, b = suite.simulations["a"], suite.simulations["b"]
50-
5157
# device and num_envs both propagated, even though num_envs has a default
5258
assert (a.device, a.num_envs) == (DEVICE, 4)
5359
assert (b.device, b.num_envs) == (DEVICE, 8)
@@ -56,20 +62,17 @@ def test_propogate_defaults_and_overrides():
5662
# ---------------------------------------------------------------------------
5763
# allow_extra – child nodes
5864
# ---------------------------------------------------------------------------
59-
60-
6165
@pytest.mark.parametrize(
6266
"has_extra, should_pass",
6367
[
6468
(False, True),
6569
(True, False),
6670
],
6771
)
68-
def test_allow_extra_child_keys(has_extra, should_pass):
72+
def test_allow_extra_child_keys(build_simulation_suite_config, has_extra, should_pass):
6973
child_node = {"env": CHILD_A}
7074
if has_extra:
7175
child_node["foo"] = "bar" # <- unknown key
72-
7376
cfg = {
7477
"env": ROOT_ENV,
7578
"num_envs": 4,
@@ -78,21 +81,18 @@ def test_allow_extra_child_keys(has_extra, should_pass):
7881
"run_dir": RUN_DIR,
7982
"simulations": {"sim": child_node},
8083
}
81-
8284
if should_pass:
83-
suite = _build(cfg)
85+
suite = build_simulation_suite_config(cfg)
8486
assert suite.simulations["sim"].device == DEVICE
8587
else:
8688
with pytest.raises(ValueError):
87-
_build(cfg)
89+
build_simulation_suite_config(cfg)
8890

8991

9092
# ---------------------------------------------------------------------------
9193
# missing required keys should always error
9294
# ---------------------------------------------------------------------------
93-
94-
95-
def test_missing_device_always_errors():
95+
def test_missing_device_always_errors(build_simulation_suite_config):
9696
cfg = {
9797
"env": ROOT_ENV,
9898
"num_envs": 4,
@@ -101,10 +101,10 @@ def test_missing_device_always_errors():
101101
"simulations": {"sim": {}}, # required 'device' omitted
102102
}
103103
with pytest.raises(ValidationError):
104-
_build(cfg)
104+
build_simulation_suite_config(cfg)
105105

106106

107-
def test_missing_suite_env_is_allowed():
107+
def test_missing_suite_env_is_allowed(build_simulation_suite_config):
108108
cfg = {
109109
"run_dir": RUN_DIR,
110110
"device": DEVICE,
@@ -116,5 +116,5 @@ def test_missing_suite_env_is_allowed():
116116
}
117117
},
118118
}
119-
suite = _build(cfg)
119+
suite = build_simulation_suite_config(cfg)
120120
assert suite.simulations["sim"].env == CHILD_A

tools/analyze.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
"""Analysis tool for MettaGrid evaluation results."""
22

3-
import logging
4-
53
import hydra
64
from omegaconf import DictConfig
75

86
from metta.eval.report import dump_stats, generate_report
7+
from metta.util.logging import setup_mettagrid_logger
98
from metta.util.runtime_configuration import setup_mettagrid_environment
109

1110

1211
@hydra.main(version_base=None, config_path="../configs", config_name="analyze_job")
1312
def main(cfg: DictConfig) -> None:
1413
setup_mettagrid_environment(cfg)
15-
logger = logging.getLogger(__name__)
14+
logger = setup_mettagrid_logger("analyze")
15+
1616
view_type = "latest"
1717
logger.info(f"Generating {view_type} report")
1818
dump_stats(cfg)

tools/play.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
import os
32
import signal # Aggressively exit on ctrl+c
43
import sys
@@ -9,31 +8,37 @@
98
from metta.agent.policy_store import PolicyStore
109
from metta.sim.simulation_config import SimulationConfig
1110
from metta.util.config import Config
11+
from metta.util.logging import setup_mettagrid_logger
1212
from metta.util.runtime_configuration import setup_mettagrid_environment
1313
from metta.util.wandb.wandb_context import WandbContext
1414

1515
signal.signal(signal.SIGINT, lambda sig, frame: os._exit(0))
1616

1717

18-
logging.basicConfig(level="INFO")
19-
logger = logging.getLogger(__name__)
20-
21-
2218
class PlayJob(Config):
2319
sim: SimulationConfig
2420
policy_uri: str
2521

22+
def __init__(self, *args, **kwargs):
23+
super().__init__(*args, **kwargs)
24+
2625

2726
@hydra.main(version_base=None, config_path="../configs", config_name="play_job")
28-
def play(cfg):
27+
def main(cfg) -> int:
2928
setup_mettagrid_environment(cfg)
3029

30+
logger = setup_mettagrid_logger("play")
31+
logger.info(f"Playing {cfg.run}")
32+
3133
with WandbContext(cfg) as wandb_run:
3234
policy_store = PolicyStore(cfg, wandb_run)
35+
3336
play_job = PlayJob(cfg.play_job)
3437
policy_record = policy_store.policy(play_job.policy_uri)
3538
metta.sim.simulator.play(play_job.sim, policy_record)
3639

40+
return 0
41+
3742

3843
if __name__ == "__main__":
39-
sys.exit(play())
44+
sys.exit(main())

0 commit comments

Comments
 (0)