Skip to content

Commit bc38784

Browse files
committed
Support constructor injection based on tags.
1 parent 106fce6 commit bc38784

File tree

5 files changed

+76
-4
lines changed

5 files changed

+76
-4
lines changed

src/dependency_injection/container.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from typing import Any, Callable, Dict, List, Optional, TypeVar, Type
55

6+
from dependency_injection.tags.all_tagged import AllTagged
7+
from dependency_injection.tags.any_tagged import AnyTagged
8+
from dependency_injection.tags.tagged import Tagged
69
from dependency_injection.registration import Registration
710
from dependency_injection.scope import DEFAULT_SCOPE_NAME, Scope
811
from dependency_injection.utils.singleton_meta import SingletonMeta
@@ -225,12 +228,47 @@ def _inject_dependencies(
225228
# **kwargs parameter
226229
pass
227230
else:
228-
# Check if constructor_args has an argument with the same name
231+
# Priority 1: Check if constructor_args has argument with same name
229232
if constructor_args and param_name in constructor_args:
230233
dependencies[param_name] = constructor_args[param_name]
231234
else:
232-
dependencies[param_name] = self.resolve(
233-
param_info.annotation, scope_name=scope_name
234-
)
235+
# Priority 2: Handle List[Tagged], List[AnyTagged[...]], ...
236+
tagged_dependencies = []
237+
if (
238+
hasattr(param_info.annotation, "__origin__")
239+
and param_info.annotation.__origin__ is list
240+
):
241+
inner_type = param_info.annotation.__args__[0]
242+
243+
if isinstance(inner_type, Tagged):
244+
tagged_dependencies = self.resolve_all(
245+
tags={inner_type.tag}
246+
)
247+
248+
elif isinstance(inner_type, AnyTagged):
249+
tagged_dependencies = self.resolve_all(
250+
tags=inner_type.tags, match_all_tags=False
251+
)
252+
253+
elif isinstance(inner_type, AllTagged):
254+
tagged_dependencies = self.resolve_all(
255+
tags=inner_type.tags, match_all_tags=True
256+
)
257+
258+
dependencies[param_name] = tagged_dependencies
259+
260+
else:
261+
# Priority 3: Regular type resolution
262+
try:
263+
dependencies[param_name] = self.resolve(
264+
param_info.annotation, scope_name=scope_name
265+
)
266+
except KeyError:
267+
raise ValueError(
268+
f"Cannot resolve dependency for parameter "
269+
f"'{param_name}' of type "
270+
f"'{param_info.annotation}' in class "
271+
f"'{implementation.__name__}'."
272+
)
235273

236274
return implementation(**dependencies)

src/dependency_injection/tags/__init__.py

Whitespace-only changes.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Type, Tuple, Set, Union
2+
3+
4+
class AllTagged:
5+
def __init__(self, tags: Tuple[Type, ...]):
6+
self.tags: Set[Type] = set(tags)
7+
8+
@classmethod
9+
def __class_getitem__(cls, item: Union[Type, Tuple[Type, ...]]) -> "AllTagged":
10+
if not isinstance(item, tuple):
11+
item = (item,)
12+
return cls(item)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Type, Tuple, Set, Union
2+
3+
4+
class AnyTagged:
5+
def __init__(self, tags: Union[Tuple[Type, ...], Type]):
6+
if not isinstance(tags, tuple):
7+
tags = (tags,)
8+
self.tags: Set[Type] = set(tags)
9+
10+
@classmethod
11+
def __class_getitem__(cls, item: Union[Type, Tuple[Type, ...]]) -> "AnyTagged":
12+
return cls(item)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import Type
2+
3+
4+
class Tagged:
5+
def __init__(self, tag: Type):
6+
self.tag = tag
7+
8+
@classmethod
9+
def __class_getitem__(cls, item: Type) -> "Tagged":
10+
return cls(item)

0 commit comments

Comments
 (0)