* Add single container prototype

* Add multiple containers prototype

* Add integration tests

* Implement from_*() methods and add tests

* Prototype inline injections

* Add integration test for inline providers

* Refactor integration tests

* Add integration test for reordered schema

* Remove unused imports from tests

* Refactor schema module

* Update tests to match latest schemas

* Add mypy_boto3_s3 to the test requirements

* Add boto3 to the test requirements

* Add set_provides for Callable, Factory, and Singleton providers

* Fix warnings in tests

* Add typing stubs for Callable, Factory, and Singleton .set_provides() attributes

* Fix singleton children to have optional provides

* Implement provider to provider resolving

* Fix pypy3 tests

* Implement boto3 session use case and add tests

* Implement lazy initialization and improve copying for Callable, Factory, Singleton, and Coroutine providers

* Fix Python 2 tests

* Add region name for boto3 integration example

* Remove f-strings from set_provides()

* Fix schema flake8 errors

* Implement lazy initialization and improve copying for Delegate provider

* Implement lazy initialization and improve copying for Object provider

* Speed up wiring tests

* Implement lazy initialization and improve copying for FactoryAggregate provider

* Implement lazy initialization and improve copying for Selector provider

* Implement lazy initialization and improve copying for Dependency provider

* Implement lazy initialization and improve copying for Resource provider

* Implement lazy initialization and improve copying for Configuration provider

* Implement lazy initialization and improve copying for ProvidedInstance provider

* Implement lazy initialization and improve copying for AttributeGetter provider

* Implement lazy initialization and improve copying for ItemGetter provider

* Implement lazy initialization and improve copying for MethodCaller provder

* Update changelog

* Fix typing in wiring module

* Fix wiring module loader uninstallation issue

* Fix provided instance providers error handing in asynchronous mode

Co-authored-by: Roman Mogylatov <rmk@Romans-MacBook-Pro.local>
This commit is contained in:
Roman Mogylatov 2021-03-20 13:16:51 -04:00 committed by GitHub
parent 8cad8c6b65
commit f961ff536a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 21101 additions and 14879 deletions

View File

@ -7,6 +7,14 @@ that were made in every particular version.
From version 0.7.6 *Dependency Injector* framework strictly From version 0.7.6 *Dependency Injector* framework strictly
follows `Semantic versioning`_ follows `Semantic versioning`_
Development version
-------------------
- Implement providers' lazy initialization.
- Improve providers' copying.
- Improve typing in wiring module.
- Fix wiring module loader uninstallation issue.
- Fix provided instance providers error handing in asynchronous mode.
4.30.0 4.30.0
------ ------
- Remove restriction to wire a dynamic container. - Remove restriction to wire a dynamic container.

View File

@ -12,5 +12,7 @@ fastapi
pydantic pydantic
numpy numpy
scipy scipy
boto3
mypy_boto3_s3
-r requirements-ext.txt -r requirements-ext.txt

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,4 @@
from pathlib import Path
from typing import ( from typing import (
Generic, Generic,
Type, Type,
@ -34,6 +35,7 @@ class Container:
def __init__(self) -> None: ... def __init__(self) -> None: ...
def __deepcopy__(self, memo: Optional[Dict[str, Any]]) -> Provider: ... def __deepcopy__(self, memo: Optional[Dict[str, Any]]) -> Provider: ...
def __setattr__(self, name: str, value: Union[Provider, Any]) -> None: ... def __setattr__(self, name: str, value: Union[Provider, Any]) -> None: ...
def __getattr__(self, name: str) -> Provider: ...
def __delattr__(self, name: str) -> None: ... def __delattr__(self, name: str) -> None: ...
def set_providers(self, **providers: Provider): ... def set_providers(self, **providers: Provider): ...
def set_provider(self, name: str, provider: Provider) -> None: ... def set_provider(self, name: str, provider: Provider) -> None: ...
@ -48,6 +50,9 @@ class Container:
def apply_container_providers_overridings(self) -> None: ... def apply_container_providers_overridings(self) -> None: ...
def reset_singletons(self) -> SingletonResetContext[C_Base]: ... def reset_singletons(self) -> SingletonResetContext[C_Base]: ...
def check_dependencies(self) -> None: ... def check_dependencies(self) -> None: ...
def from_schema(self, schema: Dict[Any, Any]) -> None: ...
def from_yaml_schema(self, filepath: Union[Path, str], loader: Optional[Any]=None) -> None: ...
def from_json_schema(self, filepath: Union[Path, str]) -> None: ...
@overload @overload
def resolve_provider_name(self, provider: Provider) -> str: ... def resolve_provider_name(self, provider: Provider) -> str: ...
@classmethod @classmethod

View File

@ -1,5 +1,6 @@
"""Containers module.""" """Containers module."""
import json
import sys import sys
try: try:
@ -7,6 +8,11 @@ try:
except ImportError: except ImportError:
asyncio = None asyncio = None
try:
import yaml
except ImportError:
yaml = None
import six import six
from . import providers, errors from . import providers, errors
@ -330,6 +336,39 @@ class DynamicContainer(Container):
f'{", ".join(undefined_names)}', f'{", ".join(undefined_names)}',
) )
def from_schema(self, schema):
"""Build container providers from schema."""
from .schema import build_schema
for name, provider in build_schema(schema).items():
self.set_provider(name, provider)
def from_yaml_schema(self, filepath, loader=None):
"""Build container providers from YAML schema.
You can specify type of loader as a second argument. By default, method
uses ``SafeLoader``.
"""
if yaml is None:
raise errors.Error(
'Unable to load yaml schema - PyYAML is not installed. '
'Install PyYAML or install Dependency Injector with yaml extras: '
'"pip install dependency-injector[yaml]"'
)
if loader is None:
loader = yaml.SafeLoader
with open(filepath) as file:
schema = yaml.load(file, loader)
self.from_schema(schema)
def from_json_schema(self, filepath):
"""Build container providers from JSON schema."""
with open(filepath) as file:
schema = json.load(file)
self.from_schema(schema)
def resolve_provider_name(self, provider): def resolve_provider_name(self, provider):
"""Try to resolve provider name.""" """Try to resolve provider name."""
for provider_name, container_provider in self.providers.items(): for provider_name, container_provider in self.providers.items():

File diff suppressed because it is too large Load Diff

View File

@ -203,7 +203,7 @@ cdef class Dict(Provider):
cdef class Resource(Provider): cdef class Resource(Provider):
cdef object __initializer cdef object __provides
cdef bint __initialized cdef bint __initialized
cdef object __shutdowner cdef object __shutdowner
cdef object __resource cdef object __resource
@ -235,27 +235,27 @@ cdef class Selector(Provider):
# Provided instance # Provided instance
cdef class ProvidedInstance(Provider): cdef class ProvidedInstance(Provider):
cdef Provider __provider cdef object __provides
cpdef object _provide(self, tuple args, dict kwargs) cpdef object _provide(self, tuple args, dict kwargs)
cdef class AttributeGetter(Provider): cdef class AttributeGetter(Provider):
cdef Provider __provider cdef object __provides
cdef object __attribute cdef object __name
cpdef object _provide(self, tuple args, dict kwargs) cpdef object _provide(self, tuple args, dict kwargs)
cdef class ItemGetter(Provider): cdef class ItemGetter(Provider):
cdef Provider __provider cdef object __provides
cdef object __item cdef object __name
cpdef object _provide(self, tuple args, dict kwargs) cpdef object _provide(self, tuple args, dict kwargs)
cdef class MethodCaller(Provider): cdef class MethodCaller(Provider):
cdef Provider __provider cdef object __provides
cdef tuple __args cdef tuple __args
cdef int __args_len cdef int __args_len
cdef tuple __kwargs cdef tuple __kwargs

View File

