diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index c5a591da..94c97ccb 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -20,6 +20,7 @@ from typing import ( TypeVar, Type, Union, + Set, cast, ) @@ -82,22 +83,53 @@ F = TypeVar('F', bound=Callable[..., Any]) Container = Any -class Registry: +class PatchedRegistry: def __init__(self): - self._storage = set() + self._callables: Set[Callable[..., Any]] = set() + self._attributes: Set[PatchedAttribute] = set() - def add(self, patched: Callable[..., Any]) -> None: - self._storage.add(patched) + def add_callable(self, patched: Callable[..., Any]) -> None: + self._callables.add(patched) - def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]: - for patched in self._storage: + def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]: + for patched in self._callables: if patched.__module__ != module.__name__: continue yield patched + def add_attribute(self, patched: 'PatchedAttribute'): + self._attributes.add(patched) -_patched_registry = Registry() + def get_attributes_from_module(self, module: ModuleType) -> Iterator['PatchedAttribute']: + for attribute in self._attributes: + if not attribute.is_in_module(module): + continue + yield attribute + + def clear_module_attributes(self, module: ModuleType): + for attribute in self._attributes.copy(): + if not attribute.is_in_module(module): + continue + self._attributes.remove(attribute) + + +class PatchedAttribute: + + def __init__(self, member: Any, name: str, marker: '_Marker'): + self.member = member + self.name = name + self.marker = marker + + @property + def module_name(self) -> str: + if isinstance(self.member, ModuleType): + return self.member.__name__ + else: + return self.member.__module__ + + def is_in_module(self, module: ModuleType) -> bool: + return self.module_name == module.__name__ class ProvidersMap: @@ -278,9 +310,6 @@ class InspectFilter: and issubclass(instance, starlette.requests.Request) -inspect_filter = InspectFilter() - - def wire( # noqa: C901 container: Container, *, @@ -301,16 +330,23 @@ def wire( # noqa: C901 providers_map = ProvidersMap(container) for module in modules: - for name, member in inspect.getmembers(module): - if inspect_filter.is_excluded(member): + for member_name, member in inspect.getmembers(module): + if _inspect_filter.is_excluded(member): continue - if inspect.isfunction(member): - _patch_fn(module, name, member, providers_map) - elif inspect.isclass(member): - for method_name, method in inspect.getmembers(member, _is_method): - _patch_method(member, method_name, method, providers_map) - for patched in _patched_registry.get_from_module(module): + if _is_marker(member): + _patch_attribute(module, member_name, member, providers_map) + elif inspect.isfunction(member): + _patch_fn(module, member_name, member, providers_map) + elif inspect.isclass(member): + cls = member + for cls_member_name, cls_member in inspect.getmembers(cls): + if _is_marker(cls_member): + _patch_attribute(cls, cls_member_name, cls_member, providers_map) + elif _is_method(cls_member): + _patch_method(cls, cls_member_name, cls_member, providers_map) + + for patched in _patched_registry.get_callables_from_module(module): _bind_injections(patched, providers_map) @@ -335,15 +371,19 @@ def unwire( for method_name, method in inspect.getmembers(member, inspect.isfunction): _unpatch(member, method_name, method) - for patched in _patched_registry.get_from_module(module): + for patched in _patched_registry.get_callables_from_module(module): _unbind_injections(patched) + for patched_attribute in _patched_registry.get_attributes_from_module(module): + _unpatch_attribute(patched_attribute) + _patched_registry.clear_module_attributes(module) + def inject(fn: F) -> F: """Decorate callable with injecting decorator.""" reference_injections, reference_closing = _fetch_reference_injections(fn) patched = _get_patched(fn, reference_injections, reference_closing) - _patched_registry.add(patched) + _patched_registry.add_callable(patched) return cast(F, patched) @@ -358,7 +398,7 @@ def _patch_fn( if not reference_injections: return fn = _get_patched(fn, reference_injections, reference_closing) - _patched_registry.add(fn) + _patched_registry.add_callable(fn) _bind_injections(fn, providers_map) @@ -384,7 +424,7 @@ def _patch_method( if not reference_injections: return fn = _get_patched(fn, reference_injections, reference_closing) - _patched_registry.add(fn) + _patched_registry.add_callable(fn) _bind_injections(fn, providers_map) @@ -411,6 +451,31 @@ def _unpatch( _unbind_injections(fn) +def _patch_attribute( + member: Any, + name: str, + marker: '_Marker', + providers_map: ProvidersMap, +) -> None: + provider = providers_map.resolve_provider(marker.provider, marker.modifier) + if provider is None: + return + + _patched_registry.add_attribute(PatchedAttribute(member, name, marker)) + + if isinstance(marker, Provide): + instance = provider() + setattr(member, name, instance) + elif isinstance(marker, Provider): + setattr(member, name, provider) + else: + raise Exception(f'Unknown type of marker {marker}') + + +def _unpatch_attribute(patched: PatchedAttribute) -> None: + setattr(patched.member, patched.name, patched.marker) + + def _fetch_reference_injections( fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -484,6 +549,10 @@ def _is_method(member): return inspect.ismethod(member) or inspect.isfunction(member) +def _is_marker(member): + return isinstance(member, Provide) or isinstance(member, Provider) + + def _get_patched(fn, reference_injections, reference_closing): if inspect.iscoroutinefunction(fn): patched = _get_async_patched(fn) @@ -825,9 +894,6 @@ class AutoLoader: importlib.invalidate_caches() -_loader = AutoLoader() - - def register_loader_containers(*containers: Container) -> None: """Register containers in auto-wiring module loader.""" _loader.register_containers(*containers) @@ -851,3 +917,8 @@ def uninstall_loader() -> None: def is_loader_installed() -> bool: """Check if auto-wiring module loader hook is installed.""" return _loader.installed + + +_patched_registry = PatchedRegistry() +_inspect_filter = InspectFilter() +_loader = AutoLoader()