Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions timm/models/_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
import logging
import os
import inspect
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
Expand Down Expand Up @@ -303,6 +304,19 @@ def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None:
for n in names:
kwargs.pop(n, None)

def _ignore_kwargs(func, kwargs):
""" Filter kwargs to those that func accepts.
"""
sig = inspect.signature(func)
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
return kwargs
filter_keys = [p.name for p in sig.parameters.values() if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)]
filtered_kwargs = {k: v for k, v in kwargs.items() if k in filter_keys}
ignored_keys = set(kwargs.keys()) - set(filtered_kwargs.keys())
if ignored_keys:
_logger.warning(
f'Ignored attempt to pass arguments ({", ".join(ignored_keys)}) to function {func}.')
return filtered_kwargs

def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None:
""" Update the default_cfg and kwargs before passing to model
Expand Down Expand Up @@ -441,6 +455,7 @@ def build_model_with_cfg(
feature_cfg['feature_cls'] = kwargs.pop('feature_cls')

# Instantiate the model
kwargs = _ignore_kwargs(model_cls.__init__, kwargs)
if model_cfg is None:
model = model_cls(**kwargs)
else:
Expand Down