diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index ec41ea8e..c637f19c 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -30,12 +30,14 @@ class WiringConfiguration: packages: List[Any] from_package: Optional[str] auto_wire: bool + keep_cache: bool def __init__( self, modules: Optional[Iterable[Any]] = None, packages: Optional[Iterable[Any]] = None, from_package: Optional[str] = None, auto_wire: bool = True, + keep_cache: bool = False, ) -> None: ... class Container: diff --git a/src/dependency_injector/containers.pyx b/src/dependency_injector/containers.pyx index 2f4c4af5..bd0a4821 100644 --- a/src/dependency_injector/containers.pyx +++ b/src/dependency_injector/containers.pyx @@ -20,14 +20,15 @@ from .wiring import wire, unwire class WiringConfiguration: """Container wiring configuration.""" - def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True): + def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True, keep_cache=False): self.modules = [*modules] if modules else [] self.packages = [*packages] if packages else [] self.from_package = from_package self.auto_wire = auto_wire + self.keep_cache = keep_cache def __deepcopy__(self, memo=None): - return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire) + return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire, self.keep_cache) class Container: @@ -258,7 +259,7 @@ class DynamicContainer(Container): """Check if auto wiring is needed.""" return self.wiring_config.auto_wire is True - def wire(self, modules=None, packages=None, from_package=None): + def wire(self, modules=None, packages=None, from_package=None, keep_cache=None): """Wire container providers with provided packages and modules. :rtype: None @@ -289,10 +290,14 @@ class DynamicContainer(Container): if not modules and not packages: return + if keep_cache is None: + keep_cache = self.wiring_config.keep_cache + wire( container=self, modules=modules, packages=packages, + keep_cache=keep_cache, ) if modules: diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index b8534ee5..9c976c2d 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -26,6 +26,13 @@ from typing import ( from typing_extensions import Self +try: + from functools import cache +except ImportError: + from functools import lru_cache + + cache = lru_cache(maxsize=None) + # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362 if sys.version_info >= (3, 9): from types import GenericAlias @@ -409,6 +416,7 @@ def wire( # noqa: C901 *, modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None, + keep_cache: bool = False, ) -> None: """Wire container providers with provided packages and modules.""" modules = [*modules] if modules else [] @@ -449,6 +457,9 @@ def wire( # noqa: C901 for patched in _patched_registry.get_callables_from_module(module): _bind_injections(patched, providers_map) + if not keep_cache: + clear_cache() + def unwire( # noqa: C901 *, @@ -604,6 +615,7 @@ def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]: return marker +@cache def _fetch_reference_injections( # noqa: C901 fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -1078,3 +1090,8 @@ def _get_members_and_annotated(obj: Any) -> Iterable[Tuple[str, Any]]: member = args[1] members.append((annotation_name, member)) return members + + +def clear_cache() -> None: + """Clear all caches used by :func:`wire`.""" + _fetch_reference_injections.cache_clear() diff --git a/tests/unit/wiring/test_cache.py b/tests/unit/wiring/test_cache.py new file mode 100644 index 00000000..d6c1f45f --- /dev/null +++ b/tests/unit/wiring/test_cache.py @@ -0,0 +1,46 @@ +"""Tests for string module and package names.""" + +from typing import Iterator, Optional + +from pytest import fixture, mark +from samples.wiring.container import Container + +from dependency_injector.wiring import _fetch_reference_injections + + +@fixture +def container() -> Iterator[Container]: + container = Container() + yield container + container.unwire() + + +@mark.parametrize( + ["arg_value", "wc_value", "empty_cache"], + [ + (None, False, True), + (False, True, True), + (True, False, False), + (None, True, False), + ], +) +def test_fetch_reference_injections_cache( + container: Container, + arg_value: Optional[bool], + wc_value: bool, + empty_cache: bool, +) -> None: + container.wiring_config.keep_cache = wc_value + container.wire( + modules=["samples.wiring.module"], + packages=["samples.wiring.package"], + keep_cache=arg_value, + ) + cache_info = _fetch_reference_injections.cache_info() + + if empty_cache: + assert cache_info == (0, 0, None, 0) + else: + assert cache_info.hits > 0 + assert cache_info.misses > 0 + assert cache_info.currsize > 0