FactoryAggregate - non string keys (#496)

* Improve FactoryAggregate typing stub

* Add implementation, typing stubs, and tests

* Update changelog

* Fix deepcopying

* Add example

* Update docs

* Fix errors formatting for pypy3
This commit is contained in:
Roman Mogylatov 2021-08-25 10:20:45 -04:00 committed by GitHub
parent 6af818102b
commit 14d8ed909b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 4107 additions and 3806 deletions

View File

@ -9,6 +9,8 @@ follows `Semantic versioning`_
Development version Development version
------------------- -------------------
- Add support of non-string keys for ``FactoryAggregate`` provider.
- Improve ``FactoryAggregate`` typing stub.
- Improve resource subclasses typing and make shutdown definition optional - Improve resource subclasses typing and make shutdown definition optional
`PR #492 <https://github.com/ets-labs/python-dependency-injector/pull/492>`_. `PR #492 <https://github.com/ets-labs/python-dependency-injector/pull/492>`_.
Thanks to `@EdwardBlair <https://github.com/EdwardBlair>`_ for suggesting the improvement. Thanks to `@EdwardBlair <https://github.com/EdwardBlair>`_ for suggesting the improvement.

View File

@ -148,13 +148,11 @@ provider with two peculiarities:
Factory aggregate Factory aggregate
----------------- -----------------
:py:class:`FactoryAggregate` provider aggregates multiple factories. When you call the :py:class:`FactoryAggregate` provider aggregates multiple factories.
``FactoryAggregate`` it delegates the call to one of the factories.
The aggregated factories are associated with the string names. When you call the The aggregated factories are associated with the string keys. When you call the
``FactoryAggregate`` you have to provide one of the these names as a first argument. ``FactoryAggregate`` you have to provide one of the these keys as a first argument.
``FactoryAggregate`` looks for the factory with a matching name and delegates it the work. The ``FactoryAggregate`` looks for the factory with a matching key and calls it with the rest of the arguments.
rest of the arguments are passed to the delegated ``Factory``.
.. image:: images/factory_aggregate.png .. image:: images/factory_aggregate.png
:width: 100% :width: 100%
@ -165,12 +163,12 @@ rest of the arguments are passed to the delegated ``Factory``.
:lines: 3- :lines: 3-
:emphasize-lines: 33-37,47 :emphasize-lines: 33-37,47
You can get a dictionary of the aggregated factories using the ``.factories`` attribute of the You can get a dictionary of the aggregated factories using the ``.factories`` attribute.
``FactoryAggregate``. To get a game factories dictionary from the previous example you can use To get a game factories dictionary from the previous example you can use
``game_factory.factories`` attribute. ``game_factory.factories`` attribute.
You can also access an aggregated factory as an attribute. To create the ``Chess`` object from the You can also access an aggregated factory as an attribute. To create the ``Chess`` object from the
previous example you can do ``chess = game_factory.chess('John', 'Jane')``. previous example you can do ``chess = game_factory.chess("John", "Jane")``.
.. note:: .. note::
You can not override the ``FactoryAggregate`` provider. You can not override the ``FactoryAggregate`` provider.
@ -178,4 +176,22 @@ previous example you can do ``chess = game_factory.chess('John', 'Jane')``.
.. note:: .. note::
When you inject the ``FactoryAggregate`` provider it is passed "as is". When you inject the ``FactoryAggregate`` provider it is passed "as is".
To use non-string keys or keys with ``.`` and ``-`` you can provide a dictionary as a positional argument:
.. code-block:: python
providers.FactoryAggregate({
SomeClass: providers.Factory(...),
"key.with.periods": providers.Factory(...),
"key-with-dashes": providers.Factory(...),
})
Example:
.. literalinclude:: ../../examples/providers/factory_aggregate_non_string_keys.py
:language: python
:lines: 3-
:emphasize-lines: 30-33,39-40
.. disqus:: .. disqus::

View File

@ -0,0 +1,45 @@
"""`FactoryAggregate` provider with non-string keys example."""
from dependency_injector import containers, providers
class Command:
...
class CommandA(Command):
...
class CommandB(Command):
...
class Handler:
...
class HandlerA(Handler):
...
class HandlerB(Handler):
...
class Container(containers.DeclarativeContainer):
handler_factory = providers.FactoryAggregate({
CommandA: providers.Factory(HandlerA),
CommandB: providers.Factory(HandlerB),
})
if __name__ == "__main__":
container = Container()
handler_a = container.handler_factory(CommandA)
handler_b = container.handler_factory(CommandB)
assert isinstance(handler_a, HandlerA)
assert isinstance(handler_b, HandlerB)

File diff suppressed because it is too large Load Diff

View File

@ -142,7 +142,7 @@ cdef class FactoryDelegate(Delegate):
cdef class FactoryAggregate(Provider): cdef class FactoryAggregate(Provider):
cdef dict __factories cdef dict __factories
cdef Factory __get_factory(self, str factory_name) cdef Factory __get_factory(self, object factory_name)
# Singleton providers # Singleton providers

View File

@ -282,19 +282,19 @@ class FactoryDelegate(Delegate):
def __init__(self, factory: Factory): ... def __init__(self, factory: Factory): ...
class FactoryAggregate(Provider): class FactoryAggregate(Provider[T]):
def __init__(self, **factories: Factory): ... def __init__(self, dict_: Optional[_Dict[Any, Factory[T]]] = None, **factories: Factory[T]): ...
def __getattr__(self, factory_name: str) -> Factory: ... def __getattr__(self, factory_name: Any) -> Factory[T]: ...
@overload @overload
def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Any: ... def __call__(self, factory_name: Any, *args: Injection, **kwargs: Injection) -> T: ...
@overload @overload
def __call__(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[Any]: ... def __call__(self, factory_name: Any, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
def async_(self, factory_name: str, *args: Injection, **kwargs: Injection) -> Awaitable[Any]: ... def async_(self, factory_name: Any, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
@property @property
def factories(self) -> _Dict[str, Factory]: ... def factories(self) -> _Dict[Any, Factory[T]]: ...
def set_factories(self, **factories: Factory) -> FactoryAggregate: ... def set_factories(self, dict_: Optional[_Dict[Any, Factory[T]]] = None, **factories: Factory[T]) -> FactoryAggregate[T]: ...
class BaseSingleton(Provider[T]): class BaseSingleton(Provider[T]):

View File

@ -2486,10 +2486,10 @@ cdef class FactoryAggregate(Provider):
__IS_DELEGATED__ = True __IS_DELEGATED__ = True
def __init__(self, **factories): def __init__(self, factories_dict_=None, **factories_kwargs):
"""Initialize provider.""" """Initialize provider."""
self.__factories = {} self.__factories = {}
self.set_factories(**factories) self.set_factories(factories_dict_, **factories_kwargs)
super(FactoryAggregate, self).__init__() super(FactoryAggregate, self).__init__()
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
@ -2499,7 +2499,7 @@ cdef class FactoryAggregate(Provider):
return copied return copied
copied = _memorized_duplicate(self, memo) copied = _memorized_duplicate(self, memo)
copied.set_factories(**deepcopy(self.factories, memo)) copied.set_factories(deepcopy(self.factories, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
@ -2521,13 +2521,23 @@ cdef class FactoryAggregate(Provider):
"""Return dictionary of factories, read-only.""" """Return dictionary of factories, read-only."""
return self.__factories return self.__factories
def set_factories(self, **factories): def set_factories(self, factories_dict_=None, **factories_kwargs):
"""Set factories.""" """Set factories."""
factories = {}
factories.update(factories_kwargs)
if factories_dict_:
factories.update(factories_dict_)
for factory in factories.values(): for factory in factories.values():
if isinstance(factory, Factory) is False: if isinstance(factory, Factory) is False:
raise Error( raise Error(
'{0} can aggregate only instances of {1}, given - {2}' '{0} can aggregate only instances of {1}, given - {2}'.format(
.format(self.__class__, Factory, factory)) self.__class__,
Factory,
factory,
),
)
self.__factories = factories self.__factories = factories
return self return self
@ -2539,8 +2549,7 @@ cdef class FactoryAggregate(Provider):
:return: Overriding context. :return: Overriding context.
:rtype: :py:class:`OverridingContext` :rtype: :py:class:`OverridingContext`
""" """
raise Error( raise Error('{0} providers could not be overridden'.format(self.__class__))
'{0} providers could not be overridden'.format(self.__class__))
@property @property
def related(self): def related(self):
@ -2561,12 +2570,10 @@ cdef class FactoryAggregate(Provider):
return self.__get_factory(factory_name)(*args, **kwargs) return self.__get_factory(factory_name)(*args, **kwargs)
cdef Factory __get_factory(self, str factory_name): cdef Factory __get_factory(self, object factory_key):
if factory_name not in self.__factories: if factory_key not in self.__factories:
raise NoSuchProviderError( raise NoSuchProviderError('{0} does not contain factory with name {1}'.format(self, factory_key))
'{0} does not contain factory with name {1}'.format( return <Factory> self.__factories[factory_key]
self, factory_name))
return <Factory> self.__factories[factory_name]
cdef class BaseSingleton(Provider): cdef class BaseSingleton(Provider):

View File

@ -55,13 +55,26 @@ animal7: Animal = provider7(1, 2, 3, b='1', c=2, e=0.0)
provider8 = providers.FactoryDelegate(providers.Factory(object)) provider8 = providers.FactoryDelegate(providers.Factory(object))
# Test 9: to check FactoryAggregate provider # Test 9: to check FactoryAggregate provider
provider9 = providers.FactoryAggregate( provider9: providers.FactoryAggregate[str] = providers.FactoryAggregate(
a=providers.Factory(object), a=providers.Factory(str, "str1"),
b=providers.Factory(object), b=providers.Factory(str, "str2"),
) )
factory_a_9: providers.Factory = provider9.a factory_a_9: providers.Factory[str] = provider9.a
factory_b_9: providers.Factory = provider9.b factory_b_9: providers.Factory[str] = provider9.b
val9: Any = provider9('a') val9: str = provider9('a')
provider9_set_non_string_keys: providers.FactoryAggregate[str] = providers.FactoryAggregate()
provider9_set_non_string_keys.set_factories({Cat: providers.Factory(str, "str")})
factory_set_non_string_9: providers.Factory[str] = provider9_set_non_string_keys.factories[Cat]
provider9_new_non_string_keys: providers.FactoryAggregate[str] = providers.FactoryAggregate(
{Cat: providers.Factory(str, "str")},
)
factory_new_non_string_9: providers.Factory[str] = provider9_new_non_string_keys.factories[Cat]
provider9_no_explicit_typing = providers.FactoryAggregate(a=providers.Factory(str, "str"))
provider9_no_explicit_typing_factory: providers.Factory[str] = provider9_no_explicit_typing.factories["a"]
provider9_no_explicit_typing_object: str = provider9_no_explicit_typing("a")
# Test 10: to check the explicit typing # Test 10: to check the explicit typing
factory10: providers.Provider[Animal] = providers.Factory(Cat) factory10: providers.Provider[Animal] = providers.Factory(Cat)

View File

@ -5,6 +5,7 @@ import sys
import unittest import unittest
from dependency_injector import ( from dependency_injector import (
containers,
providers, providers,
errors, errors,
) )
@ -498,7 +499,8 @@ class FactoryAggregateTests(unittest.TestCase):
self.example_b_factory = providers.Factory(self.ExampleB) self.example_b_factory = providers.Factory(self.ExampleB)
self.factory_aggregate = providers.FactoryAggregate( self.factory_aggregate = providers.FactoryAggregate(
example_a=self.example_a_factory, example_a=self.example_a_factory,
example_b=self.example_b_factory) example_b=self.example_b_factory,
)
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.factory_aggregate)) self.assertTrue(providers.is_provider(self.factory_aggregate))
@ -506,6 +508,35 @@ class FactoryAggregateTests(unittest.TestCase):
def test_is_delegated_provider(self): def test_is_delegated_provider(self):
self.assertTrue(providers.is_delegated(self.factory_aggregate)) self.assertTrue(providers.is_delegated(self.factory_aggregate))
def test_init_with_non_string_keys(self):
factory = providers.FactoryAggregate({
self.ExampleA: self.example_a_factory,
self.ExampleB: self.example_b_factory,
})
object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4)
object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44)
self.assertIsInstance(object_a, self.ExampleA)
self.assertEqual(object_a.init_arg1, 1)
self.assertEqual(object_a.init_arg2, 2)
self.assertEqual(object_a.init_arg3, 3)
self.assertEqual(object_a.init_arg4, 4)
self.assertIsInstance(object_b, self.ExampleB)
self.assertEqual(object_b.init_arg1, 11)
self.assertEqual(object_b.init_arg2, 22)
self.assertEqual(object_b.init_arg3, 33)
self.assertEqual(object_b.init_arg4, 44)
self.assertEqual(
factory.factories,
{
self.ExampleA: self.example_a_factory,
self.ExampleB: self.example_b_factory,
},
)
def test_init_with_not_a_factory(self): def test_init_with_not_a_factory(self):
with self.assertRaises(errors.Error): with self.assertRaises(errors.Error):
providers.FactoryAggregate( providers.FactoryAggregate(
@ -528,7 +559,37 @@ class FactoryAggregateTests(unittest.TestCase):
self.assertIsInstance(provider('example_a'), self.ExampleA) self.assertIsInstance(provider('example_a'), self.ExampleA)
self.assertIsInstance(provider('example_b'), self.ExampleB) self.assertIsInstance(provider('example_b'), self.ExampleB)
def test_set_provides_returns_self(self): def test_set_factories_with_non_string_keys(self):
factory = providers.FactoryAggregate()
factory.set_factories({
self.ExampleA: self.example_a_factory,
self.ExampleB: self.example_b_factory,
})
object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4)
object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44)
self.assertIsInstance(object_a, self.ExampleA)
self.assertEqual(object_a.init_arg1, 1)
self.assertEqual(object_a.init_arg2, 2)
self.assertEqual(object_a.init_arg3, 3)
self.assertEqual(object_a.init_arg4, 4)
self.assertIsInstance(object_b, self.ExampleB)
self.assertEqual(object_b.init_arg1, 11)
self.assertEqual(object_b.init_arg2, 22)
self.assertEqual(object_b.init_arg3, 33)
self.assertEqual(object_b.init_arg4, 44)
self.assertEqual(
factory.factories,
{
self.ExampleA: self.example_a_factory,
self.ExampleB: self.example_b_factory,
},
)
def test_set_factories_returns_self(self):
provider = providers.FactoryAggregate() provider = providers.FactoryAggregate()
self.assertIs(provider.set_factories(example_a=self.example_a_factory), provider) self.assertIs(provider.set_factories(example_a=self.example_a_factory), provider)
@ -603,6 +664,24 @@ class FactoryAggregateTests(unittest.TestCase):
self.assertIsInstance(self.factory_aggregate.example_b, type(provider_copy.example_b)) self.assertIsInstance(self.factory_aggregate.example_b, type(provider_copy.example_b))
self.assertIs(self.factory_aggregate.example_b.cls, provider_copy.example_b.cls) self.assertIs(self.factory_aggregate.example_b.cls, provider_copy.example_b.cls)
def test_deepcopy_with_non_string_keys(self):
factory_aggregate = providers.FactoryAggregate({
self.ExampleA: self.example_a_factory,
self.ExampleB: self.example_b_factory,
})
provider_copy = providers.deepcopy(factory_aggregate)
self.assertIsNot(factory_aggregate, provider_copy)
self.assertIsInstance(provider_copy, type(factory_aggregate))
self.assertIsNot(factory_aggregate.factories[self.ExampleA], provider_copy.factories[self.ExampleA])
self.assertIsInstance(factory_aggregate.factories[self.ExampleA], type(provider_copy.factories[self.ExampleA]))
self.assertIs(factory_aggregate.factories[self.ExampleA].cls, provider_copy.factories[self.ExampleA].cls)
self.assertIsNot(factory_aggregate.factories[self.ExampleB], provider_copy.factories[self.ExampleB])
self.assertIsInstance(factory_aggregate.factories[self.ExampleB], type(provider_copy.factories[self.ExampleB]))
self.assertIs(factory_aggregate.factories[self.ExampleB].cls, provider_copy.factories[self.ExampleB].cls)
def test_repr(self): def test_repr(self):
self.assertEqual(repr(self.factory_aggregate), self.assertEqual(repr(self.factory_aggregate),
'<dependency_injector.providers.' '<dependency_injector.providers.'