Fix Closing dependency resolution (#852)

Co-authored-by: federinik <federico.tomasi@outlook.com>
Co-authored-by: jazzthief <mynameisyegor@gmail.com>
This commit is contained in:
ZipFile 2025-02-23 18:31:34 +02:00 committed by GitHub
parent 8b625d81ad
commit 09efbffab1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 137 additions and 61 deletions

View File

@ -1,26 +1,26 @@
"""Wiring module.""" """Wiring module."""
import functools import functools
import inspect
import importlib import importlib
import importlib.machinery import importlib.machinery
import inspect
import pkgutil import pkgutil
import warnings
import sys import sys
import warnings
from types import ModuleType from types import ModuleType
from typing import ( from typing import (
Optional,
Iterable,
Iterator,
Callable,
Any, Any,
Tuple, Callable,
Dict, Dict,
Generic, Generic,
TypeVar, Iterable,
Type, Iterator,
Union, Optional,
Set, Set,
Tuple,
Type,
TypeVar,
Union,
cast, cast,
) )
@ -643,21 +643,18 @@ def _fetch_reference_injections( # noqa: C901
def _locate_dependent_closing_args( def _locate_dependent_closing_args(
provider: providers.Provider, provider: providers.Provider, closing_deps: Dict[str, providers.Provider]
) -> Dict[str, providers.Provider]: ) -> Dict[str, providers.Provider]:
if not hasattr(provider, "args"): for arg in [
return {} *getattr(provider, "args", []),
*getattr(provider, "kwargs", {}).values(),
closing_deps = {} ]:
for arg in provider.args: if not isinstance(arg, providers.Provider):
if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
continue continue
if isinstance(arg, providers.Resource):
closing_deps[str(id(arg))] = arg
if not arg.args and isinstance(arg, providers.Resource): _locate_dependent_closing_args(arg, closing_deps)
return {str(id(arg)): arg}
else:
closing_deps += _locate_dependent_closing_args(arg)
return closing_deps
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
@ -681,7 +678,8 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non
if injection in patched_callable.reference_closing: if injection in patched_callable.reference_closing:
patched_callable.add_closing(injection, provider) patched_callable.add_closing(injection, provider)
deps = _locate_dependent_closing_args(provider) deps = {}
_locate_dependent_closing_args(provider, deps)
for key, dep in deps.items(): for key, dep in deps.items():
patched_callable.add_closing(key, dep) patched_callable.add_closing(key, dep)

View File

@ -1,41 +1,80 @@
from typing import Any, Dict, List, Optional
from dependency_injector import containers, providers from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing from dependency_injector.wiring import Closing, Provide, inject
class Counter:
def __init__(self) -> None:
self._init = 0
self._shutdown = 0
def init(self) -> None:
self._init += 1
def shutdown(self) -> None:
self._shutdown += 1
def reset(self) -> None:
self._init = 0
self._shutdown = 0
class Service: class Service:
init_counter: int = 0 def __init__(self, counter: Optional[Counter] = None, **dependencies: Any) -> None:
shutdown_counter: int = 0 self.counter = counter or Counter()
self.dependencies = dependencies
@classmethod def init(self) -> None:
def reset_counter(cls): self.counter.init()
cls.init_counter = 0
cls.shutdown_counter = 0
@classmethod def shutdown(self) -> None:
def init(cls): self.counter.shutdown()
cls.init_counter += 1
@classmethod @property
def shutdown(cls): def init_counter(self) -> int:
cls.shutdown_counter += 1 return self.counter._init
@property
def shutdown_counter(self) -> int:
return self.counter._shutdown
class FactoryService: class FactoryService:
def __init__(self, service: Service): def __init__(self, service: Service, service2: Service):
self.service = service self.service = service
self.service2 = service2
def init_service(): class NestedService:
service = Service() def __init__(self, factory_service: FactoryService):
self.factory_service = factory_service
def init_service(counter: Counter, _list: List[int], _dict: Dict[str, int]):
service = Service(counter, _list=_list, _dict=_dict)
service.init() service.init()
yield service yield service
service.shutdown() service.shutdown()
class Container(containers.DeclarativeContainer): class Container(containers.DeclarativeContainer):
counter = providers.Singleton(Counter)
service = providers.Resource(init_service) _list = providers.List(
factory_service = providers.Factory(FactoryService, service) providers.Callable(lambda a: a, a=1), providers.Callable(lambda b: b, 2)
)
_dict = providers.Dict(
a=providers.Callable(lambda a: a, a=3), b=providers.Callable(lambda b: b, 4)
)
service = providers.Resource(init_service, counter, _list, _dict=_dict)
service2 = providers.Resource(init_service, counter, _list, _dict=_dict)
factory_service = providers.Factory(FactoryService, service, service2)
factory_service_kwargs = providers.Factory(
FactoryService,
service=service,
service2=service2,
)
nested_service = providers.Factory(NestedService, factory_service)
@inject @inject
@ -44,5 +83,21 @@ def test_function(service: Service = Closing[Provide["service"]]):
@inject @inject
def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]): def test_function_dependency(
factory: FactoryService = Closing[Provide["factory_service"]],
):
return factory return factory
@inject
def test_function_dependency_kwargs(
factory: FactoryService = Closing[Provide["factory_service_kwargs"]],
):
return factory
@inject
def test_function_nested_dependency(
nested: NestedService = Closing[Provide["nested_service"]],
):
return nested

