Add warnings for unresolved markers

This commit is contained in:
ZipFile 2025-09-17 22:08:02 +00:00
parent 759d89e9bd
commit 1222827a5e
7 changed files with 165 additions and 17 deletions

View File

@ -251,6 +251,32 @@ To inject a container use special identifier ``<container>``:
def foo(container: Container = Provide["<container>"]) -> None: def foo(container: Container = Provide["<container>"]) -> None:
... ...
Caveats
~~~~~~~
While using string identifiers you may not notice a typo in the identifier until the code is executed.
In order to aid with catching such errors early, you may pass `warn_unresolved=True` to the ``wire`` method and/or :class:`WiringConfiguration`:
.. code-block:: python
:emphasize-lines: 4
class Container(containers.DeclarativeContainer):
wiring_config = containers.WiringConfiguration(
modules=["yourapp.module"],
warn_unresolved=True,
)
Or:
.. code-block:: python
:emphasize-lines: 4
container = Container()
container.wire(
modules=["yourapp.module"],
warn_unresolved=True,
)
Making injections into modules and class attributes Making injections into modules and class attributes
--------------------------------------------------- ---------------------------------------------------

View File

@ -54,7 +54,7 @@ classifiers = [
dynamic = ["version"] dynamic = ["version"]
dependencies = [ dependencies = [
# typing.Annotated since v3.9 # typing.Annotated since v3.9
# typing.Self since v3.11 # typing.Self and typing.assert_never since v3.11
"typing-extensions; python_version<'3.11'", "typing-extensions; python_version<'3.11'",
] ]

View File

@ -72,6 +72,7 @@ class Container:
modules: Optional[Iterable[Any]] = None, modules: Optional[Iterable[Any]] = None,
packages: Optional[Iterable[Any]] = None, packages: Optional[Iterable[Any]] = None,
from_package: Optional[str] = None, from_package: Optional[str] = None,
warn_unresolved: bool = False,
) -> None: ... ) -> None: ...
def unwire(self) -> None: ... def unwire(self) -> None: ...
def init_resources(self, resource_type: Type[Resource[Any]] = Resource) -> Optional[Awaitable[None]]: ... def init_resources(self, resource_type: Type[Resource[Any]] = Resource) -> Optional[Awaitable[None]]: ...

View File

@ -20,15 +20,31 @@ from .wiring import wire, unwire
class WiringConfiguration: class WiringConfiguration:
"""Container wiring configuration.""" """Container wiring configuration."""
def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True, keep_cache=False): def __init__(
self,
modules=None,
packages=None,
from_package=None,
auto_wire=True,
keep_cache=False,
warn_unresolved=False,
):
self.modules = [*modules] if modules else [] self.modules = [*modules] if modules else []
self.packages = [*packages] if packages else [] self.packages = [*packages] if packages else []
self.from_package = from_package self.from_package = from_package
self.auto_wire = auto_wire self.auto_wire = auto_wire
self.keep_cache = keep_cache self.keep_cache = keep_cache
self.warn_unresolved = warn_unresolved
def __deepcopy__(self, memo=None): def __deepcopy__(self, memo=None):
return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire, self.keep_cache) return self.__class__(
self.modules,
self.packages,
self.from_package,
self.auto_wire,
self.keep_cache,
self.warn_unresolved,
)
class Container: class Container:
@ -259,7 +275,14 @@ class DynamicContainer(Container):
"""Check if auto wiring is needed.""" """Check if auto wiring is needed."""
return self.wiring_config.auto_wire is True return self.wiring_config.auto_wire is True
def wire(self, modules=None, packages=None, from_package=None, keep_cache=None): def wire(
self,
modules=None,
packages=None,
from_package=None,
keep_cache=None,
warn_unresolved=False,
):
"""Wire container providers with provided packages and modules. """Wire container providers with provided packages and modules.
:rtype: None :rtype: None
@ -298,6 +321,7 @@ class DynamicContainer(Container):
modules=modules, modules=modules,
packages=packages, packages=packages,
keep_cache=keep_cache, keep_cache=keep_cache,
warn_unresolved=warn_unresolved,
) )
if modules: if modules:

