diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 362493d9..1c671bb1 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -316,13 +316,15 @@ def _get_patched(fn, injections, closing): def _patched(*args, **kwargs): to_inject = kwargs.copy() for injection, provider in injections.items(): - if injection not in kwargs: + if injection not in kwargs \ + or _is_fastapi_default_arg_injection(injection, kwargs): to_inject[injection] = provider() result = fn(*args, **to_inject) for injection, provider in closing.items(): - if injection in kwargs: + if injection in kwargs \ + and not _is_fastapi_default_arg_injection(injection, kwargs): continue if not isinstance(provider, providers.Resource): continue @@ -337,13 +339,15 @@ def _get_async_patched(fn, injections, closing): async def _patched(*args, **kwargs): to_inject = kwargs.copy() for injection, provider in injections.items(): - if injection not in kwargs: + if injection not in kwargs \ + or _is_fastapi_default_arg_injection(injection, kwargs): to_inject[injection] = provider() result = await fn(*args, **to_inject) for injection, provider in closing.items(): - if injection in kwargs: + if injection in kwargs \ + and not _is_fastapi_default_arg_injection(injection, kwargs): continue if not isinstance(provider, providers.Resource): continue @@ -353,6 +357,11 @@ def _get_async_patched(fn, injections, closing): return _patched +def _is_fastapi_default_arg_injection(injection, kwargs): + """Check if injection is FastAPI injection of the default argument.""" + return injection in kwargs and isinstance(kwargs[injection], _Marker) + + def _is_patched(fn): return getattr(fn, '__wired__', False) is True diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index 753c88b4..b1886ea7 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -1,7 +1,7 @@ from decimal import Decimal import unittest -from dependency_injector.wiring import wire, Provide +from dependency_injector.wiring import wire, Provide, Closing # Runtime import to avoid syntax errors in samples on Python < 3.5 import os @@ -225,3 +225,41 @@ class WiringTest(unittest.TestCase): self.assertIs(result_2, service) self.assertEqual(result_2.init_counter, 0) self.assertEqual(result_2.shutdown_counter, 0) + + +class WiringAndFastAPITest(unittest.TestCase): + + container: Container + + def test_bypass_marker_injection(self): + container = Container() + container.wire(modules=[module]) + self.addCleanup(container.unwire) + + service = module.test_function(service=Provide[Container.service]) + self.assertIsInstance(service, Service) + + def test_closing_resource_bypass_marker_injection(self): + from wiringsamples import resourceclosing + + resourceclosing.Service.reset_counter() + + container = resourceclosing.Container() + container.wire(modules=[resourceclosing]) + self.addCleanup(container.unwire) + + result_1 = resourceclosing.test_function( + service=Closing[Provide[resourceclosing.Container.service]], + ) + self.assertIsInstance(result_1, resourceclosing.Service) + self.assertEqual(result_1.init_counter, 1) + self.assertEqual(result_1.shutdown_counter, 1) + + result_2 = resourceclosing.test_function( + service=Closing[Provide[resourceclosing.Container.service]], + ) + self.assertIsInstance(result_2, resourceclosing.Service) + self.assertEqual(result_2.init_counter, 2) + self.assertEqual(result_2.shutdown_counter, 2) + + self.assertIsNot(result_1, result_2)