Run black

This commit is contained in:
Taein Min 2025-01-20 23:45:48 +09:00
parent 633f60ac19
commit 91c90cf9e9
6 changed files with 175 additions and 140 deletions

View File

@ -5,24 +5,24 @@ from dependency_injector.wiring import Provide, inject
from typing import Annotated from typing import Annotated
class Service: class Service: ...
...
class Container(containers.DeclarativeContainer): class Container(containers.DeclarativeContainer):
service = providers.Factory(Service) service = providers.Factory(Service)
# You can place marker on parameter default value # You can place marker on parameter default value
@inject @inject
def main(service: Service = Provide[Container.service]) -> None: def main(service: Service = Provide[Container.service]) -> None: ...
...
# Also, you can place marker with typing.Annotated # Also, you can place marker with typing.Annotated
@inject @inject
def main_with_annotated(service: Annotated[Service, Provide[Container.service]]) -> None: def main_with_annotated(
... service: Annotated[Service, Provide[Container.service]]
) -> None: ...
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -27,8 +27,9 @@ from typing import (
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
from typing import GenericMeta from typing import GenericMeta
else: else:
class GenericMeta(type):
... class GenericMeta(type): ...
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362 # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
if sys.version_info >= (3, 9): if sys.version_info >= (3, 9):
@ -51,6 +52,7 @@ else:
def get_origin(tp): def get_origin(tp):
return None return None
try: try:
import fastapi.params import fastapi.params
except ImportError: except ImportError:
@ -113,7 +115,9 @@ class PatchedRegistry:
def register_callable(self, patched: "PatchedCallable") -> None: def register_callable(self, patched: "PatchedCallable") -> None:
self._callables[patched.patched] = patched self._callables[patched.patched] = patched
def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]: def get_callables_from_module(
self, module: ModuleType
) -> Iterator[Callable[..., Any]]:
for patched_callable in self._callables.values(): for patched_callable in self._callables.values():
if not patched_callable.is_in_module(module): if not patched_callable.is_in_module(module):
continue continue
@ -128,7 +132,9 @@ class PatchedRegistry:
def register_attribute(self, patched: "PatchedAttribute") -> None: def register_attribute(self, patched: "PatchedAttribute") -> None:
self._attributes.add(patched) self._attributes.add(patched)
def get_attributes_from_module(self, module: ModuleType) -> Iterator["PatchedAttribute"]: def get_attributes_from_module(
self, module: ModuleType
) -> Iterator["PatchedAttribute"]:
for attribute in self._attributes: for attribute in self._attributes:
if not attribute.is_in_module(module): if not attribute.is_in_module(module):
continue continue
@ -153,11 +159,11 @@ class PatchedCallable:
) )
def __init__( def __init__(
self, self,
patched: Optional[Callable[..., Any]] = None, patched: Optional[Callable[..., Any]] = None,
original: Optional[Callable[..., Any]] = None, original: Optional[Callable[..., Any]] = None,
reference_injections: Optional[Dict[Any, Any]] = None, reference_injections: Optional[Dict[Any, Any]] = None,
reference_closing: Optional[Dict[Any, Any]] = None, reference_closing: Optional[Dict[Any, Any]] = None,
) -> None: ) -> None:
self.patched = patched self.patched = patched
self.original = original self.original = original
@ -228,18 +234,21 @@ class ProvidersMap:
) )
def resolve_provider( def resolve_provider(
self, self,
provider: Union[providers.Provider, str], provider: Union[providers.Provider, str],
modifier: Optional["Modifier"] = None, modifier: Optional["Modifier"] = None,
) -> Optional[providers.Provider]: ) -> Optional[providers.Provider]:
if isinstance(provider, providers.Delegate): if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider) return self._resolve_delegate(provider)
elif isinstance(provider, ( elif isinstance(
providers.ProvidedInstance, provider,
providers.AttributeGetter, (
providers.ItemGetter, providers.ProvidedInstance,
providers.MethodCaller, providers.AttributeGetter,
)): providers.ItemGetter,
providers.MethodCaller,
),
):
return self._resolve_provided_instance(provider) return self._resolve_provided_instance(provider)
elif isinstance(provider, providers.ConfigurationOption): elif isinstance(provider, providers.ConfigurationOption):
return self._resolve_config_option(provider) return self._resolve_config_option(provider)
@ -251,9 +260,9 @@ class ProvidersMap:
return self._resolve_provider(provider) return self._resolve_provider(provider)
def _resolve_string_id( def _resolve_string_id(
self, self,
id: str, id: str,
modifier: Optional["Modifier"] = None, modifier: Optional["Modifier"] = None,
) -> Optional[providers.Provider]: ) -> Optional[providers.Provider]:
if id == self.CONTAINER_STRING_ID: if id == self.CONTAINER_STRING_ID:
return self._container.__self__ return self._container.__self__
@ -270,16 +279,19 @@ class ProvidersMap:
return provider return provider
def _resolve_provided_instance( def _resolve_provided_instance(
self, self,
original: providers.Provider, original: providers.Provider,
) -> Optional[providers.Provider]: ) -> Optional[providers.Provider]:
modifiers = [] modifiers = []
while isinstance(original, ( while isinstance(
original,
(
providers.ProvidedInstance, providers.ProvidedInstance,
providers.AttributeGetter, providers.AttributeGetter,
providers.ItemGetter, providers.ItemGetter,
providers.MethodCaller, providers.MethodCaller,
)): ),
):
modifiers.insert(0, original) modifiers.insert(0, original)
original = original.provides original = original.provides
@ -303,8 +315,8 @@ class ProvidersMap:
return new return new
def _resolve_delegate( def _resolve_delegate(
self, self,
original: providers.Delegate, original: providers.Delegate,
) -> Optional[providers.Provider]: ) -> Optional[providers.Provider]:
provider = self._resolve_provider(original.provides) provider = self._resolve_provider(original.provides)
if provider: if provider:
@ -312,9 +324,9 @@ class ProvidersMap:
return provider return provider
def _resolve_config_option( def _resolve_config_option(
self, self,
original: providers.ConfigurationOption, original: providers.ConfigurationOption,
as_: Any = None, as_: Any = None,
) -> Optional[providers.Provider]: ) -> Optional[providers.Provider]:
original_root = original.root original_root = original.root
new = self._resolve_provider(original_root) new = self._resolve_provider(original_root)
@ -338,8 +350,8 @@ class ProvidersMap:
return new return new
def _resolve_provider( def _resolve_provider(
self, self,
original: providers.Provider, original: providers.Provider,
) -> Optional[providers.Provider]: ) -> Optional[providers.Provider]:
try: try:
return self._map[original] return self._map[original]
@ -348,9 +360,9 @@ class ProvidersMap:
@classmethod @classmethod
def _create_providers_map( def _create_providers_map(
cls, cls,
current_container: Container, current_container: Container,
original_container: Container, original_container: Container,
) -> Dict[providers.Provider, providers.Provider]: ) -> Dict[providers.Provider, providers.Provider]:
current_providers = current_container.providers current_providers = current_container.providers
current_providers["__self__"] = current_container.__self__ current_providers["__self__"] = current_container.__self__
@ -363,8 +375,9 @@ class ProvidersMap:
original_provider = original_providers[provider_name] original_provider = original_providers[provider_name]
providers_map[original_provider] = current_provider providers_map[original_provider] = current_provider
if isinstance(current_provider, providers.Container) \ if isinstance(current_provider, providers.Container) and isinstance(
and isinstance(original_provider, providers.Container): original_provider, providers.Container
):
subcontainer_map = cls._create_providers_map( subcontainer_map = cls._create_providers_map(
current_container=current_provider.container, current_container=current_provider.container,
original_container=original_provider.container, original_container=original_provider.container,
@ -390,19 +403,21 @@ class InspectFilter:
return werkzeug and isinstance(instance, werkzeug.local.LocalProxy) return werkzeug and isinstance(instance, werkzeug.local.LocalProxy)
def _is_starlette_request_cls(self, instance: object) -> bool: def _is_starlette_request_cls(self, instance: object) -> bool:
return starlette \ return (
and isinstance(instance, type) \ starlette
and _safe_is_subclass(instance, starlette.requests.Request) and isinstance(instance, type)
and _safe_is_subclass(instance, starlette.requests.Request)
)
def _is_builtin(self, instance: object) -> bool: def _is_builtin(self, instance: object) -> bool:
return inspect.isbuiltin(instance) return inspect.isbuiltin(instance)
def wire( # noqa: C901 def wire( # noqa: C901
container: Container, container: Container,
*, *,
modules: Optional[Iterable[ModuleType]] = None, modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None,
) -> None: ) -> None:
"""Wire container providers with provided packages and modules.""" """Wire container providers with provided packages and modules."""
modules = [*modules] if modules else [] modules = [*modules] if modules else []
@ -432,18 +447,22 @@ def wire( # noqa: C901
else: else:
for cls_member_name, cls_member in cls_members: for cls_member_name, cls_member in cls_members:
if _is_marker(cls_member): if _is_marker(cls_member):
_patch_attribute(cls, cls_member_name, cls_member, providers_map) _patch_attribute(
cls, cls_member_name, cls_member, providers_map
)
elif _is_method(cls_member): elif _is_method(cls_member):
_patch_method(cls, cls_member_name, cls_member, providers_map) _patch_method(
cls, cls_member_name, cls_member, providers_map
)
for patched in _patched_registry.get_callables_from_module(module): for patched in _patched_registry.get_callables_from_module(module):
_bind_injections(patched, providers_map) _bind_injections(patched, providers_map)
def unwire( # noqa: C901 def unwire( # noqa: C901
*, *,
modules: Optional[Iterable[ModuleType]] = None, modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None,
) -> None: ) -> None:
"""Wire provided packages and modules with previous wired providers.""" """Wire provided packages and modules with previous wired providers."""
modules = [*modules] if modules else [] modules = [*modules] if modules else []
@ -457,7 +476,9 @@ def unwire( # noqa: C901
if inspect.isfunction(member): if inspect.isfunction(member):
_unpatch(module, name, member) _unpatch(module, name, member)
elif inspect.isclass(member): elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, inspect.isfunction): for method_name, method in inspect.getmembers(
member, inspect.isfunction
):
_unpatch(member, method_name, method) _unpatch(member, method_name, method)
for patched in _patched_registry.get_callables_from_module(module): for patched in _patched_registry.get_callables_from_module(module):
@ -476,10 +497,10 @@ def inject(fn: F) -> F:
def _patch_fn( def _patch_fn(
module: ModuleType, module: ModuleType,
name: str, name: str,
fn: Callable[..., Any], fn: Callable[..., Any],
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> None: ) -> None:
if not _is_patched(fn): if not _is_patched(fn):
reference_injections, reference_closing = _fetch_reference_injections(fn) reference_injections, reference_closing = _fetch_reference_injections(fn)
@ -493,14 +514,16 @@ def _patch_fn(
def _patch_method( def _patch_method(
cls: Type, cls: Type,
name: str, name: str,
method: Callable[..., Any], method: Callable[..., Any],
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> None: ) -> None:
if hasattr(cls, "__dict__") \ if (
and name in cls.__dict__ \ hasattr(cls, "__dict__")
and isinstance(cls.__dict__[name], (classmethod, staticmethod)): and name in cls.__dict__
and isinstance(cls.__dict__[name], (classmethod, staticmethod))
):
method = cls.__dict__[name] method = cls.__dict__[name]
fn = method.__func__ fn = method.__func__
else: else:
@ -521,13 +544,15 @@ def _patch_method(
def _unpatch( def _unpatch(
module: ModuleType, module: ModuleType,
name: str, name: str,
fn: Callable[..., Any], fn: Callable[..., Any],
) -> None: ) -> None:
if hasattr(module, "__dict__") \ if (
and name in module.__dict__ \ hasattr(module, "__dict__")
and isinstance(module.__dict__[name], (classmethod, staticmethod)): and name in module.__dict__
and isinstance(module.__dict__[name], (classmethod, staticmethod))
):
method = module.__dict__[name] method = module.__dict__[name]
fn = method.__func__ fn = method.__func__
@ -538,10 +563,10 @@ def _unpatch(
def _patch_attribute( def _patch_attribute(
member: Any, member: Any,
name: str, name: str,
marker: "_Marker", marker: "_Marker",
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> None: ) -> None:
provider = providers_map.resolve_provider(marker.provider, marker.modifier) provider = providers_map.resolve_provider(marker.provider, marker.modifier)
if provider is None: if provider is None:
@ -581,15 +606,14 @@ def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
def _fetch_reference_injections( # noqa: C901 def _fetch_reference_injections( # noqa: C901
fn: Callable[..., Any], fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Hotfix, see: # Hotfix, see:
# - https://github.com/ets-labs/python-dependency-injector/issues/362 # - https://github.com/ets-labs/python-dependency-injector/issues/362
# - https://github.com/ets-labs/python-dependency-injector/issues/398 # - https://github.com/ets-labs/python-dependency-injector/issues/398
if GenericAlias and any(( if GenericAlias and any(
fn is GenericAlias, (fn is GenericAlias, getattr(fn, "__func__", None) is GenericAlias)
getattr(fn, "__func__", None) is GenericAlias ):
)):
fn = fn.__init__ fn = fn.__init__
try: try:
@ -618,7 +642,9 @@ def _fetch_reference_injections( # noqa: C901
return injections, closing return injections, closing
def _locate_dependent_closing_args(provider: providers.Provider) -> Dict[str, providers.Provider]: def _locate_dependent_closing_args(
provider: providers.Provider,
) -> Dict[str, providers.Provider]:
if not hasattr(provider, "args"): if not hasattr(provider, "args"):
return {} return {}
@ -672,8 +698,8 @@ def _fetch_modules(package):
if not hasattr(package, "__path__") or not hasattr(package, "__name__"): if not hasattr(package, "__path__") or not hasattr(package, "__name__"):
return modules return modules
for module_info in pkgutil.walk_packages( for module_info in pkgutil.walk_packages(
path=package.__path__, path=package.__path__,
prefix=package.__name__ + ".", prefix=package.__name__ + ".",
): ):
module = importlib.import_module(module_info.name) module = importlib.import_module(module_info.name)
modules.append(module) modules.append(module)
@ -689,9 +715,9 @@ def _is_marker(member) -> bool:
def _get_patched( def _get_patched(
fn: F, fn: F,
reference_injections: Dict[Any, Any], reference_injections: Dict[Any, Any],
reference_closing: Dict[Any, Any], reference_closing: Dict[Any, Any],
) -> F: ) -> F:
patched_object = PatchedCallable( patched_object = PatchedCallable(
original=fn, original=fn,
@ -719,9 +745,11 @@ def _is_patched(fn) -> bool:
def _is_declarative_container(instance: Any) -> bool: def _is_declarative_container(instance: Any) -> bool:
return (isinstance(instance, type) return (
and getattr(instance, "__IS_CONTAINER__", False) is True isinstance(instance, type)
and getattr(instance, "declarative_parent", None) is None) and getattr(instance, "__IS_CONTAINER__", False) is True
and getattr(instance, "declarative_parent", None) is None
)
def _safe_is_subclass(instance: Any, cls: Type) -> bool: def _safe_is_subclass(instance: Any, cls: Type) -> bool:
@ -734,11 +762,10 @@ def _safe_is_subclass(instance: Any, cls: Type) -> bool:
class Modifier: class Modifier:
def modify( def modify(
self, self,
provider: providers.ConfigurationOption, provider: providers.ConfigurationOption,
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> providers.Provider: ) -> providers.Provider: ...
...
class TypeModifier(Modifier): class TypeModifier(Modifier):
@ -747,9 +774,9 @@ class TypeModifier(Modifier):
self.type_ = type_ self.type_ = type_
def modify( def modify(
self, self,
provider: providers.ConfigurationOption, provider: providers.ConfigurationOption,
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> providers.Provider: ) -> providers.Provider:
return provider.as_(self.type_) return provider.as_(self.type_)
@ -787,9 +814,9 @@ class RequiredModifier(Modifier):
return self return self
def modify( def modify(
self, self,
provider: providers.ConfigurationOption, provider: providers.ConfigurationOption,
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> providers.Provider: ) -> providers.Provider:
provider = provider.required() provider = provider.required()
if self.type_modifier: if self.type_modifier:
@ -808,9 +835,9 @@ class InvariantModifier(Modifier):
self.id = id self.id = id
def modify( def modify(
self, self,
provider: providers.ConfigurationOption, provider: providers.ConfigurationOption,
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> providers.Provider: ) -> providers.Provider:
invariant_segment = providers_map.resolve_provider(self.id) invariant_segment = providers_map.resolve_provider(self.id)
return provider[invariant_segment] return provider[invariant_segment]
@ -843,9 +870,9 @@ class ProvidedInstance(Modifier):
return self return self
def modify( def modify(
self, self,
provider: providers.Provider, provider: providers.Provider,
providers_map: ProvidersMap, providers_map: ProvidersMap,
) -> providers.Provider: ) -> providers.Provider:
provider = provider.provided provider = provider.provided
for type_, value in self.segments: for type_, value in self.segments:
@ -876,9 +903,9 @@ class _Marker(Generic[T], metaclass=ClassGetItemMeta):
__IS_MARKER__ = True __IS_MARKER__ = True
def __init__( def __init__(
self, self,
provider: Union[providers.Provider, Container, str], provider: Union[providers.Provider, Container, str],
modifier: Optional[Modifier] = None, modifier: Optional[Modifier] = None,
) -> None: ) -> None:
if _is_declarative_container(provider): if _is_declarative_container(provider):
provider = provider.__self__ provider = provider.__self__
@ -894,16 +921,13 @@ class _Marker(Generic[T], metaclass=ClassGetItemMeta):
return self return self
class Provide(_Marker): class Provide(_Marker): ...
...
class Provider(_Marker): class Provider(_Marker): ...
...
class Closing(_Marker): class Closing(_Marker): ...
...
class AutoLoader: class AutoLoader:
@ -953,8 +977,7 @@ class AutoLoader:
super().exec_module(module) super().exec_module(module)
loader.wire_module(module) loader.wire_module(module)
class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader): class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader): ...
...
loader_details = [ loader_details = [
(SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES), (SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
@ -1023,4 +1046,5 @@ def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
patched.injections, patched.injections,
patched.closing, patched.closing,
) )
return _patched return _patched

View File

@ -3,7 +3,9 @@ import sys
from typing_extensions import Annotated from typing_extensions import Annotated
from fastapi import FastAPI, Depends from fastapi import FastAPI, Depends
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398 from fastapi import (
Request,
) # See: https://github.com/ets-labs/python-dependency-injector/issues/398
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from dependency_injector import containers, providers from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide from dependency_injector.wiring import inject, Provide
@ -29,17 +31,17 @@ async def index(service: Service = Depends(Provide[Container.service])):
result = await service.process() result = await service.process()
return {"result": result} return {"result": result}
@app.api_route('/annotated')
@app.api_route("/annotated")
@inject @inject
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]): async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
result = await service.process() result = await service.process()
return {'result': result} return {"result": result}
@app.get("/auth") @app.get("/auth")
@inject @inject
def read_current_user( def read_current_user(credentials: HTTPBasicCredentials = Depends(security)):
credentials: HTTPBasicCredentials = Depends(security)
):
return {"username": credentials.username, "password": credentials.password} return {"username": credentials.username, "password": credentials.password}

View File

@ -34,7 +34,7 @@ def index(service: Service = Provide[Container.service]):
@inject @inject
def annotated(service: Annotated[Service, Provide[Container.service]]): def annotated(service: Annotated[Service, Provide[Container.service]]):
result = service.process() result = service.process()
return jsonify({'result': result}) return jsonify({"result": result})
container = Container() container = Container()

View File

@ -3,13 +3,17 @@ from pytest import fixture, mark
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
import os import os
_SAMPLES_DIR = os.path.abspath( _SAMPLES_DIR = os.path.abspath(
os.path.sep.join(( os.path.sep.join(
os.path.dirname(__file__), (
"../samples/", os.path.dirname(__file__),
)), "../samples/",
)
),
) )
import sys import sys
sys.path.append(_SAMPLES_DIR) sys.path.append(_SAMPLES_DIR)
@ -49,7 +53,6 @@ async def test_depends_with_annotated(async_client: AsyncClient):
assert response.json() == {"result": "Foo"} assert response.json() == {"result": "Foo"}
@mark.asyncio @mark.asyncio
async def test_depends_injection(async_client: AsyncClient): async def test_depends_injection(async_client: AsyncClient):
response = await async_client.get("/auth", auth=("john_smith", "secret")) response = await async_client.get("/auth", auth=("john_smith", "secret"))

View File

@ -2,19 +2,25 @@ import json
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
import os import os
_TOP_DIR = os.path.abspath( _TOP_DIR = os.path.abspath(
os.path.sep.join(( os.path.sep.join(
os.path.dirname(__file__), (
"../", os.path.dirname(__file__),
)), "../",
)
),
) )
_SAMPLES_DIR = os.path.abspath( _SAMPLES_DIR = os.path.abspath(
os.path.sep.join(( os.path.sep.join(
os.path.dirname(__file__), (
"../samples/", os.path.dirname(__file__),
)), "../samples/",
)
),
) )
import sys import sys
sys.path.append(_TOP_DIR) sys.path.append(_TOP_DIR)
sys.path.append(_SAMPLES_DIR) sys.path.append(_SAMPLES_DIR)
@ -36,6 +42,6 @@ def test_wiring_with_annotated():
with web.app.app_context(): with web.app.app_context():
response = client.get("/annotated") response = client.get("/annotated")
assert response.status_code == 200 assert response.status_code == 200
assert json.loads(response.data) == {"result": "OK"} assert json.loads(response.data) == {"result": "OK"}