mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-22 09:36:48 +03:00
Add wiring hotfix for fastapi
This commit is contained in:
parent
26a89d664b
commit
97c33442e0
|
@ -316,13 +316,15 @@ def _get_patched(fn, injections, closing):
|
||||||
def _patched(*args, **kwargs):
|
def _patched(*args, **kwargs):
|
||||||
to_inject = kwargs.copy()
|
to_inject = kwargs.copy()
|
||||||
for injection, provider in injections.items():
|
for injection, provider in injections.items():
|
||||||
if injection not in kwargs:
|
if injection not in kwargs \
|
||||||
|
or _is_fastapi_default_arg_injection(injection, kwargs):
|
||||||
to_inject[injection] = provider()
|
to_inject[injection] = provider()
|
||||||
|
|
||||||
result = fn(*args, **to_inject)
|
result = fn(*args, **to_inject)
|
||||||
|
|
||||||
for injection, provider in closing.items():
|
for injection, provider in closing.items():
|
||||||
if injection in kwargs:
|
if injection in kwargs \
|
||||||
|
and not _is_fastapi_default_arg_injection(injection, kwargs):
|
||||||
continue
|
continue
|
||||||
if not isinstance(provider, providers.Resource):
|
if not isinstance(provider, providers.Resource):
|
||||||
continue
|
continue
|
||||||
|
@ -337,13 +339,15 @@ def _get_async_patched(fn, injections, closing):
|
||||||
async def _patched(*args, **kwargs):
|
async def _patched(*args, **kwargs):
|
||||||
to_inject = kwargs.copy()
|
to_inject = kwargs.copy()
|
||||||
for injection, provider in injections.items():
|
for injection, provider in injections.items():
|
||||||
if injection not in kwargs:
|
if injection not in kwargs \
|
||||||
|
or _is_fastapi_default_arg_injection(injection, kwargs):
|
||||||
to_inject[injection] = provider()
|
to_inject[injection] = provider()
|
||||||
|
|
||||||
result = await fn(*args, **to_inject)
|
result = await fn(*args, **to_inject)
|
||||||
|
|
||||||
for injection, provider in closing.items():
|
for injection, provider in closing.items():
|
||||||
if injection in kwargs:
|
if injection in kwargs \
|
||||||
|
and not _is_fastapi_default_arg_injection(injection, kwargs):
|
||||||
continue
|
continue
|
||||||
if not isinstance(provider, providers.Resource):
|
if not isinstance(provider, providers.Resource):
|
||||||
continue
|
continue
|
||||||
|
@ -353,6 +357,11 @@ def _get_async_patched(fn, injections, closing):
|
||||||
return _patched
|
return _patched
|
||||||
|
|
||||||
|
|
||||||
|
def _is_fastapi_default_arg_injection(injection, kwargs):
|
||||||
|
"""Check if injection is FastAPI injection of the default argument."""
|
||||||
|
return injection in kwargs and isinstance(kwargs[injection], _Marker)
|
||||||
|
|
||||||
|
|
||||||
def _is_patched(fn):
|
def _is_patched(fn):
|
||||||
return getattr(fn, '__wired__', False) is True
|
return getattr(fn, '__wired__', False) is True
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from dependency_injector.wiring import wire, Provide
|
from dependency_injector.wiring import wire, Provide, Closing
|
||||||
|
|
||||||
# Runtime import to avoid syntax errors in samples on Python < 3.5
|
# Runtime import to avoid syntax errors in samples on Python < 3.5
|
||||||
import os
|
import os
|
||||||
|
@ -225,3 +225,41 @@ class WiringTest(unittest.TestCase):
|
||||||
self.assertIs(result_2, service)
|
self.assertIs(result_2, service)
|
||||||
self.assertEqual(result_2.init_counter, 0)
|
self.assertEqual(result_2.init_counter, 0)
|
||||||
self.assertEqual(result_2.shutdown_counter, 0)
|
self.assertEqual(result_2.shutdown_counter, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class WiringAndFastAPITest(unittest.TestCase):
|
||||||
|
|
||||||
|
container: Container
|
||||||
|
|
||||||
|
def test_bypass_marker_injection(self):
|
||||||
|
container = Container()
|
||||||
|
container.wire(modules=[module])
|
||||||
|
self.addCleanup(container.unwire)
|
||||||
|
|
||||||
|
service = module.test_function(service=Provide[Container.service])
|
||||||
|
self.assertIsInstance(service, Service)
|
||||||
|
|
||||||
|
def test_closing_resource_bypass_marker_injection(self):
|
||||||
|
from wiringsamples import resourceclosing
|
||||||
|
|
||||||
|
resourceclosing.Service.reset_counter()
|
||||||
|
|
||||||
|
container = resourceclosing.Container()
|
||||||
|
container.wire(modules=[resourceclosing])
|
||||||
|
self.addCleanup(container.unwire)
|
||||||
|
|
||||||
|
result_1 = resourceclosing.test_function(
|
||||||
|
service=Closing[Provide[resourceclosing.Container.service]],
|
||||||
|
)
|
||||||
|
self.assertIsInstance(result_1, resourceclosing.Service)
|
||||||
|
self.assertEqual(result_1.init_counter, 1)
|
||||||
|
self.assertEqual(result_1.shutdown_counter, 1)
|
||||||
|
|
||||||
|
result_2 = resourceclosing.test_function(
|
||||||
|
service=Closing[Provide[resourceclosing.Container.service]],
|
||||||
|
)
|
||||||
|
self.assertIsInstance(result_2, resourceclosing.Service)
|
||||||
|
self.assertEqual(result_2.init_counter, 2)
|
||||||
|
self.assertEqual(result_2.shutdown_counter, 2)
|
||||||
|
|
||||||
|
self.assertIsNot(result_1, result_2)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user