Implement lazy initialization and improve copying for AttributeGetter provider

This commit is contained in:
Roman Mogylatov 2021-03-09 18:00:35 -05:00
parent cec342c1e4
commit b8078f904e
6 changed files with 1659 additions and 1510 deletions

View File

@ -1402,13 +1402,13 @@ struct __pyx_obj_19dependency_injector_9providers_ProvidedInstance {
* *
* *
* cdef class AttributeGetter(Provider): # <<<<<<<<<<<<<< * cdef class AttributeGetter(Provider): # <<<<<<<<<<<<<<
* cdef Provider __provider * cdef object __provides
* cdef object __attribute * cdef object __name
*/ */
struct __pyx_obj_19dependency_injector_9providers_AttributeGetter { struct __pyx_obj_19dependency_injector_9providers_AttributeGetter {
struct __pyx_obj_19dependency_injector_9providers_Provider __pyx_base; struct __pyx_obj_19dependency_injector_9providers_Provider __pyx_base;
struct __pyx_obj_19dependency_injector_9providers_Provider *__pyx___provider; PyObject *__pyx___provides;
PyObject *__pyx___attribute; PyObject *__pyx___name;
}; };
@ -2158,8 +2158,8 @@ static struct __pyx_vtabstruct_19dependency_injector_9providers_ProvidedInstance
* *
* *
* cdef class AttributeGetter(Provider): # <<<<<<<<<<<<<< * cdef class AttributeGetter(Provider): # <<<<<<<<<<<<<<
* cdef Provider __provider * cdef object __provides
* cdef object __attribute * cdef object __name
*/ */
struct __pyx_vtabstruct_19dependency_injector_9providers_AttributeGetter { struct __pyx_vtabstruct_19dependency_injector_9providers_AttributeGetter {

File diff suppressed because it is too large Load Diff

View File

@ -241,8 +241,8 @@ cdef class ProvidedInstance(Provider):
cdef class AttributeGetter(Provider): cdef class AttributeGetter(Provider):
cdef Provider __provider cdef object __provides
cdef object __attribute cdef object __name
cpdef object _provide(self, tuple args, dict kwargs) cpdef object _provide(self, tuple args, dict kwargs)

View File

@ -226,7 +226,7 @@ class Configuration(Object[Any]):
def set_strict(self, strict: bool) -> Configuration: ... def set_strict(self, strict: bool) -> Configuration: ...
def get_children(self) -> _Dict[str, ConfigurationOption]: ... def get_children(self) -> _Dict[str, ConfigurationOption]: ...
def set_children(self, children: _Dict[str, ConfigurationOption) -> Configuration: ... def set_children(self, children: _Dict[str, ConfigurationOption]) -> Configuration: ...
def get(self, selector: str) -> Any: ... def get(self, selector: str) -> Any: ...
def set(self, selector: str, value: Any) -> OverridingContext[P]: ... def set(self, selector: str, value: Any) -> OverridingContext[P]: ...
@ -433,7 +433,10 @@ class ProvidedInstance(Provider, ProvidedInstanceFluentInterface):
class AttributeGetter(Provider, ProvidedInstanceFluentInterface): class AttributeGetter(Provider, ProvidedInstanceFluentInterface):
def __init__(self, provider: Provider, attribute: str) -> None: ... def __init__(self, provides: Optional[Provider] = None, attribute: Optional[str] = None) -> None: ...
@property
def name(self) -> Optional[str]: ...
def set_name(self, name: Optional[str]) -> ProvidedInstanceFluentInterface: ...
class ItemGetter(Provider, ProvidedInstanceFluentInterface): class ItemGetter(Provider, ProvidedInstanceFluentInterface):

View File

@ -3894,7 +3894,7 @@ cdef class ProvidedInstance(Provider):
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}(\'{self.__provides}\')' return f'{self.__class__.__name__}(\'{self.__provides}\')'
def __deepcopy__(self, memo=None): def __deepcopy__(self, memo):
copied = memo.get(id(self)) copied = memo.get(id(self))
if copied is not None: if copied is not None:
return copied return copied
@ -3939,25 +3939,26 @@ cdef class AttributeGetter(Provider):
You should not create this provider directly. See :py:class:`ProvidedInstance` instead. You should not create this provider directly. See :py:class:`ProvidedInstance` instead.
""" """
def __init__(self, provider, attribute): def __init__(self, provides=None, name=None):
self.__provider = provider self.__provides = None
self.__attribute = attribute self.set_provides(provides)
self.__name = None
self.set_name(name)
super().__init__() super().__init__()
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}(\'{self.__attribute}\')' return f'{self.__class__.__name__}(\'{self.__name}\')'
def __deepcopy__(self, memo=None):
cdef AttributeGetter copied
def __deepcopy__(self, memo):
copied = memo.get(id(self)) copied = memo.get(id(self))
if copied is not None: if copied is not None:
return copied return copied
return self.__class__( copied = _memorized_duplicate(self, memo)
deepcopy(self.__provider, memo), copied.set_provides(_copy_if_provider(self.provides, memo))
self.__attribute, copied.set_name(self.name)
) return copied
def __getattr__(self, item): def __getattr__(self, item):
return AttributeGetter(self, item) return AttributeGetter(self, item)
@ -3967,13 +3968,23 @@ cdef class AttributeGetter(Provider):
@property @property
def provides(self): def provides(self):
"""Return provider.""" """Return provider's provides."""
return self.__provider return self.__provides
def set_provides(self, provides):
"""Set provider's provides."""
self.__provides = provides
return self
@property @property
def name(self): def name(self):
"""Return name of the attribute.""" """Return name of the attribute."""
return self.__attribute return self.__name
def set_name(self, name):
"""Set name of the attribute."""
self.__name = name
return self
def call(self, *args, **kwargs): def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs) return MethodCaller(self, *args, **kwargs)
@ -3981,22 +3992,23 @@ cdef class AttributeGetter(Provider):
@property @property
def related(self): def related(self):
"""Return related providers generator.""" """Return related providers generator."""
yield self.__provider if is_provider(self.provides):
yield self.provides
yield from super().related yield from super().related
cpdef object _provide(self, tuple args, dict kwargs): cpdef object _provide(self, tuple args, dict kwargs):
provided = self.__provider(*args, **kwargs) provided = self.provides(*args, **kwargs)
if __is_future_or_coroutine(provided): if __is_future_or_coroutine(provided):
future_result = asyncio.Future() future_result = asyncio.Future()
provided = asyncio.ensure_future(provided) provided = asyncio.ensure_future(provided)
provided.add_done_callback(functools.partial(self._async_provide, future_result)) provided.add_done_callback(functools.partial(self._async_provide, future_result))
return future_result return future_result
return getattr(provided, self.__attribute) return getattr(provided, self.name)
def _async_provide(self, future_result, future): def _async_provide(self, future_result, future):
try: try:
provided = future.result() provided = future.result()
result = getattr(provided, self.__attribute) result = getattr(provided, self.name)
except Exception: except Exception:
pass pass
else: else:

View File

@ -69,16 +69,6 @@ class ProvidedInstanceTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.container = Container() self.container = Container()
def test_lazy_init(self):
provides = providers.Object(object())
provider = providers.ProvidedInstance()
provider.set_provides(provides)
self.assertIs(provider.provides, provides)
def test_set_provides_returns_self(self):
provider = providers.ProvidedInstance()
self.assertIs(provider.set_provides(providers.Provider()), provider)
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.container.service.provided)) self.assertTrue(providers.is_provider(self.container.service.provided))
@ -136,6 +126,26 @@ class ProvidedInstanceTests(unittest.TestCase):
) )
class LazyInitTests(unittest.TestCase):
def test_provided_instance(self):
provides = providers.Object(object())
provider = providers.ProvidedInstance()
provider.set_provides(provides)
self.assertIs(provider.provides, provides)
self.assertIs(provider.set_provides(providers.Provider()), provider)
def test_attribute_getter(self):
provides = providers.Object(object())
provider = providers.AttributeGetter()
provider.set_provides(provides)
provider.set_name('__dict__')
self.assertIs(provider.provides, provides)
self.assertEqual(provider.name, '__dict__')
self.assertIs(provider.set_provides(providers.Provider()), provider)
self.assertIs(provider.set_name('__dict__'), provider)
class ProvidedInstancePuzzleTests(unittest.TestCase): class ProvidedInstancePuzzleTests(unittest.TestCase):
def test_puzzled(self): def test_puzzled(self):