Aggregate provider (#544)

* Add implementation and typing stubs

* Add tests

* Add typing tests

* Refactor FactoryAggregate

* Update changelog

* Add Aggregate provider docs and example

* Update cross links between Aggregate, Selector, and FactoryAggregate docs

* Add wording improvements to the docs
This commit is contained in:
Roman Mogylatov 2022-01-09 21:45:20 -05:00 committed by GitHub
parent cfadd8c3fa
commit 742e73af1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 12685 additions and 11230 deletions

View File

@ -9,8 +9,14 @@ follows `Semantic versioning`_
Development version
-------------------
- Add new provider ``Aggregate``. It is a generalized version of ``FactoryAggregate`` that
can contain providers of any type, not only ``Factory``. See issue
`#530 <https://github.com/ets-labs/python-dependency-injector/issues/530>`_. Thanks to
`@zerlok (Danil Troshnev) <https://github.com/zerlok>`_ for suggesting the feature.
- Add argument ``as_`` to the ``config.from_env()`` method for the explicit type casting
of an environment variable value, e.g.: ``config.timeout.from_env("TIMEOUT", as_=int)``.
See issue `#533 <https://github.com/ets-labs/python-dependency-injector/issues/533>`_. Thanks to
`@gtors (Andrey Torsunov) <https://github.com/gtors>`_ for suggesting the feature.
- Add ``.providers`` attribute to the ``FactoryAggregate`` provider. It is an alias for
``FactoryAggregate.factories`` attribute.
- Add ``.set_providers()`` method to the ``FactoryAggregate`` provider. It is an alias for

View File

@ -0,0 +1,72 @@
.. _aggregate-provider:
Aggregate provider
==================
.. meta::
:keywords: Python,DI,Dependency injection,IoC,Inversion of Control,Configuration,Injection,
Aggregate,Polymorphism,Environment Variable,Flexibility
:description: Aggregate provider aggregates other providers.
This page demonstrates how to implement the polymorphism and increase the
flexibility of your application using the Aggregate provider.
:py:class:`Aggregate` provider aggregates a group of other providers.
.. currentmodule:: dependency_injector.providers
.. literalinclude:: ../../examples/providers/aggregate.py
:language: python
:lines: 3-
:emphasize-lines: 24-27
Each provider in the ``Aggregate`` is associated with a key. You can call aggregated providers by providing
their key as a first argument. All positional and keyword arguments following the key will be forwarded to
the called provider:
.. code-block:: python
yaml_reader = container.config_readers("yaml", "./config.yml", foo=...)
You can also retrieve an aggregated provider by providing its key as an attribute name:
.. code-block:: python
yaml_reader = container.config_readers.yaml("./config.yml", foo=...)
To retrieve a dictionary of aggregated providers, use ``.providers`` attribute:
.. code-block:: python
container.config_readers.providers == {
"yaml": <YAML provider>,
"json": <JSON provider>,
}
.. note::
You can not override the ``Aggregate`` provider.
.. note::
When you inject the ``Aggregate`` provider, it is passed "as is".
To use non-string keys or string keys with ``.`` and ``-``, provide a dictionary as a positional argument:
.. code-block:: python
aggregate = providers.Aggregate({
SomeClass: providers.Factory(...),
"key.with.periods": providers.Factory(...),
"key-with-dashes": providers.Factory(...),
})
.. seealso::
:ref:`selector-provider` to make injections based on a configuration value, environment variable, or a result of a callable.
``Aggregate`` provider is different from the :ref:`selector-provider`. ``Aggregate`` provider doesn't select which provider
to inject and doesn't have a selector. It is a group of providers and is always injected "as is". The rest of the interface
of both providers is similar.
.. note::
``Aggregate`` provider is a successor of :ref:`factory-aggregate-provider` provider. ``Aggregate`` provider doesn't have
a restriction on the provider type, while ``FactoryAggregate`` aggregates only ``Factory`` providers.
.. disqus::

View File

@ -145,11 +145,17 @@ provider with two peculiarities:
:lines: 3-
:emphasize-lines: 34
.. _factory-aggregate-provider:
Factory aggregate
-----------------
:py:class:`FactoryAggregate` provider aggregates multiple factories.
.. seealso::
:ref:`aggregate-provider` it's a successor of ``FactoryAggregate`` provider that can aggregate
any type of provider, not only ``Factory``.
The aggregated factories are associated with the string keys. When you call the
``FactoryAggregate`` you have to provide one of the these keys as a first argument.
``FactoryAggregate`` looks for the factory with a matching key and calls it with the rest of the arguments.

View File

@ -46,6 +46,7 @@ Providers module API docs - :py:mod:`dependency_injector.providers`
dict
configuration
resource
aggregate
selector
dependency
overriding

View File

@ -30,4 +30,7 @@ When a ``Selector`` provider is called, it gets a ``selector`` value and delegat
the provider with a matching name. The ``selector`` callable works as a switch: when the returned
value is changed the ``Selector`` provider will delegate the work to another provider.
.. seealso::
:ref:`aggregate-provider` to inject a group of providers.
.. disqus::

View File

@ -0,0 +1,39 @@
"""`Aggregate` provider example."""
from dependency_injector import containers, providers
class ConfigReader:
def __init__(self, path):
self._path = path
def read(self):
print(f"Parsing {self._path} with {self.__class__.__name__}")
...
class YamlReader(ConfigReader):
...
class JsonReader(ConfigReader):
...
class Container(containers.DeclarativeContainer):
config_readers = providers.Aggregate(
yaml=providers.Factory(YamlReader),
json=providers.Factory(JsonReader),
)
if __name__ == "__main__":
container = Container()
yaml_reader = container.config_readers("yaml", "./config.yml")
yaml_reader.read() # Parsing ./config.yml with YamlReader
json_reader = container.config_readers("json", "./config.json")
json_reader.read() # Parsing ./config.json with JsonReader

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -38,6 +38,12 @@ cdef class Delegate(Provider):
cpdef object _provide(self, tuple args, dict kwargs)
cdef class Aggregate(Provider):
cdef dict __providers
cdef Provider __get_provider(self, object provider_name)
cdef class Dependency(Provider):
cdef object __instance_of
cdef object __default
@ -142,10 +148,8 @@ cdef class FactoryDelegate(Delegate):
pass
cdef class FactoryAggregate(Provider):
cdef dict __providers
cdef Provider __get_provider(self, object provider_name)
cdef class FactoryAggregate(Aggregate):
pass
# Singleton providers

View File

@ -104,6 +104,21 @@ class Delegate(Provider[Provider]):
def set_provides(self, provides: Optional[Provider]) -> Delegate: ...
class Aggregate(Provider[T]):
def __init__(self, provider_dict: Optional[_Dict[Any, Provider[T]]] = None, **provider_kwargs: Provider[T]): ...
def __getattr__(self, provider_name: Any) -> Provider[T]: ...
@overload
def __call__(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> T: ...
@overload
def __call__(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
def async_(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
@property
def providers(self) -> _Dict[Any, Provider[T]]: ...
def set_providers(self, provider_dict: Optional[_Dict[Any, Provider[T]]] = None, **provider_kwargs: Provider[T]) -> Aggregate[T]: ...
class Dependency(Provider[T]):
def __init__(self, instance_of: Type[T] = object, default: Optional[Union[Provider, Any]] = None) -> None: ...
def __getattr__(self, name: str) -> Any: ...
@ -302,20 +317,8 @@ class FactoryDelegate(Delegate):
def __init__(self, factory: Factory): ...
class FactoryAggregate(Provider[T]):
def __init__(self, provider_dict: Optional[_Dict[Any, Provider[T]]] = None, **provider_kwargs: Provider[T]): ...
class FactoryAggregate(Aggregate[T]):
def __getattr__(self, provider_name: Any) -> Factory[T]: ...
@overload
def __call__(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> T: ...
@overload
def __call__(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
def async_(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
@property
def providers(self) -> _Dict[Any, Provider[T]]: ...
def set_providers(self, provider_dict: Optional[_Dict[Any, Provider[T]]] = None, **provider_kwargs: Provider[T]) -> FactoryAggregate[T]: ...
@property
def factories(self) -> _Dict[Any, Factory[T]]: ...
def set_factories(self, provider_dict: Optional[_Dict[Any, Factory[T]]] = None, **provider_kwargs: Factory[T]) -> FactoryAggregate[T]: ...
@ -487,7 +490,9 @@ class MethodCaller(Provider, ProvidedInstanceFluentInterface):
class OverridingContext(Generic[T]):
def __init__(self, overridden: Provider, overriding: Provider): ...
def __enter__(self) -> T: ...
def __exit__(self, *_: Any) -> None: ...
def __exit__(self, *_: Any) -> None:
pass
...
class BaseSingletonResetContext(Generic[T]):

View File

@ -632,6 +632,115 @@ cdef class Delegate(Provider):
return self.__provides
cdef class Aggregate(Provider):
"""Providers aggregate.
:py:class:`Aggregate` is a delegated provider, meaning that it is
injected "as is".
All aggregated providers can be retrieved as a read-only
dictionary :py:attr:`Aggregate.providers` or as an attribute of
:py:class:`Aggregate`, e.g. ``aggregate.provider``.
"""
__IS_DELEGATED__ = True
def __init__(self, provider_dict=None, **provider_kwargs):
"""Initialize provider."""
self.__providers = {}
self.set_providers(provider_dict, **provider_kwargs)
super().__init__()
def __deepcopy__(self, memo):
"""Create and return full copy of provider."""
copied = memo.get(id(self))
if copied is not None:
return copied
copied = _memorized_duplicate(self, memo)
copied.set_providers(deepcopy(self.providers, memo))
self._copy_overridings(copied, memo)
return copied
def __getattr__(self, factory_name):
"""Return aggregated provider."""
return self.__get_provider(factory_name)
def __str__(self):
"""Return string representation of provider.
:rtype: str
"""
return represent_provider(provider=self, provides=self.providers)
@property
def providers(self):
"""Return dictionary of providers, read-only.
Alias for ``.factories`` attribute.
"""
return dict(self.__providers)
def set_providers(self, provider_dict=None, **provider_kwargs):
"""Set providers.
Alias for ``.set_factories()`` method.
"""
providers = {}
providers.update(provider_kwargs)
if provider_dict:
providers.update(provider_dict)
for provider in providers.values():
if not is_provider(provider):
raise Error(
'{0} can aggregate only instances of {1}, given - {2}'.format(
self.__class__,
Provider,
provider,
),
)
self.__providers = providers
return self
def override(self, _):
"""Override provider with another provider.
:raise: :py:exc:`dependency_injector.errors.Error`
:return: Overriding context.
:rtype: :py:class:`OverridingContext`
"""
raise Error('{0} providers could not be overridden'.format(self.__class__))
@property
def related(self):
"""Return related providers generator."""
yield from self.__providers.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
try:
provider_name = args[0]
except IndexError:
try:
provider_name = kwargs.pop("factory_name")
except KeyError:
raise TypeError("Missing 1st required positional argument: \"provider_name\"")
else:
args = args[1:]
return self.__get_provider(provider_name)(*args, **kwargs)
cdef Provider __get_provider(self, object provider_name):
if provider_name not in self.__providers:
raise NoSuchProviderError("{0} does not contain provider with name {1}".format(self, provider_name))
return <Provider> self.__providers[provider_name]
cdef class Dependency(Provider):
""":py:class:`Dependency` provider describes dependency interface.
@ -2561,7 +2670,7 @@ cdef class FactoryDelegate(Delegate):
super(FactoryDelegate, self).__init__(factory)
cdef class FactoryAggregate(Provider):
cdef class FactoryAggregate(Aggregate):
"""Factory providers aggregate.
:py:class:`FactoryAggregate` is an aggregate of :py:class:`Factory`
@ -2575,69 +2684,6 @@ cdef class FactoryAggregate(Provider):
:py:class:`FactoryAggregate`.
"""
__IS_DELEGATED__ = True
def __init__(self, provider_dict=None, **provider_kwargs):
"""Initialize provider."""
self.__providers = {}
self.set_providers(provider_dict, **provider_kwargs)
super(FactoryAggregate, self).__init__()
def __deepcopy__(self, memo):
"""Create and return full copy of provider."""
copied = memo.get(id(self))
if copied is not None:
return copied
copied = _memorized_duplicate(self, memo)
copied.set_providers(deepcopy(self.providers, memo))
self._copy_overridings(copied, memo)
return copied
def __getattr__(self, factory_name):
"""Return aggregated provider."""
return self.__get_provider(factory_name)
def __str__(self):
"""Return string representation of provider.
:rtype: str
"""
return represent_provider(provider=self, provides=self.providers)
@property
def providers(self):
"""Return dictionary of providers, read-only.
Alias for ``.factories`` attribute.
"""
return dict(self.__providers)
def set_providers(self, provider_dict=None, **provider_kwargs):
"""Set providers.
Alias for ``.set_factories()`` method.
"""
providers = {}
providers.update(provider_kwargs)
if provider_dict:
providers.update(provider_dict)
for provider in providers.values():
if not is_provider(provider):
raise Error(
'{0} can aggregate only instances of {1}, given - {2}'.format(
self.__class__,
Provider,
provider,
),
)
self.__providers = providers
return self
@property
def factories(self):
"""Return dictionary of factories, read-only.
@ -2653,40 +2699,6 @@ cdef class FactoryAggregate(Provider):
"""
return self.set_providers(factory_dict, **factory_kwargs)
def override(self, _):
"""Override provider with another provider.
:raise: :py:exc:`dependency_injector.errors.Error`
:return: Overriding context.
:rtype: :py:class:`OverridingContext`
"""
raise Error('{0} providers could not be overridden'.format(self.__class__))
@property
def related(self):
"""Return related providers generator."""
yield from self.__providers.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
try:
provider_name = args[0]
except IndexError:
try:
provider_name = kwargs.pop("factory_name")
except KeyError:
raise TypeError("Missing 1st required positional argument: \"provider_name\"")
else:
args = args[1:]
return self.__get_provider(provider_name)(*args, **kwargs)
cdef Provider __get_provider(self, object provider_name):
if provider_name not in self.__providers:
raise NoSuchProviderError("{0} does not contain provider with name {1}".format(self, provider_name))
return <Provider> self.__providers[provider_name]
cdef class BaseSingleton(Provider):
"""Base class of singleton providers."""

32
tests/typing/aggregate.py Normal file
View File

@ -0,0 +1,32 @@
from dependency_injector import providers
class Animal:
...
class Cat(Animal):
...
# Test 1: to check Aggregate provider
provider1: providers.Aggregate[str] = providers.Aggregate(
a=providers.Object("str1"),
b=providers.Object("str2"),
)
provider_a_1: providers.Provider[str] = provider1.a
provider_b_1: providers.Provider[str] = provider1.b
val1: str = provider1("a")
provider1_set_non_string_keys: providers.Aggregate[str] = providers.Aggregate()
provider1_set_non_string_keys.set_providers({Cat: providers.Object("str")})
provider_set_non_string_1: providers.Provider[str] = provider1_set_non_string_keys.providers[Cat]
provider1_new_non_string_keys: providers.Aggregate[str] = providers.Aggregate(
{Cat: providers.Object("str")},
)
factory_new_non_string_1: providers.Provider[str] = provider1_new_non_string_keys.providers[Cat]
provider1_no_explicit_typing = providers.Aggregate(a=providers.Object("str"))
provider1_no_explicit_typing_factory: providers.Provider[str] = provider1_no_explicit_typing.providers["a"]
provider1_no_explicit_typing_object: str = provider1_no_explicit_typing("a")

View File

@ -0,0 +1,292 @@
"""Aggregate provider tests."""
from dependency_injector import providers, errors
from pytest import fixture, mark, raises
class Example:
def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None, init_arg4=None):
self.init_arg1 = init_arg1
self.init_arg2 = init_arg2
self.init_arg3 = init_arg3
self.init_arg4 = init_arg4
self.attribute1 = None
self.attribute2 = None
class ExampleA(Example):
pass
class ExampleB(Example):
pass
@fixture
def factory_a():
return providers.Factory(ExampleA)
@fixture
def factory_b():
return providers.Factory(ExampleB)
@fixture
def aggregate_type():
return "default"
@fixture
def aggregate(aggregate_type, factory_a, factory_b):
if aggregate_type == "empty":
return providers.Aggregate()
elif aggregate_type == "non-string-keys":
return providers.Aggregate({
ExampleA: factory_a,
ExampleB: factory_b,
})
elif aggregate_type == "default":
return providers.Aggregate(
example_a=factory_a,
example_b=factory_b,
)
else:
raise ValueError("Unknown factory type \"{0}\"".format(aggregate_type))
def test_is_provider(aggregate):
assert providers.is_provider(aggregate) is True
def test_is_delegated_provider(aggregate):
assert providers.is_delegated(aggregate) is True
@mark.parametrize("aggregate_type", ["non-string-keys"])
def test_init_with_non_string_keys(aggregate, factory_a, factory_b):
object_a = aggregate(ExampleA, 1, 2, init_arg3=3, init_arg4=4)
object_b = aggregate(ExampleB, 11, 22, init_arg3=33, init_arg4=44)
assert isinstance(object_a, ExampleA)
assert object_a.init_arg1 == 1
assert object_a.init_arg2 == 2
assert object_a.init_arg3 == 3
assert object_a.init_arg4 == 4
assert isinstance(object_b, ExampleB)
assert object_b.init_arg1 == 11
assert object_b.init_arg2 == 22
assert object_b.init_arg3 == 33
assert object_b.init_arg4 == 44
assert aggregate.providers == {
ExampleA: factory_a,
ExampleB: factory_b,
}
def test_init_with_not_a_factory():
with raises(errors.Error):
providers.Aggregate(
example_a=providers.Factory(ExampleA),
example_b=object(),
)
@mark.parametrize("aggregate_type", ["empty"])
def test_init_optional_providers(aggregate, factory_a, factory_b):
aggregate.set_providers(
example_a=factory_a,
example_b=factory_b,
)
assert aggregate.providers == {
"example_a": factory_a,
"example_b": factory_b,
}
assert isinstance(aggregate("example_a"), ExampleA)
assert isinstance(aggregate("example_b"), ExampleB)
@mark.parametrize("aggregate_type", ["non-string-keys"])
def test_set_providers_with_non_string_keys(aggregate, factory_a, factory_b):
aggregate.set_providers({
ExampleA: factory_a,
ExampleB: factory_b,
})
object_a = aggregate(ExampleA, 1, 2, init_arg3=3, init_arg4=4)
object_b = aggregate(ExampleB, 11, 22, init_arg3=33, init_arg4=44)
assert isinstance(object_a, ExampleA)
assert object_a.init_arg1 == 1
assert object_a.init_arg2 == 2
assert object_a.init_arg3 == 3
assert object_a.init_arg4 == 4
assert isinstance(object_b, ExampleB)
assert object_b.init_arg1 == 11
assert object_b.init_arg2 == 22
assert object_b.init_arg3 == 33
assert object_b.init_arg4 == 44
assert aggregate.providers == {
ExampleA: factory_a,
ExampleB: factory_b,
}
def test_set_providers_returns_self(aggregate, factory_a):
assert aggregate.set_providers(example_a=factory_a) is aggregate
@mark.parametrize("aggregate_type", ["empty"])
def test_init_optional_providers(aggregate, factory_a, factory_b):
aggregate.set_providers(
example_a=factory_a,
example_b=factory_b,
)
assert aggregate.providers == {
"example_a": factory_a,
"example_b": factory_b,
}
assert isinstance(aggregate("example_a"), ExampleA)
assert isinstance(aggregate("example_b"), ExampleB)
@mark.parametrize("aggregate_type", ["non-string-keys"])
def test_set_providers_with_non_string_keys(aggregate, factory_a, factory_b):
aggregate.set_providers({
ExampleA: factory_a,
ExampleB: factory_b,
})
object_a = aggregate(ExampleA, 1, 2, init_arg3=3, init_arg4=4)
object_b = aggregate(ExampleB, 11, 22, init_arg3=33, init_arg4=44)
assert isinstance(object_a, ExampleA)
assert object_a.init_arg1 == 1
assert object_a.init_arg2 == 2
assert object_a.init_arg3 == 3
assert object_a.init_arg4 == 4
assert isinstance(object_b, ExampleB)
assert object_b.init_arg1 == 11
assert object_b.init_arg2 == 22
assert object_b.init_arg3 == 33
assert object_b.init_arg4 == 44
assert aggregate.providers == {
ExampleA: factory_a,
ExampleB: factory_b,
}
def test_set_providers_returns_self(aggregate, factory_a):
assert aggregate.set_providers(example_a=factory_a) is aggregate
def test_call(aggregate):
object_a = aggregate("example_a", 1, 2, init_arg3=3, init_arg4=4)
object_b = aggregate("example_b", 11, 22, init_arg3=33, init_arg4=44)
assert isinstance(object_a, ExampleA)
assert object_a.init_arg1 == 1
assert object_a.init_arg2 == 2
assert object_a.init_arg3 == 3
assert object_a.init_arg4 == 4
assert isinstance(object_b, ExampleB)
assert object_b.init_arg1 == 11
assert object_b.init_arg2 == 22
assert object_b.init_arg3 == 33
assert object_b.init_arg4 == 44
def test_call_factory_name_as_kwarg(aggregate):
object_a = aggregate(
factory_name="example_a",
init_arg1=1,
init_arg2=2,
init_arg3=3,
init_arg4=4,
)
assert isinstance(object_a, ExampleA)
assert object_a.init_arg1 == 1
assert object_a.init_arg2 == 2
assert object_a.init_arg3 == 3
assert object_a.init_arg4 == 4
def test_call_no_factory_name(aggregate):
with raises(TypeError):
aggregate()
def test_call_no_such_provider(aggregate):
with raises(errors.NoSuchProviderError):
aggregate("unknown")
def test_overridden(aggregate):
with raises(errors.Error):
aggregate.override(providers.Object(object()))
def test_getattr(aggregate, factory_a, factory_b):
assert aggregate.example_a is factory_a
assert aggregate.example_b is factory_b
def test_getattr_no_such_provider(aggregate):
with raises(errors.NoSuchProviderError):
aggregate.unknown
def test_providers(aggregate, factory_a, factory_b):
assert aggregate.providers == dict(
example_a=factory_a,
example_b=factory_b,
)
def test_deepcopy(aggregate):
provider_copy = providers.deepcopy(aggregate)
assert aggregate is not provider_copy
assert isinstance(provider_copy, type(aggregate))
assert aggregate.example_a is not provider_copy.example_a
assert isinstance(aggregate.example_a, type(provider_copy.example_a))
assert aggregate.example_a.cls is provider_copy.example_a.cls
assert aggregate.example_b is not provider_copy.example_b
assert isinstance(aggregate.example_b, type(provider_copy.example_b))
assert aggregate.example_b.cls is provider_copy.example_b.cls
@mark.parametrize("aggregate_type", ["non-string-keys"])
def test_deepcopy_with_non_string_keys(aggregate):
provider_copy = providers.deepcopy(aggregate)
assert aggregate is not provider_copy
assert isinstance(provider_copy, type(aggregate))
assert aggregate.providers[ExampleA] is not provider_copy.providers[ExampleA]
assert isinstance(aggregate.providers[ExampleA], type(provider_copy.providers[ExampleA]))
assert aggregate.providers[ExampleA].provides is provider_copy.providers[ExampleA].provides
assert aggregate.providers[ExampleB] is not provider_copy.providers[ExampleB]
assert isinstance(aggregate.providers[ExampleB], type(provider_copy.providers[ExampleB]))
assert aggregate.providers[ExampleB].provides is provider_copy.providers[ExampleB].provides
def test_repr(aggregate):
assert repr(aggregate) == (
"<dependency_injector.providers."
"Aggregate({0}) at {1}>".format(
repr(aggregate.providers),
hex(id(aggregate)),
)
)