mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-07-02 03:13:15 +03:00
Run black
This commit is contained in:
parent
633f60ac19
commit
91c90cf9e9
|
@ -5,24 +5,24 @@ from dependency_injector.wiring import Provide, inject
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
|
||||||
class Service:
|
class Service: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class Container(containers.DeclarativeContainer):
|
class Container(containers.DeclarativeContainer):
|
||||||
|
|
||||||
service = providers.Factory(Service)
|
service = providers.Factory(Service)
|
||||||
|
|
||||||
|
|
||||||
# You can place marker on parameter default value
|
# You can place marker on parameter default value
|
||||||
@inject
|
@inject
|
||||||
def main(service: Service = Provide[Container.service]) -> None:
|
def main(service: Service = Provide[Container.service]) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
# Also, you can place marker with typing.Annotated
|
# Also, you can place marker with typing.Annotated
|
||||||
@inject
|
@inject
|
||||||
def main_with_annotated(service: Annotated[Service, Provide[Container.service]]) -> None:
|
def main_with_annotated(
|
||||||
...
|
service: Annotated[Service, Provide[Container.service]]
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -27,8 +27,9 @@ from typing import (
|
||||||
if sys.version_info < (3, 7):
|
if sys.version_info < (3, 7):
|
||||||
from typing import GenericMeta
|
from typing import GenericMeta
|
||||||
else:
|
else:
|
||||||
class GenericMeta(type):
|
|
||||||
...
|
class GenericMeta(type): ...
|
||||||
|
|
||||||
|
|
||||||
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
|
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
|
||||||
if sys.version_info >= (3, 9):
|
if sys.version_info >= (3, 9):
|
||||||
|
@ -51,6 +52,7 @@ else:
|
||||||
def get_origin(tp):
|
def get_origin(tp):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import fastapi.params
|
import fastapi.params
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -113,7 +115,9 @@ class PatchedRegistry:
|
||||||
def register_callable(self, patched: "PatchedCallable") -> None:
|
def register_callable(self, patched: "PatchedCallable") -> None:
|
||||||
self._callables[patched.patched] = patched
|
self._callables[patched.patched] = patched
|
||||||
|
|
||||||
def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
|
def get_callables_from_module(
|
||||||
|
self, module: ModuleType
|
||||||
|
) -> Iterator[Callable[..., Any]]:
|
||||||
for patched_callable in self._callables.values():
|
for patched_callable in self._callables.values():
|
||||||
if not patched_callable.is_in_module(module):
|
if not patched_callable.is_in_module(module):
|
||||||
continue
|
continue
|
||||||
|
@ -128,7 +132,9 @@ class PatchedRegistry:
|
||||||
def register_attribute(self, patched: "PatchedAttribute") -> None:
|
def register_attribute(self, patched: "PatchedAttribute") -> None:
|
||||||
self._attributes.add(patched)
|
self._attributes.add(patched)
|
||||||
|
|
||||||
def get_attributes_from_module(self, module: ModuleType) -> Iterator["PatchedAttribute"]:
|
def get_attributes_from_module(
|
||||||
|
self, module: ModuleType
|
||||||
|
) -> Iterator["PatchedAttribute"]:
|
||||||
for attribute in self._attributes:
|
for attribute in self._attributes:
|
||||||
if not attribute.is_in_module(module):
|
if not attribute.is_in_module(module):
|
||||||
continue
|
continue
|
||||||
|
@ -153,11 +159,11 @@ class PatchedCallable:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
patched: Optional[Callable[..., Any]] = None,
|
patched: Optional[Callable[..., Any]] = None,
|
||||||
original: Optional[Callable[..., Any]] = None,
|
original: Optional[Callable[..., Any]] = None,
|
||||||
reference_injections: Optional[Dict[Any, Any]] = None,
|
reference_injections: Optional[Dict[Any, Any]] = None,
|
||||||
reference_closing: Optional[Dict[Any, Any]] = None,
|
reference_closing: Optional[Dict[Any, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.patched = patched
|
self.patched = patched
|
||||||
self.original = original
|
self.original = original
|
||||||
|
@ -228,18 +234,21 @@ class ProvidersMap:
|
||||||
)
|
)
|
||||||
|
|
||||||
def resolve_provider(
|
def resolve_provider(
|
||||||
self,
|
self,
|
||||||
provider: Union[providers.Provider, str],
|
provider: Union[providers.Provider, str],
|
||||||
modifier: Optional["Modifier"] = None,
|
modifier: Optional["Modifier"] = None,
|
||||||
) -> Optional[providers.Provider]:
|
) -> Optional[providers.Provider]:
|
||||||
if isinstance(provider, providers.Delegate):
|
if isinstance(provider, providers.Delegate):
|
||||||
return self._resolve_delegate(provider)
|
return self._resolve_delegate(provider)
|
||||||
elif isinstance(provider, (
|
elif isinstance(
|
||||||
providers.ProvidedInstance,
|
provider,
|
||||||
providers.AttributeGetter,
|
(
|
||||||
providers.ItemGetter,
|
providers.ProvidedInstance,
|
||||||
providers.MethodCaller,
|
providers.AttributeGetter,
|
||||||
)):
|
providers.ItemGetter,
|
||||||
|
providers.MethodCaller,
|
||||||
|
),
|
||||||
|
):
|
||||||
return self._resolve_provided_instance(provider)
|
return self._resolve_provided_instance(provider)
|
||||||
elif isinstance(provider, providers.ConfigurationOption):
|
elif isinstance(provider, providers.ConfigurationOption):
|
||||||
return self._resolve_config_option(provider)
|
return self._resolve_config_option(provider)
|
||||||
|
@ -251,9 +260,9 @@ class ProvidersMap:
|
||||||
return self._resolve_provider(provider)
|
return self._resolve_provider(provider)
|
||||||
|
|
||||||
def _resolve_string_id(
|
def _resolve_string_id(
|
||||||
self,
|
self,
|
||||||
id: str,
|
id: str,
|
||||||
modifier: Optional["Modifier"] = None,
|
modifier: Optional["Modifier"] = None,
|
||||||
) -> Optional[providers.Provider]:
|
) -> Optional[providers.Provider]:
|
||||||
if id == self.CONTAINER_STRING_ID:
|
if id == self.CONTAINER_STRING_ID:
|
||||||
return self._container.__self__
|
return self._container.__self__
|
||||||
|
@ -270,16 +279,19 @@ class ProvidersMap:
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
def _resolve_provided_instance(
|
def _resolve_provided_instance(
|
||||||
self,
|
self,
|
||||||
original: providers.Provider,
|
original: providers.Provider,
|
||||||
) -> Optional[providers.Provider]:
|
) -> Optional[providers.Provider]:
|
||||||
modifiers = []
|
modifiers = []
|
||||||
while isinstance(original, (
|
while isinstance(
|
||||||
|
original,
|
||||||
|
(
|
||||||
providers.ProvidedInstance,
|
providers.ProvidedInstance,
|
||||||
providers.AttributeGetter,
|
providers.AttributeGetter,
|
||||||
providers.ItemGetter,
|
providers.ItemGetter,
|
||||||
providers.MethodCaller,
|
providers.MethodCaller,
|
||||||
)):
|
),
|
||||||
|
):
|
||||||
modifiers.insert(0, original)
|
modifiers.insert(0, original)
|
||||||
original = original.provides
|
original = original.provides
|
||||||
|
|
||||||
|
@ -303,8 +315,8 @@ class ProvidersMap:
|
||||||
return new
|
return new
|
||||||
|
|
||||||
def _resolve_delegate(
|
def _resolve_delegate(
|
||||||
self,
|
self,
|
||||||
original: providers.Delegate,
|
original: providers.Delegate,
|
||||||
) -> Optional[providers.Provider]:
|
) -> Optional[providers.Provider]:
|
||||||
provider = self._resolve_provider(original.provides)
|
provider = self._resolve_provider(original.provides)
|
||||||
if provider:
|
if provider:
|
||||||
|
@ -312,9 +324,9 @@ class ProvidersMap:
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
def _resolve_config_option(
|
def _resolve_config_option(
|
||||||
self,
|
self,
|
||||||
original: providers.ConfigurationOption,
|
original: providers.ConfigurationOption,
|
||||||
as_: Any = None,
|
as_: Any = None,
|
||||||
) -> Optional[providers.Provider]:
|
) -> Optional[providers.Provider]:
|
||||||
original_root = original.root
|
original_root = original.root
|
||||||
new = self._resolve_provider(original_root)
|
new = self._resolve_provider(original_root)
|
||||||
|
@ -338,8 +350,8 @@ class ProvidersMap:
|
||||||
return new
|
return new
|
||||||
|
|
||||||
def _resolve_provider(
|
def _resolve_provider(
|
||||||
self,
|
self,
|
||||||
original: providers.Provider,
|
original: providers.Provider,
|
||||||
) -> Optional[providers.Provider]:
|
) -> Optional[providers.Provider]:
|
||||||
try:
|
try:
|
||||||
return self._map[original]
|
return self._map[original]
|
||||||
|
@ -348,9 +360,9 @@ class ProvidersMap:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _create_providers_map(
|
def _create_providers_map(
|
||||||
cls,
|
cls,
|
||||||
current_container: Container,
|
current_container: Container,
|
||||||
original_container: Container,
|
original_container: Container,
|
||||||
) -> Dict[providers.Provider, providers.Provider]:
|
) -> Dict[providers.Provider, providers.Provider]:
|
||||||
current_providers = current_container.providers
|
current_providers = current_container.providers
|
||||||
current_providers["__self__"] = current_container.__self__
|
current_providers["__self__"] = current_container.__self__
|
||||||
|
@ -363,8 +375,9 @@ class ProvidersMap:
|
||||||
original_provider = original_providers[provider_name]
|
original_provider = original_providers[provider_name]
|
||||||
providers_map[original_provider] = current_provider
|
providers_map[original_provider] = current_provider
|
||||||
|
|
||||||
if isinstance(current_provider, providers.Container) \
|
if isinstance(current_provider, providers.Container) and isinstance(
|
||||||
and isinstance(original_provider, providers.Container):
|
original_provider, providers.Container
|
||||||
|
):
|
||||||
subcontainer_map = cls._create_providers_map(
|
subcontainer_map = cls._create_providers_map(
|
||||||
current_container=current_provider.container,
|
current_container=current_provider.container,
|
||||||
original_container=original_provider.container,
|
original_container=original_provider.container,
|
||||||
|
@ -390,19 +403,21 @@ class InspectFilter:
|
||||||
return werkzeug and isinstance(instance, werkzeug.local.LocalProxy)
|
return werkzeug and isinstance(instance, werkzeug.local.LocalProxy)
|
||||||
|
|
||||||
def _is_starlette_request_cls(self, instance: object) -> bool:
|
def _is_starlette_request_cls(self, instance: object) -> bool:
|
||||||
return starlette \
|
return (
|
||||||
and isinstance(instance, type) \
|
starlette
|
||||||
and _safe_is_subclass(instance, starlette.requests.Request)
|
and isinstance(instance, type)
|
||||||
|
and _safe_is_subclass(instance, starlette.requests.Request)
|
||||||
|
)
|
||||||
|
|
||||||
def _is_builtin(self, instance: object) -> bool:
|
def _is_builtin(self, instance: object) -> bool:
|
||||||
return inspect.isbuiltin(instance)
|
return inspect.isbuiltin(instance)
|
||||||
|
|
||||||
|
|
||||||
def wire( # noqa: C901
|
def wire( # noqa: C901
|
||||||
container: Container,
|
container: Container,
|
||||||
*,
|
*,
|
||||||
modules: Optional[Iterable[ModuleType]] = None,
|
modules: Optional[Iterable[ModuleType]] = None,
|
||||||
packages: Optional[Iterable[ModuleType]] = None,
|
packages: Optional[Iterable[ModuleType]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Wire container providers with provided packages and modules."""
|
"""Wire container providers with provided packages and modules."""
|
||||||
modules = [*modules] if modules else []
|
modules = [*modules] if modules else []
|
||||||
|
@ -432,18 +447,22 @@ def wire( # noqa: C901
|
||||||
else:
|
else:
|
||||||
for cls_member_name, cls_member in cls_members:
|
for cls_member_name, cls_member in cls_members:
|
||||||
if _is_marker(cls_member):
|
if _is_marker(cls_member):
|
||||||
_patch_attribute(cls, cls_member_name, cls_member, providers_map)
|
_patch_attribute(
|
||||||
|
cls, cls_member_name, cls_member, providers_map
|
||||||
|
)
|
||||||
elif _is_method(cls_member):
|
elif _is_method(cls_member):
|
||||||
_patch_method(cls, cls_member_name, cls_member, providers_map)
|
_patch_method(
|
||||||
|
cls, cls_member_name, cls_member, providers_map
|
||||||
|
)
|
||||||
|
|
||||||
for patched in _patched_registry.get_callables_from_module(module):
|
for patched in _patched_registry.get_callables_from_module(module):
|
||||||
_bind_injections(patched, providers_map)
|
_bind_injections(patched, providers_map)
|
||||||
|
|
||||||
|
|
||||||
def unwire( # noqa: C901
|
def unwire( # noqa: C901
|
||||||
*,
|
*,
|
||||||
modules: Optional[Iterable[ModuleType]] = None,
|
modules: Optional[Iterable[ModuleType]] = None,
|
||||||
packages: Optional[Iterable[ModuleType]] = None,
|
packages: Optional[Iterable[ModuleType]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Wire provided packages and modules with previous wired providers."""
|
"""Wire provided packages and modules with previous wired providers."""
|
||||||
modules = [*modules] if modules else []
|
modules = [*modules] if modules else []
|
||||||
|
@ -457,7 +476,9 @@ def unwire( # noqa: C901
|
||||||
if inspect.isfunction(member):
|
if inspect.isfunction(member):
|
||||||
_unpatch(module, name, member)
|
_unpatch(module, name, member)
|
||||||
elif inspect.isclass(member):
|
elif inspect.isclass(member):
|
||||||
for method_name, method in inspect.getmembers(member, inspect.isfunction):
|
for method_name, method in inspect.getmembers(
|
||||||
|
member, inspect.isfunction
|
||||||
|
):
|
||||||
_unpatch(member, method_name, method)
|
_unpatch(member, method_name, method)
|
||||||
|
|
||||||
for patched in _patched_registry.get_callables_from_module(module):
|
for patched in _patched_registry.get_callables_from_module(module):
|
||||||
|
@ -476,10 +497,10 @@ def inject(fn: F) -> F:
|
||||||
|
|
||||||
|
|
||||||
def _patch_fn(
|
def _patch_fn(
|
||||||
module: ModuleType,
|
module: ModuleType,
|
||||||
name: str,
|
name: str,
|
||||||
fn: Callable[..., Any],
|
fn: Callable[..., Any],
|
||||||
providers_map: ProvidersMap,
|
providers_map: ProvidersMap,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not _is_patched(fn):
|
if not _is_patched(fn):
|
||||||
reference_injections, reference_closing = _fetch_reference_injections(fn)
|
reference_injections, reference_closing = _fetch_reference_injections(fn)
|
||||||
|
@ -493,14 +514,16 @@ def _patch_fn(
|
||||||
|
|
||||||
|
|
||||||
def _patch_method(
|
def _patch_method(
|
||||||
cls: Type,
|
cls: Type,
|
||||||
name: str,
|
name: str,
|
||||||
method: Callable[..., Any],
|
method: Callable[..., Any],
|
||||||
providers_map: ProvidersMap,
|
providers_map: ProvidersMap,
|
||||||
) -> None:
|
) -> None:
|
||||||
if hasattr(cls, "__dict__") \
|
if (
|
||||||
and name in cls.__dict__ \
|
hasattr(cls, "__dict__")
|
||||||
and isinstance(cls.__dict__[name], (classmethod, staticmethod)):
|
and name in cls.__dict__
|
||||||
|
and isinstance(cls.__dict__[name], (classmethod, staticmethod))
|
||||||
|
):
|
||||||
method = cls.__dict__[name]
|
method = cls.__dict__[name]
|
||||||
fn = method.__func__
|
fn = method.__func__
|
||||||
else:
|
else:
|
||||||
|
@ -521,13 +544,15 @@ def _patch_method(
|
||||||
|
|
||||||
|
|
||||||
def _unpatch(
|
def _unpatch(
|
||||||
module: ModuleType,
|
module: ModuleType,
|
||||||
name: str,
|
name: str,
|
||||||
fn: Callable[..., Any],
|
fn: Callable[..., Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
if hasattr(module, "__dict__") \
|
if (
|
||||||
and name in module.__dict__ \
|
hasattr(module, "__dict__")
|
||||||
and isinstance(module.__dict__[name], (classmethod, staticmethod)):
|
and name in module.__dict__
|
||||||
|
and isinstance(module.__dict__[name], (classmethod, staticmethod))
|
||||||
|
):
|
||||||
method = module.__dict__[name]
|
method = module.__dict__[name]
|
||||||
fn = method.__func__
|
fn = method.__func__
|
||||||
|
|
||||||
|
@ -538,10 +563,10 @@ def _unpatch(
|
||||||
|
|
||||||
|
|
||||||
def _patch_attribute(
|
def _patch_attribute(
|
||||||
member: Any,
|
member: Any,
|
||||||
name: str,
|
name: str,
|
||||||
marker: "_Marker",
|
marker: "_Marker",
|
||||||
providers_map: ProvidersMap,
|
providers_map: ProvidersMap,
|
||||||
) -> None:
|
) -> None:
|
||||||
provider = providers_map.resolve_provider(marker.provider, marker.modifier)
|
provider = providers_map.resolve_provider(marker.provider, marker.modifier)
|
||||||
if provider is None:
|
if provider is None:
|
||||||
|
@ -581,15 +606,14 @@ def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
|
||||||
|
|
||||||
|
|
||||||
def _fetch_reference_injections( # noqa: C901
|
def _fetch_reference_injections( # noqa: C901
|
||||||
fn: Callable[..., Any],
|
fn: Callable[..., Any],
|
||||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
# Hotfix, see:
|
# Hotfix, see:
|
||||||
# - https://github.com/ets-labs/python-dependency-injector/issues/362
|
# - https://github.com/ets-labs/python-dependency-injector/issues/362
|
||||||
# - https://github.com/ets-labs/python-dependency-injector/issues/398
|
# - https://github.com/ets-labs/python-dependency-injector/issues/398
|
||||||
if GenericAlias and any((
|
if GenericAlias and any(
|
||||||
fn is GenericAlias,
|
(fn is GenericAlias, getattr(fn, "__func__", None) is GenericAlias)
|
||||||
getattr(fn, "__func__", None) is GenericAlias
|
):
|
||||||
)):
|
|
||||||
fn = fn.__init__
|
fn = fn.__init__
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -618,7 +642,9 @@ def _fetch_reference_injections( # noqa: C901
|
||||||
return injections, closing
|
return injections, closing
|
||||||
|
|
||||||
|
|
||||||
def _locate_dependent_closing_args(provider: providers.Provider) -> Dict[str, providers.Provider]:
|
def _locate_dependent_closing_args(
|
||||||
|
provider: providers.Provider,
|
||||||
|
) -> Dict[str, providers.Provider]:
|
||||||
if not hasattr(provider, "args"):
|
if not hasattr(provider, "args"):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@ -672,8 +698,8 @@ def _fetch_modules(package):
|
||||||
if not hasattr(package, "__path__") or not hasattr(package, "__name__"):
|
if not hasattr(package, "__path__") or not hasattr(package, "__name__"):
|
||||||
return modules
|
return modules
|
||||||
for module_info in pkgutil.walk_packages(
|
for module_info in pkgutil.walk_packages(
|
||||||
path=package.__path__,
|
path=package.__path__,
|
||||||
prefix=package.__name__ + ".",
|
prefix=package.__name__ + ".",
|
||||||
):
|
):
|
||||||
module = importlib.import_module(module_info.name)
|
module = importlib.import_module(module_info.name)
|
||||||
modules.append(module)
|
modules.append(module)
|
||||||
|
@ -689,9 +715,9 @@ def _is_marker(member) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def _get_patched(
|
def _get_patched(
|
||||||
fn: F,
|
fn: F,
|
||||||
reference_injections: Dict[Any, Any],
|
reference_injections: Dict[Any, Any],
|
||||||
reference_closing: Dict[Any, Any],
|
reference_closing: Dict[Any, Any],
|
||||||
) -> F:
|
) -> F:
|
||||||
patched_object = PatchedCallable(
|
patched_object = PatchedCallable(
|
||||||
original=fn,
|
original=fn,
|
||||||
|
@ -719,9 +745,11 @@ def _is_patched(fn) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def _is_declarative_container(instance: Any) -> bool:
|
def _is_declarative_container(instance: Any) -> bool:
|
||||||
return (isinstance(instance, type)
|
return (
|
||||||
and getattr(instance, "__IS_CONTAINER__", False) is True
|
isinstance(instance, type)
|
||||||
and getattr(instance, "declarative_parent", None) is None)
|
and getattr(instance, "__IS_CONTAINER__", False) is True
|
||||||
|
and getattr(instance, "declarative_parent", None) is None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _safe_is_subclass(instance: Any, cls: Type) -> bool:
|
def _safe_is_subclass(instance: Any, cls: Type) -> bool:
|
||||||
|
@ -734,11 +762,10 @@ def _safe_is_subclass(instance: Any, cls: Type) -> bool:
|
||||||
class Modifier:
|
class Modifier:
|
||||||
|
|
||||||
def modify(
|
def modify(
|
||||||
self,
|
self,
|
||||||
provider: providers.ConfigurationOption,
|
provider: providers.ConfigurationOption,
|
||||||
providers_map: ProvidersMap,
|
providers_map: ProvidersMap,
|
||||||
) -> providers.Provider:
|
) -> providers.Provider: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class TypeModifier(Modifier):
|
class TypeModifier(Modifier):
|
||||||
|
@ -747,9 +774,9 @@ class TypeModifier(Modifier):
|
||||||
self.type_ = type_
|
self.type_ = type_
|
||||||
|
|
||||||
def modify(
|
def modify(
|
||||||
self,
|
self,
|
||||||
provider: providers.ConfigurationOption,
|
provider: providers.ConfigurationOption,
|
||||||
providers_map: ProvidersMap,
|
providers_map: ProvidersMap,
|
||||||
) -> providers.Provider:
|
) -> providers.Provider:
|
||||||
return provider.as_(self.type_)
|
return provider.as_(self.type_)
|
||||||
|
|
||||||
|
@ -787,9 +814,9 @@ class RequiredModifier(Modifier):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def modify(
|
def modify(
|
||||||
self,
|
self,
|
||||||
provider: providers.ConfigurationOption,
|
provider: providers.ConfigurationOption,
|
||||||
providers_map: ProvidersMap,
|
providers_map: ProvidersMap,
|
||||||
) -> providers.Provider:
|
) -> providers.Provider:
|
||||||
provider = provider.required()
|
provider = provider.required()
|
||||||
if self.type_modifier:
|
if self.type_modifier:
|
||||||
|
@ -808,9 +835,9 @@ class InvariantModifier(Modifier):
|
||||||
self.id = id
|
self.id = id
|
||||||
|
|
||||||
def modify(
|
def modify(
|
||||||
self,
|
self,
|
||||||
provider: providers.ConfigurationOption,
|
provider: providers.ConfigurationOption,
|
||||||
providers_map: ProvidersMap,
|
providers_map: ProvidersMap,
|
||||||
) -> providers.Provider:
|
) -> providers.Provider:
|
||||||
invariant_segment = providers_map.resolve_provider(self.id)
|
invariant_segment = providers_map.resolve_provider(self.id)
|
||||||
return provider[invariant_segment]
|
return provider[invariant_segment]
|
||||||
|
@ -843,9 +870,9 @@ class ProvidedInstance(Modifier):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def modify(
|
def modify(
|
||||||
self,
|
self,
|
||||||
provider: providers.Provider,
|
provider: providers.Provider,
|
||||||
providers_map: ProvidersMap,
|
providers_map: ProvidersMap,
|
||||||
) -> providers.Provider:
|
) -> providers.Provider:
|
||||||
provider = provider.provided
|
provider = provider.provided
|
||||||
for type_, value in self.segments:
|
for type_, value in self.segments:
|
||||||
|
@ -876,9 +903,9 @@ class _Marker(Generic[T], metaclass=ClassGetItemMeta):
|
||||||
__IS_MARKER__ = True
|
__IS_MARKER__ = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider: Union[providers.Provider, Container, str],
|
provider: Union[providers.Provider, Container, str],
|
||||||
modifier: Optional[Modifier] = None,
|
modifier: Optional[Modifier] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if _is_declarative_container(provider):
|
if _is_declarative_container(provider):
|
||||||
provider = provider.__self__
|
provider = provider.__self__
|
||||||
|
@ -894,16 +921,13 @@ class _Marker(Generic[T], metaclass=ClassGetItemMeta):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class Provide(_Marker):
|
class Provide(_Marker): ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class Provider(_Marker):
|
class Provider(_Marker): ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class Closing(_Marker):
|
class Closing(_Marker): ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class AutoLoader:
|
class AutoLoader:
|
||||||
|
@ -953,8 +977,7 @@ class AutoLoader:
|
||||||
super().exec_module(module)
|
super().exec_module(module)
|
||||||
loader.wire_module(module)
|
loader.wire_module(module)
|
||||||
|
|
||||||
class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader):
|
class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader): ...
|
||||||
...
|
|
||||||
|
|
||||||
loader_details = [
|
loader_details = [
|
||||||
(SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
|
(SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
|
||||||
|
@ -1023,4 +1046,5 @@ def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
|
||||||
patched.injections,
|
patched.injections,
|
||||||
patched.closing,
|
patched.closing,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _patched
|
return _patched
|
||||||
|
|
|
@ -3,7 +3,9 @@ import sys
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from fastapi import FastAPI, Depends
|
from fastapi import FastAPI, Depends
|
||||||
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398
|
from fastapi import (
|
||||||
|
Request,
|
||||||
|
) # See: https://github.com/ets-labs/python-dependency-injector/issues/398
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from dependency_injector import containers, providers
|
from dependency_injector import containers, providers
|
||||||
from dependency_injector.wiring import inject, Provide
|
from dependency_injector.wiring import inject, Provide
|
||||||
|
@ -29,17 +31,17 @@ async def index(service: Service = Depends(Provide[Container.service])):
|
||||||
result = await service.process()
|
result = await service.process()
|
||||||
return {"result": result}
|
return {"result": result}
|
||||||
|
|
||||||
@app.api_route('/annotated')
|
|
||||||
|
@app.api_route("/annotated")
|
||||||
@inject
|
@inject
|
||||||
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
|
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
|
||||||
result = await service.process()
|
result = await service.process()
|
||||||
return {'result': result}
|
return {"result": result}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/auth")
|
@app.get("/auth")
|
||||||
@inject
|
@inject
|
||||||
def read_current_user(
|
def read_current_user(credentials: HTTPBasicCredentials = Depends(security)):
|
||||||
credentials: HTTPBasicCredentials = Depends(security)
|
|
||||||
):
|
|
||||||
return {"username": credentials.username, "password": credentials.password}
|
return {"username": credentials.username, "password": credentials.password}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ def index(service: Service = Provide[Container.service]):
|
||||||
@inject
|
@inject
|
||||||
def annotated(service: Annotated[Service, Provide[Container.service]]):
|
def annotated(service: Annotated[Service, Provide[Container.service]]):
|
||||||
result = service.process()
|
result = service.process()
|
||||||
return jsonify({'result': result})
|
return jsonify({"result": result})
|
||||||
|
|
||||||
|
|
||||||
container = Container()
|
container = Container()
|
||||||
|
|
|
@ -3,13 +3,17 @@ from pytest import fixture, mark
|
||||||
|
|
||||||
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
|
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
|
||||||
import os
|
import os
|
||||||
|
|
||||||
_SAMPLES_DIR = os.path.abspath(
|
_SAMPLES_DIR = os.path.abspath(
|
||||||
os.path.sep.join((
|
os.path.sep.join(
|
||||||
os.path.dirname(__file__),
|
(
|
||||||
"../samples/",
|
os.path.dirname(__file__),
|
||||||
)),
|
"../samples/",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append(_SAMPLES_DIR)
|
sys.path.append(_SAMPLES_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,7 +53,6 @@ async def test_depends_with_annotated(async_client: AsyncClient):
|
||||||
assert response.json() == {"result": "Foo"}
|
assert response.json() == {"result": "Foo"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
async def test_depends_injection(async_client: AsyncClient):
|
async def test_depends_injection(async_client: AsyncClient):
|
||||||
response = await async_client.get("/auth", auth=("john_smith", "secret"))
|
response = await async_client.get("/auth", auth=("john_smith", "secret"))
|
||||||
|
|
|
@ -2,19 +2,25 @@ import json
|
||||||
|
|
||||||
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
|
# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
|
||||||
import os
|
import os
|
||||||
|
|
||||||
_TOP_DIR = os.path.abspath(
|
_TOP_DIR = os.path.abspath(
|
||||||
os.path.sep.join((
|
os.path.sep.join(
|
||||||
os.path.dirname(__file__),
|
(
|
||||||
"../",
|
os.path.dirname(__file__),
|
||||||
)),
|
"../",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
_SAMPLES_DIR = os.path.abspath(
|
_SAMPLES_DIR = os.path.abspath(
|
||||||
os.path.sep.join((
|
os.path.sep.join(
|
||||||
os.path.dirname(__file__),
|
(
|
||||||
"../samples/",
|
os.path.dirname(__file__),
|
||||||
)),
|
"../samples/",
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append(_TOP_DIR)
|
sys.path.append(_TOP_DIR)
|
||||||
sys.path.append(_SAMPLES_DIR)
|
sys.path.append(_SAMPLES_DIR)
|
||||||
|
|
||||||
|
@ -36,6 +42,6 @@ def test_wiring_with_annotated():
|
||||||
|
|
||||||
with web.app.app_context():
|
with web.app.app_context():
|
||||||
response = client.get("/annotated")
|
response = client.get("/annotated")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert json.loads(response.data) == {"result": "OK"}
|
assert json.loads(response.data) == {"result": "OK"}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user