This commit is contained in:
Roman Mogylatov 2020-09-18 17:17:05 -04:00
parent 0c41e2671f
commit ae4477c5ab

View File

@ -2,10 +2,18 @@
import functools import functools
import inspect import inspect
import sys
import pkgutil import pkgutil
from types import ModuleType from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar from typing import Optional, Iterable, Callable, Any, Type, Dict, Generic, TypeVar
# Spike for Python 3.6
if sys.version_info < (3, 7):
from typing import GenericMeta
else:
class GenericMeta(type):
...
from . import providers from . import providers
AnyContainer = Any AnyContainer = Any
@ -73,14 +81,28 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
continue continue
marker = parameter.default marker = parameter.default
provider = None
provider_name = container.resolve_provider_name(marker.provider) provider_name = container.resolve_provider_name(marker.provider)
if provider_name: if provider_name:
provider = container.providers[provider_name] provider = container.providers[provider_name]
elif config and isinstance(marker.provider, providers.ConfigurationOption):
provider = _prepare_config_injection(marker.provider, parameter, config)
else:
continue
if config and isinstance(marker.provider, providers.ConfigurationOption): if isinstance(marker, Provide):
full_option_name = marker.provider.get_name() injections[parameter_name] = provider
elif isinstance(marker, Provider):
injections[parameter_name] = provider.provider
return injections
def _prepare_config_injection(
option: providers.ConfigurationOption,
parameter: inspect.Parameter,
config: providers.Configuration,
) -> providers.Provider:
full_option_name = option.get_name()
_, *parts = full_option_name.split('.') _, *parts = full_option_name.split('.')
relative_option_name = '.'.join(parts) relative_option_name = '.'.join(parts)
provider = config.get_option_provider(relative_option_name) provider = config.get_option_provider(relative_option_name)
@ -90,16 +112,7 @@ def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict
provider = provider.as_float() provider = provider.as_float()
elif parameter.annotation is not inspect.Parameter.empty: elif parameter.annotation is not inspect.Parameter.empty:
provider = provider.as_(parameter.annotation) provider = provider.as_(parameter.annotation)
return provider
if provider is None:
continue
if isinstance(marker, Provide):
injections[parameter_name] = provider
elif isinstance(marker, Provider):
injections[parameter_name] = provider.provider
return injections
def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]: def _resolve_container_config(container: AnyContainer) -> Optional[providers.Configuration]:
@ -141,7 +154,11 @@ class ClassGetItemMeta(type):
return cls(item) return cls(item)
class _Marker(Generic[T], metaclass=ClassGetItemMeta): class GenericClassGetItemMeta(GenericMeta, ClassGetItemMeta):
pass
class _Marker(Generic[T], metaclass=GenericClassGetItemMeta):
def __init__(self, provider: providers.Provider) -> None: def __init__(self, provider: providers.Provider) -> None:
self.provider = provider self.provider = provider