@ -79,7 +79,9 @@ class Provider(Generic[T]):
class Object(Provider[T]): class Object(Provider[T]):
def __init__(self, provides: T) -> None: ... def __init__(self, provides: Optional[T] = None) -> None: ...
def provides(self) -> Optional[T]: ...
def set_provides(self, provides: Optional[T]) -> Object: ...
class Self(Provider[T]): class Self(Provider[T]):
@ -94,16 +96,21 @@ class Delegate(Provider[Provider]):
def __init__(self, provides: Optional[Provider] = None) -> None: ... def __init__(self, provides: Optional[Provider] = None) -> None: ...
@property @property
def provides(self) -> Optional[Provider]: ... def provides(self) -> Optional[Provider]: ...
def set_provides(self, provides: Optional[Provider]): ... def set_provides(self, provides: Optional[Provider]) -> Delegate: ...
class Dependency(Provider[T]): class Dependency(Provider[T]):
def __init__(self, instance_of: Type[T] = object, default: Optional[Union[Provider, Any]] = None) -> None: ... def __init__(self, instance_of: Type[T] = object, default: Optional[Union[Provider, Any]] = None) -> None: ...
def __getattr__(self, name: str) -> Any: ... def __getattr__(self, name: str) -> Any: ...
@property @property
def instance_of(self) -> Type[T]: ... def instance_of(self) -> Type[T]: ...
def set_instance_of(self, instance_of: Type[T]) -> Dependency[T]: ...
@property @property
def default(self) -> Provider[T]: ... def default(self) -> Provider[T]: ...
def set_default(self, default: Optional[Union[Provider, Any]]) -> Dependency[T]: ...
@property @property
def is_defined(self) -> bool: ... def is_defined(self) -> bool: ...
def provided_by(self, provider: Provider) -> OverridingContext[P]: ... def provided_by(self, provider: Provider) -> OverridingContext[P]: ...
@ -131,9 +138,10 @@ class DependenciesContainer(Object):
class Callable(Provider[T]): class Callable(Provider[T]):
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[_Callable[..., T]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@property @property
def provides(self) -> T: ... def provides(self) -> Optional[T]: ...
def set_provides(self, provides: Optional[_Callable[..., T]]) -> Callable[T]: ...
@property @property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Callable[T]: ... def add_args(self, *args: Injection) -> Callable[T]: ...
@ -207,7 +215,19 @@ class Configuration(Object[Any]):
def __exit__(self, *exc_info: Any) -> None: ... def __exit__(self, *exc_info: Any) -> None: ...
def __getattr__(self, item: str) -> ConfigurationOption: ... def __getattr__(self, item: str) -> ConfigurationOption: ...
def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ... def __getitem__(self, item: Union[str, Provider]) -> ConfigurationOption: ...
def get_name(self) -> str: ... def get_name(self) -> str: ...
def set_name(self, name: str) -> Configuration: ...
def get_default(self) -> _Dict[Any, Any]: ...
def set_default(self, default: _Dict[Any, Any]): ...
def get_strict(self) -> bool: ...
def set_strict(self, strict: bool) -> Configuration: ...
def get_children(self) -> _Dict[str, ConfigurationOption]: ...
def set_children(self, children: _Dict[str, ConfigurationOption]) -> Configuration: ...
def get(self, selector: str) -> Any: ... def get(self, selector: str) -> Any: ...
def set(self, selector: str, value: Any) -> OverridingContext[P]: ... def set(self, selector: str, value: Any) -> OverridingContext[P]: ...
def reset_cache(self) -> None: ... def reset_cache(self) -> None: ...
@ -221,11 +241,12 @@ class Configuration(Object[Any]):
class Factory(Provider[T]): class Factory(Provider[T]):
provided_type: Optional[Type] provided_type: Optional[Type]
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[_Callable[..., T]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@property @property
def cls(self) -> T: ... def cls(self) -> T: ...
@property @property
def provides(self) -> T: ... def provides(self) -> T: ...
def set_provides(self, provides: Optional[_Callable[..., T]]) -> Factory[T]: ...
@property @property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Factory[T]: ... def add_args(self, *args: Injection) -> Factory[T]: ...
@ -266,15 +287,17 @@ class FactoryAggregate(Provider):
@property @property
def factories(self) -> _Dict[str, Factory]: ... def factories(self) -> _Dict[str, Factory]: ...
def set_factories(self, **factories: Factory) -> FactoryAggregate: ...
class BaseSingleton(Provider[T]): class BaseSingleton(Provider[T]):
provided_type = Optional[Type] provided_type = Optional[Type]
def __init__(self, provides: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[_Callable[..., T]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@property @property
def cls(self) -> T: ... def cls(self) -> T: ...
@property @property
def provides(self) -> T: ... def provides(self) -> T: ...
def set_provides(self, provides: Optional[_Callable[..., T]]) -> BaseSingleton[T]: ...
@property @property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> BaseSingleton[T]: ... def add_args(self, *args: Injection) -> BaseSingleton[T]: ...
@ -340,19 +363,20 @@ class Dict(Provider[_Dict]):
class Resource(Provider[T]): class Resource(Provider[T]):
@overload @overload
def __init__(self, initializer: Type[resources.Resource[T]], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[Type[resources.Resource[T]]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@overload @overload
def __init__(self, initializer: Type[resources.AsyncResource[T]], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[Type[resources.AsyncResource[T]]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@overload @overload
def __init__(self, initializer: _Callable[..., _Iterator[T]], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[_Callable[..., _Iterator[T]]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@overload @overload
def __init__(self, initializer: _Callable[..., _AsyncIterator[T]], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[_Callable[..., _AsyncIterator[T]]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@overload @overload
def __init__(self, initializer: _Callable[..., _Coroutine[Injection, Injection, T]], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[_Callable[..., _Coroutine[Injection, Injection, T]]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@overload @overload
def __init__(self, initializer: _Callable[..., T], *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[_Callable[..., T]] = None, *args: Injection, **kwargs: Injection) -> None: ...
@property @property
def initializer(self) -> _Callable[..., Any]: ... def provides(self) -> Optional[_Callable[..., Any]]: ...
def set_provides(self, provides: Optional[Any]) -> Resource[T]: ...
@property @property
def args(self) -> Tuple[Injection]: ... def args(self) -> Tuple[Injection]: ...
def add_args(self, *args: Injection) -> Resource[T]: ... def add_args(self, *args: Injection) -> Resource[T]: ...
@ -383,32 +407,47 @@ class Container(Provider[T]):
class Selector(Provider[Any]): class Selector(Provider[Any]):
def __init__(self, selector: _Callable[..., Any], **providers: Provider): ... def __init__(self, selector: Optional[_Callable[..., Any]] = None, **providers: Provider): ...
def __getattr__(self, name: str) -> Provider: ... def __getattr__(self, name: str) -> Provider: ...
@property
def selector(self) -> Optional[_Callable[..., Any]]: ...
def set_selector(self, selector: Optional[_Callable[..., Any]]) -> Selector: ...
@property @property
def providers(self) -> _Dict[str, Provider]: ... def providers(self) -> _Dict[str, Provider]: ...
def set_providers(self, **providers: Provider) -> Selector: ...
class ProvidedInstanceFluentInterface: class ProvidedInstanceFluentInterface:
def __getattr__(self, item: Any) -> AttributeGetter: ... def __getattr__(self, item: Any) -> AttributeGetter: ...
def __getitem__(self, item: Any) -> ItemGetter: ... def __getitem__(self, item: Any) -> ItemGetter: ...
def call(self, *args: Injection, **kwargs: Injection) -> MethodCaller: ... def call(self, *args: Injection, **kwargs: Injection) -> MethodCaller: ...
@property
def provides(self) -> Optional[Provider]: ...
def set_provides(self, provides: Optional[Provider]) -> ProvidedInstanceFluentInterface: ...
class ProvidedInstance(Provider, ProvidedInstanceFluentInterface): class ProvidedInstance(Provider, ProvidedInstanceFluentInterface):
def __init__(self, provider: Provider) -> None: ... def __init__(self, provides: Optional[Provider] = None) -> None: ...
class AttributeGetter(Provider, ProvidedInstanceFluentInterface): class AttributeGetter(Provider, ProvidedInstanceFluentInterface):
def __init__(self, provider: Provider, attribute: str) -> None: ... def __init__(self, provides: Optional[Provider] = None, name: Optional[str] = None) -> None: ...
@property
def name(self) -> Optional[str]: ...
def set_name(self, name: Optional[str]) -> ProvidedInstanceFluentInterface: ...
class ItemGetter(Provider, ProvidedInstanceFluentInterface): class ItemGetter(Provider, ProvidedInstanceFluentInterface):
def __init__(self, provider: Provider, item: str) -> None: ... def __init__(self, provides: Optional[Provider] = None, name: Optional[str] = None) -> None: ...
@property
def name(self) -> Optional[str]: ...
def set_name(self, name: Optional[str]) -> ProvidedInstanceFluentInterface: ...
class MethodCaller(Provider, ProvidedInstanceFluentInterface): class MethodCaller(Provider, ProvidedInstanceFluentInterface):
def __init__(self, provider: Provider, *args: Injection, **kwargs: Injection) -> None: ... def __init__(self, provides: Optional[Provider] = None, *args: Injection, **kwargs: Injection) -> None: ...
class OverridingContext(Generic[T]): class OverridingContext(Generic[T]):

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,231 @@
"""Schema module."""
import builtins
import importlib
from typing import Dict, Any, Type, Optional
from . import containers, providers
ContainerSchema = Dict[Any, Any]
ProviderSchema = Dict[Any, Any]
class SchemaProcessorV1:
def __init__(self, schema: ContainerSchema) -> None:
self._schema = schema
self._container = containers.DynamicContainer()
def process(self):
"""Process schema."""
self._create_providers(self._schema['container'])
self._setup_injections(self._schema['container'])
def get_providers(self):
"""Return providers."""
return self._container.providers
def _create_providers(
self,
provider_schema: ProviderSchema,
container: Optional[containers.Container] = None,
) -> None:
if container is None:
container = self._container
for provider_name, data in provider_schema.items():
provider = None
if 'provider' in data:
provider_type = _get_provider_cls(data['provider'])
args = []
# provides = data.get('provides')
# if provides:
# provides = _import_string(provides)
# if provides:
# args.append(provides)
provider = provider_type(*args)
if provider is None:
provider = providers.Container(containers.DynamicContainer)
container.set_provider(provider_name, provider)
if isinstance(provider, providers.Container):
self._create_providers(provider_schema=data, container=provider)
def _setup_injections( # noqa: C901
self,
provider_schema: ProviderSchema,
container: Optional[containers.Container] = None,
) -> None:
if container is None:
container = self._container
for provider_name, data in provider_schema.items():
provider = getattr(container, provider_name)
args = []
kwargs = {}
provides = data.get('provides')
if provides:
if isinstance(provides, str) and provides.startswith('container.'):
provides = self._resolve_provider(provides[len('container.'):])
else:
provides = _import_string(provides)
provider.set_provides(provides)
arg_injections = data.get('args')
if arg_injections:
for arg in arg_injections:
injection = None
if isinstance(arg, str) and arg.startswith('container.'):
injection = self._resolve_provider(arg[len('container.'):])
# TODO: refactoring
if isinstance(arg, dict):
provider_args = []
provider_type = _get_provider_cls(arg.get('provider'))
provides = arg.get('provides')
if provides:
if isinstance(provides, str) and provides.startswith('container.'):
provides = self._resolve_provider(provides[len('container.'):])
else:
provides = _import_string(provides)
provider_args.append(provides)
for provider_arg in arg.get('args', []):
if isinstance(provider_arg, str) \
and provider_arg.startswith('container.'):
provider_args.append(
self._resolve_provider(provider_arg[len('container.'):]),
)
injection = provider_type(*provider_args)
if not injection:
injection = arg
args.append(injection)
if args:
provider.add_args(*args)
kwarg_injections = data.get('kwargs')
if kwarg_injections:
for name, arg in kwarg_injections.items():
injection = None
if isinstance(arg, str) and arg.startswith('container.'):
injection = self._resolve_provider(arg[len('container.'):])
# TODO: refactoring
if isinstance(arg, dict):
provider_args = []
provider_type = _get_provider_cls(arg.get('provider'))
provides = arg.get('provides')
if provides:
if isinstance(provides, str) and provides.startswith('container.'):
provides = self._resolve_provider(provides[len('container.'):])
else:
provides = _import_string(provides)
provider_args.append(provides)
for provider_arg in arg.get('args', []):
if isinstance(provider_arg, str) \
and provider_arg.startswith('container.'):
provider_args.append(
self._resolve_provider(provider_arg[len('container.'):]),
)
injection = provider_type(*provider_args)
if not injection:
injection = arg
kwargs[name] = injection
if kwargs:
provider.add_kwargs(**kwargs)
if isinstance(provider, providers.Container):
self._setup_injections(provider_schema=data, container=provider)
def _resolve_provider(self, name: str) -> Optional[providers.Provider]:
segments = name.split('.')
try:
provider = getattr(self._container, segments[0])
except AttributeError:
return None
for segment in segments[1:]:
parentheses = ''
if '(' in segment and ')' in segment:
parentheses = segment[segment.find('('):segment.rfind(')')+1]
segment = segment.replace(parentheses, '')
try:
provider = getattr(provider, segment)
except AttributeError:
# TODO
return None
if parentheses:
# TODO
provider = provider()
return provider
def build_schema(schema: ContainerSchema) -> Dict[str, providers.Provider]:
"""Build provider schema."""
schema_processor = SchemaProcessorV1(schema)
schema_processor.process()
return schema_processor.get_providers()
def _get_provider_cls(provider_cls_name: str) -> Type[providers.Provider]:
std_provider_type = _fetch_provider_cls_from_std(provider_cls_name)
if std_provider_type:
return std_provider_type
custom_provider_type = _import_provider_cls(provider_cls_name)
if custom_provider_type:
return custom_provider_type
raise SchemaError(f'Undefined provider class "{provider_cls_name}"')
def _fetch_provider_cls_from_std(provider_cls_name: str) -> Optional[Type[providers.Provider]]:
return getattr(providers, provider_cls_name, None)
def _import_provider_cls(provider_cls_name: str) -> Optional[Type[providers.Provider]]:
try:
cls = _import_string(provider_cls_name)
except (ImportError, ValueError) as exception:
raise SchemaError(f'Can not import provider "{provider_cls_name}"') from exception
except AttributeError:
return None
else:
if isinstance(cls, type) and not issubclass(cls, providers.Provider):
raise SchemaError(f'Provider class "{cls}" is not a subclass of providers base class')
return cls
def _import_string(string_name: str) -> Optional[object]:
segments = string_name.split('.')
if len(segments) == 1:
member = getattr(builtins, segments[0], None)
if member:
return member
module_name = '.'.join(segments[:-1])
if not module_name:
return None
member = segments[-1]
module = importlib.import_module(module_name)
return getattr(module, member, None)
class SchemaError(Exception):
"""Schema-related error."""

View File

@ -771,7 +771,7 @@ class ProvidedInstance(Modifier):
def modify( def modify(
self, self,
provider: providers.ConfigurationOption, provider: providers.Provider,
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> providers.Provider: ) -> providers.Provider:
provider = provider.provided provider = provider.provided
@ -860,7 +860,7 @@ class AutoLoader:
@property @property
def installed(self): def installed(self):
return self._path_hook is not None return self._path_hook in sys.path_hooks
def install(self): def install(self):
if self.installed: if self.installed:

View File

@ -412,7 +412,7 @@ class DeclarativeContainerTests(unittest.TestCase):
class Services(containers.DeclarativeContainer): class Services(containers.DeclarativeContainer):
a = providers.Dependency() a = providers.Dependency()
c = providers.Factory(C, a=a) c = providers.Factory(C, a=a)
b = providers.Factory(B, fa=a.delegate()) b = providers.Factory(B, fa=a.provider)
a = providers.Factory(A) a = providers.Factory(A)
assert isinstance(Services(a=a).c().a, A) # ok assert isinstance(Services(a=a).c().a, A) # ok

View File

@ -655,6 +655,32 @@ class ProvidedInstanceTests(AsyncTestCase):
self.assertIs(instance2.resource, RESOURCE1) self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource) self.assertIs(instance1.resource, instance2.resource)
def test_provided_attribute_error(self):
async def raise_exception():
raise RuntimeError()
class TestContainer(containers.DeclarativeContainer):
client = providers.Factory(raise_exception)
container = TestContainer()
with self.assertRaises(RuntimeError):
self._run(container.client.provided.attr())
def test_provided_attribute_undefined_attribute(self):
class TestClient:
def __init__(self, resource):
self.resource = resource
class TestContainer(containers.DeclarativeContainer):
resource = providers.Resource(init_resource, providers.Object(RESOURCE1))
client = providers.Factory(TestClient, resource=resource)
container = TestContainer()
with self.assertRaises(AttributeError):
self._run(container.client.provided.attr())
def test_provided_item(self): def test_provided_item(self):
class TestClient: class TestClient:
def __init__(self, resource): def __init__(self, resource):
@ -685,6 +711,28 @@ class ProvidedInstanceTests(AsyncTestCase):
self.assertIs(instance2.resource, RESOURCE1) self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource) self.assertIs(instance1.resource, instance2.resource)
def test_provided_item_error(self):
async def raise_exception():
raise RuntimeError()
class TestContainer(containers.DeclarativeContainer):
client = providers.Factory(raise_exception)
container = TestContainer()
with self.assertRaises(RuntimeError):
self._run(container.client.provided['item']())
def test_provided_item_undefined_item(self):
class TestContainer(containers.DeclarativeContainer):
resource = providers.Resource(init_resource, providers.Object(RESOURCE1))
client = providers.Factory(dict, resource=resource)
container = TestContainer()
with self.assertRaises(KeyError):
self._run(container.client.provided['item']())
def test_provided_method_call(self): def test_provided_method_call(self):
class TestClient: class TestClient:
def __init__(self, resource): def __init__(self, resource):
@ -715,6 +763,31 @@ class ProvidedInstanceTests(AsyncTestCase):
self.assertIs(instance2.resource, RESOURCE1) self.assertIs(instance2.resource, RESOURCE1)
self.assertIs(instance1.resource, instance2.resource) self.assertIs(instance1.resource, instance2.resource)
def test_provided_method_call_parent_error(self):
async def raise_exception():
raise RuntimeError()
class TestContainer(containers.DeclarativeContainer):
client = providers.Factory(raise_exception)
container = TestContainer()
with self.assertRaises(RuntimeError):
self._run(container.client.provided.method.call()())
def test_provided_method_call_error(self):
class TestClient:
def method(self):
raise RuntimeError()
class TestContainer(containers.DeclarativeContainer):
client = providers.Factory(TestClient)
container = TestContainer()
with self.assertRaises(RuntimeError):
self._run(container.client.provided.method.call()())
class DependencyTests(AsyncTestCase): class DependencyTests(AsyncTestCase):
@ -996,7 +1069,7 @@ class AsyncProvidersWithAsyncDependenciesTests(AsyncTestCase):
container = Container() container = Container()
service = self._run(container.service()) service = self._run(container.service())
self.assertEquals(service, {'service': 'ok', 'db': {'db': 'ok'}}) self.assertEqual(service, {'service': 'ok', 'db': {'db': 'ok'}})
class AsyncProviderWithAwaitableObjectTests(AsyncTestCase): class AsyncProviderWithAwaitableObjectTests(AsyncTestCase):

View File

@ -1,6 +1,7 @@
"""Dependency injector base providers unit tests.""" """Dependency injector base providers unit tests."""
import unittest import unittest
import warnings
from dependency_injector import ( from dependency_injector import (
containers, containers,
@ -21,13 +22,14 @@ class ProviderTests(unittest.TestCase):
self.assertRaises(NotImplementedError, self.provider.__call__) self.assertRaises(NotImplementedError, self.provider.__call__)
def test_delegate(self): def test_delegate(self):
with warnings.catch_warnings():
warnings.simplefilter('ignore')
delegate1 = self.provider.delegate() delegate1 = self.provider.delegate()
delegate2 = self.provider.delegate()
self.assertIsInstance(delegate1, providers.Delegate) self.assertIsInstance(delegate1, providers.Delegate)
self.assertIs(delegate1(), self.provider) self.assertIs(delegate1(), self.provider)
delegate2 = self.provider.delegate()
self.assertIsInstance(delegate2, providers.Delegate) self.assertIsInstance(delegate2, providers.Delegate)
self.assertIs(delegate2(), self.provider) self.assertIs(delegate2(), self.provider)
@ -150,6 +152,17 @@ class ObjectProviderTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Object(object()))) self.assertTrue(providers.is_provider(providers.Object(object())))
def test_init_optional_provides(self):
instance = object()
provider = providers.Object()
provider.set_provides(instance)
self.assertIs(provider.provides, instance)
self.assertIs(provider(), instance)
def test_set_provides_returns_self(self):
provider = providers.Object()
self.assertIs(provider.set_provides(object()), provider)
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Object(object()) provider = providers.Object(object())
self.assertIsInstance(provider.provided, providers.ProvidedInstance) self.assertIsInstance(provider.provided, providers.ProvidedInstance)
@ -289,6 +302,16 @@ class DelegateTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.delegate)) self.assertTrue(providers.is_provider(self.delegate))
def test_init_optional_provides(self):
provider = providers.Delegate()
provider.set_provides(self.delegated)
self.assertIs(provider.provides, self.delegated)
self.assertIs(provider(), self.delegated)
def test_set_provides_returns_self(self):
provider = providers.Delegate()
self.assertIs(provider.set_provides(self.delegated), provider)
def test_init_with_not_provider(self): def test_init_with_not_provider(self):
self.assertRaises(errors.Error, providers.Delegate, object()) self.assertRaises(errors.Error, providers.Delegate, object())
@ -312,6 +335,24 @@ class DependencyTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.provider = providers.Dependency(instance_of=list) self.provider = providers.Dependency(instance_of=list)
def test_init_optional(self):
list_provider = providers.List(1, 2, 3)
provider = providers.Dependency()
provider.set_instance_of(list)
provider.set_default(list_provider)
self.assertIs(provider.instance_of, list)
self.assertIs(provider.default, list_provider)
self.assertEqual(provider(), [1, 2, 3])
def test_set_instance_of_returns_self(self):
provider = providers.Dependency()
self.assertIs(provider.set_instance_of(list), provider)
def test_set_default_returns_self(self):
provider = providers.Dependency()
self.assertIs(provider.set_default(providers.Provider()), provider)
def test_init_with_not_class(self): def test_init_with_not_class(self):
self.assertRaises(TypeError, providers.Dependency, object()) self.assertRaises(TypeError, providers.Dependency, object())

View File

@ -22,6 +22,16 @@ class CallableTests(unittest.TestCase):
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, providers.Callable, 123) self.assertRaises(errors.Error, providers.Callable, 123)
def test_init_optional_provides(self):
provider = providers.Callable()
provider.set_provides(object)
self.assertIs(provider.provides, object)
self.assertIsInstance(provider(), object)
def test_set_provides_returns_self(self):
provider = providers.Callable()
self.assertIs(provider.set_provides(object), provider)
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) self.assertIsInstance(provider.provided, providers.ProvidedInstance)

View File

@ -28,6 +28,28 @@ class ConfigTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
del self.config del self.config
def test_init_optional(self):
provider = providers.Configuration()
provider.set_name('myconfig')
provider.set_default({'foo': 'bar'})
provider.set_strict(True)
self.assertEqual(provider.get_name(), 'myconfig')
self.assertEqual(provider.get_default(), {'foo': 'bar'})
self.assertTrue(provider.get_strict())
def test_set_name_returns_self(self):
provider = providers.Configuration()
self.assertIs(provider.set_name('myconfig'), provider)
def test_set_default_returns_self(self):
provider = providers.Configuration()
self.assertIs(provider.set_default({}), provider)
def test_set_strict_returns_self(self):
provider = providers.Configuration()
self.assertIs(provider.set_strict(True), provider)
def test_default_name(self): def test_default_name(self):
config = providers.Configuration() config = providers.Configuration()
self.assertEqual(config.get_name(), 'config') self.assertEqual(config.get_name(), 'config')

View File

@ -1,8 +1,8 @@
"""Dependency injector coroutine providers unit tests.""" """Dependency injector coroutine providers unit tests."""
import asyncio import asyncio
import unittest import unittest
import warnings
from dependency_injector import ( from dependency_injector import (
providers, providers,
@ -43,6 +43,16 @@ class CoroutineTests(AsyncTestCase):
def test_init_with_not_coroutine(self): def test_init_with_not_coroutine(self):
self.assertRaises(errors.Error, providers.Coroutine, lambda: None) self.assertRaises(errors.Error, providers.Coroutine, lambda: None)
def test_init_optional_provides(self):
provider = providers.Coroutine()
provider.set_provides(_example)
self.assertIs(provider.provides, _example)
self.assertEqual(run(provider(1, 2, 3, 4)), (1, 2, 3, 4))
def test_set_provides_returns_self(self):
provider = providers.Coroutine()
self.assertIs(provider.set_provides(_example), provider)
def test_call_with_positional_args(self): def test_call_with_positional_args(self):
provider = providers.Coroutine(_example, 1, 2, 3, 4) provider = providers.Coroutine(_example, 1, 2, 3, 4)
self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4)) self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4))
@ -232,6 +242,9 @@ class AbstractCoroutineTests(AsyncTestCase):
providers.Coroutine) providers.Coroutine)
def test_call_overridden_by_coroutine(self): def test_call_overridden_by_coroutine(self):
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@asyncio.coroutine @asyncio.coroutine
def _abstract_example(): def _abstract_example():
raise RuntimeError('Should not be raised') raise RuntimeError('Should not be raised')
@ -242,6 +255,9 @@ class AbstractCoroutineTests(AsyncTestCase):
self.assertTrue(self._run(provider(1, 2, 3, 4)), (1, 2, 3, 4)) self.assertTrue(self._run(provider(1, 2, 3, 4)), (1, 2, 3, 4))
def test_call_overridden_by_delegated_coroutine(self): def test_call_overridden_by_delegated_coroutine(self):
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@asyncio.coroutine @asyncio.coroutine
def _abstract_example(): def _abstract_example():
raise RuntimeError('Should not be raised') raise RuntimeError('Should not be raised')

View File

@ -34,6 +34,16 @@ class FactoryTests(unittest.TestCase):
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, providers.Factory, 123) self.assertRaises(errors.Error, providers.Factory, 123)
def test_init_optional_provides(self):
provider = providers.Factory()
provider.set_provides(object)
self.assertIs(provider.provides, object)
self.assertIsInstance(provider(), object)
def test_set_provides_returns_self(self):
provider = providers.Factory()
self.assertIs(provider.set_provides(object), provider)
def test_init_with_valid_provided_type(self): def test_init_with_valid_provided_type(self):
class ExampleProvider(providers.Factory): class ExampleProvider(providers.Factory):
provided_type = Example provided_type = Example
@ -502,6 +512,26 @@ class FactoryAggregateTests(unittest.TestCase):
example_a=providers.Factory(self.ExampleA), example_a=providers.Factory(self.ExampleA),
example_b=object()) example_b=object())
def test_init_optional_factories(self):
provider = providers.FactoryAggregate()
provider.set_factories(
example_a=self.example_a_factory,
example_b=self.example_b_factory,
)
self.assertEqual(
provider.factories,
{
'example_a': self.example_a_factory,
'example_b': self.example_b_factory,
},
)
self.assertIsInstance(provider('example_a'), self.ExampleA)
self.assertIsInstance(provider('example_b'), self.ExampleB)
def test_set_provides_returns_self(self):
provider = providers.FactoryAggregate()
self.assertIs(provider.set_factories(example_a=self.example_a_factory), provider)
def test_call(self): def test_call(self):
object_a = self.factory_aggregate('example_a', object_a = self.factory_aggregate('example_a',
1, 2, init_arg3=3, init_arg4=4) 1, 2, init_arg3=3, init_arg4=4)

View File

@ -126,6 +126,44 @@ class ProvidedInstanceTests(unittest.TestCase):
) )
class LazyInitTests(unittest.TestCase):
def test_provided_instance(self):
provides = providers.Object(object())
provider = providers.ProvidedInstance()
provider.set_provides(provides)
self.assertIs(provider.provides, provides)
self.assertIs(provider.set_provides(providers.Provider()), provider)
def test_attribute_getter(self):
provides = providers.Object(object())
provider = providers.AttributeGetter()
provider.set_provides(provides)
provider.set_name('__dict__')
self.assertIs(provider.provides, provides)
self.assertEqual(provider.name, '__dict__')
self.assertIs(provider.set_provides(providers.Provider()), provider)
self.assertIs(provider.set_name('__dict__'), provider)
def test_item_getter(self):
provides = providers.Object({'foo': 'bar'})
provider = providers.ItemGetter()
provider.set_provides(provides)
provider.set_name('foo')
self.assertIs(provider.provides, provides)
self.assertEqual(provider.name, 'foo')
self.assertIs(provider.set_provides(providers.Provider()), provider)
self.assertIs(provider.set_name('foo'), provider)
def test_method_caller(self):
provides = providers.Object(lambda: 42)
provider = providers.MethodCaller()
provider.set_provides(provides)
self.assertIs(provider.provides, provides)
self.assertEqual(provider(), 42)
self.assertIs(provider.set_provides(providers.Provider()), provider)
class ProvidedInstancePuzzleTests(unittest.TestCase): class ProvidedInstancePuzzleTests(unittest.TestCase):
def test_puzzled(self): def test_puzzled(self):

