mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-07-04 12:23:14 +03:00
Add implementation
This commit is contained in:
parent
c787ac2f63
commit
1d28e62a93
|
@ -20,6 +20,7 @@ from typing import (
|
|||
TypeVar,
|
||||
Type,
|
||||
Union,
|
||||
Set,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
@ -82,22 +83,53 @@ F = TypeVar('F', bound=Callable[..., Any])
|
|||
Container = Any
|
||||
|
||||
|
||||
class Registry:
|
||||
class PatchedRegistry:
|
||||
|
||||
def __init__(self):
|
||||
self._storage = set()
|
||||
self._callables: Set[Callable[..., Any]] = set()
|
||||
self._attributes: Set[PatchedAttribute] = set()
|
||||
|
||||
def add(self, patched: Callable[..., Any]) -> None:
|
||||
self._storage.add(patched)
|
||||
def add_callable(self, patched: Callable[..., Any]) -> None:
|
||||
self._callables.add(patched)
|
||||
|
||||
def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
|
||||
for patched in self._storage:
|
||||
def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
|
||||
for patched in self._callables:
|
||||
if patched.__module__ != module.__name__:
|
||||
continue
|
||||
yield patched
|
||||
|
||||
def add_attribute(self, patched: 'PatchedAttribute'):
|
||||
self._attributes.add(patched)
|
||||
|
||||
_patched_registry = Registry()
|
||||
def get_attributes_from_module(self, module: ModuleType) -> Iterator['PatchedAttribute']:
|
||||
for attribute in self._attributes:
|
||||
if not attribute.is_in_module(module):
|
||||
continue
|
||||
yield attribute
|
||||
|
||||
def clear_module_attributes(self, module: ModuleType):
|
||||
for attribute in self._attributes.copy():
|
||||
if not attribute.is_in_module(module):
|
||||
continue
|
||||
self._attributes.remove(attribute)
|
||||
|
||||
|
||||
class PatchedAttribute:
|
||||
|
||||
def __init__(self, member: Any, name: str, marker: '_Marker'):
|
||||
self.member = member
|
||||
self.name = name
|
||||
self.marker = marker
|
||||
|
||||
@property
|
||||
def module_name(self) -> str:
|
||||
if isinstance(self.member, ModuleType):
|
||||
return self.member.__name__
|
||||
else:
|
||||
return self.member.__module__
|
||||
|
||||
def is_in_module(self, module: ModuleType) -> bool:
|
||||
return self.module_name == module.__name__
|
||||
|
||||
|
||||
class ProvidersMap:
|
||||
|
@ -278,9 +310,6 @@ class InspectFilter:
|
|||
and issubclass(instance, starlette.requests.Request)
|
||||
|
||||
|
||||
inspect_filter = InspectFilter()
|
||||
|
||||
|
||||
def wire( # noqa: C901
|
||||
container: Container,
|
||||
*,
|
||||
|
@ -301,16 +330,23 @@ def wire( # noqa: C901
|
|||
providers_map = ProvidersMap(container)
|
||||
|
||||
for module in modules:
|
||||
for name, member in inspect.getmembers(module):
|
||||
if inspect_filter.is_excluded(member):
|
||||
for member_name, member in inspect.getmembers(module):
|
||||
if _inspect_filter.is_excluded(member):
|
||||
continue
|
||||
if inspect.isfunction(member):
|
||||
_patch_fn(module, name, member, providers_map)
|
||||
elif inspect.isclass(member):
|
||||
for method_name, method in inspect.getmembers(member, _is_method):
|
||||
_patch_method(member, method_name, method, providers_map)
|
||||
|
||||
for patched in _patched_registry.get_from_module(module):
|
||||
if _is_marker(member):
|
||||
_patch_attribute(module, member_name, member, providers_map)
|
||||
elif inspect.isfunction(member):
|
||||
_patch_fn(module, member_name, member, providers_map)
|
||||
elif inspect.isclass(member):
|
||||
cls = member
|
||||
for cls_member_name, cls_member in inspect.getmembers(cls):
|
||||
if _is_marker(cls_member):
|
||||
_patch_attribute(cls, cls_member_name, cls_member, providers_map)
|
||||
elif _is_method(cls_member):
|
||||
_patch_method(cls, cls_member_name, cls_member, providers_map)
|
||||
|
||||
for patched in _patched_registry.get_callables_from_module(module):
|
||||
_bind_injections(patched, providers_map)
|
||||
|
||||
|
||||
|
@ -335,15 +371,19 @@ def unwire(
|
|||
for method_name, method in inspect.getmembers(member, inspect.isfunction):
|
||||
_unpatch(member, method_name, method)
|
||||
|
||||
for patched in _patched_registry.get_from_module(module):
|
||||
for patched in _patched_registry.get_callables_from_module(module):
|
||||
_unbind_injections(patched)
|
||||
|
||||
for patched_attribute in _patched_registry.get_attributes_from_module(module):
|
||||
_unpatch_attribute(patched_attribute)
|
||||
_patched_registry.clear_module_attributes(module)
|
||||
|
||||
|
||||
def inject(fn: F) -> F:
|
||||
"""Decorate callable with injecting decorator."""
|
||||
reference_injections, reference_closing = _fetch_reference_injections(fn)
|
||||
patched = _get_patched(fn, reference_injections, reference_closing)
|
||||
_patched_registry.add(patched)
|
||||
_patched_registry.add_callable(patched)
|
||||
return cast(F, patched)
|
||||
|
||||
|
||||
|
@ -358,7 +398,7 @@ def _patch_fn(
|
|||
if not reference_injections:
|
||||
return
|
||||
fn = _get_patched(fn, reference_injections, reference_closing)
|
||||
_patched_registry.add(fn)
|
||||
_patched_registry.add_callable(fn)
|
||||
|
||||
_bind_injections(fn, providers_map)
|
||||
|
||||
|
@ -384,7 +424,7 @@ def _patch_method(
|
|||
if not reference_injections:
|
||||
return
|
||||
fn = _get_patched(fn, reference_injections, reference_closing)
|
||||
_patched_registry.add(fn)
|
||||
_patched_registry.add_callable(fn)
|
||||
|
||||
_bind_injections(fn, providers_map)
|
||||
|
||||
|
@ -411,6 +451,31 @@ def _unpatch(
|
|||
_unbind_injections(fn)
|
||||
|
||||
|
||||
def _patch_attribute(
|
||||
member: Any,
|
||||
name: str,
|
||||
marker: '_Marker',
|
||||
providers_map: ProvidersMap,
|
||||
) -> None:
|
||||
provider = providers_map.resolve_provider(marker.provider, marker.modifier)
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
_patched_registry.add_attribute(PatchedAttribute(member, name, marker))
|
||||
|
||||
if isinstance(marker, Provide):
|
||||
instance = provider()
|
||||
setattr(member, name, instance)
|
||||
elif isinstance(marker, Provider):
|
||||
setattr(member, name, provider)
|
||||
else:
|
||||
raise Exception(f'Unknown type of marker {marker}')
|
||||
|
||||
|
||||
def _unpatch_attribute(patched: PatchedAttribute) -> None:
|
||||
setattr(patched.member, patched.name, patched.marker)
|
||||
|
||||
|
||||
def _fetch_reference_injections(
|
||||
fn: Callable[..., Any],
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
|
@ -484,6 +549,10 @@ def _is_method(member):
|
|||
return inspect.ismethod(member) or inspect.isfunction(member)
|
||||
|
||||
|
||||
def _is_marker(member):
|
||||
return isinstance(member, Provide) or isinstance(member, Provider)
|
||||
|
||||
|
||||
def _get_patched(fn, reference_injections, reference_closing):
|
||||
if inspect.iscoroutinefunction(fn):
|
||||
patched = _get_async_patched(fn)
|
||||
|
@ -825,9 +894,6 @@ class AutoLoader:
|
|||
importlib.invalidate_caches()
|
||||
|
||||
|
||||
_loader = AutoLoader()
|
||||
|
||||
|
||||
def register_loader_containers(*containers: Container) -> None:
|
||||
"""Register containers in auto-wiring module loader."""
|
||||
_loader.register_containers(*containers)
|
||||
|
@ -851,3 +917,8 @@ def uninstall_loader() -> None:
|
|||
def is_loader_installed() -> bool:
|
||||
"""Check if auto-wiring module loader hook is installed."""
|
||||
return _loader.installed
|
||||
|
||||
|
||||
_patched_registry = PatchedRegistry()
|
||||
_inspect_filter = InspectFilter()
|
||||
_loader = AutoLoader()
|
||||
|
|
Loading…
Reference in New Issue
Block a user