Introduce concept with annotations

This commit is contained in:
Roman Mogylatov 2020-09-18 16:55:41 -04:00
parent 9b4761fce5
commit 0c41e2671f
9 changed files with 1875 additions and 1012 deletions

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,7 @@ class Container:
def override_providers(self, **overriding_providers: Provider) -> None: ...
def reset_last_overriding(self) -> None: ...
def reset_override(self) -> None: ...
def resolve_provider_name(self, provider_to_resolve: Provider) -> Optional[str]: ...
def wire(self, modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None) -> None: ...

View File

@ -57,6 +57,7 @@ class DynamicContainer(object):
self.provider_type = Provider
self.providers = dict()
self.overridden = tuple()
self.declarative_parent = None
super(DynamicContainer, self).__init__()
def __deepcopy__(self, memo):
@ -68,6 +69,7 @@ class DynamicContainer(object):
copied = self.__class__()
copied.provider_type = Provider
copied.overridden = deepcopy(self.overridden, memo)
copied.declarative_parent = self.declarative_parent
for name, provider in deepcopy(self.providers, memo).items():
setattr(copied, name, provider)
@ -179,6 +181,20 @@ class DynamicContainer(object):
for provider in six.itervalues(self.providers):
provider.reset_override()
def resolve_provider_name(self, provider_to_resolve):
"""Try to resolve provider name by its instance."""
if self.declarative_parent:
provider_name = self.declarative_parent.resolve_provider_name(provider_to_resolve)
if provider_name:
return provider_name
for provider_name, container_provider in self.providers.items():
if container_provider is provider_to_resolve:
return provider_name
else:
return None
def wire(self, modules=None, packages=None):
"""Wire container providers with provided packages and modules by name.
@ -329,6 +345,7 @@ class DeclarativeContainer(object):
"""
container = cls.instance_type()
container.provider_type = cls.provider_type
container.declarative_parent = cls
container.set_providers(**deepcopy(cls.providers))
container.override_providers(**overriding_providers)
return container
@ -382,6 +399,15 @@ class DeclarativeContainer(object):
for provider in six.itervalues(cls.providers):
provider.reset_override()
@classmethod
def resolve_provider_name(cls, provider_to_resolve):
"""Try to resolve provider name by its instance."""
for provider_name, container_provider in cls.providers.items():
if container_provider is provider_to_resolve:
return provider_name
else:
return None
@classmethod
def wire(cls, modules=None, packages=None):
"""Wire container providers with provided packages and modules by name.

View File

@ -4,11 +4,12 @@ import functools
import inspect
import pkgutil
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Type, Dict
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
from . import providers
AnyContainer = Any
T = TypeVar('T')
def wire(
@ -62,26 +63,41 @@ def _patch_fn(
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]:
signature = inspect.signature(fn)
config = _resolve_container_config(container)
signature = inspect.signature(fn)
injections = {}
for parameter_name, parameter in signature.parameters.items():
if parameter_name in container.providers:
injections[parameter_name] = container.providers[parameter_name]
if not isinstance(parameter.default, _Marker):
continue
marker = parameter.default
if parameter_name.endswith('_provider'):
provider_name = parameter_name[:-9]
if provider_name in container.providers:
injections[parameter_name] = container.providers[provider_name].provider
provider = None
if config and isinstance(parameter.default, ConfigurationOption):
option_provider = config.get_option_provider(parameter.default.selector)
if parameter.annotation:
injections[parameter_name] = option_provider.as_(parameter.annotation)
else:
injections[parameter_name] = option_provider
provider_name = container.resolve_provider_name(marker.provider)
if provider_name:
provider = container.providers[provider_name]
if config and isinstance(marker.provider, providers.ConfigurationOption):
full_option_name = marker.provider.get_name()
_, *parts = full_option_name.split('.')
relative_option_name = '.'.join(parts)
provider = config.get_option_provider(relative_option_name)
if parameter.annotation is int:
provider = provider.as_int()
elif parameter.annotation is float:
provider = provider.as_float()
elif parameter.annotation is not inspect.Parameter.empty:
provider = provider.as_(parameter.annotation)
if provider is None:
continue
if isinstance(marker, Provide):
injections[parameter_name] = provider
elif isinstance(marker, Provider):
injections[parameter_name] = provider.provider
return injections
@ -118,18 +134,24 @@ def _patch_with_injections(fn, injections):
return _patched
class ConfigurationOptionMeta(type):
class ClassGetItemMeta(type):
def __getitem__(cls, item):
# Spike for Python 3.6
return cls(item)
class ConfigurationOption(metaclass=ConfigurationOptionMeta):
"""Configuration option marker."""
class _Marker(Generic[T], metaclass=ClassGetItemMeta):
def __init__(self, provider: providers.Provider) -> None:
self.provider = provider
def __init__(self, selector: str):
self.selector = selector
def __class_getitem__(cls, item):
def __class_getitem__(cls, item) -> T:
return cls(item)
class Provide(_Marker):
...
class Provider(_Marker):
...

View File

@ -0,0 +1,10 @@
from dependency_injector import containers, providers
from .service import Service
class Container(containers.DeclarativeContainer):
config = providers.Configuration()
service = providers.Factory(Service)

View File

@ -1,24 +1,30 @@
"""Test module for wiring."""
from dependency_injector.wiring import ConfigurationOption
from typing import Callable
from dependency_injector.wiring import Provide, Provider
from .container import Container
from .service import Service
class TestClass:
def __init__(self, service):
def __init__(self, service: Service = Provide[Container.service]):
self.service = service
def test_function(service):
def test_function(service: Service = Provide[Container.service]):
return service
def test_function_provider(service_provider):
return service_provider()
def test_function_provider(service_provider: Callable[..., Service] = Provider[Container.service]):
service = service_provider()
return service
def test_config_value(
some_value_int: int = ConfigurationOption['a.b.c'],
some_value_str: str = ConfigurationOption['a.b.c'],
some_value_int: int = Provide[Container.config.a.b.c],
some_value_str: str = Provide[Container.config.a.b.c],
):
return some_value_int, some_value_str

View File

@ -1,2 +1,8 @@
def test_function(service):
from dependency_injector.wiring import Provide
from ...container import Container
from ...service import Service
def test_function(service: Service = Provide[Container.service]):
return service

View File

@ -0,0 +1,2 @@
class Service:
service_attr: int

View File

@ -1,19 +1,9 @@
import unittest
from dependency_injector import containers, providers
from . import module, package
class Service:
...
class Container(containers.DeclarativeContainer):
config = providers.Configuration()
service = providers.Factory(Service)
from .service import Service
from .container import Container
class WiringTest(unittest.TestCase):