View File

@ -29,6 +29,16 @@ class ResourceTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Resource(init_fn))) self.assertTrue(providers.is_provider(providers.Resource(init_fn)))
def test_init_optional_provides(self):
provider = providers.Resource()
provider.set_provides(init_fn)
self.assertIs(provider.provides, init_fn)
self.assertEqual(provider(), (tuple(), dict()))
def test_set_provides_returns_self(self):
provider = providers.Resource()
self.assertIs(provider.set_provides(init_fn), provider)
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) self.assertIsInstance(provider.provided, providers.ProvidedInstance)

View File

@ -16,6 +16,28 @@ class SelectorTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Selector(self.selector))) self.assertTrue(providers.is_provider(providers.Selector(self.selector)))
def test_init_optional(self):
one = providers.Object(1)
two = providers.Object(2)
provider = providers.Selector()
provider.set_selector(self.selector)
provider.set_providers(one=one, two=two)
self.assertEqual(provider.providers, {'one': one, 'two': two})
with self.selector.override('one'):
self.assertEqual(provider(), one())
with self.selector.override('two'):
self.assertEqual(provider(), two())
def test_set_selector_returns_self(self):
provider = providers.Selector()
self.assertIs(provider.set_selector(self.selector), provider)
def test_set_providers_returns_self(self):
provider = providers.Selector()
self.assertIs(provider.set_providers(one=providers.Provider()), provider)
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Selector(self.selector) provider = providers.Selector(self.selector)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) self.assertIsInstance(provider.provided, providers.ProvidedInstance)

