From 91c90cf9e9db0b029c5ccb6376d849d72e1e7809 Mon Sep 17 00:00:00 2001 From: Taein Min Date: Mon, 20 Jan 2025 23:45:48 +0900 Subject: [PATCH] Run black --- examples/wiring/example.py | 12 +- src/dependency_injector/wiring.py | 250 +++++++++++++----------- tests/unit/samples/wiringfastapi/web.py | 14 +- tests/unit/samples/wiringflask/web.py | 2 +- tests/unit/wiring/test_fastapi_py36.py | 13 +- tests/unit/wiring/test_flask_py36.py | 24 ++- 6 files changed, 175 insertions(+), 140 deletions(-) diff --git a/examples/wiring/example.py b/examples/wiring/example.py index 5361fa95..0e32b192 100644 --- a/examples/wiring/example.py +++ b/examples/wiring/example.py @@ -5,24 +5,24 @@ from dependency_injector.wiring import Provide, inject from typing import Annotated -class Service: - ... +class Service: ... class Container(containers.DeclarativeContainer): service = providers.Factory(Service) + # You can place marker on parameter default value @inject -def main(service: Service = Provide[Container.service]) -> None: - ... +def main(service: Service = Provide[Container.service]) -> None: ... # Also, you can place marker with typing.Annotated @inject -def main_with_annotated(service: Annotated[Service, Provide[Container.service]]) -> None: - ... +def main_with_annotated( + service: Annotated[Service, Provide[Container.service]] +) -> None: ... if __name__ == "__main__": diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index d88bac9f..67d56f29 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -27,8 +27,9 @@ from typing import ( if sys.version_info < (3, 7): from typing import GenericMeta else: - class GenericMeta(type): - ... + + class GenericMeta(type): ... + # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362 if sys.version_info >= (3, 9): @@ -51,6 +52,7 @@ else: def get_origin(tp): return None + try: import fastapi.params except ImportError: @@ -113,7 +115,9 @@ class PatchedRegistry: def register_callable(self, patched: "PatchedCallable") -> None: self._callables[patched.patched] = patched - def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]: + def get_callables_from_module( + self, module: ModuleType + ) -> Iterator[Callable[..., Any]]: for patched_callable in self._callables.values(): if not patched_callable.is_in_module(module): continue @@ -128,7 +132,9 @@ class PatchedRegistry: def register_attribute(self, patched: "PatchedAttribute") -> None: self._attributes.add(patched) - def get_attributes_from_module(self, module: ModuleType) -> Iterator["PatchedAttribute"]: + def get_attributes_from_module( + self, module: ModuleType + ) -> Iterator["PatchedAttribute"]: for attribute in self._attributes: if not attribute.is_in_module(module): continue @@ -153,11 +159,11 @@ class PatchedCallable: ) def __init__( - self, - patched: Optional[Callable[..., Any]] = None, - original: Optional[Callable[..., Any]] = None, - reference_injections: Optional[Dict[Any, Any]] = None, - reference_closing: Optional[Dict[Any, Any]] = None, + self, + patched: Optional[Callable[..., Any]] = None, + original: Optional[Callable[..., Any]] = None, + reference_injections: Optional[Dict[Any, Any]] = None, + reference_closing: Optional[Dict[Any, Any]] = None, ) -> None: self.patched = patched self.original = original @@ -228,18 +234,21 @@ class ProvidersMap: ) def resolve_provider( - self, - provider: Union[providers.Provider, str], - modifier: Optional["Modifier"] = None, + self, + provider: Union[providers.Provider, str], + modifier: Optional["Modifier"] = None, ) -> Optional[providers.Provider]: if isinstance(provider, providers.Delegate): return self._resolve_delegate(provider) - elif isinstance(provider, ( - providers.ProvidedInstance, - providers.AttributeGetter, - providers.ItemGetter, - providers.MethodCaller, - )): + elif isinstance( + provider, + ( + providers.ProvidedInstance, + providers.AttributeGetter, + providers.ItemGetter, + providers.MethodCaller, + ), + ): return self._resolve_provided_instance(provider) elif isinstance(provider, providers.ConfigurationOption): return self._resolve_config_option(provider) @@ -251,9 +260,9 @@ class ProvidersMap: return self._resolve_provider(provider) def _resolve_string_id( - self, - id: str, - modifier: Optional["Modifier"] = None, + self, + id: str, + modifier: Optional["Modifier"] = None, ) -> Optional[providers.Provider]: if id == self.CONTAINER_STRING_ID: return self._container.__self__ @@ -270,16 +279,19 @@ class ProvidersMap: return provider def _resolve_provided_instance( - self, - original: providers.Provider, + self, + original: providers.Provider, ) -> Optional[providers.Provider]: modifiers = [] - while isinstance(original, ( + while isinstance( + original, + ( providers.ProvidedInstance, providers.AttributeGetter, providers.ItemGetter, providers.MethodCaller, - )): + ), + ): modifiers.insert(0, original) original = original.provides @@ -303,8 +315,8 @@ class ProvidersMap: return new def _resolve_delegate( - self, - original: providers.Delegate, + self, + original: providers.Delegate, ) -> Optional[providers.Provider]: provider = self._resolve_provider(original.provides) if provider: @@ -312,9 +324,9 @@ class ProvidersMap: return provider def _resolve_config_option( - self, - original: providers.ConfigurationOption, - as_: Any = None, + self, + original: providers.ConfigurationOption, + as_: Any = None, ) -> Optional[providers.Provider]: original_root = original.root new = self._resolve_provider(original_root) @@ -338,8 +350,8 @@ class ProvidersMap: return new def _resolve_provider( - self, - original: providers.Provider, + self, + original: providers.Provider, ) -> Optional[providers.Provider]: try: return self._map[original] @@ -348,9 +360,9 @@ class ProvidersMap: @classmethod def _create_providers_map( - cls, - current_container: Container, - original_container: Container, + cls, + current_container: Container, + original_container: Container, ) -> Dict[providers.Provider, providers.Provider]: current_providers = current_container.providers current_providers["__self__"] = current_container.__self__ @@ -363,8 +375,9 @@ class ProvidersMap: original_provider = original_providers[provider_name] providers_map[original_provider] = current_provider - if isinstance(current_provider, providers.Container) \ - and isinstance(original_provider, providers.Container): + if isinstance(current_provider, providers.Container) and isinstance( + original_provider, providers.Container + ): subcontainer_map = cls._create_providers_map( current_container=current_provider.container, original_container=original_provider.container, @@ -390,19 +403,21 @@ class InspectFilter: return werkzeug and isinstance(instance, werkzeug.local.LocalProxy) def _is_starlette_request_cls(self, instance: object) -> bool: - return starlette \ - and isinstance(instance, type) \ - and _safe_is_subclass(instance, starlette.requests.Request) + return ( + starlette + and isinstance(instance, type) + and _safe_is_subclass(instance, starlette.requests.Request) + ) def _is_builtin(self, instance: object) -> bool: return inspect.isbuiltin(instance) def wire( # noqa: C901 - container: Container, - *, - modules: Optional[Iterable[ModuleType]] = None, - packages: Optional[Iterable[ModuleType]] = None, + container: Container, + *, + modules: Optional[Iterable[ModuleType]] = None, + packages: Optional[Iterable[ModuleType]] = None, ) -> None: """Wire container providers with provided packages and modules.""" modules = [*modules] if modules else [] @@ -432,18 +447,22 @@ def wire( # noqa: C901 else: for cls_member_name, cls_member in cls_members: if _is_marker(cls_member): - _patch_attribute(cls, cls_member_name, cls_member, providers_map) + _patch_attribute( + cls, cls_member_name, cls_member, providers_map + ) elif _is_method(cls_member): - _patch_method(cls, cls_member_name, cls_member, providers_map) + _patch_method( + cls, cls_member_name, cls_member, providers_map + ) for patched in _patched_registry.get_callables_from_module(module): _bind_injections(patched, providers_map) def unwire( # noqa: C901 - *, - modules: Optional[Iterable[ModuleType]] = None, - packages: Optional[Iterable[ModuleType]] = None, + *, + modules: Optional[Iterable[ModuleType]] = None, + packages: Optional[Iterable[ModuleType]] = None, ) -> None: """Wire provided packages and modules with previous wired providers.""" modules = [*modules] if modules else [] @@ -457,7 +476,9 @@ def unwire( # noqa: C901 if inspect.isfunction(member): _unpatch(module, name, member) elif inspect.isclass(member): - for method_name, method in inspect.getmembers(member, inspect.isfunction): + for method_name, method in inspect.getmembers( + member, inspect.isfunction + ): _unpatch(member, method_name, method) for patched in _patched_registry.get_callables_from_module(module): @@ -476,10 +497,10 @@ def inject(fn: F) -> F: def _patch_fn( - module: ModuleType, - name: str, - fn: Callable[..., Any], - providers_map: ProvidersMap, + module: ModuleType, + name: str, + fn: Callable[..., Any], + providers_map: ProvidersMap, ) -> None: if not _is_patched(fn): reference_injections, reference_closing = _fetch_reference_injections(fn) @@ -493,14 +514,16 @@ def _patch_fn( def _patch_method( - cls: Type, - name: str, - method: Callable[..., Any], - providers_map: ProvidersMap, + cls: Type, + name: str, + method: Callable[..., Any], + providers_map: ProvidersMap, ) -> None: - if hasattr(cls, "__dict__") \ - and name in cls.__dict__ \ - and isinstance(cls.__dict__[name], (classmethod, staticmethod)): + if ( + hasattr(cls, "__dict__") + and name in cls.__dict__ + and isinstance(cls.__dict__[name], (classmethod, staticmethod)) + ): method = cls.__dict__[name] fn = method.__func__ else: @@ -521,13 +544,15 @@ def _patch_method( def _unpatch( - module: ModuleType, - name: str, - fn: Callable[..., Any], + module: ModuleType, + name: str, + fn: Callable[..., Any], ) -> None: - if hasattr(module, "__dict__") \ - and name in module.__dict__ \ - and isinstance(module.__dict__[name], (classmethod, staticmethod)): + if ( + hasattr(module, "__dict__") + and name in module.__dict__ + and isinstance(module.__dict__[name], (classmethod, staticmethod)) + ): method = module.__dict__[name] fn = method.__func__ @@ -538,10 +563,10 @@ def _unpatch( def _patch_attribute( - member: Any, - name: str, - marker: "_Marker", - providers_map: ProvidersMap, + member: Any, + name: str, + marker: "_Marker", + providers_map: ProvidersMap, ) -> None: provider = providers_map.resolve_provider(marker.provider, marker.modifier) if provider is None: @@ -581,15 +606,14 @@ def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]: def _fetch_reference_injections( # noqa: C901 - fn: Callable[..., Any], + fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: # Hotfix, see: # - https://github.com/ets-labs/python-dependency-injector/issues/362 # - https://github.com/ets-labs/python-dependency-injector/issues/398 - if GenericAlias and any(( - fn is GenericAlias, - getattr(fn, "__func__", None) is GenericAlias - )): + if GenericAlias and any( + (fn is GenericAlias, getattr(fn, "__func__", None) is GenericAlias) + ): fn = fn.__init__ try: @@ -618,7 +642,9 @@ def _fetch_reference_injections( # noqa: C901 return injections, closing -def _locate_dependent_closing_args(provider: providers.Provider) -> Dict[str, providers.Provider]: +def _locate_dependent_closing_args( + provider: providers.Provider, +) -> Dict[str, providers.Provider]: if not hasattr(provider, "args"): return {} @@ -672,8 +698,8 @@ def _fetch_modules(package): if not hasattr(package, "__path__") or not hasattr(package, "__name__"): return modules for module_info in pkgutil.walk_packages( - path=package.__path__, - prefix=package.__name__ + ".", + path=package.__path__, + prefix=package.__name__ + ".", ): module = importlib.import_module(module_info.name) modules.append(module) @@ -689,9 +715,9 @@ def _is_marker(member) -> bool: def _get_patched( - fn: F, - reference_injections: Dict[Any, Any], - reference_closing: Dict[Any, Any], + fn: F, + reference_injections: Dict[Any, Any], + reference_closing: Dict[Any, Any], ) -> F: patched_object = PatchedCallable( original=fn, @@ -719,9 +745,11 @@ def _is_patched(fn) -> bool: def _is_declarative_container(instance: Any) -> bool: - return (isinstance(instance, type) - and getattr(instance, "__IS_CONTAINER__", False) is True - and getattr(instance, "declarative_parent", None) is None) + return ( + isinstance(instance, type) + and getattr(instance, "__IS_CONTAINER__", False) is True + and getattr(instance, "declarative_parent", None) is None + ) def _safe_is_subclass(instance: Any, cls: Type) -> bool: @@ -734,11 +762,10 @@ def _safe_is_subclass(instance: Any, cls: Type) -> bool: class Modifier: def modify( - self, - provider: providers.ConfigurationOption, - providers_map: ProvidersMap, - ) -> providers.Provider: - ... + self, + provider: providers.ConfigurationOption, + providers_map: ProvidersMap, + ) -> providers.Provider: ... class TypeModifier(Modifier): @@ -747,9 +774,9 @@ class TypeModifier(Modifier): self.type_ = type_ def modify( - self, - provider: providers.ConfigurationOption, - providers_map: ProvidersMap, + self, + provider: providers.ConfigurationOption, + providers_map: ProvidersMap, ) -> providers.Provider: return provider.as_(self.type_) @@ -787,9 +814,9 @@ class RequiredModifier(Modifier): return self def modify( - self, - provider: providers.ConfigurationOption, - providers_map: ProvidersMap, + self, + provider: providers.ConfigurationOption, + providers_map: ProvidersMap, ) -> providers.Provider: provider = provider.required() if self.type_modifier: @@ -808,9 +835,9 @@ class InvariantModifier(Modifier): self.id = id def modify( - self, - provider: providers.ConfigurationOption, - providers_map: ProvidersMap, + self, + provider: providers.ConfigurationOption, + providers_map: ProvidersMap, ) -> providers.Provider: invariant_segment = providers_map.resolve_provider(self.id) return provider[invariant_segment] @@ -843,9 +870,9 @@ class ProvidedInstance(Modifier): return self def modify( - self, - provider: providers.Provider, - providers_map: ProvidersMap, + self, + provider: providers.Provider, + providers_map: ProvidersMap, ) -> providers.Provider: provider = provider.provided for type_, value in self.segments: @@ -876,9 +903,9 @@ class _Marker(Generic[T], metaclass=ClassGetItemMeta): __IS_MARKER__ = True def __init__( - self, - provider: Union[providers.Provider, Container, str], - modifier: Optional[Modifier] = None, + self, + provider: Union[providers.Provider, Container, str], + modifier: Optional[Modifier] = None, ) -> None: if _is_declarative_container(provider): provider = provider.__self__ @@ -894,16 +921,13 @@ class _Marker(Generic[T], metaclass=ClassGetItemMeta): return self -class Provide(_Marker): - ... +class Provide(_Marker): ... -class Provider(_Marker): - ... +class Provider(_Marker): ... -class Closing(_Marker): - ... +class Closing(_Marker): ... class AutoLoader: @@ -953,8 +977,7 @@ class AutoLoader: super().exec_module(module) loader.wire_module(module) - class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader): - ... + class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader): ... loader_details = [ (SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES), @@ -1023,4 +1046,5 @@ def _get_async_patched(fn: F, patched: PatchedCallable) -> F: patched.injections, patched.closing, ) + return _patched diff --git a/tests/unit/samples/wiringfastapi/web.py b/tests/unit/samples/wiringfastapi/web.py index f5c22d58..c1ed5102 100644 --- a/tests/unit/samples/wiringfastapi/web.py +++ b/tests/unit/samples/wiringfastapi/web.py @@ -3,7 +3,9 @@ import sys from typing_extensions import Annotated from fastapi import FastAPI, Depends -from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398 +from fastapi import ( + Request, +) # See: https://github.com/ets-labs/python-dependency-injector/issues/398 from fastapi.security import HTTPBasic, HTTPBasicCredentials from dependency_injector import containers, providers from dependency_injector.wiring import inject, Provide @@ -29,17 +31,17 @@ async def index(service: Service = Depends(Provide[Container.service])): result = await service.process() return {"result": result} -@app.api_route('/annotated') + +@app.api_route("/annotated") @inject async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]): result = await service.process() - return {'result': result} + return {"result": result} + @app.get("/auth") @inject -def read_current_user( - credentials: HTTPBasicCredentials = Depends(security) -): +def read_current_user(credentials: HTTPBasicCredentials = Depends(security)): return {"username": credentials.username, "password": credentials.password} diff --git a/tests/unit/samples/wiringflask/web.py b/tests/unit/samples/wiringflask/web.py index 0d21e3f9..890b9c26 100644 --- a/tests/unit/samples/wiringflask/web.py +++ b/tests/unit/samples/wiringflask/web.py @@ -34,7 +34,7 @@ def index(service: Service = Provide[Container.service]): @inject def annotated(service: Annotated[Service, Provide[Container.service]]): result = service.process() - return jsonify({'result': result}) + return jsonify({"result": result}) container = Container() diff --git a/tests/unit/wiring/test_fastapi_py36.py b/tests/unit/wiring/test_fastapi_py36.py index 74655775..f1e3930d 100644 --- a/tests/unit/wiring/test_fastapi_py36.py +++ b/tests/unit/wiring/test_fastapi_py36.py @@ -3,13 +3,17 @@ from pytest import fixture, mark # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir import os + _SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../samples/", + ) + ), ) import sys + sys.path.append(_SAMPLES_DIR) @@ -49,7 +53,6 @@ async def test_depends_with_annotated(async_client: AsyncClient): assert response.json() == {"result": "Foo"} - @mark.asyncio async def test_depends_injection(async_client: AsyncClient): response = await async_client.get("/auth", auth=("john_smith", "secret")) diff --git a/tests/unit/wiring/test_flask_py36.py b/tests/unit/wiring/test_flask_py36.py index beb23fdc..97420275 100644 --- a/tests/unit/wiring/test_flask_py36.py +++ b/tests/unit/wiring/test_flask_py36.py @@ -2,19 +2,25 @@ import json # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir import os + _TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../", + ) + ), ) _SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../samples/", + ) + ), ) import sys + sys.path.append(_TOP_DIR) sys.path.append(_SAMPLES_DIR) @@ -36,6 +42,6 @@ def test_wiring_with_annotated(): with web.app.app_context(): response = client.get("/annotated") - + assert response.status_code == 200 assert json.loads(response.data) == {"result": "OK"}