Add support for List and Dict providers to _locate_dependent_closing_args

This commit is contained in:
ZipFile 2025-02-01 13:11:51 +00:00
parent 72a316c2b1
commit f8ab127856
3 changed files with 95 additions and 107 deletions

View File

@ -1,26 +1,26 @@
"""Wiring module.""" """Wiring module."""
import functools import functools
import inspect
import importlib import importlib
import importlib.machinery import importlib.machinery
import inspect
import pkgutil import pkgutil
import warnings
import sys import sys
import warnings
from types import ModuleType from types import ModuleType
from typing import ( from typing import (
Optional,
Iterable,
Iterator,
Callable,
Any, Any,
Tuple, Callable,
Dict, Dict,
Generic, Generic,
TypeVar, Iterable,
Type, Iterator,
Union, Optional,
Set, Set,
Tuple,
Type,
TypeVar,
Union,
cast, cast,
) )
@ -643,21 +643,18 @@ def _fetch_reference_injections( # noqa: C901
def _locate_dependent_closing_args( def _locate_dependent_closing_args(
provider: providers.Provider, provider: providers.Provider, closing_deps: Dict[str, providers.Provider]
) -> Dict[str, providers.Provider]: ) -> Dict[str, providers.Provider]:
if not hasattr(provider, "args"): for arg in [
return {} *getattr(provider, "args", []),
*getattr(provider, "kwargs", {}).values(),
closing_deps = {} ]:
for arg in [*provider.args, *provider.kwargs.values()]: if not isinstance(arg, providers.Provider):
if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
continue continue
if isinstance(arg, providers.Resource): if isinstance(arg, providers.Resource):
return {str(id(arg)): arg} closing_deps[str(id(arg))] = arg
if arg.args or arg.kwargs:
closing_deps |= _locate_dependent_closing_args(arg)
return closing_deps _locate_dependent_closing_args(arg, closing_deps)
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
@ -681,7 +678,8 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non
if injection in patched_callable.reference_closing: if injection in patched_callable.reference_closing:
patched_callable.add_closing(injection, provider) patched_callable.add_closing(injection, provider)
deps = _locate_dependent_closing_args(provider) deps = {}
_locate_dependent_closing_args(provider, deps)
for key, dep in deps.items(): for key, dep in deps.items():
patched_callable.add_closing(key, dep) patched_callable.add_closing(key, dep)
@ -1030,8 +1028,8 @@ _inspect_filter = InspectFilter()
_loader = AutoLoader() _loader = AutoLoader()
# Optimizations # Optimizations
from ._cwiring import _get_sync_patched # noqa
from ._cwiring import _async_inject # noqa from ._cwiring import _async_inject # noqa
from ._cwiring import _get_sync_patched # noqa
# Wiring uses the following Python wrapper because there is # Wiring uses the following Python wrapper because there is

View File

@ -1,35 +1,49 @@
from typing import Any, Dict, List, Optional
from dependency_injector import containers, providers from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing from dependency_injector.wiring import Closing, Provide, inject
class Singleton: class Counter:
pass def __init__(self) -> None:
self._init = 0
self._shutdown = 0
def init(self) -> None:
self._init += 1
def shutdown(self) -> None:
self._shutdown += 1
def reset(self) -> None:
self._init = 0
self._shutdown = 0
class Service: class Service:
init_counter: int = 0 def __init__(self, counter: Optional[Counter] = None, **dependencies: Any) -> None:
shutdown_counter: int = 0 self.counter = counter or Counter()
dependency: Singleton = None self.dependencies = dependencies
@classmethod def init(self) -> None:
def reset_counter(cls): self.counter.init()
cls.init_counter = 0
cls.shutdown_counter = 0
@classmethod def shutdown(self) -> None:
def init(cls, dependency: Singleton = None): self.counter.shutdown()
if dependency:
cls.dependency = dependency
cls.init_counter += 1
@classmethod @property
def shutdown(cls): def init_counter(self) -> int:
cls.shutdown_counter += 1 return self.counter._init
@property
def shutdown_counter(self) -> int:
return self.counter._shutdown
class FactoryService: class FactoryService:
def __init__(self, service: Service): def __init__(self, service: Service, service2: Service):
self.service = service self.service = service
self.service2 = service2
class NestedService: class NestedService:
@ -37,42 +51,28 @@ class NestedService:
self.factory_service = factory_service self.factory_service = factory_service
def init_service(): def init_service(counter: Counter, _list: List[int], _dict: Dict[str, int]):
service = Service() service = Service(counter, _list=_list, _dict=_dict)
service.init() service.init()
yield service yield service
service.shutdown() service.shutdown()
def init_service_with_singleton(singleton: Singleton):
service = Service()
service.init(singleton)
yield service
service.shutdown()
class Container(containers.DeclarativeContainer): class Container(containers.DeclarativeContainer):
counter = providers.Singleton(Counter)
service = providers.Resource(init_service) _list = providers.List(
factory_service = providers.Factory(FactoryService, service) providers.Callable(lambda a: a, a=1), providers.Callable(lambda b: b, 2)
)
_dict = providers.Dict(
a=providers.Callable(lambda a: a, a=3), b=providers.Callable(lambda b: b, 4)
)
service = providers.Resource(init_service, counter, _list, _dict=_dict)
service2 = providers.Resource(init_service, counter, _list, _dict=_dict)
factory_service = providers.Factory(FactoryService, service, service2)
factory_service_kwargs = providers.Factory( factory_service_kwargs = providers.Factory(
FactoryService, FactoryService,
service=service service=service,
) service2=service2,
nested_service = providers.Factory(NestedService, factory_service)
class ContainerSingleton(containers.DeclarativeContainer):
singleton = providers.Singleton(Singleton)
service = providers.Resource(
init_service_with_singleton,
singleton
)
factory_service = providers.Factory(FactoryService, service)
factory_service_kwargs = providers.Factory(
FactoryService,
service=service
) )
nested_service = providers.Factory(NestedService, factory_service) nested_service = providers.Factory(NestedService, factory_service)
@ -84,20 +84,20 @@ def test_function(service: Service = Closing[Provide["service"]]):
@inject @inject
def test_function_dependency( def test_function_dependency(
factory: FactoryService = Closing[Provide["factory_service"]] factory: FactoryService = Closing[Provide["factory_service"]],
): ):
return factory return factory
@inject @inject
def test_function_dependency_kwargs( def test_function_dependency_kwargs(
factory: FactoryService = Closing[Provide["factory_service_kwargs"]] factory: FactoryService = Closing[Provide["factory_service_kwargs"]],
): ):
return factory return factory
@inject @inject
def test_function_nested_dependency( def test_function_nested_dependency(
nested: NestedService = Closing[Provide["nested_service"]] nested: NestedService = Closing[Provide["nested_service"]],
): ):
return nested return nested

View File

@ -2,13 +2,13 @@
from decimal import Decimal from decimal import Decimal
from pytest import fixture, mark, raises
from samples.wiringstringids import module, package, resourceclosing
from samples.wiringstringids.container import Container, SubContainer
from samples.wiringstringids.service import Service
from dependency_injector import errors from dependency_injector import errors
from dependency_injector.wiring import Closing, Provide, Provider, wire from dependency_injector.wiring import Closing, Provide, Provider, wire
from pytest import fixture, mark, raises
from samples.wiringstringids import module, package, resourceclosing
from samples.wiringstringids.service import Service
from samples.wiringstringids.container import Container, SubContainer
@fixture(autouse=True) @fixture(autouse=True)
@ -33,14 +33,12 @@ def subcontainer():
container.unwire() container.unwire()
@fixture(params=[ @fixture
resourceclosing.Container,
resourceclosing.ContainerSingleton,
])
def resourceclosing_container(request): def resourceclosing_container(request):
container = request.param() container = resourceclosing.Container()
container.wire(modules=[resourceclosing]) container.wire(modules=[resourceclosing])
yield container with container.reset_singletons():
yield container
container.unwire() container.unwire()
@ -277,72 +275,65 @@ def test_wire_multiple_containers():
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_resource(): def test_closing_resource():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function() result_1 = resourceclosing.test_function()
assert isinstance(result_1, resourceclosing.Service) assert isinstance(result_1, resourceclosing.Service)
assert result_1.init_counter == 1 assert result_1.init_counter == 1
assert result_1.shutdown_counter == 1 assert result_1.shutdown_counter == 1
assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}}
result_2 = resourceclosing.test_function() result_2 = resourceclosing.test_function()
assert isinstance(result_2, resourceclosing.Service) assert isinstance(result_2, resourceclosing.Service)
assert result_2.init_counter == 2 assert result_2.init_counter == 2
assert result_2.shutdown_counter == 2 assert result_2.shutdown_counter == 2
assert result_1.dependencies == {"_list": [1, 2], "_dict": {"a": 3, "b": 4}}
assert result_1 is not result_2 assert result_1 is not result_2
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource(): def test_closing_dependency_resource():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function_dependency() result_1 = resourceclosing.test_function_dependency()
assert isinstance(result_1, resourceclosing.FactoryService) assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 1 assert result_1.service.init_counter == 2
assert result_1.service.shutdown_counter == 1 assert result_1.service.shutdown_counter == 2
result_2 = resourceclosing.test_function_dependency() result_2 = resourceclosing.test_function_dependency()
assert isinstance(result_2, resourceclosing.FactoryService) assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 2 assert result_2.service.init_counter == 4
assert result_2.service.shutdown_counter == 2 assert result_2.service.shutdown_counter == 4
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_dependency_resource_kwargs(): def test_closing_dependency_resource_kwargs():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function_dependency_kwargs() result_1 = resourceclosing.test_function_dependency_kwargs()
assert isinstance(result_1, resourceclosing.FactoryService) assert isinstance(result_1, resourceclosing.FactoryService)
assert result_1.service.init_counter == 1 assert result_1.service.init_counter == 2
assert result_1.service.shutdown_counter == 1 assert result_1.service.shutdown_counter == 2
result_2 = resourceclosing.test_function_dependency_kwargs() result_2 = resourceclosing.test_function_dependency_kwargs()
assert isinstance(result_2, resourceclosing.FactoryService) assert isinstance(result_2, resourceclosing.FactoryService)
assert result_2.service.init_counter == 2 assert result_2.service.init_counter == 4
assert result_2.service.shutdown_counter == 2 assert result_2.service.shutdown_counter == 4
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_nested_dependency_resource(): def test_closing_nested_dependency_resource():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function_nested_dependency() result_1 = resourceclosing.test_function_nested_dependency()
assert isinstance(result_1, resourceclosing.NestedService) assert isinstance(result_1, resourceclosing.NestedService)
assert result_1.factory_service.service.init_counter == 1 assert result_1.factory_service.service.init_counter == 2
assert result_1.factory_service.service.shutdown_counter == 1 assert result_1.factory_service.service.shutdown_counter == 2
result_2 = resourceclosing.test_function_nested_dependency() result_2 = resourceclosing.test_function_nested_dependency()
assert isinstance(result_2, resourceclosing.NestedService) assert isinstance(result_2, resourceclosing.NestedService)
assert result_2.factory_service.service.init_counter == 2 assert result_2.factory_service.service.init_counter == 4
assert result_2.factory_service.service.shutdown_counter == 2 assert result_2.factory_service.service.shutdown_counter == 4
assert result_1 is not result_2 assert result_1 is not result_2
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_resource_bypass_marker_injection(): def test_closing_resource_bypass_marker_injection():
resourceclosing.Service.reset_counter()
result_1 = resourceclosing.test_function(service=Closing[Provide["service"]]) result_1 = resourceclosing.test_function(service=Closing[Provide["service"]])
assert isinstance(result_1, resourceclosing.Service) assert isinstance(result_1, resourceclosing.Service)
assert result_1.init_counter == 1 assert result_1.init_counter == 1
@ -358,7 +349,6 @@ def test_closing_resource_bypass_marker_injection():
@mark.usefixtures("resourceclosing_container") @mark.usefixtures("resourceclosing_container")
def test_closing_resource_context(): def test_closing_resource_context():
resourceclosing.Service.reset_counter()
service = resourceclosing.Service() service = resourceclosing.Service()
result_1 = resourceclosing.test_function(service=service) result_1 = resourceclosing.test_function(service=service)