Implement lazy initialization and improve copying for Callable, Factory, Singleton, and Coroutine providers

This commit is contained in:
Roman Mogylatov 2021-03-07 10:17:15 -05:00
parent dbc6393561
commit 0fab16db94
7 changed files with 6468 additions and 6106 deletions

File diff suppressed because it is too large Load Diff

View File

@ -134,7 +134,7 @@ class Callable(Provider[T]):
def __init__(self, provides: Optional[_Callable[..., T]] = None, *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[_Callable[..., T]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@property @property
def provides(self) -> Optional[T]: ... def provides(self) -> Optional[T]: ...
def set_provides(self, provides: Optional[_Callable[..., T]]) -> None: ... def set_provides(self, provides: Optional[_Callable[..., T]]) -> Callable[T]: ...
@property @property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Callable[T]: ... def add_args(self, *args: Injection) -> Callable[T]: ...
@ -227,7 +227,7 @@ class Factory(Provider[T]):
def cls(self) -> T: ... def cls(self) -> T: ...
@property @property
def provides(self) -> T: ... def provides(self) -> T: ...
def set_provides(self, provides: Optional[_Callable[..., T]]) -> None: ... def set_provides(self, provides: Optional[_Callable[..., T]]) -> Factory[T]: ...
@property @property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Factory[T]: ... def add_args(self, *args: Injection) -> Factory[T]: ...
@ -277,7 +277,7 @@ class BaseSingleton(Provider[T]):
def cls(self) -> T: ... def cls(self) -> T: ...
@property @property
def provides(self) -> T: ... def provides(self) -> T: ...
def set_provides(self, provides: Optional[_Callable[..., T]]) -> None: ... def set_provides(self, provides: Optional[_Callable[..., T]]) -> BaseSingleton[T]: ...
@property @property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> BaseSingleton[T]: ... def add_args(self, *args: Injection) -> BaseSingleton[T]: ...

View File

@ -1047,9 +1047,10 @@ cdef class Callable(Provider):
if isinstance(provides, Provider): if isinstance(provides, Provider):
provides = deepcopy(provides, memo) provides = deepcopy(provides, memo)
copied = self.__class__(provides, copied = _memorized_duplicate(self, memo)
*deepcopy(self.args, memo), copied.set_provides(provides)
**deepcopy(self.kwargs, memo)) copied.set_args(*deepcopy(self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
@ -1075,6 +1076,7 @@ cdef class Callable(Provider):
f'got {provides} instead' f'got {provides} instead'
) )
self.__provides = provides self.__provides = provides
return self
@property @property
def args(self): def args(self):
@ -1271,22 +1273,14 @@ cdef class Coroutine(Callable):
_is_coroutine = _is_coroutine_marker _is_coroutine = _is_coroutine_marker
def __init__(self, provides, *args, **kwargs): def set_provides(self, provides):
"""Initializer. """Set provider's provides."""
:param provides: Wrapped callable.
:type provides: callable
"""
if not asyncio: if not asyncio:
raise Error('Package asyncio is not available') raise Error('Package asyncio is not available')
if provides and not asyncio.iscoroutinefunction(provides):
if not asyncio.iscoroutinefunction(provides): raise Error(f'Provider {_class_qualname(self)} expected to get coroutine function, '
raise Error('Provider {0} expected to get coroutine function, ' f'got {provides} instead')
'got {1}'.format('.'.join((self.__class__.__module__, return super().set_provides(provides)
self.__class__.__name__)),
provides))
super(Coroutine, self).__init__(provides, *args, **kwargs)
cdef class DelegatedCoroutine(Coroutine): cdef class DelegatedCoroutine(Coroutine):
@ -2106,13 +2100,14 @@ cdef class Factory(Provider):
if copied is not None: if copied is not None:
return copied return copied
cls = self.cls provides = self.provides
if isinstance(cls, Provider): if isinstance(provides, Provider):
cls = deepcopy(cls, memo) provides = deepcopy(provides, memo)
copied = self.__class__(cls, copied = _memorized_duplicate(self, memo)
*deepcopy(self.args, memo), copied.set_provides(provides)
**deepcopy(self.kwargs, memo)) copied.set_args(*deepcopy(self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo))
copied.set_attributes(**deepcopy(self.attributes, memo)) copied.set_attributes(**deepcopy(self.attributes, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
@ -2147,6 +2142,7 @@ cdef class Factory(Provider):
f'{self.__class__.provided_type} instances' f'{self.__class__.provided_type} instances'
) )
self.__instantiator.set_provides(provides) self.__instantiator.set_provides(provides)
return self
@property @property
def args(self): def args(self):
@ -2477,13 +2473,14 @@ cdef class BaseSingleton(Provider):
if copied is not None: if copied is not None:
return copied return copied
cls = self.cls provides = self.provides
if isinstance(cls, Provider): if isinstance(provides, Provider):
cls = deepcopy(cls, memo) provides = deepcopy(provides, memo)
copied = self.__class__(cls, copied = _memorized_duplicate(self, memo)
*deepcopy(self.args, memo), copied.set_provides(provides)
**deepcopy(self.kwargs, memo)) copied.set_args(*deepcopy(self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo))
copied.set_attributes(**deepcopy(self.attributes, memo)) copied.set_attributes(**deepcopy(self.attributes, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
@ -2510,6 +2507,7 @@ cdef class BaseSingleton(Provider):
f'{self.__class__.provided_type} instances' f'{self.__class__.provided_type} instances'
) )
self.__instantiator.set_provides(provides) self.__instantiator.set_provides(provides)
return self
@property @property
def args(self): def args(self):
@ -4459,3 +4457,16 @@ cpdef _copy_parent(object from_, object to, dict memo):
else from_.parent else from_.parent
) )
to.assign_parent(copied_parent) to.assign_parent(copied_parent)
cpdef object _memorized_duplicate(object instance, dict memo):
copied = instance.__class__()
memo[id(instance)] = copied
return copied
cpdef str _class_qualname(object instance):
name = getattr(instance.__class__, '__qualname__', None)
if not name:
name = '.'.join((instance.__class__.__module__, instance.__class__.__name__))
return name

View File

@ -22,6 +22,16 @@ class CallableTests(unittest.TestCase):
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, providers.Callable, 123) self.assertRaises(errors.Error, providers.Callable, 123)
def test_init_optional_provides(self):
provider = providers.Callable()
provider.set_provides(object)
self.assertIs(provider.provides, object)
self.assertIsInstance(provider(), object)
def test_set_provides_returns_self(self):
provider = providers.Callable()
self.assertIs(provider.set_provides(object), provider)
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) self.assertIsInstance(provider.provided, providers.ProvidedInstance)

View File

@ -43,6 +43,16 @@ class CoroutineTests(AsyncTestCase):
def test_init_with_not_coroutine(self): def test_init_with_not_coroutine(self):
self.assertRaises(errors.Error, providers.Coroutine, lambda: None) self.assertRaises(errors.Error, providers.Coroutine, lambda: None)
def test_init_optional_provides(self):
provider = providers.Coroutine()
provider.set_provides(_example)
self.assertIs(provider.provides, _example)
self.assertEqual(run(provider(1, 2, 3, 4)), (1, 2, 3, 4))
def test_set_provides_returns_self(self):
provider = providers.Coroutine()
self.assertIs(provider.set_provides(_example), provider)
def test_call_with_positional_args(self): def test_call_with_positional_args(self):
provider = providers.Coroutine(_example, 1, 2, 3, 4) provider = providers.Coroutine(_example, 1, 2, 3, 4)
self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4)) self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4))

View File

@ -34,6 +34,16 @@ class FactoryTests(unittest.TestCase):
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, providers.Factory, 123) self.assertRaises(errors.Error, providers.Factory, 123)
def test_init_optional_provides(self):
provider = providers.Factory()
provider.set_provides(object)
self.assertIs(provider.provides, object)
self.assertIsInstance(provider(), object)
def test_set_provides_returns_self(self):
provider = providers.Factory()
self.assertIs(provider.set_provides(object), provider)
def test_init_with_valid_provided_type(self): def test_init_with_valid_provided_type(self):
class ExampleProvider(providers.Factory): class ExampleProvider(providers.Factory):
provided_type = Example provided_type = Example

View File

@ -36,6 +36,16 @@ class _BaseSingletonTestCase(object):
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, self.singleton_cls, 123) self.assertRaises(errors.Error, self.singleton_cls, 123)
def test_init_optional_provides(self):
provider = self.singleton_cls()
provider.set_provides(object)
self.assertIs(provider.provides, object)
self.assertIsInstance(provider(), object)
def test_set_provides_returns_self(self):
provider = self.singleton_cls()
self.assertIs(provider.set_provides(object), provider)
def test_init_with_valid_provided_type(self): def test_init_with_valid_provided_type(self):
class ExampleProvider(self.singleton_cls): class ExampleProvider(self.singleton_cls):
provided_type = Example provided_type = Example