Improve debugability of deepcopy errors (#839)

This commit is contained in:
ZipFile 2025-01-01 21:22:29 +02:00 committed by GitHub
parent 3ba4704bc1
commit d82d9fb822
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 155 additions and 12 deletions

View File

@ -10,3 +10,24 @@ class Error(Exception):
class NoSuchProviderError(Error, AttributeError): class NoSuchProviderError(Error, AttributeError):
"""Error that is raised when provider lookup is failed.""" """Error that is raised when provider lookup is failed."""
class NonCopyableArgumentError(Error):
"""Error that is raised when provider argument is not deep-copyable."""
index: int
keyword: str
provider: object
def __init__(self, provider: object, index: int = -1, keyword: str = "") -> None:
self.provider = provider
self.index = index
self.keyword = keyword
def __str__(self) -> str:
s = (
f"keyword argument {self.keyword}"
if self.keyword
else f"argument at index {self.index}"
)
return f"Couldn't copy {s} for provider {self.provider!r}"

View File

@ -530,7 +530,21 @@ def is_delegated(instance: Any) -> bool: ...
def represent_provider(provider: Provider, provides: Any) -> str: ... def represent_provider(provider: Provider, provides: Any) -> str: ...
def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None): Any: ... def deepcopy(instance: Any, memo: Optional[_Dict[Any, Any]] = None) -> Any: ...
def deepcopy_args(
provider: Provider[Any],
args: Tuple[Any, ...],
memo: Optional[_Dict[int, Any]] = None,
) -> Tuple[Any, ...]: ...
def deepcopy_kwargs(
provider: Provider[Any],
kwargs: _Dict[str, Any],
memo: Optional[_Dict[int, Any]] = None,
) -> Dict[str, Any]: ...
def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ... def merge_dicts(dict1: _Dict[Any, Any], dict2: _Dict[Any, Any]) -> _Dict[Any, Any]: ...

View File

@ -71,6 +71,7 @@ except ImportError:
from .errors import ( from .errors import (
Error, Error,
NoSuchProviderError, NoSuchProviderError,
NonCopyableArgumentError,
) )
cimport cython cimport cython
@ -1252,8 +1253,8 @@ cdef class Callable(Provider):
copied = _memorized_duplicate(self, memo) copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo)) copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo)) copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo)) copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
return copied return copied
@ -2539,8 +2540,8 @@ cdef class Factory(Provider):
copied = _memorized_duplicate(self, memo) copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo)) copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo)) copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo)) copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
copied.set_attributes(**deepcopy(self.attributes, memo)) copied.set_attributes(**deepcopy(self.attributes, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
return copied return copied
@ -2838,8 +2839,8 @@ cdef class BaseSingleton(Provider):
copied = _memorized_duplicate(self, memo) copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo)) copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo)) copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo)) copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
copied.set_attributes(**deepcopy(self.attributes, memo)) copied.set_attributes(**deepcopy(self.attributes, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
return copied return copied
@ -3451,7 +3452,7 @@ cdef class List(Provider):
return copied return copied
copied = _memorized_duplicate(self, memo) copied = _memorized_duplicate(self, memo)
copied.set_args(*deepcopy(self.args, memo)) copied.set_args(*deepcopy_args(self, self.args, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
return copied return copied
@ -3674,8 +3675,8 @@ cdef class Resource(Provider):
copied = _memorized_duplicate(self, memo) copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo)) copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo)) copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo)) copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
@ -4525,8 +4526,8 @@ cdef class MethodCaller(Provider):
copied = _memorized_duplicate(self, memo) copied = _memorized_duplicate(self, memo)
copied.set_provides(_copy_if_provider(self.provides, memo)) copied.set_provides(_copy_if_provider(self.provides, memo))
copied.set_args(*deepcopy(self.args, memo)) copied.set_args(*deepcopy_args(self, self.args, memo))
copied.set_kwargs(**deepcopy(self.kwargs, memo)) copied.set_kwargs(**deepcopy_kwargs(self, self.kwargs, memo))
self._copy_overridings(copied, memo) self._copy_overridings(copied, memo)
return copied return copied
@ -4927,6 +4928,48 @@ cpdef object deepcopy(object instance, dict memo=None):
return copy.deepcopy(instance, memo) return copy.deepcopy(instance, memo)
cpdef tuple deepcopy_args(
Provider provider,
tuple args,
dict[int, object] memo = None,
):
"""A wrapper for deepcopy for positional arguments.
Used to improve debugability of objects that cannot be deep-copied.
"""
cdef list[object] out = []
for i, arg in enumerate(args):
try:
out.append(copy.deepcopy(arg, memo))
except Exception as e:
raise NonCopyableArgumentError(provider, index=i) from e
return tuple(out)
cpdef dict[str, object] deepcopy_kwargs(
Provider provider,
dict[str, object] kwargs,
dict[int, object] memo = None,
):
"""A wrapper for deepcopy for keyword arguments.
Used to improve debugability of objects that cannot be deep-copied.
"""
cdef dict[str, object] out = {}
for name, arg in kwargs.items():
try:
out[name] = copy.deepcopy(arg, memo)
except Exception as e:
raise NonCopyableArgumentError(provider, keyword=name) from e
return out
def __add_sys_streams(memo): def __add_sys_streams(memo):
"""Add system streams to memo dictionary. """Add system streams to memo dictionary.

View File

@ -0,0 +1,65 @@
import sys
from typing import Any, Dict, NoReturn
from pytest import raises
from dependency_injector.errors import NonCopyableArgumentError
from dependency_injector.providers import (
Provider,
deepcopy,
deepcopy_args,
deepcopy_kwargs,
)
class NonCopiable:
def __deepcopy__(self, memo: Dict[int, Any]) -> NoReturn:
raise NotImplementedError
def test_deepcopy_streams_not_copied() -> None:
l = [sys.stdin, sys.stdout, sys.stderr]
assert deepcopy(l) == l
def test_deepcopy_args() -> None:
provider = Provider[None]()
copiable = NonCopiable()
memo: Dict[int, Any] = {id(copiable): copiable}
assert deepcopy_args(provider, (1, copiable), memo) == (1, copiable)
def test_deepcopy_args_non_copiable() -> None:
provider = Provider[None]()
copiable = NonCopiable()
memo: Dict[int, Any] = {id(copiable): copiable}
with raises(
NonCopyableArgumentError,
match=r"^Couldn't copy argument at index 3 for provider ",
):
deepcopy_args(provider, (1, copiable, object(), NonCopiable()), memo)
def test_deepcopy_kwargs() -> None:
provider = Provider[None]()
copiable = NonCopiable()
memo: Dict[int, Any] = {id(copiable): copiable}
assert deepcopy_kwargs(provider, {"x": 1, "y": copiable}, memo) == {
"x": 1,
"y": copiable,
}
def test_deepcopy_kwargs_non_copiable() -> None:
provider = Provider[None]()
copiable = NonCopiable()
memo: Dict[int, Any] = {id(copiable): copiable}
with raises(
NonCopyableArgumentError,
match=r"^Couldn't copy keyword argument z for provider ",
):
deepcopy_kwargs(provider, {"x": 1, "y": copiable, "z": NonCopiable()}, memo)