Fix typing for wiring marker

This commit is contained in:
ZipFile 2025-05-30 19:28:49 +00:00
parent 99489afa3f
commit b9df88eea7
3 changed files with 100 additions and 35 deletions

View File

@ -41,7 +41,7 @@ class WiringConfiguration:
class Container:
provider_type: Type[Provider] = Provider
providers: Dict[str, Provider]
dependencies: Dict[str, Provider]
dependencies: Dict[str, Provider[Any]]
overridden: Tuple[Provider]
wiring_config: WiringConfiguration
auto_load_config: bool = True

View File

@ -8,13 +8,14 @@ import pkgutil
import sys
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
Optional,
Protocol,
Set,
Tuple,
Type,
@ -23,6 +24,7 @@ from typing import (
cast,
)
from typing_extensions import Self
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
if sys.version_info >= (3, 9):
@ -66,7 +68,6 @@ except ImportError:
from . import providers
__all__ = (
"wire",
"unwire",
@ -89,6 +90,10 @@ __all__ = (
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
if TYPE_CHECKING:
from .containers import Container
else:
Container = Any
@ -777,15 +782,15 @@ class RequiredModifier(Modifier):
def __init__(self) -> None:
self.type_modifier = None
def as_int(self) -> "RequiredModifier":
def as_int(self) -> Self:
self.type_modifier = TypeModifier(int)
return self
def as_float(self) -> "RequiredModifier":
def as_float(self) -> Self:
self.type_modifier = TypeModifier(float)
return self
def as_(self, type_: Type) -> "RequiredModifier":
def as_(self, type_: Type) -> Self:
self.type_modifier = TypeModifier(type_)
return self
@ -833,15 +838,15 @@ class ProvidedInstance(Modifier):
def __init__(self) -> None:
self.segments = []
def __getattr__(self, item):
def __getattr__(self, item: str) -> Self:
self.segments.append((self.TYPE_ATTRIBUTE, item))
return self
def __getitem__(self, item):
def __getitem__(self, item) -> Self:
self.segments.append((self.TYPE_ITEM, item))
return self
def call(self):
def call(self) -> Self:
self.segments.append((self.TYPE_CALL, None))
return self
@ -866,7 +871,30 @@ def provided() -> ProvidedInstance:
return ProvidedInstance()
class _Marker(Generic[T]):
MarkerItem = Union[
str,
providers.Provider[Any],
Tuple[str, TypeModifier],
Type[Container],
"_Marker",
]
if TYPE_CHECKING:
class _Marker(Protocol):
__IS_MARKER__: bool
def __call__(self) -> Self: ...
def __getattr__(self, item: str) -> Self: ...
def __getitem__(self, item: Any) -> Any: ...
Provide: _Marker
Provider: _Marker
Closing: _Marker
else:
class _Marker:
__IS_MARKER__ = True
@ -880,21 +908,18 @@ class _Marker(Generic[T]):
self.provider = provider
self.modifier = modifier
def __class_getitem__(cls, item) -> T:
def __class_getitem__(cls, item: MarkerItem) -> Self:
if isinstance(item, tuple):
return cls(*item)
return cls(item)
def __call__(self) -> T:
def __call__(self) -> Self:
return self
class Provide(_Marker): ...
class Provider(_Marker): ...
class Closing(_Marker): ...
@ -998,8 +1023,8 @@ _inspect_filter = InspectFilter()
_loader = AutoLoader()
# Optimizations
from ._cwiring import _sync_inject # noqa
from ._cwiring import _async_inject # noqa
from ._cwiring import _sync_inject # noqa
# Wiring uses the following Python wrapper because there is
@ -1028,13 +1053,17 @@ def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
patched.injections,
patched.closing,
)
return cast(F, _patched)
if sys.version_info >= (3, 10):
def _get_annotations(obj: Any) -> Dict[str, Any]:
return inspect.get_annotations(obj)
else:
def _get_annotations(obj: Any) -> Dict[str, Any]:
return getattr(obj, "__annotations__", {})

36
tests/typing/wiring.py Normal file
View File

@ -0,0 +1,36 @@
from typing import Iterator
from typing_extensions import Annotated
from dependency_injector.containers import DeclarativeContainer
from dependency_injector.providers import Object, Resource
from dependency_injector.wiring import Closing, Provide, required
def _resource() -> Iterator[int]:
yield 1
class Container(DeclarativeContainer):
value = Object(1)
res = Resource(_resource)
def default_by_ref(value: int = Provide[Container.value]) -> None: ...
def default_by_string(value: int = Provide["value"]) -> None: ...
def default_by_string_with_modifier(
value: int = Provide["value", required().as_int()]
) -> None: ...
def default_container(container: Container = Provide[Container]) -> None: ...
def default_with_closing(value: int = Closing[Provide[Container.res]]) -> None: ...
def annotated_by_ref(value: Annotated[int, Provide[Container.value]]) -> None: ...
def annotated_by_string(value: Annotated[int, Provide["value"]]) -> None: ...
def annotated_by_string_with_modifier(
value: Annotated[int, Provide["value", required().as_int()]],
) -> None: ...
def annotated_container(
container: Annotated[Container, Provide[Container]],
) -> None: ...
def annotated_with_closing(
value: Annotated[int, Closing[Provide[Container.res]]],
) -> None: ...