View File

@ -36,6 +36,16 @@ class _BaseSingletonTestCase(object):
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, self.singleton_cls, 123) self.assertRaises(errors.Error, self.singleton_cls, 123)
def test_init_optional_provides(self):
provider = self.singleton_cls()
provider.set_provides(object)
self.assertIs(provider.provides, object)
self.assertIsInstance(provider(), object)
def test_set_provides_returns_self(self):
provider = self.singleton_cls()
self.assertIs(provider.set_provides(object), provider)
def test_init_with_valid_provided_type(self): def test_init_with_valid_provided_type(self):
class ExampleProvider(self.singleton_cls): class ExampleProvider(self.singleton_cls):
provided_type = Example provided_type = Example

View File

@ -0,0 +1,27 @@
version: "1"
container:
config:
provider: Configuration
session:
provider: Singleton
provides: boto3.session.Session
kwargs:
aws_access_key_id: container.config.aws_access_key_id
aws_secret_access_key: container.config.aws_secret_access_key
aws_session_token: container.config.aws_session_token
region_name: container.config.aws_region_name
s3_client:
provider: Factory
provides: container.session.provided.client.call()
kwargs:
service_name: s3
sqs_client:
provider: Factory
provides: container.session.provided.client.call()
kwargs:
service_name: sqs

