diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 1c671bb1..5afafbcb 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -208,11 +208,15 @@ def _patch_fn( fn: Callable[..., Any], providers_map: ProvidersMap, ) -> None: - injections, closing = _resolve_injections(fn, providers_map) - if not injections: - return - patched = _patch_with_injections(fn, injections, closing) - setattr(module, name, _wrap_patched(patched, fn, injections, closing)) + if not _is_patched(fn): + reference_injections, reference_closing = _fetch_reference_injections(fn) + if not reference_injections: + return + fn = _get_patched(fn, reference_injections, reference_closing) + + _bind_injections(fn, providers_map) + + setattr(module, name, fn) def _patch_method( @@ -221,28 +225,26 @@ def _patch_method( method: Callable[..., Any], providers_map: ProvidersMap, ) -> None: - injections, closing = _resolve_injections(method, providers_map) - if not injections: - return - if hasattr(cls, '__dict__') \ and name in cls.__dict__ \ and isinstance(cls.__dict__[name], (classmethod, staticmethod)): method = cls.__dict__[name] - patched = _patch_with_injections(method.__func__, injections, closing) - patched = type(method)(patched) + fn = method.__func__ else: - patched = _patch_with_injections(method, injections, closing) + fn = method - setattr(cls, name, _wrap_patched(patched, method, injections, closing)) + if not _is_patched(fn): + reference_injections, reference_closing = _fetch_reference_injections(fn) + if not reference_injections: + return + fn = _get_patched(fn, reference_injections, reference_closing) + _bind_injections(fn, providers_map) -def _wrap_patched(patched: Callable[..., Any], original, injections, closing): - patched.__wired__ = True - patched.__original__ = original - patched.__injections__ = injections - patched.__closing__ = closing - return patched + if isinstance(method, (classmethod, staticmethod)): + fn = type(method)(fn) + + setattr(cls, name, fn) def _unpatch( @@ -250,14 +252,20 @@ def _unpatch( name: str, fn: Callable[..., Any], ) -> None: + if hasattr(module, '__dict__') \ + and name in module.__dict__ \ + and isinstance(module.__dict__[name], (classmethod, staticmethod)): + method = module.__dict__[name] + fn = method.__func__ + if not _is_patched(fn): return - setattr(module, name, _get_original_from_patched(fn)) + + _unbind_injections(fn) -def _resolve_injections( +def _fetch_reference_injections( fn: Callable[..., Any], - providers_map: ProvidersMap, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: signature = inspect.signature(fn) @@ -268,24 +276,33 @@ def _resolve_injections( continue marker = parameter.default - closing_modifier = False if isinstance(marker, Closing): - closing_modifier = True marker = marker.provider + closing[parameter_name] = marker + injections[parameter_name] = marker + return injections, closing + + +def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None: + for injection, marker in fn.__reference_injections__.items(): provider = providers_map.resolve_provider(marker.provider) + if provider is None: continue - if closing_modifier: - closing[parameter_name] = provider - if isinstance(marker, Provide): - injections[parameter_name] = provider + fn.__injections__[injection] = provider elif isinstance(marker, Provider): - injections[parameter_name] = provider.provider + fn.__injections__[injection] = provider.provider - return injections, closing + if injection in fn.__reference_closing__: + fn.__closing__[injection] = provider + + +def _unbind_injections(fn: Callable[..., Any]) -> None: + fn.__injections__ = {} + fn.__closing__ = {} def _fetch_modules(package): @@ -303,26 +320,34 @@ def _is_method(member): return inspect.ismethod(member) or inspect.isfunction(member) -def _patch_with_injections(fn, injections, closing): +def _get_patched(fn, reference_injections, reference_closing): if inspect.iscoroutinefunction(fn): - _patched = _get_async_patched(fn, injections, closing) + patched = _get_async_patched(fn) else: - _patched = _get_patched(fn, injections, closing) - return _patched + patched = _get_sync_patched(fn) + + patched.__wired__ = True + patched.__original__ = fn + patched.__injections__ = {} + patched.__reference_injections__ = reference_injections + patched.__closing__ = {} + patched.__reference_closing__ = reference_closing + + return patched -def _get_patched(fn, injections, closing): +def _get_sync_patched(fn): @functools.wraps(fn) def _patched(*args, **kwargs): to_inject = kwargs.copy() - for injection, provider in injections.items(): + for injection, provider in _patched.__injections__.items(): if injection not in kwargs \ or _is_fastapi_default_arg_injection(injection, kwargs): to_inject[injection] = provider() result = fn(*args, **to_inject) - for injection, provider in closing.items(): + for injection, provider in _patched.__closing__.items(): if injection in kwargs \ and not _is_fastapi_default_arg_injection(injection, kwargs): continue @@ -334,18 +359,18 @@ def _get_patched(fn, injections, closing): return _patched -def _get_async_patched(fn, injections, closing): +def _get_async_patched(fn): @functools.wraps(fn) async def _patched(*args, **kwargs): to_inject = kwargs.copy() - for injection, provider in injections.items(): + for injection, provider in _patched.__injections__.items(): if injection not in kwargs \ or _is_fastapi_default_arg_injection(injection, kwargs): to_inject[injection] = provider() result = await fn(*args, **to_inject) - for injection, provider in closing.items(): + for injection, provider in _patched.__closing__.items(): if injection in kwargs \ and not _is_fastapi_default_arg_injection(injection, kwargs): continue @@ -366,10 +391,6 @@ def _is_patched(fn): return getattr(fn, '__wired__', False) is True -def _get_original_from_patched(fn): - return getattr(fn, '__original__') - - def _is_declarative_container_instance(instance: Any) -> bool: return (not isinstance(instance, type) and getattr(instance, '__IS_CONTAINER__', False) is True