Implement Provide[foo.provided.bar.baz.call()]

This commit is contained in:
Roman Mogylatov 2020-09-26 01:07:32 -04:00
parent 6d92df32aa
commit 95db0eddc9
5 changed files with 1352 additions and 607 deletions

File diff suppressed because it is too large Load Diff

View File

@ -2703,6 +2703,11 @@ cdef class ProvidedInstance(Provider):
def __getitem__(self, item):
return ItemGetter(self, item)
@property
def provides(self):
"""Return provider."""
return self.__provider
def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs)
@ -2742,6 +2747,16 @@ cdef class AttributeGetter(Provider):
def __getitem__(self, item):
return ItemGetter(self, item)
@property
def provides(self):
"""Return provider."""
return self.__provider
@property
def name(self):
"""Return name of the attribute."""
return self.__attribute
def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs)
@ -2782,6 +2797,16 @@ cdef class ItemGetter(Provider):
def __getitem__(self, item):
return ItemGetter(self, item)
@property
def provides(self):
"""Return provider."""
return self.__provider
@property
def name(self):
"""Return name of the item."""
return self.__item
def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs)
@ -2833,6 +2858,37 @@ cdef class MethodCaller(Provider):
def __getitem__(self, item):
return ItemGetter(self, item)
@property
def provides(self):
"""Return provider."""
return self.__provider
@property
def args(self):
"""Return positional argument injections."""
cdef int index
cdef PositionalInjection arg
cdef list args
args = list()
for index in range(self.__args_len):
arg = self.__args[index]
args.append(arg.__value)
return tuple(args)
@property
def kwargs(self):
"""Return keyword argument injections."""
cdef int index
cdef NamedInjection kwarg
cdef dict kwargs
kwargs = dict()
for index in range(self.__kwargs_len):
kwarg = self.__kwargs[index]
kwargs[kwarg.__name] = kwarg.__value
return kwargs
def call(self, *args, **kwargs):
return MethodCaller(self, *args, **kwargs)

View File

@ -130,6 +130,16 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
elif isinstance(marker.provider, providers.Delegate):
provider_name = container.resolve_provider_name(marker.provider.provides)
provider = container.providers[provider_name]
elif isinstance(marker.provider, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
provider = _prepare_provided_instance_injection(
marker.provider,
container,
)
elif isinstance(marker.provider, providers.Provider):
provider_name = container.resolve_provider_name(marker.provider)
if not provider_name:
@ -159,6 +169,40 @@ def _prepare_config_injection(
return provider
def _prepare_provided_instance_injection(
current_provider: providers.Provider,
container: AnyContainer,
) -> providers.Provider:
provided_instance_markers = []
instance_provider_marker = current_provider
while isinstance(instance_provider_marker, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
provided_instance_markers.insert(0, instance_provider_marker)
instance_provider_marker = instance_provider_marker.provides
provider_name = container.resolve_provider_name(instance_provider_marker)
provider = container.providers[provider_name]
for provided_instance in provided_instance_markers:
if isinstance(provided_instance, providers.ProvidedInstance):
provider = provider.provided
elif isinstance(provided_instance, providers.AttributeGetter):
provider = getattr(provider, provided_instance.name)
elif isinstance(provided_instance, providers.ItemGetter):
provider = provider[provided_instance.name]
elif isinstance(provided_instance, providers.MethodCaller):
provider = provider.call(
*provided_instance.args,
**provided_instance.kwargs,
)
return provider
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
for provider in container.providers.values():
if isinstance(provider, providers.Configuration):

View File

@ -35,3 +35,7 @@ def test_config_value(
def test_provide_provider(service_provider: Callable[..., Service] = Provider[Container.service.provider]):
service = service_provider()
return service
def test_provided_instance(some_value: int = Provide[Container.service.provided.foo['bar'].call()]):
return some_value

View File

@ -62,3 +62,13 @@ class WiringTest(unittest.TestCase):
def test_provide_provider(self):
service = module.test_provide_provider()
self.assertIsInstance(service, Service)
def test_provided_instance(self):
class TestService:
foo = {
'bar': lambda: 10,
}
with self.container.service.override(TestService()):
some_value = module.test_provided_instance()
self.assertEqual(some_value, 10)