Add warnings for unresolved markers

This commit is contained in:
ZipFile 2025-09-17 22:08:02 +00:00
parent 759d89e9bd
commit 1222827a5e
7 changed files with 165 additions and 17 deletions

View File

@ -251,6 +251,32 @@ To inject a container use special identifier ``<container>``:
def foo(container: Container = Provide["<container>"]) -> None:
...
Caveats
~~~~~~~
While using string identifiers you may not notice a typo in the identifier until the code is executed.
In order to aid with catching such errors early, you may pass `warn_unresolved=True` to the ``wire`` method and/or :class:`WiringConfiguration`:
.. code-block:: python
:emphasize-lines: 4
class Container(containers.DeclarativeContainer):
wiring_config = containers.WiringConfiguration(
modules=["yourapp.module"],
warn_unresolved=True,
)
Or:
.. code-block:: python
:emphasize-lines: 4
container = Container()
container.wire(
modules=["yourapp.module"],
warn_unresolved=True,
)
Making injections into modules and class attributes
---------------------------------------------------

View File

@ -54,7 +54,7 @@ classifiers = [
dynamic = ["version"]
dependencies = [
# typing.Annotated since v3.9
# typing.Self since v3.11
# typing.Self and typing.assert_never since v3.11
"typing-extensions; python_version<'3.11'",
]

View File

@ -72,6 +72,7 @@ class Container:
modules: Optional[Iterable[Any]] = None,
packages: Optional[Iterable[Any]] = None,
from_package: Optional[str] = None,
warn_unresolved: bool = False,
) -> None: ...
def unwire(self) -> None: ...
def init_resources(self, resource_type: Type[Resource[Any]] = Resource) -> Optional[Awaitable[None]]: ...

View File