View File

@ -0,0 +1,58 @@
version: "1"
container:
core:
config:
provider: Configuration
gateways:
database_client:
provider: Singleton
provides: sqlite3.connect
args:
- provider: Callable
provides: schemasample.utils.return_
args:
- container.core.config.database.dsn
s3_client:
provider: Singleton
provides: boto3.client
kwargs:
service_name: s3
aws_access_key_id: container.core.config.aws.access_key_id
aws_secret_access_key: container.core.config.aws.secret_access_key
services:
user:
provider: Factory
provides: schemasample.services.UserService
kwargs:
db:
provider: Callable
provides: schemasample.utils.return_
args:
- container.gateways.database_client
auth:
provider: Factory
provides: schemasample.services.AuthService
kwargs:
db:
provider: Callable
provides: schemasample.utils.return_
args:
- container.gateways.database_client
token_ttl: container.core.config.auth.token_ttl.as_int()
photo:
provider: Factory
provides: schemasample.services.PhotoService
kwargs:
db:
provider: Callable
provides: schemasample.utils.return_
args:
- container.gateways.database_client
s3: container.gateways.s3_client

View File

@ -0,0 +1,43 @@
version: "1"
container:
services:
user:
provider: Factory
provides: schemasample.services.UserService
kwargs:
db: container.gateways.database_client
auth:
provider: Factory
provides: schemasample.services.AuthService
kwargs:
db: container.gateways.database_client
token_ttl: container.core.config.auth.token_ttl.as_int()
photo:
provider: Factory
provides: schemasample.services.PhotoService
kwargs:
db: container.gateways.database_client
s3: container.gateways.s3_client
gateways:
database_client:
provider: Singleton
provides: sqlite3.connect
args:
- container.core.config.database.dsn
s3_client:
provider: Singleton
provides: boto3.client
kwargs:
service_name: s3
aws_access_key_id: core.config.aws.access_key_id
aws_secret_access_key: core.config.aws.secret_access_key
core:
config:
provider: Configuration