View File

@ -2,13 +2,13 @@
from decimal import Decimal from decimal import Decimal
from pytest import fixture, mark, raises
from samples.wiringstringids import module, package, resourceclosing
from samples.wiringstringids.container import Container, SubContainer
from samples.wiringstringids.service import Service
from dependency_injector import errors from dependency_injector import errors
from dependency_injector.wiring import Closing, Provide, Provider, wire from dependency_injector.wiring import Closing, Provide, Provider, wire
from pytest import fixture, mark, raises
from samples.wiringstringids import module, package, resourceclosing
from samples.wiringstringids.service import Service
from samples.wiringstringids.container import Container, SubContainer
@fixture(autouse=True) @fixture(autouse=True)
@ -34,10 +34,11 @@ def subcontainer():
@fixture @fixture
def resourceclosing_container(): def resourceclosing_container(request):
container = resourceclosing.Container() container = resourceclosing.Container()
container.wire(modules=[resourceclosing]) container.wire(modules=[resourceclosing])
yield container with container.reset_singletons():
yield container
container.unwire() container.unwire()
@ -274,42 +275,65 @@ def test_wire_multiple_containers():
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_resource(): def test_closing_resource():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function() result_1 = resourceclosing.test_function()
assert isinstance(result_1, resourceclosing.Service) assert isinstance(result_1, resourceclosing.Service)
assert result_1.init_counter == 1 assert result_1.init_counter == 1
assert result_1.shutdown_counter == 1 assert result_1.shutdown_counter == 1
assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}}
result_2 = resourceclosing.test_function() result_2 = resourceclosing.test_function()
assert isinstance(result_2, resourceclosing.Service) assert isinstance(result_2, resourceclosing.Service)
assert result_2.init_counter == 2 assert result_2.init_counter == 2
assert result_2.shutdown_counter == 2 assert result_2.shutdown_counter == 2
assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}}
assert result_1 is not result_2 assert result_1 is not result_2
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource(): def test_closing_dependency_resource():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function_dependency() result_1 = resourceclosing.test_function_dependency()
assert isinstance(result_1, resourceclosing.FactoryService) assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 1 assert result_1.service.init_counter == 2
assert result_1.service.shutdown_counter == 1 assert result_1.service.shutdown_counter == 2
result_2 = resourceclosing.test_function_dependency() result_2 = resourceclosing.test_function_dependency()
assert isinstance(result_2, resourceclosing.FactoryService) assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 2 assert result_2.service.init_counter == 4
assert result_2.service.shutdown_counter == 2 assert result_2.service.shutdown_counter == 4
@mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource_kwargs():
result_1 = resourceclosing.test_function_dependency_kwargs()
assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 2
assert result_1.service.shutdown_counter == 2
result_2 = resourceclosing.test_function_dependency_kwargs()
assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 4
assert result_2.service.shutdown_counter == 4
@mark.usefixtures("resourceclosing_container")
def test_closing_nested_dependency_resource():
result_1 = resourceclosing.test_function_nested_dependency()
assert isinstance(result_1, resourceclosing.NestedService)
assert result_1.factory_service.service.init_counter == 2
assert result_1.factory_service.service.shutdown_counter == 2
result_2 = resourceclosing.test_function_nested_dependency()
assert isinstance(result_2, resourceclosing.NestedService)
assert result_2.factory_service.service.init_counter == 4
assert result_2.factory_service.service.shutdown_counter == 4
assert result_1 is not result_2 assert result_1 is not result_2
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_resource_bypass_marker_injection(): def test_closing_resource_bypass_marker_injection():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function(service=Closing[Provide["service"]]) result_1 = resourceclosing.test_function(service=Closing[Provide["service"]])
assert isinstance(result_1, resourceclosing.Service) assert isinstance(result_1, resourceclosing.Service)
assert result_1.init_counter == 1 assert result_1.init_counter == 1
@ -325,7 +349,6 @@ def test_closing_resource_bypass_marker_injection():
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_resource_context(): def test_closing_resource_context():
resourceclosing.Service.reset_counter()
service = resourceclosing.Service() service = resourceclosing.Service()
result_1 = resourceclosing.test_function(service=service) result_1 = resourceclosing.test_function(service=service)