From d82d9fb8222acc91960947d01ef310f4d5aa2b63 Mon Sep 17 00:00:00 2001 From: ZipFile Date: Wed, 1 Jan 2025 21:22:29 +0200 Subject: [PATCH] Improve debugability of deepcopy errors (#839) --- src/dependency_injector/errors.py | 21 ++++++ src/dependency_injector/providers.pyi | 16 ++++- src/dependency_injector/providers.pyx | 65 +++++++++++++++---- .../unit/providers/utils/test_deepcopy_py3.py | 65 +++++++++++++++++++ 4 files changed, 155 insertions(+), 12 deletions(-) create mode 100644 tests/unit/providers/utils/test_deepcopy_py3.py diff --git a/src/dependency_injector/errors.py b/src/dependency_injector/errors.py index 7b11862e..407313ce 100644 --- a/src/dependency_injector/errors.py +++ b/src/dependency_injector/errors.py @@ -10,3 +10,24 @@ class Error(Exception): class NoSuchProviderError(Error, AttributeError): """Error that is raised when provider lookup is failed.""" + + +class NonCopyableArgumentError(Error): + """Error that is raised when provider argument is not deep-copyable.""" + + index: int + keyword: str + provider: object + + def __init__(self, provider: object, index: int = -1, keyword: str = "") -> None: + self.provider = provider + self.index = index + self.keyword = keyword + + def __str__(self) -> str: + s = ( + f"keyword argument {self.keyword}" + if self.keyword + else f"argument at index {self.index}" + ) + return f"Couldn't copy {s} for provider {self.provider!r}" diff --git a/src/dependency_injector/providers.pyi b/src/dependency_injector/providers.pyi index 83d6ca88..b7fbf211 100644 --- a/src/dependency_injector/providers.pyi +++ b/src/dependency_injector/providers.pyi @@ -530,7 +530,21 @@ def is_delegated(instance: Any) -> bool: ... def represent_provider(provider: Provider, provides: Any) -> str: ... -def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ... +def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None) -> Any: ... + + +def deepcopy_args( + provider: Provider[Any], + args: Tuple[Any, ...], + memo: Optional[_Dict[int, Any]] = None, +) -> Tuple[Any, ...]: ... + + +def deepcopy_kwargs( + provider: Provider[Any], + kwargs: _Dict[str, Any], + memo: Optional[_Dict[int, Any]] = None, +) -> Dict[str, Any]: ... def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ... diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index 2db9fa2f..84c1fad7 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -71,6 +71,7 @@ except ImportError: from .errors import ( Error, NoSuchProviderError, + NonCopyableArgumentError, ) cimport cython @@ -1252,8 +1253,8 @@ cdef class Callable(Provider): copied = _memorized_duplicate(self, memo) copied.set_provides(_copy_if_provider(self.provides, memo)) - copied.set_args(*deepcopy(self.args, memo)) - copied.set_kwargs(**deepcopy(self.kwargs, memo)) + copied.set_args(*deepcopy_args(self, self.args, memo)) + copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo)) self._copy_overridings(copied, memo) return copied @@ -2539,8 +2540,8 @@ cdef class Factory(Provider): copied = _memorized_duplicate(self, memo) copied.set_provides(_copy_if_provider(self.provides, memo)) - copied.set_args(*deepcopy(self.args, memo)) - copied.set_kwargs(**deepcopy(self.kwargs, memo)) + copied.set_args(*deepcopy_args(self, self.args, memo)) + copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo)) copied.set_attributes(**deepcopy(self.attributes, memo)) self._copy_overridings(copied, memo) return copied @@ -2838,8 +2839,8 @@ cdef class BaseSingleton(Provider): copied = _memorized_duplicate(self, memo) copied.set_provides(_copy_if_provider(self.provides, memo)) - copied.set_args(*deepcopy(self.args, memo)) - copied.set_kwargs(**deepcopy(self.kwargs, memo)) + copied.set_args(*deepcopy_args(self, self.args, memo)) + copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo)) copied.set_attributes(**deepcopy(self.attributes, memo)) self._copy_overridings(copied, memo) return copied @@ -3451,7 +3452,7 @@ cdef class List(Provider): return copied copied = _memorized_duplicate(self, memo) - copied.set_args(*deepcopy(self.args, memo)) + copied.set_args(*deepcopy_args(self, self.args, memo)) self._copy_overridings(copied, memo) return copied @@ -3674,8 +3675,8 @@ cdef class Resource(Provider): copied = _memorized_duplicate(self, memo) copied.set_provides(_copy_if_provider(self.provides, memo)) - copied.set_args(*deepcopy(self.args, memo)) - copied.set_kwargs(**deepcopy(self.kwargs, memo)) + copied.set_args(*deepcopy_args(self, self.args, memo)) + copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo)) self._copy_overridings(copied, memo) @@ -4525,8 +4526,8 @@ cdef class MethodCaller(Provider): copied = _memorized_duplicate(self, memo) copied.set_provides(_copy_if_provider(self.provides, memo)) - copied.set_args(*deepcopy(self.args, memo)) - copied.set_kwargs(**deepcopy(self.kwargs, memo)) + copied.set_args(*deepcopy_args(self, self.args, memo)) + copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo)) self._copy_overridings(copied, memo) return copied @@ -4927,6 +4928,48 @@ cpdef object deepcopy(object instance, dict memo=None): return copy.deepcopy(instance, memo) +cpdef tuple deepcopy_args( + Provider provider, + tuple args, + dict[int, object] memo = None, +): + """A wrapper for deepcopy for positional arguments. + + Used to improve debugability of objects that cannot be deep-copied. + """ + + cdef list[object] out = [] + + for i, arg in enumerate(args): + try: + out.append(copy.deepcopy(arg, memo)) + except Exception as e: + raise NonCopyableArgumentError(provider, index=i) from e + + return tuple(out) + + +cpdef dict[str, object] deepcopy_kwargs( + Provider provider, + dict[str, object] kwargs, + dict[int, object] memo = None, +): + """A wrapper for deepcopy for keyword arguments. + + Used to improve debugability of objects that cannot be deep-copied. + """ + + cdef dict[str, object] out = {} + + for name, arg in kwargs.items(): + try: + out[name] = copy.deepcopy(arg, memo) + except Exception as e: + raise NonCopyableArgumentError(provider, keyword=name) from e + + return out + + def __add_sys_streams(memo): """Add system streams to memo dictionary. diff --git a/tests/unit/providers/utils/test_deepcopy_py3.py b/tests/unit/providers/utils/test_deepcopy_py3.py new file mode 100644 index 00000000..57f7a7da --- /dev/null +++ b/tests/unit/providers/utils/test_deepcopy_py3.py @@ -0,0 +1,65 @@ +import sys +from typing import Any, Dict, NoReturn + +from pytest import raises + +from dependency_injector.errors import NonCopyableArgumentError +from dependency_injector.providers import ( + Provider, + deepcopy, + deepcopy_args, + deepcopy_kwargs, +) + + +class NonCopiable: + def __deepcopy__(self, memo: Dict[int, Any]) -> NoReturn: + raise NotImplementedError + + +def test_deepcopy_streams_not_copied() -> None: + l = [sys.stdin, sys.stdout, sys.stderr] + assert deepcopy(l) == l + + +def test_deepcopy_args() -> None: + provider = Provider[None]() + copiable = NonCopiable() + memo: Dict[int, Any] = {id(copiable): copiable} + + assert deepcopy_args(provider, (1, copiable), memo) == (1, copiable) + + +def test_deepcopy_args_non_copiable() -> None: + provider = Provider[None]() + copiable = NonCopiable() + memo: Dict[int, Any] = {id(copiable): copiable} + + with raises( + NonCopyableArgumentError, + match=r"^Couldn't copy argument at index 3 for provider ", + ): + deepcopy_args(provider, (1, copiable, object(), NonCopiable()), memo) + + +def test_deepcopy_kwargs() -> None: + provider = Provider[None]() + copiable = NonCopiable() + memo: Dict[int, Any] = {id(copiable): copiable} + + assert deepcopy_kwargs(provider, {"x": 1, "y": copiable}, memo) == { + "x": 1, + "y": copiable, + } + + +def test_deepcopy_kwargs_non_copiable() -> None: + provider = Provider[None]() + copiable = NonCopiable() + memo: Dict[int, Any] = {id(copiable): copiable} + + with raises( + NonCopyableArgumentError, + match=r"^Couldn't copy keyword argument z for provider ", + ): + deepcopy_kwargs(provider, {"x": 1, "y": copiable, "z": NonCopiable()}, memo)