Skip to content

Commit 03c5e94

Browse files
Experimental omegaconf+ parser mode that supports interpolation across configs and arguments (#765)
--------- Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent 9fce2b9 commit 03c5e94

File tree

12 files changed

+289
-77
lines changed

12 files changed

+289
-77
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ Added
2323
<https://github.com/omni-us/jsonargparse/pull/758>`__).
2424
- New ``ActionFail`` for arguments that should fail parsing with a given error
2525
message (`#759 <https://github.com/omni-us/jsonargparse/pull/759>`__).
26+
- Experimental ``omegaconf+`` parser mode that supports variable interpolation
27+
and resolving across configs and command line arguments. Depending on
28+
community feedback, in v5.0.0 this new mode could replace the current
29+
``omegaconf`` mode, introducing a breaking change (`#765
30+
<https://github.com/omni-us/jsonargparse/pull/765>`__).
2631

2732
Fixed
2833
^^^^^
@@ -34,6 +39,8 @@ Fixed
3439
- Environment variable names not shown in help for positional arguments when
3540
``default_env`` is true (`#763
3641
<https://github.com/omni-us/jsonargparse/pull/763>`__).
42+
- ``parse_object`` not parsing correctly configs (`#765
43+
<https://github.com/omni-us/jsonargparse/pull/765>`__).
3744

3845
Changed
3946
^^^^^^^

DOCUMENTATION.rst

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2398,8 +2398,8 @@ instantiates :class:`Data` first, then use the ``num_classes`` attribute to
23982398
instantiate :class:`Model`.
23992399

24002400

2401-
Variable interpolation
2402-
======================
2401+
OmegaConf variable interpolation
2402+
================================
24032403

24042404
One of the possible reasons to add a parser mode (see :ref:`custom-loaders`) can
24052405
be to have support for variable interpolation in yaml files. Any library could
@@ -2463,10 +2463,22 @@ This yaml could be parsed as follows:
24632463

24642464
.. note::
24652465

2466-
The ``parser_mode='omegaconf'`` provides support for `OmegaConf's
2467-
<https://omegaconf.readthedocs.io/>`__ variable interpolation in a single
2468-
yaml file. It is not possible to do interpolation across multiple yaml files
2469-
or in an isolated individual command line argument.
2466+
The ``parser_mode="omegaconf"`` provides support for `OmegaConf's resolvers
2467+
<https://omegaconf.readthedocs.io/en/latest/custom_resolvers.html/>`__ in a
2468+
single YAML file. It is not possible to do interpolation across multiple
2469+
YAML files or in an isolated individual command line argument.
2470+
2471+
Experimental ``omegaconf+`` mode
2472+
--------------------------------
2473+
2474+
There is a new experimental ``omegaconf+`` parser mode that doesn't suffer from
2475+
the limitations of ``omegaconf`` mentioned above. Instead of applying OmegaConf
2476+
resolvers for each YAML config, the resolving is applied once at the end of
2477+
parsing. Because of this, in nested subconfigs, references to config nodes need
2478+
to be relative to work correctly.
2479+
2480+
Depending on feedback from the community, this mode might become the default
2481+
``omegaconf`` mode in v5.0.0.
24702482

24712483

24722484
.. _environment-variables:

jsonargparse/_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
6363
class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators", default=None)
6464
nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[])
6565
applied_instantiation_links: ContextVar[Optional[set]] = ContextVar("applied_instantiation_links", default=None)
66+
path_dump_preserve_relative: ContextVar[bool] = ContextVar("path_dump_preserve_relative", default=False)
6667

6768

6869
parser_context_vars = {
@@ -74,6 +75,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
7475
"class_instantiators": class_instantiators,
7576
"nested_links": nested_links,
7677
"applied_instantiation_links": applied_instantiation_links,
78+
"path_dump_preserve_relative": path_dump_preserve_relative,
7779
}
7880

7981