@ -20,15 +20,31 @@ from .wiring import wire, unwire
class WiringConfiguration:
"""Container wiring configuration."""
def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True, keep_cache=False):
def __init__(
self,
modules=None,
packages=None,
from_package=None,
auto_wire=True,
keep_cache=False,
warn_unresolved=False,
):
self.modules = [*modules] if modules else []
self.packages = [*packages] if packages else []
self.from_package = from_package
self.auto_wire = auto_wire
self.keep_cache = keep_cache
self.warn_unresolved = warn_unresolved
def __deepcopy__(self, memo=None):
return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire, self.keep_cache)
return self.__class__(
self.modules,
self.packages,
self.from_package,
self.auto_wire,
self.keep_cache,
self.warn_unresolved,
)
class Container:
@ -259,7 +275,14 @@ class DynamicContainer(Container):
"""Check if auto wiring is needed."""
return self.wiring_config.auto_wire is True
def wire(self, modules=None, packages=None, from_package=None, keep_cache=None):
def wire(
self,
modules=None,
packages=None,
from_package=None,
keep_cache=None,
warn_unresolved=False,
):
"""Wire container providers with provided packages and modules.
:rtype: None
@ -298,6 +321,7 @@ class DynamicContainer(Container):
modules=modules,
packages=packages,
keep_cache=keep_cache,
warn_unresolved=warn_unresolved,
)
if modules:

View File

@ -25,15 +25,14 @@ from typing import (
Type,
TypeVar,
Union,
assert_never,
cast,
)
from warnings import warn
try:
from typing import Self
from typing import Self, assert_never
except ImportError:
from typing_extensions import Self
from typing_extensions import Self, assert_never
try:
from functools import cache
@ -140,6 +139,10 @@ class DIWiringWarning(RuntimeWarning):
"""Base class for all warnings raised by the wiring module."""
class UnresolvedMarkerWarning(DIWiringWarning):
"""Warning raised when a marker with string identifier cannot be resolved against container."""
class PatchedRegistry:
def __init__(self) -> None:
@ -434,6 +437,7 @@ def wire( # noqa: C901
modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None,
keep_cache: bool = False,
warn_unresolved: bool = False,
) -> None:
"""Wire container providers with provided packages and modules."""
modules = [*modules] if modules else []
@ -450,9 +454,23 @@ def wire( # noqa: C901
continue
if _is_marker(member):
_patch_attribute(module, member_name, member, providers_map)
_patch_attribute(
module,
member_name,
member,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
)
elif inspect.isfunction(member):
_patch_fn(module, member_name, member, providers_map)
_patch_fn(
module,
member_name,
member,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
)
elif inspect.isclass(member):
cls = member
try:
@ -464,15 +482,30 @@ def wire( # noqa: C901
for cls_member_name, cls_member in cls_members:
if _is_marker(cls_member):
_patch_attribute(
cls, cls_member_name, cls_member, providers_map
cls,
cls_member_name,
cls_member,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
)
elif _is_method(cls_member):
_patch_method(
cls, cls_member_name, cls_member, providers_map
cls,
cls_member_name,
cls_member,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
)
for patched in _patched_registry.get_callables_from_module(module):
_bind_injections(patched, providers_map)
_bind_injections(
patched,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=1,
)
if not keep_cache:
clear_cache()
@ -525,6 +558,8 @@ def _patch_fn(
name: str,
fn: Callable[..., Any],
providers_map: ProvidersMap,
warn_unresolved: bool = False,
warn_unresolved_stacklevel: int = 0,
) -> None:
if not _is_patched(fn):
reference_injections, reference_closing = _fetch_reference_injections(fn)
@ -532,7 +567,12 @@ def _patch_fn(
return
fn = _get_patched(fn, reference_injections, reference_closing)
_bind_injections(fn, providers_map)
_bind_injections(
fn,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=warn_unresolved_stacklevel + 1,
)
setattr(module, name, fn)
@ -542,6 +582,8 @@ def _patch_method(
name: str,
method: Callable[..., Any],
providers_map: ProvidersMap,
warn_unresolved: bool = False,
warn_unresolved_stacklevel: int = 0,
) -> None:
if (
hasattr(cls, "__dict__")
@ -559,7 +601,12 @@ def _patch_method(
return
fn = _get_patched(fn, reference_injections, reference_closing)
_bind_injections(fn, providers_map)
_bind_injections(
fn,
providers_map,
warn_unresolved=warn_unresolved,
warn_unresolved_stacklevel=warn_unresolved_stacklevel + 1,
)
if fn is method:
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/884
@ -595,9 +642,17 @@ def _patch_attribute(
name: str,
marker: "_Marker",
providers_map: ProvidersMap,
warn_unresolved: bool = False,
warn_unresolved_stacklevel: int = 0,
) -> None:
provider = providers_map.resolve_provider(marker.provider, marker.modifier)
if provider is None:
if warn_unresolved:
warn(
f"Unresolved marker {name} in {member!r}",
UnresolvedMarkerWarning,
stacklevel=warn_unresolved_stacklevel + 2,
)
return
_patched_registry.register_attribute(PatchedAttribute(member, name, marker))
@ -674,7 +729,12 @@ def _fetch_reference_injections( # noqa: C901
return injections, closing
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
def _bind_injections(
fn: Callable[..., Any],
providers_map: ProvidersMap,
warn_unresolved: bool = False,
warn_unresolved_stacklevel: int = 0,
) -> None:
patched_callable = _patched_registry.get_callable(fn)
if patched_callable is None:
return
@ -683,6 +743,12 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non
provider = providers_map.resolve_provider(marker.provider, marker.modifier)
if provider is None:
if warn_unresolved:
warn(
f"Unresolved marker {injection} in {fn.__qualname__}",
UnresolvedMarkerWarning,
stacklevel=warn_unresolved_stacklevel + 2,
)
continue
if isinstance(marker, Provide):

View File

@ -0,0 +1,15 @@
from dependency_injector.wiring import Provide, inject
missing_obj: object = Provide["missing"]
class TestMissingClass:
obj: object = Provide["missing"]
def method(self, obj: object = Provide["missing"]) -> object:
return obj
@inject
def test_missing_function(obj: object = Provide["missing"]):
return obj

View File

@ -3,13 +3,19 @@
import re
from decimal import Decimal
from pytest import fixture, mark, raises
from pytest import fixture, mark, raises, warns
from samples.wiringstringids import module, package, resourceclosing
from samples.wiringstringids.container import Container, SubContainer
from samples.wiringstringids.service import Service
from dependency_injector import errors
from dependency_injector.wiring import Closing, Provide, Provider, wire
from dependency_injector.wiring import (
Closing,
Provide,
Provider,
UnresolvedMarkerWarning,
wire,
)
@fixture(autouse=True)
@ -73,6 +79,16 @@ def test_module_attribute_wiring_with_invalid_marker(container: Container):
container.wire(modules=[module_invalid_attr_injection])
def test_warn_unresolved_marker(container: Container):
from samples.wiringstringids import missing
with warns(
UnresolvedMarkerWarning,
match=r"^Unresolved marker .+ in .+$",
):
container.wire(modules=[missing], warn_unresolved=True)
def test_class_wiring():
test_class_object = module.TestClass()
assert isinstance(test_class_object.service, Service)