Refactor FactoryAggregate

This commit is contained in:
Roman Mogylatov 2022-01-08 19:43:56 -05:00
parent ec78d03b9b
commit 03487d7945
5 changed files with 5077 additions and 7321 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -148,10 +148,8 @@ cdef class FactoryDelegate(Delegate):
pass pass
cdef class FactoryAggregate(Provider): cdef class FactoryAggregate(Aggregate):
cdef dict __providers pass
cdef Provider __get_provider(self, object provider_name)
# Singleton providers # Singleton providers

View File

@ -317,20 +317,8 @@ class FactoryDelegate(Delegate):
def __init__(self, factory: Factory): ... def __init__(self, factory: Factory): ...
class FactoryAggregate(Provider[T]): class FactoryAggregate(Aggregate[T]):
def __init__(self, provider_dict: Optional[_Dict[Any, Provider[T]]] = None, **provider_kwargs: Provider[T]): ...
def __getattr__(self, provider_name: Any) -> Factory[T]: ... def __getattr__(self, provider_name: Any) -> Factory[T]: ...
@overload
def __call__(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> T: ...
@overload
def __call__(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
def async_(self, provider_name: Optional[Any] = None, *args: Injection, **kwargs: Injection) -> Awaitable[T]: ...
@property
def providers(self) -> _Dict[Any, Provider[T]]: ...
def set_providers(self, provider_dict: Optional[_Dict[Any, Provider[T]]] = None, **provider_kwargs: Provider[T]) -> FactoryAggregate[T]: ...
@property @property
def factories(self) -> _Dict[Any, Factory[T]]: ... def factories(self) -> _Dict[Any, Factory[T]]: ...
def set_factories(self, provider_dict: Optional[_Dict[Any, Factory[T]]] = None, **provider_kwargs: Factory[T]) -> FactoryAggregate[T]: ... def set_factories(self, provider_dict: Optional[_Dict[Any, Factory[T]]] = None, **provider_kwargs: Factory[T]) -> FactoryAggregate[T]: ...
@ -502,7 +490,9 @@ class MethodCaller(Provider, ProvidedInstanceFluentInterface):
class OverridingContext(Generic[T]): class OverridingContext(Generic[T]):
def __init__(self, overridden: Provider, overriding: Provider): ... def __init__(self, overridden: Provider, overriding: Provider): ...
def __enter__(self) -> T: ... def __enter__(self) -> T: ...
def __exit__(self, *_: Any) -> None: ... def __exit__(self, *_: Any) -> None:
pass
...
class BaseSingletonResetContext(Generic[T]): class BaseSingletonResetContext(Generic[T]):

View File

@ -2670,7 +2670,7 @@ cdef class FactoryDelegate(Delegate):
super(FactoryDelegate, self).__init__(factory) super(FactoryDelegate, self).__init__(factory)
cdef class FactoryAggregate(Provider): cdef class FactoryAggregate(Aggregate):
"""Factory providers aggregate. """Factory providers aggregate.
:py:class:`FactoryAggregate` is an aggregate of :py:class:`Factory` :py:class:`FactoryAggregate` is an aggregate of :py:class:`Factory`
@ -2684,69 +2684,6 @@ cdef class FactoryAggregate(Provider):
:py:class:`FactoryAggregate`. :py:class:`FactoryAggregate`.
""" """
__IS_DELEGATED__ = True
def __init__(self, provider_dict=None, **provider_kwargs):
"""Initialize provider."""
self.__providers = {}
self.set_providers(provider_dict, **provider_kwargs)
super(FactoryAggregate, self).__init__()
def __deepcopy__(self, memo):
"""Create and return full copy of provider."""
copied = memo.get(id(self))
if copied is not None:
return copied
copied = _memorized_duplicate(self, memo)
copied.set_providers(deepcopy(self.providers, memo))
self._copy_overridings(copied, memo)
return copied
def __getattr__(self, factory_name):
"""Return aggregated provider."""
return self.__get_provider(factory_name)
def __str__(self):
"""Return string representation of provider.
:rtype: str
"""
return represent_provider(provider=self, provides=self.providers)
@property
def providers(self):
"""Return dictionary of providers, read-only.
Alias for ``.factories`` attribute.
"""
return dict(self.__providers)
def set_providers(self, provider_dict=None, **provider_kwargs):
"""Set providers.
Alias for ``.set_factories()`` method.
"""
providers = {}
providers.update(provider_kwargs)
if provider_dict:
providers.update(provider_dict)
for provider in providers.values():
if not is_provider(provider):
raise Error(
'{0} can aggregate only instances of {1}, given - {2}'.format(
self.__class__,
Provider,
provider,
),
)
self.__providers = providers
return self
@property @property
def factories(self): def factories(self):
"""Return dictionary of factories, read-only. """Return dictionary of factories, read-only.
@ -2762,40 +2699,6 @@ cdef class FactoryAggregate(Provider):
""" """
return self.set_providers(factory_dict, **factory_kwargs) return self.set_providers(factory_dict, **factory_kwargs)
def override(self, _):
"""Override provider with another provider.
:raise: :py:exc:`dependency_injector.errors.Error`
:return: Overriding context.
:rtype: :py:class:`OverridingContext`
"""
raise Error('{0} providers could not be overridden'.format(self.__class__))
@property
def related(self):
"""Return related providers generator."""
yield from self.__providers.values()
yield from super().related
cpdef object _provide(self, tuple args, dict kwargs):
try:
provider_name = args[0]
except IndexError:
try:
provider_name = kwargs.pop("factory_name")
except KeyError:
raise TypeError("Missing 1st required positional argument: \"provider_name\"")
else:
args = args[1:]
return self.__get_provider(provider_name)(*args, **kwargs)
cdef Provider __get_provider(self, object provider_name):
if provider_name not in self.__providers:
raise NoSuchProviderError("{0} does not contain provider with name {1}".format(self, provider_name))
return <Provider> self.__providers[provider_name]
cdef class BaseSingleton(Provider): cdef class BaseSingleton(Provider):
"""Base class of singleton providers.""" """Base class of singleton providers."""