View File

@ -0,0 +1,43 @@
version: "1"
container:
core:
config:
provider: Configuration
gateways:
database_client:
provider: Singleton
provides: sqlite3.connect
args:
- container.core.config.database.dsn
s3_client:
provider: Singleton
provides: boto3.client
kwargs:
service_name: s3
aws_access_key_id: container.core.config.aws.access_key_id
aws_secret_access_key: container.core.config.aws.secret_access_key
services:
user:
provider: Factory
provides: schemasample.services.UserService
kwargs:
db: container.gateways.database_client
auth:
provider: Factory
provides: schemasample.services.AuthService
kwargs:
db: container.gateways.database_client
token_ttl: container.core.config.auth.token_ttl.as_int()
photo:
provider: Factory
provides: schemasample.services.PhotoService
kwargs:
db: container.gateways.database_client
s3: container.gateways.s3_client

View File

@ -0,0 +1,39 @@
version: "1"
container:
config:
provider: Configuration
database_client:
provider: Singleton
provides: sqlite3.connect
args:
- container.config.database.dsn
s3_client:
provider: Singleton
provides: boto3.client
kwargs:
service_name: s3
aws_access_key_id: container.config.aws.access_key_id
aws_secret_access_key: container.config.aws.secret_access_key
user_service:
provider: Factory
provides: schemasample.services.UserService
kwargs:
db: container.database_client
auth_service:
provider: Factory
provides: schemasample.services.AuthService
kwargs:
db: container.database_client
token_ttl: container.config.auth.token_ttl.as_int()
photo_service:
provider: Factory
provides: schemasample.services.PhotoService
kwargs:
db: container.database_client
s3: container.s3_client

View File

@ -0,0 +1,56 @@
"""Services module."""
import logging
import sqlite3
from typing import Dict
from mypy_boto3_s3 import S3Client
class BaseService:
def __init__(self) -> None:
self.logger = logging.getLogger(
f'{__name__}.{self.__class__.__name__}',
)
class UserService(BaseService):
def __init__(self, db: sqlite3.Connection) -> None:
self.db = db
super().__init__()
def get_user(self, email: str) -> Dict[str, str]:
self.logger.debug('User %s has been found in database', email)
return {'email': email, 'password_hash': '...'}
class AuthService(BaseService):
def __init__(self, db: sqlite3.Connection, token_ttl: int) -> None:
self.db = db
self.token_ttl = token_ttl
super().__init__()
def authenticate(self, user: Dict[str, str], password: str) -> None:
assert password is not None
self.logger.debug(
'User %s has been successfully authenticated',
user['email'],
)
class PhotoService(BaseService):
def __init__(self, db: sqlite3.Connection, s3: S3Client) -> None:
self.db = db
self.s3 = s3
super().__init__()
def upload_photo(self, user: Dict[str, str], photo_path: str) -> None:
self.logger.debug(
'Photo %s has been successfully uploaded by user %s',
photo_path,
user['email'],
)

View File

@ -0,0 +1,2 @@
def return_(instance):
return instance

View File

@ -0,0 +1,13 @@
"""Test module for wiring."""
import sys
if 'pypy' not in sys.version.lower():
import numpy # noqa
from numpy import * # noqa
import scipy # noqa
from scipy import * # noqa
import builtins # noqa
from builtins import * # noqa

View File

@ -1,7 +1,6 @@
"""Test module for wiring.""" """Test module for wiring."""
from decimal import Decimal from decimal import Decimal
import sys
from typing import Callable from typing import Callable
from dependency_injector import providers from dependency_injector import providers
@ -129,16 +128,3 @@ def test_class_decorator(service: Service = Provide[Container.service]):
def test_container(container: Container = Provide[Container]): def test_container(container: Container = Provide[Container]):
return container.service() return container.service()
# Import tests
if 'pypy' not in sys.version.lower():
import numpy # noqa
from numpy import * # noqa
import scipy # noqa
from scipy import * # noqa
import builtins # noqa
from builtins import * # noqa

View File

@ -0,0 +1 @@
"""Schema tests."""

View File

