Add wiring hotfix for fastapi

This commit is contained in:
Roman Mogylatov 2020-11-12 15:54:49 -05:00
parent 26a89d664b
commit 97c33442e0
2 changed files with 52 additions and 5 deletions

View File

@ -316,13 +316,15 @@ def _get_patched(fn, injections, closing):
def _patched(*args, **kwargs):
to_inject = kwargs.copy()
for injection, provider in injections.items():
if injection not in kwargs:
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider()
result = fn(*args, **to_inject)
for injection, provider in closing.items():
if injection in kwargs:
if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs):
continue
if not isinstance(provider, providers.Resource):
continue
@ -337,13 +339,15 @@ def _get_async_patched(fn, injections, closing):
async def _patched(*args, **kwargs):
to_inject = kwargs.copy()
for injection, provider in injections.items():
if injection not in kwargs:
if injection not in kwargs \
or _is_fastapi_default_arg_injection(injection, kwargs):
to_inject[injection] = provider()
result = await fn(*args, **to_inject)
for injection, provider in closing.items():
if injection in kwargs:
if injection in kwargs \
and not _is_fastapi_default_arg_injection(injection, kwargs):
continue
if not isinstance(provider, providers.Resource):
continue
@ -353,6 +357,11 @@ def _get_async_patched(fn, injections, closing):
return _patched
def _is_fastapi_default_arg_injection(injection, kwargs):
"""Check if injection is FastAPI injection of the default argument."""
return injection in kwargs and isinstance(kwargs[injection], _Marker)
def _is_patched(fn):
return getattr(fn, '__wired__', False) is True

View File

@ -1,7 +1,7 @@
from decimal import Decimal
import unittest
from dependency_injector.wiring import wire, Provide
from dependency_injector.wiring import wire, Provide, Closing
# Runtime import to avoid syntax errors in samples on Python < 3.5
import os
@ -225,3 +225,41 @@ class WiringTest(unittest.TestCase):
self.assertIs(result_2, service)
self.assertEqual(result_2.init_counter, 0)
self.assertEqual(result_2.shutdown_counter, 0)
class WiringAndFastAPITest(unittest.TestCase):
container: Container
def test_bypass_marker_injection(self):
container = Container()
container.wire(modules=[module])
self.addCleanup(container.unwire)
service = module.test_function(service=Provide[Container.service])
self.assertIsInstance(service, Service)
def test_closing_resource_bypass_marker_injection(self):
from wiringsamples import resourceclosing
resourceclosing.Service.reset_counter()
container = resourceclosing.Container()
container.wire(modules=[resourceclosing])
self.addCleanup(container.unwire)
result_1 = resourceclosing.test_function(
service=Closing[Provide[resourceclosing.Container.service]],
)
self.assertIsInstance(result_1, resourceclosing.Service)
self.assertEqual(result_1.init_counter, 1)
self.assertEqual(result_1.shutdown_counter, 1)
result_2 = resourceclosing.test_function(
service=Closing[Provide[resourceclosing.Container.service]],
)
self.assertIsInstance(result_2, resourceclosing.Service)
self.assertEqual(result_2.init_counter, 2)
self.assertEqual(result_2.shutdown_counter, 2)
self.assertIsNot(result_1, result_2)