diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index 4b40fbba..ec41ea8e 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -41,7 +41,7 @@ class WiringConfiguration: class Container: provider_type: Type[Provider] = Provider providers: Dict[str, Provider] - dependencies: Dict[str, Provider] + dependencies: Dict[str, Provider[Any]] overridden: Tuple[Provider] wiring_config: WiringConfiguration auto_load_config: bool = True diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 6829d53b..b8534ee5 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -8,13 +8,14 @@ import pkgutil import sys from types import ModuleType from typing import ( + TYPE_CHECKING, Any, Callable, Dict, - Generic, Iterable, Iterator, Optional, + Protocol, Set, Tuple, Type, @@ -23,6 +24,7 @@ from typing import ( cast, ) +from typing_extensions import Self # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362 if sys.version_info >= (3, 9): @@ -66,7 +68,6 @@ except ImportError: from . import providers - __all__ = ( "wire", "unwire", @@ -89,7 +90,11 @@ __all__ = ( T = TypeVar("T") F = TypeVar("F", bound=Callable[..., Any]) -Container = Any + +if TYPE_CHECKING: + from .containers import Container +else: + Container = Any class PatchedRegistry: @@ -777,15 +782,15 @@ class RequiredModifier(Modifier): def __init__(self) -> None: self.type_modifier = None - def as_int(self) -> "RequiredModifier": + def as_int(self) -> Self: self.type_modifier = TypeModifier(int) return self - def as_float(self) -> "RequiredModifier": + def as_float(self) -> Self: self.type_modifier = TypeModifier(float) return self - def as_(self, type_: Type) -> "RequiredModifier": + def as_(self, type_: Type) -> Self: self.type_modifier = TypeModifier(type_) return self @@ -833,15 +838,15 @@ class ProvidedInstance(Modifier): def __init__(self) -> None: self.segments = [] - def __getattr__(self, item): + def __getattr__(self, item: str) -> Self: self.segments.append((self.TYPE_ATTRIBUTE, item)) return self - def __getitem__(self, item): + def __getitem__(self, item) -> Self: self.segments.append((self.TYPE_ITEM, item)) return self - def call(self): + def call(self) -> Self: self.segments.append((self.TYPE_CALL, None)) return self @@ -866,36 +871,56 @@ def provided() -> ProvidedInstance: return ProvidedInstance() -class _Marker(Generic[T]): - - __IS_MARKER__ = True - - def __init__( - self, - provider: Union[providers.Provider, Container, str], - modifier: Optional[Modifier] = None, - ) -> None: - if _is_declarative_container(provider): - provider = provider.__self__ - self.provider = provider - self.modifier = modifier - - def __class_getitem__(cls, item) -> T: - if isinstance(item, tuple): - return cls(*item) - return cls(item) - - def __call__(self) -> T: - return self +MarkerItem = Union[ + str, + providers.Provider[Any], + Tuple[str, TypeModifier], + Type[Container], + "_Marker", +] -class Provide(_Marker): ... +if TYPE_CHECKING: + class _Marker(Protocol): + __IS_MARKER__: bool -class Provider(_Marker): ... + def __call__(self) -> Self: ... + def __getattr__(self, item: str) -> Self: ... + def __getitem__(self, item: Any) -> Any: ... + Provide: _Marker + Provider: _Marker + Closing: _Marker +else: -class Closing(_Marker): ... + class _Marker: + + __IS_MARKER__ = True + + def __init__( + self, + provider: Union[providers.Provider, Container, str], + modifier: Optional[Modifier] = None, + ) -> None: + if _is_declarative_container(provider): + provider = provider.__self__ + self.provider = provider + self.modifier = modifier + + def __class_getitem__(cls, item: MarkerItem) -> Self: + if isinstance(item, tuple): + return cls(*item) + return cls(item) + + def __call__(self) -> Self: + return self + + class Provide(_Marker): ... + + class Provider(_Marker): ... + + class Closing(_Marker): ... class AutoLoader: @@ -998,8 +1023,8 @@ _inspect_filter = InspectFilter() _loader = AutoLoader() # Optimizations -from ._cwiring import _sync_inject # noqa from ._cwiring import _async_inject # noqa +from ._cwiring import _sync_inject # noqa # Wiring uses the following Python wrapper because there is @@ -1028,13 +1053,17 @@ def _get_sync_patched(fn: F, patched: PatchedCallable) -> F: patched.injections, patched.closing, ) + return cast(F, _patched) if sys.version_info >= (3, 10): + def _get_annotations(obj: Any) -> Dict[str, Any]: return inspect.get_annotations(obj) + else: + def _get_annotations(obj: Any) -> Dict[str, Any]: return getattr(obj, "__annotations__", {}) diff --git a/tests/typing/wiring.py b/tests/typing/wiring.py new file mode 100644 index 00000000..078fe3cf --- /dev/null +++ b/tests/typing/wiring.py @@ -0,0 +1,36 @@ +from typing import Iterator + +from typing_extensions import Annotated + +from dependency_injector.containers import DeclarativeContainer +from dependency_injector.providers import Object, Resource +from dependency_injector.wiring import Closing, Provide, required + + +def _resource() -> Iterator[int]: + yield 1 + + +class Container(DeclarativeContainer): + value = Object(1) + res = Resource(_resource) + + +def default_by_ref(value: int = Provide[Container.value]) -> None: ... +def default_by_string(value: int = Provide["value"]) -> None: ... +def default_by_string_with_modifier( + value: int = Provide["value", required().as_int()] +) -> None: ... +def default_container(container: Container = Provide[Container]) -> None: ... +def default_with_closing(value: int = Closing[Provide[Container.res]]) -> None: ... +def annotated_by_ref(value: Annotated[int, Provide[Container.value]]) -> None: ... +def annotated_by_string(value: Annotated[int, Provide["value"]]) -> None: ... +def annotated_by_string_with_modifier( + value: Annotated[int, Provide["value", required().as_int()]], +) -> None: ... +def annotated_container( + container: Annotated[Container, Provide[Container]], +) -> None: ... +def annotated_with_closing( + value: Annotated[int, Closing[Provide[Container.res]]], +) -> None: ...