@ -0,0 +1,162 @@
import contextlib
import json
import os.path
import tempfile
import unittest
import yaml
from dependency_injector import containers, providers, errors
class FromSchemaTests(unittest.TestCase):
def test(self):
container = containers.DynamicContainer()
container.from_schema(
{
'version': '1',
'container': {
'provider1': {
'provider': 'Factory',
'provides': 'list',
'args': [1, 2, 3],
},
'provider2': {
'provider': 'Factory',
'provides': 'dict',
'kwargs': {
'one': 'container.provider1',
'two': 2,
},
},
},
},
)
self.assertIsInstance(container.provider1, providers.Factory)
self.assertIs(container.provider1.provides, list)
self.assertEqual(container.provider1.args, (1, 2, 3))
self.assertIsInstance(container.provider2, providers.Factory)
self.assertIs(container.provider2.provides, dict)
self.assertEqual(container.provider2.kwargs, {'one': container.provider1, 'two': 2})
class FromYamlSchemaTests(unittest.TestCase):
def test(self):
container = containers.DynamicContainer()
with tempfile.TemporaryDirectory() as tmp_dir:
schema_path = os.path.join(tmp_dir, 'schema.yml')
with open(schema_path, 'w') as file:
file.write("""
version: "1"
container:
provider1:
provider: Factory
provides: list
args:
- 1
- 2
- 3
provider2:
provider: Factory
provides: dict
kwargs:
one: container.provider1
two: 2
""")
container.from_yaml_schema(schema_path)
self.assertIsInstance(container.provider1, providers.Factory)
self.assertIs(container.provider1.provides, list)
self.assertEqual(container.provider1.args, (1, 2, 3))
self.assertIsInstance(container.provider2, providers.Factory)
self.assertIs(container.provider2.provides, dict)
self.assertEqual(container.provider2.kwargs, {'one': container.provider1, 'two': 2})
def test_with_loader(self):
container = containers.DynamicContainer()
with tempfile.TemporaryDirectory() as tmp_dir:
schema_path = os.path.join(tmp_dir, 'schema.yml')
with open(schema_path, 'w') as file:
file.write("""
version: "1"
container:
provider:
provider: Factory
provides: list
args: [1, 2, 3]
""")
container.from_yaml_schema(schema_path, loader=yaml.Loader)
self.assertIsInstance(container.provider, providers.Factory)
self.assertIs(container.provider.provides, list)
self.assertEqual(container.provider.args, (1, 2, 3))
def test_no_yaml_installed(self):
@contextlib.contextmanager
def no_yaml_module():
containers.yaml = None
yield
containers.yaml = yaml
container = containers.DynamicContainer()
with no_yaml_module():
with self.assertRaises(errors.Error) as error:
container.from_yaml_schema('./no-yaml-installed.yml')
self.assertEqual(
error.exception.args[0],
'Unable to load yaml schema - PyYAML is not installed. '
'Install PyYAML or install Dependency Injector with yaml extras: '
'"pip install dependency-injector[yaml]"',
)
class FromJsonSchemaTests(unittest.TestCase):
def test(self):
container = containers.DynamicContainer()
with tempfile.TemporaryDirectory() as tmp_dir:
schema_path = os.path.join(tmp_dir, 'schema.json')
with open(schema_path, 'w') as file:
file.write(
json.dumps(
{
'version': '1',
'container': {
'provider1': {
'provider': 'Factory',
'provides': 'list',
'args': [1, 2, 3],
},
'provider2': {
'provider': 'Factory',
'provides': 'dict',
'kwargs': {
'one': 'container.provider1',
'two': 2,
},
},
},
},
indent=4,
),
)
container.from_json_schema(schema_path)
self.assertIsInstance(container.provider1, providers.Factory)
self.assertIs(container.provider1.provides, list)
self.assertEqual(container.provider1.args, (1, 2, 3))
self.assertIsInstance(container.provider2, providers.Factory)
self.assertIs(container.provider2.provides, dict)
self.assertEqual(container.provider2.kwargs, {'one': container.provider1, 'two': 2})

View File

