This commit is contained in:
Roman Mogylatov 2020-09-18 17:17:05 -04:00
parent 0c41e2671f
commit ae4477c5ab

View File

@ -2,10 +2,18 @@
import functools import functools
import inspect import inspect
import sys
import pkgutil import pkgutil
from types import ModuleType from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
# Spike for Python 3.6
if sys.version_info < (3, 7):
from typing import GenericMeta
else:
class GenericMeta(type):
...
from . import providers from . import providers
AnyContainer = Any AnyContainer = Any
@ -73,25 +81,12 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
continue continue
marker = parameter.default marker = parameter.default
provider = None
provider_name = container.resolve_provider_name(marker.provider) provider_name = container.resolve_provider_name(marker.provider)
if provider_name: if provider_name:
provider = container.providers[provider_name] provider = container.providers[provider_name]
elif config and isinstance(marker.provider, providers.ConfigurationOption):
if config and isinstance(marker.provider, providers.ConfigurationOption): provider = _prepare_config_injection(marker.provider, parameter, config)
full_option_name = marker.provider.get_name() else:
_, *parts = full_option_name.split('.')
relative_option_name = '.'.join(parts)
provider = config.get_option_provider(relative_option_name)
if parameter.annotation is int:
provider = provider.as_int()
elif parameter.annotation is float:
provider = provider.as_float()
elif parameter.annotation is not inspect.Parameter.empty:
provider = provider.as_(parameter.annotation)
if provider is None:
continue continue
if isinstance(marker, Provide): if isinstance(marker, Provide):
@ -102,6 +97,24 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
return injections return injections
def _prepare_config_injection(
option: providers.ConfigurationOption,
parameter: inspect.Parameter,
config: providers.Configuration,
) -> providers.Provider:
full_option_name = option.get_name()
_, *parts = full_option_name.split('.')
relative_option_name = '.'.join(parts)
provider = config.get_option_provider(relative_option_name)
if parameter.annotation is int:
provider = provider.as_int()
elif parameter.annotation is float:
provider = provider.as_float()
elif parameter.annotation is not inspect.Parameter.empty:
provider = provider.as_(parameter.annotation)
return provider
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]: def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
for provider in container.providers.values(): for provider in container.providers.values():
if isinstance(provider, providers.Configuration): if isinstance(provider, providers.Configuration):
@ -141,7 +154,11 @@ class ClassGetItemMeta(type):
return cls(item) return cls(item)
class _Marker(Generic[T], metaclass=ClassGetItemMeta): class GenericClassGetItemMeta(GenericMeta, ClassGetItemMeta):
pass
class _Marker(Generic[T], metaclass=GenericClassGetItemMeta):
def __init__(self, provider: providers.Provider) -> None: def __init__(self, provider: providers.Provider) -> None:
self.provider = provider self.provider = provider