Skip to content

Commit 18d538f

Browse files
author
Filip Nešťák
committed
Allow annotated in classes
1 parent 0e9905e commit 18d538f

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

injector/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,9 @@ def _is_new_union_type(instance: Any) -> bool:
12081208
new_union_type = getattr(types, 'UnionType', None)
12091209
return new_union_type is not None and isinstance(instance, new_union_type)
12101210

1211+
def _is_package_annotation(annotation: Any) -> bool:
1212+
return _is_specialization(annotation, Annotated) and (_inject_marker in annotation.__metadata__ or _noinject_marker in annotation.__metadata__)
1213+
12111214
spec = inspect.getfullargspec(callable)
12121215

12131216
try:
@@ -1238,7 +1241,8 @@ def _is_new_union_type(instance: Any) -> bool:
12381241
bindings.pop(spec.varkw, None)
12391242

12401243
for k, v in list(bindings.items()):
1241-
if _is_specialization(v, Annotated):
1244+
# extract metadata only from Inject and NonInject
1245+
if _is_package_annotation(v):
12421246
v, metadata = v.__origin__, v.__metadata__
12431247
bindings[k] = v
12441248
else:

injector_test.py

+50
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""Functional tests for the "Injector" dependency injection framework."""
1212

1313
from contextlib import contextmanager
14+
from dataclasses import dataclass
1415
from typing import Any, NewType, Optional, Union
1516
import abc
1617
import sys
@@ -1754,3 +1755,52 @@ def configure(binder):
17541755
injector = Injector([configure])
17551756
assert injector.get(foo) == 123
17561757
assert injector.get(bar) == 456
1758+
1759+
1760+
def test_annotated_integration_with_annotated():
1761+
UserID = Annotated[int, 'user_id']
1762+
1763+
@inject
1764+
class TestClass:
1765+
def __init__(self, user_id: UserID):
1766+
self.user_id = user_id
1767+
1768+
def configure(binder):
1769+
binder.bind(UserID, to=123)
1770+
1771+
injector = Injector([configure])
1772+
1773+
test_class = injector.get(TestClass)
1774+
assert test_class.user_id == 123
1775+
1776+
1777+
def test_newtype_integration_with_annotated():
1778+
UserID = NewType('UserID', int)
1779+
1780+
@inject
1781+
class TestClass:
1782+
def __init__(self, user_id: UserID):
1783+
self.user_id = user_id
1784+
1785+
def configure(binder):
1786+
binder.bind(UserID, to=123)
1787+
1788+
injector = Injector([configure])
1789+
1790+
test_class = injector.get(TestClass)
1791+
assert test_class.user_id == 123
1792+
1793+
def test_dataclass_annotated_parameter():
1794+
Foo = Annotated[int, object()]
1795+
1796+
def configure(binder):
1797+
binder.bind(Foo, to=123)
1798+
1799+
@inject
1800+
@dataclass
1801+
class MyClass:
1802+
foo: Foo
1803+
1804+
injector = Injector([configure])
1805+
instance = injector.get(MyClass)
1806+
assert instance.foo == 123

0 commit comments

Comments
 (0)