diff --git a/docs/wiring.rst b/docs/wiring.rst index bb6ba156..3d5778c4 100644 --- a/docs/wiring.rst +++ b/docs/wiring.rst @@ -251,6 +251,32 @@ To inject a container use special identifier ````: def foo(container: Container = Provide[""]) -> None: ... +Caveats +~~~~~~~ + +While using string identifiers you may not notice a typo in the identifier until the code is executed. +In order to aid with catching such errors early, you may pass `warn_unresolved=True` to the ``wire`` method and/or :class:`WiringConfiguration`: + +.. code-block:: python + :emphasize-lines: 4 + + class Container(containers.DeclarativeContainer): + wiring_config = containers.WiringConfiguration( + modules=["yourapp.module"], + warn_unresolved=True, + ) + +Or: + +.. code-block:: python + :emphasize-lines: 4 + + container = Container() + container.wire( + modules=["yourapp.module"], + warn_unresolved=True, + ) + Making injections into modules and class attributes --------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index ef0b946d..5fef5a15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ classifiers = [ dynamic = ["version"] dependencies = [ # typing.Annotated since v3.9 - # typing.Self since v3.11 + # typing.Self and typing.assert_never since v3.11 "typing-extensions; python_version<'3.11'", ] diff --git a/src/dependency_injector/containers.pyi b/src/dependency_injector/containers.pyi index f21a8791..95eef00a 100644 --- a/src/dependency_injector/containers.pyi +++ b/src/dependency_injector/containers.pyi @@ -72,6 +72,7 @@ class Container: modules: Optional[Iterable[Any]] = None, packages: Optional[Iterable[Any]] = None, from_package: Optional[str] = None, + warn_unresolved: bool = False, ) -> None: ... def unwire(self) -> None: ... def init_resources(self, resource_type: Type[Resource[Any]] = Resource) -> Optional[Awaitable[None]]: ... diff --git a/src/dependency_injector/containers.pyx b/src/dependency_injector/containers.pyx index 99762da2..f9e8ea91 100644 --- a/src/dependency_injector/containers.pyx +++ b/src/dependency_injector/containers.pyx @@ -20,15 +20,31 @@ from .wiring import wire, unwire class WiringConfiguration: """Container wiring configuration.""" - def __init__(self, modules=None, packages=None, from_package=None, auto_wire=True, keep_cache=False): + def __init__( + self, + modules=None, + packages=None, + from_package=None, + auto_wire=True, + keep_cache=False, + warn_unresolved=False, + ): self.modules = [*modules] if modules else [] self.packages = [*packages] if packages else [] self.from_package = from_package self.auto_wire = auto_wire self.keep_cache = keep_cache + self.warn_unresolved = warn_unresolved def __deepcopy__(self, memo=None): - return self.__class__(self.modules, self.packages, self.from_package, self.auto_wire, self.keep_cache) + return self.__class__( + self.modules, + self.packages, + self.from_package, + self.auto_wire, + self.keep_cache, + self.warn_unresolved, + ) class Container: @@ -259,7 +275,14 @@ class DynamicContainer(Container): """Check if auto wiring is needed.""" return self.wiring_config.auto_wire is True - def wire(self, modules=None, packages=None, from_package=None, keep_cache=None): + def wire( + self, + modules=None, + packages=None, + from_package=None, + keep_cache=None, + warn_unresolved=False, + ): """Wire container providers with provided packages and modules. :rtype: None @@ -298,6 +321,7 @@ class DynamicContainer(Container): modules=modules, packages=packages, keep_cache=keep_cache, + warn_unresolved=warn_unresolved, ) if modules: diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 4b8dbf4d..211fdcde 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -25,15 +25,14 @@ from typing import ( Type, TypeVar, Union, - assert_never, cast, ) from warnings import warn try: - from typing import Self + from typing import Self, assert_never except ImportError: - from typing_extensions import Self + from typing_extensions import Self, assert_never try: from functools import cache @@ -140,6 +139,10 @@ class DIWiringWarning(RuntimeWarning): """Base class for all warnings raised by the wiring module.""" +class UnresolvedMarkerWarning(DIWiringWarning): + """Warning raised when a marker with string identifier cannot be resolved against container.""" + + class PatchedRegistry: def __init__(self) -> None: @@ -434,6 +437,7 @@ def wire( # noqa: C901 modules: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None, keep_cache: bool = False, + warn_unresolved: bool = False, ) -> None: """Wire container providers with provided packages and modules.""" modules = [*modules] if modules else [] @@ -450,9 +454,23 @@ def wire( # noqa: C901 continue if _is_marker(member): - _patch_attribute(module, member_name, member, providers_map) + _patch_attribute( + module, + member_name, + member, + providers_map, + warn_unresolved=warn_unresolved, + warn_unresolved_stacklevel=1, + ) elif inspect.isfunction(member): - _patch_fn(module, member_name, member, providers_map) + _patch_fn( + module, + member_name, + member, + providers_map, + warn_unresolved=warn_unresolved, + warn_unresolved_stacklevel=1, + ) elif inspect.isclass(member): cls = member try: @@ -464,15 +482,30 @@ def wire( # noqa: C901 for cls_member_name, cls_member in cls_members: if _is_marker(cls_member): _patch_attribute( - cls, cls_member_name, cls_member, providers_map + cls, + cls_member_name, + cls_member, + providers_map, + warn_unresolved=warn_unresolved, + warn_unresolved_stacklevel=1, ) elif _is_method(cls_member): _patch_method( - cls, cls_member_name, cls_member, providers_map + cls, + cls_member_name, + cls_member, + providers_map, + warn_unresolved=warn_unresolved, + warn_unresolved_stacklevel=1, ) for patched in _patched_registry.get_callables_from_module(module): - _bind_injections(patched, providers_map) + _bind_injections( + patched, + providers_map, + warn_unresolved=warn_unresolved, + warn_unresolved_stacklevel=1, + ) if not keep_cache: clear_cache() @@ -525,6 +558,8 @@ def _patch_fn( name: str, fn: Callable[..., Any], providers_map: ProvidersMap, + warn_unresolved: bool = False, + warn_unresolved_stacklevel: int = 0, ) -> None: if not _is_patched(fn): reference_injections, reference_closing = _fetch_reference_injections(fn) @@ -532,7 +567,12 @@ def _patch_fn( return fn = _get_patched(fn, reference_injections, reference_closing) - _bind_injections(fn, providers_map) + _bind_injections( + fn, + providers_map, + warn_unresolved=warn_unresolved, + warn_unresolved_stacklevel=warn_unresolved_stacklevel + 1, + ) setattr(module, name, fn) @@ -542,6 +582,8 @@ def _patch_method( name: str, method: Callable[..., Any], providers_map: ProvidersMap, + warn_unresolved: bool = False, + warn_unresolved_stacklevel: int = 0, ) -> None: if ( hasattr(cls, "__dict__") @@ -559,7 +601,12 @@ def _patch_method( return fn = _get_patched(fn, reference_injections, reference_closing) - _bind_injections(fn, providers_map) + _bind_injections( + fn, + providers_map, + warn_unresolved=warn_unresolved, + warn_unresolved_stacklevel=warn_unresolved_stacklevel + 1, + ) if fn is method: # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/884 @@ -595,9 +642,17 @@ def _patch_attribute( name: str, marker: "_Marker", providers_map: ProvidersMap, + warn_unresolved: bool = False, + warn_unresolved_stacklevel: int = 0, ) -> None: provider = providers_map.resolve_provider(marker.provider, marker.modifier) if provider is None: + if warn_unresolved: + warn( + f"Unresolved marker {name} in {member!r}", + UnresolvedMarkerWarning, + stacklevel=warn_unresolved_stacklevel + 2, + ) return _patched_registry.register_attribute(PatchedAttribute(member, name, marker)) @@ -674,7 +729,12 @@ def _fetch_reference_injections( # noqa: C901 return injections, closing -def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: +def _bind_injections( + fn: Callable[..., Any], + providers_map: ProvidersMap, + warn_unresolved: bool = False, + warn_unresolved_stacklevel: int = 0, +) -> None: patched_callable = _patched_registry.get_callable(fn) if patched_callable is None: return @@ -683,6 +743,12 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non provider = providers_map.resolve_provider(marker.provider, marker.modifier) if provider is None: + if warn_unresolved: + warn( + f"Unresolved marker {injection} in {fn.__qualname__}", + UnresolvedMarkerWarning, + stacklevel=warn_unresolved_stacklevel + 2, + ) continue if isinstance(marker, Provide): diff --git a/tests/unit/samples/wiringstringids/missing.py b/tests/unit/samples/wiringstringids/missing.py new file mode 100644 index 00000000..b8bafae5 --- /dev/null +++ b/tests/unit/samples/wiringstringids/missing.py @@ -0,0 +1,15 @@ +from dependency_injector.wiring import Provide, inject + +missing_obj: object = Provide["missing"] + + +class TestMissingClass: + obj: object = Provide["missing"] + + def method(self, obj: object = Provide["missing"]) -> object: + return obj + + +@inject +def test_missing_function(obj: object = Provide["missing"]): + return obj diff --git a/tests/unit/wiring/string_ids/test_main_py36.py b/tests/unit/wiring/string_ids/test_main_py36.py index 904c4e3e..3a4e344e 100644 --- a/tests/unit/wiring/string_ids/test_main_py36.py +++ b/tests/unit/wiring/string_ids/test_main_py36.py @@ -3,13 +3,19 @@ import re from decimal import Decimal -from pytest import fixture, mark, raises +from pytest import fixture, mark, raises, warns from samples.wiringstringids import module, package, resourceclosing from samples.wiringstringids.container import Container, SubContainer from samples.wiringstringids.service import Service from dependency_injector import errors -from dependency_injector.wiring import Closing, Provide, Provider, wire +from dependency_injector.wiring import ( + Closing, + Provide, + Provider, + UnresolvedMarkerWarning, + wire, +) @fixture(autouse=True) @@ -73,6 +79,16 @@ def test_module_attribute_wiring_with_invalid_marker(container: Container): container.wire(modules=[module_invalid_attr_injection]) +def test_warn_unresolved_marker(container: Container): + from samples.wiringstringids import missing + + with warns( + UnresolvedMarkerWarning, + match=r"^Unresolved marker .+ in .+$", + ): + container.wire(modules=[missing], warn_unresolved=True) + + def test_class_wiring(): test_class_object = module.TestClass() assert isinstance(test_class_object.service, Service)