diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index f1f118b5..071a4071 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -4,6 +4,11 @@ from typing import Type, Dict, Tuple, Optional, Any, Union, ClassVar, Callable a from .providers import Provider +C_Base = TypeVar('C_Base', bound='Container') +C = TypeVar('C', bound='DeclarativeContainer') +C_Overriding = TypeVar('C_Overriding', bound='DeclarativeContainer') + + class Container: provider_type: Type[Provider] = Provider providers: Dict[str, Provider] @@ -13,7 +18,7 @@ class Container: def __setattr__(self, name: str, value: Union[Provider, Any]) -> None: ... def __delattr__(self, name: str) -> None: ... def set_providers(self, **providers: Provider): ... - def override(self, overriding: DynamicContainer) -> None: ... + def override(self, overriding: C_Base) -> None: ... def override_providers(self, **overriding_providers: Provider) -> None: ... def reset_last_overriding(self) -> None: ... def reset_override(self) -> None: ... @@ -31,9 +36,6 @@ class DeclarativeContainer(Container): def __init__(self, **overriding_providers: Union[Provider, Any]) -> None: ... -C = TypeVar('C', bound=DeclarativeContainer) -C_Overriding = TypeVar('C_Overriding', bound=DeclarativeContainer) - def override(container: Type[C]) -> _Callable[[Type[C_Overriding]], Type[C_Overriding]]: ... diff --git a/tests/typing/declarative_container.py b/tests/typing/declarative_container.py index 70ae08b9..53df3144 100644 --- a/tests/typing/declarative_container.py +++ b/tests/typing/declarative_container.py @@ -30,3 +30,11 @@ class Container31(containers.DeclarativeContainer): @containers.copy(Container31) class Container32(containers.DeclarativeContainer): ... + + +# Test 4: to override() +class Container4(containers.DeclarativeContainer): + provider = providers.Factory(int) + +container4 = Container4() +container4.override(Container4())