diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index f0d5b9ec..f35b9c44 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -9,6 +9,7 @@ follows `Semantic versioning`_ Development version ------------------- +- Improve ``FactoryAggregate`` typing stub. - Improve resource subclasses typing and make shutdown definition optional `PR #492 `_. Thanks to `@EdwardBlair `_ for suggesting the improvement. diff --git a/src/dependency_injector/providers.pyi b/src/dependency_injector/providers.pyi index 17eee7a8..2733716e 100644 --- a/src/dependency_injector/providers.pyi +++ b/src/dependency_injector/providers.pyi @@ -282,19 +282,19 @@ class FactoryDelegate(Delegate): def __init__(self, factory: Factory): ... -class FactoryAggregate(Provider): - def __init__(self, **factories: Factory): ... - def __getattr__(self, factory_name: str) -> Factory: ... +class FactoryAggregate(Provider[T]): + def __init__(self, **factories: Factory[T]): ... + def __getattr__(self, factory_name: str) -> Factory[T]: ... @overload - def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Any: ... + def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> T: ... @overload - def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[Any]: ... - def async_(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[Any]: ... + def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ... + def async_(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ... @property - def factories(self) -> _Dict[str, Factory]: ... - def set_factories(self, **factories: Factory) -> FactoryAggregate: ... + def factories(self) -> _Dict[str, Factory[T]]: ... + def set_factories(self, **factories: Factory[T]) -> FactoryAggregate[T]: ... class BaseSingleton(Provider[T]): diff --git a/tests/typing/factory.py b/tests/typing/factory.py index 963eb37c..8e206007 100644 --- a/tests/typing/factory.py +++ b/tests/typing/factory.py @@ -55,13 +55,13 @@ animal7: Animal = provider7(1, 2, 3, b='1', c=2, e=0.0) provider8 = providers.FactoryDelegate(providers.Factory(object)) # Test 9: to check FactoryAggregate provider -provider9 = providers.FactoryAggregate( - a=providers.Factory(object), - b=providers.Factory(object), +provider9: providers.FactoryAggregate[str] = providers.FactoryAggregate( + a=providers.Factory(str, "str1"), + b=providers.Factory(str, "str2"), ) -factory_a_9: providers.Factory = provider9.a -factory_b_9: providers.Factory = provider9.b -val9: Any = provider9('a') +factory_a_9: providers.Factory[str] = provider9.a +factory_b_9: providers.Factory[str] = provider9.b +val9: str = provider9('a') # Test 10: to check the explicit typing factory10: providers.Provider[Animal] = providers.Factory(Cat)