jsonargparse/_core.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
fsspec_support,
8383
import_fsspec,
8484
import_jsonnet,
85+
omegaconf_apply,
8586
pyyaml_available,
8687
)
8788
from ._parameter_resolvers import UnknownDefault
@@ -378,9 +379,12 @@ def _parse_common(
378379
with parser_context(lenient_check=True):
379380
ActionTypeHint.add_sub_defaults(self, cfg)
380381

381-
_ActionPrintConfig.print_config_if_requested(self, cfg)
382-
383382
with parser_context(parent_parser=self):
383+
if not lenient_check.get() and self.parser_mode == "omegaconf+":
384+
cfg = omegaconf_apply(self, cfg)
385+
386+
_ActionPrintConfig.print_config_if_requested(self, cfg)
387+
384388
try:
385389
ActionLink.apply_parsing_links(self, cfg)
386390
except Exception as ex:
@@ -1401,6 +1405,13 @@ def _apply_actions(
14011405
if isinstance(action, _ActionConfigLoad):
14021406
config_keys.add(action_dest)
14031407
keys.append(action_dest)
1408+
elif isinstance(action, ActionConfigFile):
1409+
if isinstance(value, str):
1410+
cfg.pop(action_dest)
1411+
preserve = Namespace({k: cfg[k] for k in keys[num:]})
1412+
ActionConfigFile.apply_config(self, cfg, action_dest, value)
1413+
cfg.update(preserve)
1414+
continue
14041415
elif getattr(action, "jsonnet_ext_vars", False):
14051416
prev_cfg[action_dest] = value
14061417
cfg[action_dest] = value
@@ -1450,7 +1461,7 @@ def _check_value_key(
14501461
value = action.check_type(value, self)
14511462
elif hasattr(action, "_check_type"):
14521463
with parser_context(parent_parser=self):
1453-
value = action._check_type_(value, cfg=cfg, append=append) # type: ignore[attr-defined]
1464+
value = action._check_type_(value, cfg=cfg, append=append, mode=self.parser_mode) # type: ignore[attr-defined]
14541465
elif action.type is not None:
14551466
try:
14561467
if action.nargs in {None, "?"} or action.nargs == 0:
@@ -1593,8 +1604,8 @@ def parser_mode(self) -> str:
15931604

15941605
@parser_mode.setter
15951606
def parser_mode(self, parser_mode: str):
1596-
if parser_mode == "omegaconf":
1597-
set_omegaconf_loader()
1607+
if parser_mode in {"omegaconf", "omegaconf+"}:
1608+
set_omegaconf_loader(parser_mode)
15981609
if parser_mode not in loaders:
15991610
raise ValueError(f"The only accepted values for parser_mode are {set(loaders)}.")
16001611
if parser_mode == "jsonnet":

jsonargparse/_loaders_dumpers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,12 @@ def set_dumper(format_name: str, dumper_fn: Callable[[Any], str]):
335335
dumpers[format_name] = dumper_fn
336336

337337

338-
def set_omegaconf_loader():
339-
if omegaconf_support and "omegaconf" not in loaders:
338+
def set_omegaconf_loader(mode="omegaconf"):
339+
if omegaconf_support and mode not in loaders:
340340
from ._optionals import get_omegaconf_loader
341341

342-
set_loader("omegaconf", get_omegaconf_loader(), get_loader_exceptions("yaml"))
342+
loader = yaml_load if mode == "omegaconf+" else get_omegaconf_loader()
343+
set_loader(mode, loader, get_loader_exceptions("yaml"))
343344

344345

345346
set_loader("jsonnet", jsonnet_load, get_loader_exceptions("jsonnet"))

jsonargparse/_optionals.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,24 @@ def omegaconf_load(value):
286286
return omegaconf_load
287287

288288

289+
def omegaconf_apply(parser, cfg):
290+
if "${" not in str(cfg):
291+
return cfg
292+
293+
with missing_package_raise("omegaconf", "omegaconf_apply"):
294+
from omegaconf import OmegaConf
295+
296+
from ._common import parser_context
297+
298+
with parser_context(path_dump_preserve_relative=True):
299+
cfg_dict = parser.dump(
300+
cfg, format="json_compact", skip_validation=True, skip_none=False, skip_link_targets=False
301+
)
302+
cfg_omegaconf = OmegaConf.create(cfg_dict)
303+
cfg_dict = OmegaConf.to_container(cfg_omegaconf, resolve=True)
304+
return parser._apply_actions(cfg_dict)
305+
306+
289307
annotated_alias = typing_extensions_import("_AnnotatedAlias")
290308

291309

jsonargparse/_typehints.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
get_unaliased_type,
5555
is_dataclass_like,
5656
is_subclass,
57+
lenient_check,
5758
nested_links,
5859
parent_parser,
5960
parser_context,
@@ -524,7 +525,7 @@ def __call__(self, *args, **kwargs):
524525
if "nargs" in kwargs and kwargs["nargs"] == 0:
525526
raise ValueError("ActionTypeHint does not allow nargs=0.")
526527
return ActionTypeHint(**kwargs)
527-
cfg, val, opt_str = args[1:]
528+
parser, cfg, val, opt_str = args
528529
if not (self.nargs == "?" and val is None):
529530
if isinstance(opt_str, str) and opt_str.startswith(f"--{self.dest}."):
530531
if opt_str.startswith(f"--{self.dest}.init_args."):
@@ -533,7 +534,7 @@ def __call__(self, *args, **kwargs):
533534
sub_opt = opt_str[len(f"--{self.dest}.") :]
534535
val = NestedArg(key=sub_opt, val=val)
535536
append = opt_str == f"--{self.dest}+"
536-
val = self._check_type_(val, append=append, cfg=cfg)
537+
val = self._check_type_(val, append=append, cfg=cfg, mode=parser.parser_mode)
537538
if is_subclass_spec(val):
538539
prev_val = cfg.get(self.dest)
539540
if is_subclass_spec(prev_val) and "init_args" in prev_val:
@@ -545,7 +546,7 @@ def __call__(self, *args, **kwargs):
545546
cfg.update(val, self.dest)
546547
return None
547548

548-
def _check_type(self, value, append=False, cfg=None):
549+
def _check_type(self, value, append=False, cfg=None, mode=None):
549550
islist = _is_action_value_list(self)
550551
if not islist:
551552
value = [value]
@@ -584,7 +585,14 @@ def _check_type(self, value, append=False, cfg=None):
584585
val = adapt_typehints(orig_val, self._typehint, default=self.default, **kwargs)
585586
ex = None
586587
except ValueError:
587-
if self._enable_path and config_path is None and isinstance(orig_val, str):
588+
if (
589+
lenient_check.get()
590+
and mode == "omegaconf+"
591+
and isinstance(orig_val, str)
592+
and "${" in orig_val
593+
):
594+
ex = None
595+
elif self._enable_path and config_path is None and isinstance(orig_val, str):
588596
msg = f"\n- Expected a config path but {orig_val} either not accessible or invalid\n- "
589597
raise type(ex)(msg + str(ex)) from ex
590598
if ex:

jsonargparse/_util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,11 +282,13 @@ def get_typehint_origin(typehint):
282282

283283

284284
@contextmanager
285-
def change_to_path_dir(path: Optional["Path"]) -> Iterator[Optional[str]]:
285+
def change_to_path_dir(path: Optional[Union["Path", str]]) -> Iterator[Optional[str]]:
286286
"""A context manager for running code in the directory of a path."""
287287
path_dir = current_path_dir.get()
288288
chdir: Union[bool, str] = False
289289
if path is not None:
290+
if isinstance(path, str):
291+
path = Path(path, mode="d")
290292
if path._url_data and (path.is_url or path.is_fsspec):
291293
scheme = path._url_data.scheme
292294
path_dir = path._url_data.url_path

jsonargparse/typing.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
else:
1414
_TypeAlias = type
1515

16-
from ._common import is_final_class
16+
from ._common import is_final_class, path_dump_preserve_relative
1717
from ._optionals import final, pydantic_support
18-
from ._util import Path, get_import_path, get_private_kwargs, import_object
18+
from ._util import Path, change_to_path_dir, get_import_path, get_private_kwargs, import_object
1919

2020
__all__ = [
2121
"final",
@@ -219,6 +219,15 @@ def _is_path_type(value, type_class):
219219
return isinstance(value, Path)
220220

221221

222+
def _serialize_path(path: Path):
223+
if path_dump_preserve_relative.get() and path.relative != path.absolute:
224+
return {
225+
"relative": path._relative,
226+
"cwd": path._cwd,
227+
}
228+
return str(path)
229+
230+
222231
def path_type(mode: str, docstring: Optional[str] = None, **kwargs) -> _TypeAlias:
223232
"""Creates or returns an already registered path type class.
224233
@@ -249,10 +258,14 @@ class PathType(Path):
249258
_expression = name
250259
_mode = mode
251260
_skip_check = skip_check
252-
_type = str
261+
_type = _serialize_path
253262

254263
def __init__(self, v, **k):
255-
super().__init__(v, mode=self._mode, skip_check=self._skip_check, **k)
264+
if isinstance(v, dict) and set(v) == {"cwd", "relative"}:
265+
with change_to_path_dir(v["cwd"]):
266+
super().__init__(v["relative"], mode=self._mode, skip_check=self._skip_check, **k)
267+
else:
268+
super().__init__(v, mode=self._mode, skip_check=self._skip_check, **k)
256269

257270
restricted_type = type(name, (PathType,), {"__doc__": docstring})
258271
add_type(restricted_type, register_key, type_check=_is_path_type)

jsonargparse_tests/test_core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,18 @@ def test_parse_object_simple(parser):
191191
pytest.raises(ArgumentError, lambda: parser.parse_object({"undefined": True}))
192192

193193

194+
def test_parse_object_config(parser):
195+
parser.add_argument("--cfg", action="config")
196+
parser.add_argument("--a", type=int)
197+
parser.add_argument("--b", type=int)
198+
path = Path("config.json")
199+
path.write_text('{"a": 1, "b": 2}')
200+
cfg = parser.parse_object({"b": 0, "cfg": str(path), "a": 3})
201+
popped_cfg = cfg.pop("cfg")
202+
assert popped_cfg[0].relative == "config.json"
203+
assert cfg == Namespace(a=3, b=2)
204+
205+
194206
def test_parse_object_nested(parser):
195207
parser.add_argument("--l1.l2.op", type=float)
196208
assert parser.parse_object({"l1": {"l2": {"op": 2.1}}}).l1.l2.op == 2.1

0 commit comments

Comments
 (0)