View File

@ -25,15 +25,14 @@ from typing import (
Type, Type,
TypeVar, TypeVar,
Union, Union,
assert_never,
cast, cast,
) )
from warnings import warn from warnings import warn
try: try:
from typing import Self from typing import Self, assert_never
except ImportError: except ImportError:
from typing_extensions import Self from typing_extensions import Self, assert_never
try: try:
from functools import cache from functools import cache
@ -140,6 +139,10 @@ class DIWiringWarning(RuntimeWarning):
"""Base class for all warnings raised by the wiring module.""" """Base class for all warnings raised by the wiring module."""
class UnresolvedMarkerWarning(DIWiringWarning):
"""Warning raised when a marker with string identifier cannot be resolved against container."""
class PatchedRegistry: class PatchedRegistry:
def __init__(self) -> None: def __init__(self) -> None:
@ -434,6 +437,7 @@ def wire( # noqa: C901
modules: Optional[Iterable[ModuleType]] = None, modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None,
keep_cache: bool = False, keep_cache: bool = False,
warn_unresolved: bool = False,
) -> None: ) -> None:
"""Wire container providers with provided packages and modules.""" """Wire container providers with provided packages and modules."""
modules = [*modules] if modules else [] modules = [*modules] if modules else []
@ -450,9 +454,23 @@ def wire( # noqa: C901
continue continue
if _is_marker(member): if _is_marker(member):
_patch_attribute(module, member_name, member, providers_map) _patch_attribute(
module,
member_name,
member,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
)
elif inspect.isfunction(member): elif inspect.isfunction(member):
_patch_fn(module, member_name, member, providers_map) _patch_fn(
module,
member_name,
member,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
)
elif inspect.isclass(member): elif inspect.isclass(member):
cls = member cls = member
try: try:
@ -464,15 +482,30 @@ def wire( # noqa: C901
for cls_member_name, cls_member in cls_members: for cls_member_name, cls_member in cls_members:
if _is_marker(cls_member): if _is_marker(cls_member):
_patch_attribute( _patch_attribute(
cls, cls_member_name, cls_member, providers_map cls,
cls_member_name,
cls_member,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
) )
elif _is_method(cls_member): elif _is_method(cls_member):
_patch_method( _patch_method(
cls, cls_member_name, cls_member, providers_map cls,
cls_member_name,
cls_member,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
) )
for patched in _patched_registry.get_callables_from_module(module): for patched in _patched_registry.get_callables_from_module(module):
_bind_injections(patched, providers_map) _bind_injections(
patched,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
)
if not keep_cache: if not keep_cache:
clear_cache() clear_cache()
@ -525,6 +558,8 @@ def _patch_fn(
name: str, name: str,
fn: Callable[..., Any], fn: Callable[..., Any],
providers_map: ProvidersMap, providers_map: ProvidersMap,
warn_unresolved: bool = False,
warn_unresolved_stacklevel: int = 0,
) -> None: ) -> None:
if not _is_patched(fn): if not _is_patched(fn):
reference_injections, reference_closing = _fetch_reference_injections(fn) reference_injections, reference_closing = _fetch_reference_injections(fn)
@ -532,7 +567,12 @@ def _patch_fn(
return return
fn = _get_patched(fn, reference_injections, reference_closing) fn = _get_patched(fn, reference_injections, reference_closing)
_bind_injections(fn, providers_map) _bind_injections(
fn,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=warn_unresolved_stacklevel + 1,
)
setattr(module, name, fn) setattr(module, name, fn)
@ -542,6 +582,8 @@ def _patch_method(
name: str, name: str,
method: Callable[..., Any], method: Callable[..., Any],
providers_map: ProvidersMap, providers_map: ProvidersMap,
warn_unresolved: bool = False,
warn_unresolved_stacklevel: int = 0,
) -> None: ) -> None:
if ( if (
hasattr(cls, "__dict__") hasattr(cls, "__dict__")
@ -559,7 +601,12 @@ def _patch_method(
return return
fn = _get_patched(fn, reference_injections, reference_closing) fn = _get_patched(fn, reference_injections, reference_closing)
_bind_injections(fn, providers_map) _bind_injections(
fn,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=warn_unresolved_stacklevel + 1,
)
if fn is method: if fn is method:
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/884 # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/884
@ -595,9 +642,17 @@ def _patch_attribute(
name: str, name: str,
marker: "_Marker", marker: "_Marker",
providers_map: ProvidersMap, providers_map: ProvidersMap,
warn_unresolved: bool = False,
warn_unresolved_stacklevel: int = 0,
) -> None: ) -> None:
provider = providers_map.resolve_provider(marker.provider, marker.modifier) provider = providers_map.resolve_provider(marker.provider, marker.modifier)
if provider is None: if provider is None:
if warn_unresolved:
warn(
f"Unresolved marker {name} in {member!r}",
UnresolvedMarkerWarning,
stacklevel=warn_unresolved_stacklevel + 2,
)
return return
_patched_registry.register_attribute(PatchedAttribute(member, name, marker)) _patched_registry.register_attribute(PatchedAttribute(member, name, marker))
@ -674,7 +729,12 @@ def _fetch_reference_injections( # noqa: C901
return injections, closing return injections, closing
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: def _bind_injections(
fn: Callable[..., Any],
providers_map: ProvidersMap,
warn_unresolved: bool = False,
warn_unresolved_stacklevel: int = 0,
) -> None:
patched_callable = _patched_registry.get_callable(fn) patched_callable = _patched_registry.get_callable(fn)
if patched_callable is None: if patched_callable is None:
return return
@ -683,6 +743,12 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non
provider = providers_map.resolve_provider(marker.provider, marker.modifier) provider = providers_map.resolve_provider(marker.provider, marker.modifier)
if provider is None: if provider is None:
if warn_unresolved:
warn(
f"Unresolved marker {injection} in {fn.__qualname__}",
UnresolvedMarkerWarning,
stacklevel=warn_unresolved_stacklevel + 2,
)
continue continue
if isinstance(marker, Provide): if isinstance(marker, Provide):

View File

@ -0,0 +1,15 @@
from dependency_injector.wiring import Provide, inject
missing_obj: object = Provide["missing"]
class TestMissingClass:
obj: object = Provide["missing"]
def method(self, obj: object = Provide["missing"]) -> object:
return obj
@inject
def test_missing_function(obj: object = Provide["missing"]):
return obj

View File

@ -3,13 +3,19 @@
import re import re
from decimal import Decimal from decimal import Decimal
from pytest import fixture, mark, raises from pytest import fixture, mark, raises, warns
from samples.wiringstringids import module, package, resourceclosing from samples.wiringstringids import module, package, resourceclosing
from samples.wiringstringids.container import Container, SubContainer from samples.wiringstringids.container import Container, SubContainer
from samples.wiringstringids.service import Service from samples.wiringstringids.service import Service
from dependency_injector import errors from dependency_injector import errors
from dependency_injector.wiring import Closing, Provide, Provider, wire from dependency_injector.wiring import (
Closing,
Provide,
Provider,
UnresolvedMarkerWarning,
wire,
)
@fixture(autouse=True) @fixture(autouse=True)
@ -73,6 +79,16 @@ def test_module_attribute_wiring_with_invalid_marker(container: Container):
container.wire(modules=[module_invalid_attr_injection]) container.wire(modules=[module_invalid_attr_injection])
def test_warn_unresolved_marker(container: Container):
from samples.wiringstringids import missing
with warns(
UnresolvedMarkerWarning,
match=r"^Unresolved marker .+ in .+$",
):
container.wire(modules=[missing], warn_unresolved=True)
def test_class_wiring(): def test_class_wiring():
test_class_object = module.TestClass() test_class_object = module.TestClass()
assert isinstance(test_class_object.service, Service) assert isinstance(test_class_object.service, Service)