diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index 8b7a3e9e..f1f118b5 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -1,5 +1,5 @@ from types import ModuleType -from typing import Type, Dict, Tuple, Optional, Any, Union, ClassVar, Callable as _Callable, Iterable +from typing import Type, Dict, Tuple, Optional, Any, Union, ClassVar, Callable as _Callable, Iterable, TypeVar from .providers import Provider @@ -31,9 +31,14 @@ class DeclarativeContainer(Container): def __init__(self, **overriding_providers: Union[Provider, Any]) -> None: ... -def override(container: Container) -> _Callable[[Container], Container]: ... +C = TypeVar('C', bound=DeclarativeContainer) +C_Overriding = TypeVar('C_Overriding', bound=DeclarativeContainer) -def copy(container: Container) -> _Callable[[Container], Container]: ... +def override(container: Type[C]) -> _Callable[[Type[C_Overriding]], Type[C_Overriding]]: ... + + +def copy(container: Type[C]) -> _Callable[[Type[C_Overriding]], Type[C_Overriding]]: ... + def is_container(instance: Any) -> bool: ... diff --git a/tests/typing/declarative_container.py b/tests/typing/declarative_container.py index 46db4080..70ae08b9 100644 --- a/tests/typing/declarative_container.py +++ b/tests/typing/declarative_container.py @@ -10,3 +10,23 @@ container1 = Container1() container1_type: containers.Container = Container1() provider1: providers.Provider = container1.provider val1: int = container1.provider(3) + + +# Test 2: to check @override decorator +class Container21(containers.DeclarativeContainer): + provider = providers.Factory(int) + + +@containers.override(Container21) +class Container22(containers.DeclarativeContainer): + ... + + +# Test 3: to check @copy decorator +class Container31(containers.DeclarativeContainer): + provider = providers.Factory(int) + + +@containers.copy(Container31) +class Container32(containers.DeclarativeContainer): + ...