mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-25 11:04:01 +03:00
Add wiring hotfix for fastapi
This commit is contained in:
parent
26a89d664b
commit
97c33442e0
|
@ -316,13 +316,15 @@ def _get_patched(fn, injections, closing):
|
|||
def _patched(*args, **kwargs):
|
||||
to_inject = kwargs.copy()
|
||||
for injection, provider in injections.items():
|
||||
if injection not in kwargs:
|
||||
if injection not in kwargs \
|
||||
or _is_fastapi_default_arg_injection(injection, kwargs):
|
||||
to_inject[injection] = provider()
|
||||
|
||||
result = fn(*args, **to_inject)
|
||||
|
||||
for injection, provider in closing.items():
|
||||
if injection in kwargs:
|
||||
if injection in kwargs \
|
||||
and not _is_fastapi_default_arg_injection(injection, kwargs):
|
||||
continue
|
||||
if not isinstance(provider, providers.Resource):
|
||||
continue
|
||||
|
@ -337,13 +339,15 @@ def _get_async_patched(fn, injections, closing):
|
|||
async def _patched(*args, **kwargs):
|
||||
to_inject = kwargs.copy()
|
||||
for injection, provider in injections.items():
|
||||
if injection not in kwargs:
|
||||
if injection not in kwargs \
|
||||
or _is_fastapi_default_arg_injection(injection, kwargs):
|
||||
to_inject[injection] = provider()
|
||||
|
||||
result = await fn(*args, **to_inject)
|
||||
|
||||
for injection, provider in closing.items():
|
||||
if injection in kwargs:
|
||||
if injection in kwargs \
|
||||
and not _is_fastapi_default_arg_injection(injection, kwargs):
|
||||
continue
|
||||
if not isinstance(provider, providers.Resource):
|
||||
continue
|
||||
|
@ -353,6 +357,11 @@ def _get_async_patched(fn, injections, closing):
|
|||
return _patched
|
||||
|
||||
|
||||
def _is_fastapi_default_arg_injection(injection, kwargs):
|
||||
"""Check if injection is FastAPI injection of the default argument."""
|
||||
return injection in kwargs and isinstance(kwargs[injection], _Marker)
|
||||
|
||||
|
||||
def _is_patched(fn):
|
||||
return getattr(fn, '__wired__', False) is True
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from decimal import Decimal
|
||||
import unittest
|
||||
|
||||
from dependency_injector.wiring import wire, Provide
|
||||
from dependency_injector.wiring import wire, Provide, Closing
|
||||
|
||||
# Runtime import to avoid syntax errors in samples on Python < 3.5
|
||||
import os
|
||||
|
@ -225,3 +225,41 @@ class WiringTest(unittest.TestCase):
|
|||
self.assertIs(result_2, service)
|
||||
self.assertEqual(result_2.init_counter, 0)
|
||||
self.assertEqual(result_2.shutdown_counter, 0)
|
||||
|
||||
|
||||
class WiringAndFastAPITest(unittest.TestCase):
|
||||
|
||||
container: Container
|
||||
|
||||
def test_bypass_marker_injection(self):
|
||||
container = Container()
|
||||
container.wire(modules=[module])
|
||||
self.addCleanup(container.unwire)
|
||||
|
||||
service = module.test_function(service=Provide[Container.service])
|
||||
self.assertIsInstance(service, Service)
|
||||
|
||||
def test_closing_resource_bypass_marker_injection(self):
|
||||
from wiringsamples import resourceclosing
|
||||
|
||||
resourceclosing.Service.reset_counter()
|
||||
|
||||
container = resourceclosing.Container()
|
||||
container.wire(modules=[resourceclosing])
|
||||
self.addCleanup(container.unwire)
|
||||
|
||||
result_1 = resourceclosing.test_function(
|
||||
service=Closing[Provide[resourceclosing.Container.service]],
|
||||
)
|
||||
self.assertIsInstance(result_1, resourceclosing.Service)
|
||||
self.assertEqual(result_1.init_counter, 1)
|
||||
self.assertEqual(result_1.shutdown_counter, 1)
|
||||
|
||||
result_2 = resourceclosing.test_function(
|
||||
service=Closing[Provide[resourceclosing.Container.service]],
|
||||
)
|
||||
self.assertIsInstance(result_2, resourceclosing.Service)
|
||||
self.assertEqual(result_2.init_counter, 2)
|
||||
self.assertEqual(result_2.shutdown_counter, 2)
|
||||
|
||||
self.assertIsNot(result_1, result_2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user