Run black

This commit is contained in:
Taein Min 2025-01-20 23:45:48 +09:00
parent 633f60ac19
commit 91c90cf9e9
6 changed files with 175 additions and 140 deletions

View File

@ -5,24 +5,24 @@ from dependency_injector.wiring import Provide, inject
from typing import Annotated
class Service:
...
class Service: ...
class Container(containers.DeclarativeContainer):
service = providers.Factory(Service)
# You can place marker on parameter default value
@inject
def main(service: Service = Provide[Container.service]) -> None:
...
def main(service: Service = Provide[Container.service]) -> None: ...
# Also, you can place marker with typing.Annotated
@inject
def main_with_annotated(service: Annotated[Service, Provide[Container.service]]) -> None:
...
def main_with_annotated(
service: Annotated[Service, Provide[Container.service]]
) -> None: ...
if __name__ == "__main__":

View File

@ -27,8 +27,9 @@ from typing import (
if sys.version_info < (3, 7):
from typing import GenericMeta
else:
class GenericMeta(type):
...
class GenericMeta(type): ...
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
if sys.version_info >= (3, 9):
@ -51,6 +52,7 @@ else:
def get_origin(tp):
return None
try:
import fastapi.params
except ImportError:
@ -113,7 +115,9 @@ class PatchedRegistry:
def register_callable(self, patched: "PatchedCallable") -> None:
self._callables[patched.patched] = patched
def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
def get_callables_from_module(
self, module: ModuleType
) -> Iterator[Callable[..., Any]]:
for patched_callable in self._callables.values():
if not patched_callable.is_in_module(module):
continue
@ -128,7 +132,9 @@ class PatchedRegistry:
def register_attribute(self, patched: "PatchedAttribute") -> None:
self._attributes.add(patched)
def get_attributes_from_module(self, module: ModuleType) -> Iterator["PatchedAttribute"]:
def get_attributes_from_module(
self, module: ModuleType
) -> Iterator["PatchedAttribute"]:
for attribute in self._attributes:
if not attribute.is_in_module(module):
continue
@ -234,12 +240,15 @@ class ProvidersMap:
) -> Optional[providers.Provider]:
if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider)
elif isinstance(provider, (
elif isinstance(
provider,
(
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
),
):
return self._resolve_provided_instance(provider)
elif isinstance(provider, providers.ConfigurationOption):
return self._resolve_config_option(provider)
@ -274,12 +283,15 @@ class ProvidersMap:
original: providers.Provider,
) -> Optional[providers.Provider]:
modifiers = []
while isinstance(original, (
while isinstance(
original,
(
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
),
):
modifiers.insert(0, original)
original = original.provides
@ -363,8 +375,9 @@ class ProvidersMap:
original_provider = original_providers[provider_name]
providers_map[original_provider] = current_provider
if isinstance(current_provider, providers.Container) \
and isinstance(original_provider, providers.Container):
if isinstance(current_provider, providers.Container) and isinstance(
original_provider, providers.Container
):
subcontainer_map = cls._create_providers_map(
current_container=current_provider.container,
original_container=original_provider.container,
@ -390,9 +403,11 @@ class InspectFilter:
return werkzeug and isinstance(instance, werkzeug.local.LocalProxy)
def _is_starlette_request_cls(self, instance: object) -> bool:
return starlette \
and isinstance(instance, type) \
return (
starlette
and isinstance(instance, type)
and _safe_is_subclass(instance, starlette.requests.Request)
)
def _is_builtin(self, instance: object) -> bool:
return inspect.isbuiltin(instance)
@ -432,9 +447,13 @@ def wire( # noqa: C901
else:
for cls_member_name, cls_member in cls_members:
if _is_marker(cls_member):
_patch_attribute(cls, cls_member_name, cls_member, providers_map)
_patch_attribute(
cls, cls_member_name, cls_member, providers_map
)
elif _is_method(cls_member):
_patch_method(cls, cls_member_name, cls_member, providers_map)
_patch_method(
cls, cls_member_name, cls_member, providers_map
)
for patched in _patched_registry.get_callables_from_module(module):
_bind_injections(patched, providers_map)
@ -457,7 +476,9 @@ def unwire( # noqa: C901
if inspect.isfunction(member):
_unpatch(module, name, member)
elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, inspect.isfunction):
for method_name, method in inspect.getmembers(
member, inspect.isfunction
):
_unpatch(member, method_name, method)
for patched in _patched_registry.get_callables_from_module(module):
@ -498,9 +519,11 @@ def _patch_method(
method: Callable[..., Any],
providers_map: ProvidersMap,
) -> None:
if hasattr(cls, "__dict__") \
and name in cls.__dict__ \
and isinstance(cls.__dict__[name], (classmethod, staticmethod)):
if (
hasattr(cls, "__dict__")
and name in cls.__dict__
and isinstance(cls.__dict__[name], (classmethod, staticmethod))
):
method = cls.__dict__[name]
fn = method.__func__
else:
@ -525,9 +548,11 @@ def _unpatch(
name: str,
fn: Callable[..., Any],
) -> None:
if hasattr(module, "__dict__") \
and name in module.__dict__ \
and isinstance(module.__dict__[name], (classmethod, staticmethod)):
if (
hasattr(module, "__dict__")
and name in module.__dict__
and isinstance(module.__dict__[name], (classmethod, staticmethod))
):
method = module.__dict__[name]
fn = method.__func__
@ -586,10 +611,9 @@ def _fetch_reference_injections( # noqa: C901
# Hotfix, see:
# - https://github.com/ets-labs/python-dependency-injector/issues/362
# - https://github.com/ets-labs/python-dependency-injector/issues/398
if GenericAlias and any((
fn is GenericAlias,
getattr(fn, "__func__", None) is GenericAlias
)):
if GenericAlias and any(
(fn is GenericAlias, getattr(fn, "__func__", None) is GenericAlias)
):
fn = fn.__init__
try:
@ -618,7 +642,9 @@ def _fetch_reference_injections( # noqa: C901
return injections, closing
def _locate_dependent_closing_args(provider: providers.Provider) -> Dict[str, providers.Provider]:
def _locate_dependent_closing_args(
provider: providers.Provider,
) -> Dict[str, providers.Provider]:
if not hasattr(provider, "args"):
return {}
@ -719,9 +745,11 @@ def _is_patched(fn) -> bool:
def _is_declarative_container(instance: Any) -> bool:
return (isinstance(instance, type)
return (
isinstance(instance, type)
and getattr(instance, "__IS_CONTAINER__", False) is True
and getattr(instance, "declarative_parent", None) is None)
and getattr(instance, "declarative_parent", None) is None
)
def _safe_is_subclass(instance: Any, cls: Type) -> bool:
@ -737,8 +765,7 @@ class Modifier:
self,
provider: providers.ConfigurationOption,
providers_map: ProvidersMap,
) -> providers.Provider:
...
) -> providers.Provider: ...
class TypeModifier(Modifier):
@ -894,16 +921,13 @@ class _Marker(Generic[T], metaclass=ClassGetItemMeta):
return self
class Provide(_Marker):
...
class Provide(_Marker): ...
class Provider(_Marker):
...
class Provider(_Marker): ...
class Closing(_Marker):
...
class Closing(_Marker): ...
class AutoLoader:
@ -953,8 +977,7 @@ class AutoLoader:
super().exec_module(module)
loader.wire_module(module)
class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader):
...
class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader): ...
loader_details = [
(SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
@ -1023,4 +1046,5 @@ def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
patched.injections,
patched.closing,
)
return _patched

View File

@ -3,7 +3,9 @@ import sys
from typing_extensions import Annotated
from fastapi import FastAPI, Depends
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398
from fastapi import (
Request,
) # See: https://github.com/ets-labs/python-dependency-injector/issues/398
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide
@ -29,17 +31,17 @@ async def index(service: Service = Depends(Provide[Container.service])):
result = await service.process()
return {"result": result}
@app.api_route('/annotated')
@app.api_route("/annotated")
@inject
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
result = await service.process()
return {'result': result}
return {"result": result}
@app.get("/auth")
@inject
def read_current_user(
credentials: HTTPBasicCredentials = Depends(security)
):
def read_current_user(credentials: HTTPBasicCredentials = Depends(security)):
return {"username": credentials.username, "password": credentials.password}

View File

@ -34,7 +34,7 @@ def index(service: Service = Provide[Container.service]):
@inject
def annotated(service: Annotated[Service, Provide[Container.service]]):
result = service.process()
return jsonify({'result': result})
return jsonify({"result": result})
container = Container()

View File

@ -3,13 +3,17 @@ from pytest import fixture, mark
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
import os
_SAMPLES_DIR = os.path.abspath(
os.path.sep.join((
os.path.sep.join(
(
os.path.dirname(__file__),
"../samples/",
)),
)
),
)
import sys
sys.path.append(_SAMPLES_DIR)
@ -49,7 +53,6 @@ async def test_depends_with_annotated(async_client: AsyncClient):
assert response.json() == {"result": "Foo"}
@mark.asyncio
async def test_depends_injection(async_client: AsyncClient):
response = await async_client.get("/auth", auth=("john_smith", "secret"))

View File

@ -2,19 +2,25 @@ import json
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.sep.join(
(
os.path.dirname(__file__),
"../",
)),
)
),
)
_SAMPLES_DIR = os.path.abspath(
os.path.sep.join((
os.path.sep.join(
(
os.path.dirname(__file__),
"../samples/",
)),
)
),
)
import sys
sys.path.append(_TOP_DIR)
sys.path.append(_SAMPLES_DIR)