Fix typing for wiring marker

This commit is contained in:
ZipFile 2025-05-30 19:28:49 +00:00
parent 99489afa3f
commit b9df88eea7
3 changed files with 100 additions and 35 deletions

View File

@ -41,7 +41,7 @@ class WiringConfiguration:
class Container: class Container:
provider_type: Type[Provider] = Provider provider_type: Type[Provider] = Provider
providers: Dict[str, Provider] providers: Dict[str, Provider]
dependencies: Dict[str, Provider] dependencies: Dict[str, Provider[Any]]
overridden: Tuple[Provider] overridden: Tuple[Provider]
wiring_config: WiringConfiguration wiring_config: WiringConfiguration
auto_load_config: bool = True auto_load_config: bool = True

View File

@ -8,13 +8,14 @@ import pkgutil
import sys import sys
from types import ModuleType from types import ModuleType
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
Dict, Dict,
Generic,
Iterable, Iterable,
Iterator, Iterator,
Optional, Optional,
Protocol,
Set, Set,
Tuple, Tuple,
Type, Type,
@ -23,6 +24,7 @@ from typing import (
cast, cast,
) )
from typing_extensions import Self
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362 # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
if sys.version_info >= (3, 9): if sys.version_info >= (3, 9):
@ -66,7 +68,6 @@ except ImportError:
from . import providers from . import providers
__all__ = ( __all__ = (
"wire", "wire",
"unwire", "unwire",
@ -89,6 +90,10 @@ __all__ = (
T = TypeVar("T") T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any]) F = TypeVar("F", bound=Callable[..., Any])
if TYPE_CHECKING:
from .containers import Container
else:
Container = Any Container = Any
@ -777,15 +782,15 @@ class RequiredModifier(Modifier):
def __init__(self) -> None: def __init__(self) -> None:
self.type_modifier = None self.type_modifier = None
def as_int(self) -> "RequiredModifier": def as_int(self) -> Self:
self.type_modifier = TypeModifier(int) self.type_modifier = TypeModifier(int)
return self return self
def as_float(self) -> "RequiredModifier": def as_float(self) -> Self:
self.type_modifier = TypeModifier(float) self.type_modifier = TypeModifier(float)
return self return self
def as_(self, type_: Type) -> "RequiredModifier": def as_(self, type_: Type) -> Self:
self.type_modifier = TypeModifier(type_) self.type_modifier = TypeModifier(type_)
return self return self
@ -833,15 +838,15 @@ class ProvidedInstance(Modifier):
def __init__(self) -> None: def __init__(self) -> None:
self.segments = [] self.segments = []
def __getattr__(self, item): def __getattr__(self, item: str) -> Self:
self.segments.append((self.TYPE_ATTRIBUTE, item)) self.segments.append((self.TYPE_ATTRIBUTE, item))
return self return self
def __getitem__(self, item): def __getitem__(self, item) -> Self:
self.segments.append((self.TYPE_ITEM, item)) self.segments.append((self.TYPE_ITEM, item))
return self return self
def call(self): def call(self) -> Self:
self.segments.append((self.TYPE_CALL, None)) self.segments.append((self.TYPE_CALL, None))
return self return self
@ -866,7 +871,30 @@ def provided() -> ProvidedInstance:
return ProvidedInstance() return ProvidedInstance()
class _Marker(Generic[T]): MarkerItem = Union[
str,
providers.Provider[Any],
Tuple[str, TypeModifier],
Type[Container],
"_Marker",
]
if TYPE_CHECKING:
class _Marker(Protocol):
__IS_MARKER__: bool
def __call__(self) -> Self: ...
def __getattr__(self, item: str) -> Self: ...
def __getitem__(self, item: Any) -> Any: ...
Provide: _Marker
Provider: _Marker
Closing: _Marker
else:
class _Marker:
__IS_MARKER__ = True __IS_MARKER__ = True
@ -880,21 +908,18 @@ class _Marker(Generic[T]):
self.provider = provider self.provider = provider
self.modifier = modifier self.modifier = modifier
def __class_getitem__(cls, item) -> T: def __class_getitem__(cls, item: MarkerItem) -> Self:
if isinstance(item, tuple): if isinstance(item, tuple):
return cls(*item) return cls(*item)
return cls(item) return cls(item)
def __call__(self) -> T: def __call__(self) -> Self:
return self return self
class Provide(_Marker): ... class Provide(_Marker): ...
class Provider(_Marker): ... class Provider(_Marker): ...
class Closing(_Marker): ... class Closing(_Marker): ...
@ -998,8 +1023,8 @@ _inspect_filter = InspectFilter()
_loader = AutoLoader() _loader = AutoLoader()
# Optimizations # Optimizations
from ._cwiring import _sync_inject # noqa
from ._cwiring import _async_inject # noqa from ._cwiring import _async_inject # noqa
from ._cwiring import _sync_inject # noqa
# Wiring uses the following Python wrapper because there is # Wiring uses the following Python wrapper because there is
@ -1028,13 +1053,17 @@ def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
patched.injections, patched.injections,
patched.closing, patched.closing,
) )
return cast(F, _patched) return cast(F, _patched)
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
def _get_annotations(obj: Any) -> Dict[str, Any]: def _get_annotations(obj: Any) -> Dict[str, Any]:
return inspect.get_annotations(obj) return inspect.get_annotations(obj)
else: else:
def _get_annotations(obj: Any) -> Dict[str, Any]: def _get_annotations(obj: Any) -> Dict[str, Any]:
return getattr(obj, "__annotations__", {}) return getattr(obj, "__annotations__", {})

36
tests/typing/wiring.py Normal file
View File

@ -0,0 +1,36 @@
from typing import Iterator
from typing_extensions import Annotated
from dependency_injector.containers import DeclarativeContainer
from dependency_injector.providers import Object, Resource
from dependency_injector.wiring import Closing, Provide, required
def _resource() -> Iterator[int]:
yield 1
class Container(DeclarativeContainer):
value = Object(1)
res = Resource(_resource)
def default_by_ref(value: int = Provide[Container.value]) -> None: ...
def default_by_string(value: int = Provide["value"]) -> None: ...
def default_by_string_with_modifier(
value: int = Provide["value", required().as_int()]
) -> None: ...
def default_container(container: Container = Provide[Container]) -> None: ...
def default_with_closing(value: int = Closing[Provide[Container.res]]) -> None: ...
def annotated_by_ref(value: Annotated[int, Provide[Container.value]]) -> None: ...
def annotated_by_string(value: Annotated[int, Provide["value"]]) -> None: ...
def annotated_by_string_with_modifier(
value: Annotated[int, Provide["value", required().as_int()]],
) -> None: ...
def annotated_container(
container: Annotated[Container, Provide[Container]],
) -> None: ...
def annotated_with_closing(
value: Annotated[int, Closing[Provide[Container.res]]],
) -> None: ...