diff --git a/src/dependency_injector/providers.pxd b/src/dependency_injector/providers.pxd index 21ed7f22..50c16a27 100644 --- a/src/dependency_injector/providers.pxd +++ b/src/dependency_injector/providers.pxd @@ -239,6 +239,13 @@ cdef class Resource(Provider): cpdef object _provide(self, tuple args, dict kwargs) +cdef class ContextLocalResource(Resource): + cdef object _resource_context_var + cdef object _shutdowner_context_var + + cpdef object _provide(self, tuple args, dict kwargs) + + cdef class Container(Provider): cdef object _container_cls cdef dict _overriding_providers diff --git a/src/dependency_injector/providers.pyi b/src/dependency_injector/providers.pyi index 8f9b525a..d6168d64 100644 --- a/src/dependency_injector/providers.pyi +++ b/src/dependency_injector/providers.pyi @@ -525,6 +525,8 @@ class Resource(Provider[T]): def init(self) -> Optional[Awaitable[T]]: ... def shutdown(self) -> Optional[Awaitable]: ... +class ContextLocalResource(Resource[T]):... + class Container(Provider[T]): def __init__( self, diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index d8a8ab35..045b8dc7 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -3186,7 +3186,7 @@ cdef class ThreadLocalSingleton(BaseSingleton): return future_result self._storage.instance = instance - + return instance def _async_init_instance(self, future_result, result): @@ -3867,6 +3867,133 @@ cdef class Resource(Provider): return self._resource +cdef class ContextLocalResource(Resource): + _none = object() + + def __init__(self, provides=None, *args, **kwargs): + self._resource_context_var = ContextVar("_resource_context_var", default=self._none) + self._shutdowner_context_var = ContextVar("_shutdowner_context_var", default=self._none) + super().__init__(provides, *args, **kwargs) + + def __deepcopy__(self, memo): + """Create and return full copy of provider.""" + copied = memo.get(id(self)) + if copied is not None: + return copied + + if self._resource_context_var.get() != self._none: + raise Error("Can not copy initialized resource") + copied = _memorized_duplicate(self, memo) + copied.set_provides(_copy_if_provider(self.provides, memo)) + copied.set_args(*deepcopy_args(self, self.args, memo)) + copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo)) + + self._copy_overridings(copied, memo) + + return copied + + @property + def initialized(self): + """Check if resource is initialized.""" + return self._resource_context_var.get() != self._none + + + def shutdown(self): + """Shutdown resource.""" + if self._resource_context_var.get() == self._none : + self._reset_all_contex_vars() + if self._async_mode == ASYNC_MODE_ENABLED: + return NULL_AWAITABLE + return + if self._shutdowner_context_var.get(): + future = self._shutdowner_context_var.get()(None, None, None) + if __is_future_or_coroutine(future): + self._reset_all_contex_vars() + return ensure_future(self._shutdown_async(future)) + + + self._reset_all_contex_vars() + if self._async_mode == ASYNC_MODE_ENABLED: + return NULL_AWAITABLE + + def _reset_all_contex_vars(self): + self._resource_context_var.set(self._none) + self._shutdowner_context_var.set(self._none) + + + async def _shutdown_async(self, future) -> None: + await future + + + async def _handle_async_cm(self, obj) -> None: + resource = await obj.__aenter__() + return resource + + async def _provide_async(self, future): + try: + obj = await future + + if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): + resource = await obj.__aenter__() + shutdowner = obj.__aexit__ + elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): + resource = obj.__enter__() + shutdowner = obj.__exit__ + else: + resource = obj + shutdowner = None + + return resource, shutdowner + except: + raise + + cpdef object _provide(self, tuple args, dict kwargs): + if self._resource_context_var.get() != self._none: + return self._resource_context_var.get() + obj = __call( + self._provides, + args, + self._args, + self._args_len, + kwargs, + self._kwargs, + self._kwargs_len, + self._async_mode, + ) + + if __is_future_or_coroutine(obj): + future_result = asyncio.Future() + future = ensure_future(self._provide_async(obj)) + future.add_done_callback(functools.partial(self._async_init_instance, future_result)) + return future_result + elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): + resource = obj.__enter__() + self._resource_context_var.set(resource) + self._shutdowner_context_var.set(obj.__exit__) + elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): + resource = ensure_future(self._handle_async_cm(obj)) + self._resource_context_var.set(resource) + self._shutdowner_context_var.set(obj.__aexit__) + return resource + else: + self._resource_context_var.set(obj) + self._shutdowner_context_var.set(None) + + return self._resource_context_var.get() + + def _async_init_instance(self, future_result, result): + try: + resource, shutdowner = result.result() + except Exception as exception: + self._resource_context_var.set(self._none) + self._shutdowner_context_var.set(self._none) + future_result.set_exception(exception) + else: + self._resource_context_var.set(resource) + self._shutdowner_context_var.set(shutdowner) + future_result.set_result(resource) + + cdef class Container(Provider): """Container provider provides an instance of declarative container. diff --git a/tests/unit/providers/resource/test_context_local_resource_py38.py b/tests/unit/providers/resource/test_context_local_resource_py38.py new file mode 100644 index 00000000..63f3c9b6 --- /dev/null +++ b/tests/unit/providers/resource/test_context_local_resource_py38.py @@ -0,0 +1,478 @@ +"""Resource provider tests.""" + +import asyncio +import decimal +import sys +from contextlib import contextmanager +from typing import Any + +from pytest import mark, raises + +from dependency_injector import containers, errors, providers, resources + +def init_fn(*args, **kwargs): + return args, kwargs + + +def test_is_provider(): + assert providers.is_provider(providers.ContextLocalResource(init_fn)) is True + + +def test_init_optional_provides(): + provider = providers.ContextLocalResource() + provider.set_provides(init_fn) + assert provider.provides is init_fn + assert provider() == (tuple(), dict()) + + +def test_set_provides_returns_(): + provider = providers.ContextLocalResource() + assert provider.set_provides(init_fn) is provider + + +@mark.parametrize( + "str_name,cls", + [ + ("dependency_injector.providers.Factory", providers.Factory), + ("decimal.Decimal", decimal.Decimal), + ("list", list), + (".test_context_local_resource_py38.test_is_provider", test_is_provider), + ("test_is_provider", test_is_provider), + ], +) +def test_set_provides_string_imports(str_name, cls): + print( providers.ContextLocalResource(str_name).provides) + print(cls) + assert providers.ContextLocalResource(str_name).provides is cls + + +def test_provided_instance_provider(): + provider = providers.ContextLocalResource(init_fn) + assert isinstance(provider.provided, providers.ProvidedInstance) + + +def test_injection(): + resource = object() + + def _init(): + _init.counter += 1 + return resource + + _init.counter = 0 + + class Container(containers.DeclarativeContainer): + context_local_resource = providers.ContextLocalResource(_init) + dependency1 = providers.List(context_local_resource) + dependency2 = providers.List(context_local_resource) + + container = Container() + list1 = container.dependency1() + list2 = container.dependency2() + + assert list1 == [resource] + assert list1[0] is resource + + assert list2 == [resource] + assert list2[0] is resource + + assert _init.counter == 1 + + +def test_injection_in_different_context(): + def _init(): + return object() + + async def _async_init(): + return object() + + + class Container(containers.DeclarativeContainer): + context_local_resource = providers.ContextLocalResource(_init) + async_context_local_resource = providers.ContextLocalResource(_async_init) + + loop = asyncio.get_event_loop() + container = Container() + obj1 = loop.run_until_complete(container.async_context_local_resource()) + obj2 = loop.run_until_complete(container.async_context_local_resource()) + assert obj1!=obj2 + + obj3 = container.context_local_resource() + obj4 = container.context_local_resource() + + assert obj3==obj4 + + + + +def test_init_function(): + def _init(): + _init.counter += 1 + + _init.counter = 0 + + provider = providers.ContextLocalResource(_init) + + result1 = provider() + assert result1 is None + assert _init.counter == 1 + + result2 = provider() + assert result2 is None + assert _init.counter == 1 + + provider.shutdown() + + +def test_init_generator(): + def _init(): + _init.init_counter += 1 + yield + _init.shutdown_counter += 1 + + _init.init_counter = 0 + _init.shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + result1 = provider() + assert result1 is None + assert _init.init_counter == 1 + assert _init.shutdown_counter == 0 + + provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + result2 = provider() + assert result2 is None + assert _init.init_counter == 2 + assert _init.shutdown_counter == 1 + + provider.shutdown() + assert _init.init_counter == 2 + assert _init.shutdown_counter == 2 + + +def test_init_context_manager() -> None: + init_counter, shutdown_counter = 0, 0 + + @contextmanager + def _init(): + nonlocal init_counter, shutdown_counter + + init_counter += 1 + yield + shutdown_counter += 1 + + init_counter = 0 + shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + result1 = provider() + assert result1 is None + assert init_counter == 1 + assert shutdown_counter == 0 + + provider.shutdown() + assert init_counter == 1 + assert shutdown_counter == 1 + + result2 = provider() + assert result2 is None + assert init_counter == 2 + assert shutdown_counter == 1 + + provider.shutdown() + assert init_counter == 2 + assert shutdown_counter == 2 + + +def test_init_class(): + class TestResource(resources.Resource): + init_counter = 0 + shutdown_counter = 0 + + def init(self): + self.__class__.init_counter += 1 + + def shutdown(self, _): + self.__class__.shutdown_counter += 1 + + provider = providers.ContextLocalResource(TestResource) + + result1 = provider() + assert result1 is None + assert TestResource.init_counter == 1 + assert TestResource.shutdown_counter == 0 + + provider.shutdown() + assert TestResource.init_counter == 1 + assert TestResource.shutdown_counter == 1 + + result2 = provider() + assert result2 is None + assert TestResource.init_counter == 2 + assert TestResource.shutdown_counter == 1 + + provider.shutdown() + assert TestResource.init_counter == 2 + assert TestResource.shutdown_counter == 2 + + +def test_init_class_generic_typing(): + # See issue: https://github.com/ets-labs/python-dependency-injector/issues/488 + class TestDependency: + ... + + class TestResource(resources.Resource[TestDependency]): + def init(self, *args: Any, **kwargs: Any) -> TestDependency: + return TestDependency() + + def shutdown(self, resource: TestDependency) -> None: ... + + assert issubclass(TestResource, resources.Resource) is True + + +def test_init_class_abc_init_definition_is_required(): + class TestResource(resources.Resource): + ... + + with raises(TypeError) as context: + TestResource() + + assert "Can't instantiate abstract class TestResource" in str(context.value) + assert "init" in str(context.value) + + +def test_init_class_abc_shutdown_definition_is_not_required(): + class TestResource(resources.Resource): + def init(self): + ... + + assert hasattr(TestResource(), "shutdown") is True + + +def test_init_not_callable(): + provider = providers.ContextLocalResource(1) + with raises(TypeError, match=r"object is not callable"): + provider.init() + + +def test_init_and_shutdown(): + def _init(): + _init.init_counter += 1 + yield + _init.shutdown_counter += 1 + + _init.init_counter = 0 + _init.shutdown_counter = 0 + + provider = providers.ContextLocalResource(_init) + + result1 = provider.init() + assert result1 is None + assert _init.init_counter == 1 + assert _init.shutdown_counter == 0 + + provider.shutdown() + assert _init.init_counter == 1 + assert _init.shutdown_counter == 1 + + result2 = provider.init() + assert result2 is None + assert _init.init_counter == 2 + assert _init.shutdown_counter == 1 + + provider.shutdown() + assert _init.init_counter == 2 + assert _init.shutdown_counter == 2 + + +def test_shutdown_of_not_initialized(): + def _init(): + yield + + provider = providers.ContextLocalResource(_init) + + result = provider.shutdown() + assert result is None + + +def test_initialized(): + provider = providers.ContextLocalResource(init_fn) + assert provider.initialized is False + + provider.init() + assert provider.initialized is True + + provider.shutdown() + assert provider.initialized is False + + +def test_call_with_context_args(): + provider = providers.ContextLocalResource(init_fn, "i1", "i2") + assert provider("i3", i4=4) == (("i1", "i2", "i3"), {"i4": 4}) + + +def test_fluent_interface(): + provider = providers.ContextLocalResource(init_fn) \ + .add_args(1, 2) \ + .add_kwargs(a3=3, a4=4) + assert provider() == ((1, 2), {"a3": 3, "a4": 4}) + + +def test_set_args(): + provider = providers.ContextLocalResource(init_fn) \ + .add_args(1, 2) \ + .set_args(3, 4) + assert provider.args == (3, 4) + + +def test_clear_args(): + provider = providers.ContextLocalResource(init_fn) \ + .add_args(1, 2) \ + .clear_args() + assert provider.args == tuple() + + +def test_set_kwargs(): + provider = providers.ContextLocalResource(init_fn) \ + .add_kwargs(a1="i1", a2="i2") \ + .set_kwargs(a3="i3", a4="i4") + assert provider.kwargs == {"a3": "i3", "a4": "i4"} + + +def test_clear_kwargs(): + provider = providers.ContextLocalResource(init_fn) \ + .add_kwargs(a1="i1", a2="i2") \ + .clear_kwargs() + assert provider.kwargs == {} + + +def test_call_overridden(): + provider = providers.ContextLocalResource(init_fn, 1) + overriding_provider1 = providers.ContextLocalResource(init_fn, 2) + overriding_provider2 = providers.ContextLocalResource(init_fn, 3) + + provider.override(overriding_provider1) + provider.override(overriding_provider2) + + instance1 = provider() + instance2 = provider() + + assert instance1 is instance2 + assert instance1 == ((3,), {}) + assert instance2 == ((3,), {}) + + +def test_deepcopy(): + provider = providers.ContextLocalResource(init_fn, 1, 2, a3=3, a4=4) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert provider.args == provider_copy.args + assert provider.kwargs == provider_copy.kwargs + assert isinstance(provider, providers.ContextLocalResource) + + +def test_deepcopy_initialized(): + provider = providers.ContextLocalResource(init_fn) + provider.init() + + with raises(errors.Error): + providers.deepcopy(provider) + + +def test_deepcopy_from_memo(): + provider = providers.ContextLocalResource(init_fn) + provider_copy_memo = providers.ContextLocalResource(init_fn) + + provider_copy = providers.deepcopy( + provider, + memo={id(provider): provider_copy_memo}, + ) + + assert provider_copy is provider_copy_memo + + +def test_deepcopy_args(): + provider = providers.ContextLocalResource(init_fn) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_args(dependent_provider1, dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.args[0] + dependent_provider_copy2 = provider_copy.args[1] + + assert provider.args != provider_copy.args + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_kwargs(): + provider = providers.ContextLocalResource(init_fn) + dependent_provider1 = providers.Factory(list) + dependent_provider2 = providers.Factory(dict) + + provider.add_kwargs(d1=dependent_provider1, d2=dependent_provider2) + + provider_copy = providers.deepcopy(provider) + dependent_provider_copy1 = provider_copy.kwargs["d1"] + dependent_provider_copy2 = provider_copy.kwargs["d2"] + + assert provider.kwargs != provider_copy.kwargs + + assert dependent_provider1.cls is dependent_provider_copy1.cls + assert dependent_provider1 is not dependent_provider_copy1 + + assert dependent_provider2.cls is dependent_provider_copy2.cls + assert dependent_provider2 is not dependent_provider_copy2 + + +def test_deepcopy_overridden(): + provider = providers.ContextLocalResource(init_fn) + object_provider = providers.Object(object()) + + provider.override(object_provider) + + provider_copy = providers.deepcopy(provider) + object_provider_copy = provider_copy.overridden[0] + + assert provider is not provider_copy + assert provider.args == provider_copy.args + assert isinstance(provider, providers.ContextLocalResource) + + assert object_provider is not object_provider_copy + assert isinstance(object_provider_copy, providers.Object) + + +def test_deepcopy_with_sys_streams(): + provider = providers.ContextLocalResource(init_fn) + provider.add_args(sys.stdin, sys.stdout, sys.stderr) + + provider_copy = providers.deepcopy(provider) + + assert provider is not provider_copy + assert isinstance(provider_copy, providers.ContextLocalResource) + assert provider.args[0] is sys.stdin + assert provider.args[1] is sys.stdout + assert provider.args[2] is sys.stderr + + +def test_repr(): + provider = providers.ContextLocalResource(init_fn) + + assert repr(provider) == ( + "".format( + repr(init_fn), + hex(id(provider)), + ) + )