@ -0,0 +1,293 @@
import sqlite3
import unittest
from dependency_injector import containers
# Runtime import
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
_SAMPLES_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../samples/',
)),
)
import sys
sys.path.append(_SAMPLES_DIR)
from schemasample.services import UserService, AuthService, PhotoService
class TestSchemaSingleContainer(unittest.TestCase):
def test(self):
container = containers.DynamicContainer()
container.from_yaml_schema(f'{_SAMPLES_DIR}/schemasample/container-single.yml')
container.config.from_dict({
'database': {
'dsn': ':memory:',
},
'aws': {
'access_key_id': 'KEY',
'secret_access_key': 'SECRET',
},
'auth': {
'token_ttl': 3600,
},
})
# User service
user_service1 = container.user_service()
user_service2 = container.user_service()
self.assertIsInstance(user_service1, UserService)
self.assertIsInstance(user_service2, UserService)
self.assertIsNot(user_service1, user_service2)
self.assertIsInstance(user_service1.db, sqlite3.Connection)
self.assertIsInstance(user_service2.db, sqlite3.Connection)
self.assertIs(user_service1.db, user_service2.db)
# Auth service
auth_service1 = container.auth_service()
auth_service2 = container.auth_service()
self.assertIsInstance(auth_service1, AuthService)
self.assertIsInstance(auth_service2, AuthService)
self.assertIsNot(auth_service1, auth_service2)
self.assertIsInstance(auth_service1.db, sqlite3.Connection)
self.assertIsInstance(auth_service2.db, sqlite3.Connection)
self.assertIs(auth_service1.db, auth_service2.db)
self.assertIs(auth_service1.db, container.database_client())
self.assertIs(auth_service2.db, container.database_client())
self.assertEqual(auth_service1.token_ttl, 3600)
self.assertEqual(auth_service2.token_ttl, 3600)
# Photo service
photo_service1 = container.photo_service()
photo_service2 = container.photo_service()
self.assertIsInstance(photo_service1, PhotoService)
self.assertIsInstance(photo_service2, PhotoService)
self.assertIsNot(photo_service1, photo_service2)
self.assertIsInstance(photo_service1.db, sqlite3.Connection)
self.assertIsInstance(photo_service2.db, sqlite3.Connection)
self.assertIs(photo_service1.db, photo_service2.db)
self.assertIs(photo_service1.db, container.database_client())
self.assertIs(photo_service2.db, container.database_client())
self.assertIs(photo_service1.s3, photo_service2.s3)
self.assertIs(photo_service1.s3, container.s3_client())
self.assertIs(photo_service2.s3, container.s3_client())
class TestSchemaMultipleContainers(unittest.TestCase):
def test(self):
container = containers.DynamicContainer()
container.from_yaml_schema(f'{_SAMPLES_DIR}/schemasample/container-multiple.yml')
container.core.config.from_dict({
'database': {
'dsn': ':memory:',
},
'aws': {
'access_key_id': 'KEY',
'secret_access_key': 'SECRET',
},
'auth': {
'token_ttl': 3600,
},
})
# User service
user_service1 = container.services.user()
user_service2 = container.services.user()
self.assertIsInstance(user_service1, UserService)
self.assertIsInstance(user_service2, UserService)
self.assertIsNot(user_service1, user_service2)
self.assertIsInstance(user_service1.db, sqlite3.Connection)
self.assertIsInstance(user_service2.db, sqlite3.Connection)
self.assertIs(user_service1.db, user_service2.db)
# Auth service
auth_service1 = container.services.auth()
auth_service2 = container.services.auth()
self.assertIsInstance(auth_service1, AuthService)
self.assertIsInstance(auth_service2, AuthService)
self.assertIsNot(auth_service1, auth_service2)
self.assertIsInstance(auth_service1.db, sqlite3.Connection)
self.assertIsInstance(auth_service2.db, sqlite3.Connection)
self.assertIs(auth_service1.db, auth_service2.db)
self.assertIs(auth_service1.db, container.gateways.database_client())
self.assertIs(auth_service2.db, container.gateways.database_client())
self.assertEqual(auth_service1.token_ttl, 3600)
self.assertEqual(auth_service2.token_ttl, 3600)
# Photo service
photo_service1 = container.services.photo()
photo_service2 = container.services.photo()
self.assertIsInstance(photo_service1, PhotoService)
self.assertIsInstance(photo_service2, PhotoService)
self.assertIsNot(photo_service1, photo_service2)
self.assertIsInstance(photo_service1.db, sqlite3.Connection)
self.assertIsInstance(photo_service2.db, sqlite3.Connection)
self.assertIs(photo_service1.db, photo_service2.db)
self.assertIs(photo_service1.db, container.gateways.database_client())
self.assertIs(photo_service2.db, container.gateways.database_client())
self.assertIs(photo_service1.s3, photo_service2.s3)
self.assertIs(photo_service1.s3, container.gateways.s3_client())
self.assertIs(photo_service2.s3, container.gateways.s3_client())
class TestSchemaMultipleContainersReordered(unittest.TestCase):
def test(self):
container = containers.DynamicContainer()
container.from_yaml_schema(f'{_SAMPLES_DIR}/schemasample/container-multiple-reordered.yml')
container.core.config.from_dict({
'database': {
'dsn': ':memory:',
},
'aws': {
'access_key_id': 'KEY',
'secret_access_key': 'SECRET',
},
'auth': {
'token_ttl': 3600,
},
})
# User service
user_service1 = container.services.user()
user_service2 = container.services.user()
self.assertIsInstance(user_service1, UserService)
self.assertIsInstance(user_service2, UserService)
self.assertIsNot(user_service1, user_service2)
self.assertIsInstance(user_service1.db, sqlite3.Connection)
self.assertIsInstance(user_service2.db, sqlite3.Connection)
self.assertIs(user_service1.db, user_service2.db)
# Auth service
auth_service1 = container.services.auth()
auth_service2 = container.services.auth()
self.assertIsInstance(auth_service1, AuthService)
self.assertIsInstance(auth_service2, AuthService)
self.assertIsNot(auth_service1, auth_service2)
self.assertIsInstance(auth_service1.db, sqlite3.Connection)
self.assertIsInstance(auth_service2.db, sqlite3.Connection)
self.assertIs(auth_service1.db, auth_service2.db)
self.assertIs(auth_service1.db, container.gateways.database_client())
self.assertIs(auth_service2.db, container.gateways.database_client())
self.assertEqual(auth_service1.token_ttl, 3600)
self.assertEqual(auth_service2.token_ttl, 3600)
# Photo service
photo_service1 = container.services.photo()
photo_service2 = container.services.photo()
self.assertIsInstance(photo_service1, PhotoService)
self.assertIsInstance(photo_service2, PhotoService)
self.assertIsNot(photo_service1, photo_service2)
self.assertIsInstance(photo_service1.db, sqlite3.Connection)
self.assertIsInstance(photo_service2.db, sqlite3.Connection)
self.assertIs(photo_service1.db, photo_service2.db)
self.assertIs(photo_service1.db, container.gateways.database_client())
self.assertIs(photo_service2.db, container.gateways.database_client())
self.assertIs(photo_service1.s3, photo_service2.s3)
self.assertIs(photo_service1.s3, container.gateways.s3_client())
self.assertIs(photo_service2.s3, container.gateways.s3_client())
class TestSchemaMultipleContainersWithInlineProviders(unittest.TestCase):
def test(self):
container = containers.DynamicContainer()
container.from_yaml_schema(f'{_SAMPLES_DIR}/schemasample/container-multiple-inline.yml')
container.core.config.from_dict({
'database': {
'dsn': ':memory:',
},
'aws': {
'access_key_id': 'KEY',
'secret_access_key': 'SECRET',
},
'auth': {
'token_ttl': 3600,
},
})
# User service
user_service1 = container.services.user()
user_service2 = container.services.user()
self.assertIsInstance(user_service1, UserService)
self.assertIsInstance(user_service2, UserService)
self.assertIsNot(user_service1, user_service2)
self.assertIsInstance(user_service1.db, sqlite3.Connection)
self.assertIsInstance(user_service2.db, sqlite3.Connection)
self.assertIs(user_service1.db, user_service2.db)
# Auth service
auth_service1 = container.services.auth()
auth_service2 = container.services.auth()
self.assertIsInstance(auth_service1, AuthService)
self.assertIsInstance(auth_service2, AuthService)
self.assertIsNot(auth_service1, auth_service2)
self.assertIsInstance(auth_service1.db, sqlite3.Connection)
self.assertIsInstance(auth_service2.db, sqlite3.Connection)
self.assertIs(auth_service1.db, auth_service2.db)
self.assertIs(auth_service1.db, container.gateways.database_client())
self.assertIs(auth_service2.db, container.gateways.database_client())
self.assertEqual(auth_service1.token_ttl, 3600)
self.assertEqual(auth_service2.token_ttl, 3600)
# Photo service
photo_service1 = container.services.photo()
photo_service2 = container.services.photo()
self.assertIsInstance(photo_service1, PhotoService)
self.assertIsInstance(photo_service2, PhotoService)
self.assertIsNot(photo_service1, photo_service2)
self.assertIsInstance(photo_service1.db, sqlite3.Connection)
self.assertIsInstance(photo_service2.db, sqlite3.Connection)
self.assertIs(photo_service1.db, photo_service2.db)
self.assertIs(photo_service1.db, container.gateways.database_client())
self.assertIs(photo_service2.db, container.gateways.database_client())
self.assertIs(photo_service1.s3, photo_service2.s3)
self.assertIs(photo_service1.s3, container.gateways.s3_client())
self.assertIs(photo_service2.s3, container.gateways.s3_client())
class TestSchemaBoto3Session(unittest.TestCase):
def test(self):
container = containers.DynamicContainer()
container.from_yaml_schema(f'{_SAMPLES_DIR}/schemasample/container-boto3-session.yml')
container.config.from_dict(
{
'aws_access_key_id': 'key',
'aws_secret_access_key': 'secret',
'aws_session_token': 'token',
'aws_region_name': 'us-east-1',
},
)
self.assertEqual(container.s3_client().__class__.__name__, 'S3')
self.assertEqual(container.sqs_client().__class__.__name__, 'SQS')

View File

@ -437,6 +437,7 @@ class AutoLoaderTest(unittest.TestCase):
def test_register_container(self): def test_register_container(self):
register_loader_containers(self.container) register_loader_containers(self.container)
importlib.reload(module) importlib.reload(module)
importlib.import_module('wiringsamples.imports')
service = module.test_function() service = module.test_function()
self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)

View File

@ -1,4 +1,6 @@
import contextlib
from decimal import Decimal from decimal import Decimal
import importlib
import unittest import unittest
from dependency_injector.wiring import ( from dependency_injector.wiring import (
@ -6,6 +8,8 @@ from dependency_injector.wiring import (
Provide, Provide,
Provider, Provider,
Closing, Closing,
register_loader_containers,
unregister_loader_containers,
) )
from dependency_injector import containers, providers, errors from dependency_injector import containers, providers, errors
@ -409,27 +413,27 @@ class WiringAsyncInjectionsTest(AsyncTestCase):
self.assertEqual(asyncinjections.resource2.shutdown_counter, 2) self.assertEqual(asyncinjections.resource2.shutdown_counter, 2)
# class AutoLoaderTest(unittest.TestCase): class AutoLoaderTest(unittest.TestCase):
#
# container: Container container: Container
#
# def setUp(self) -> None: def setUp(self) -> None:
# self.container = Container(config={'a': {'b': {'c': 10}}}) self.container = Container(config={'a': {'b': {'c': 10}}})
# importlib.reload(module) importlib.reload(module)
#
# def tearDown(self) -> None: def tearDown(self) -> None:
# with contextlib.suppress(ValueError): with contextlib.suppress(ValueError):
# unregister_loader_containers(self.container) unregister_loader_containers(self.container)
#
# self.container.unwire() self.container.unwire()
#
# @classmethod @classmethod
# def tearDownClass(cls) -> None: def tearDownClass(cls) -> None:
# importlib.reload(module) importlib.reload(module)
#
# def test_register_container(self): def test_register_container(self):
# register_loader_containers(self.container) register_loader_containers(self.container)
# importlib.reload(module) importlib.reload(module)
#
# service = module.test_function() service = module.test_function()
# self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)

View File

@ -10,6 +10,8 @@ deps=
fastapi fastapi
numpy numpy
scipy scipy
boto3
mypy_boto3_s3
extras= extras=
yaml yaml
pydantic pydantic
@ -68,6 +70,8 @@ commands=
deps= deps=
httpx httpx
fastapi fastapi
boto3
mypy_boto3_s3
extras= extras=
yaml yaml
flask flask