mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2025-07-18 03:52:19 +03:00
Add wiring registry
This commit is contained in:
parent
06d865c7b7
commit
4057a79cf8
|
@ -6,7 +6,7 @@ import importlib
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
|
from typing import Optional, Iterable, Iterator, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
|
||||||
|
|
||||||
if sys.version_info < (3, 7):
|
if sys.version_info < (3, 7):
|
||||||
from typing import GenericMeta
|
from typing import GenericMeta
|
||||||
|
@ -32,6 +32,24 @@ F = TypeVar('F', bound=Callable[..., Any])
|
||||||
Container = Any
|
Container = Any
|
||||||
|
|
||||||
|
|
||||||
|
class Registry:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._storage = set()
|
||||||
|
|
||||||
|
def add(self, patched: Callable[..., Any]) -> None:
|
||||||
|
self._storage.add(patched)
|
||||||
|
|
||||||
|
def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
|
||||||
|
for patched in self._storage:
|
||||||
|
if patched.__module__ != module.__name__:
|
||||||
|
continue
|
||||||
|
yield patched
|
||||||
|
|
||||||
|
|
||||||
|
_patched_registry = Registry()
|
||||||
|
|
||||||
|
|
||||||
class ProvidersMap:
|
class ProvidersMap:
|
||||||
|
|
||||||
def __init__(self, container):
|
def __init__(self, container):
|
||||||
|
@ -181,6 +199,9 @@ def wire(
|
||||||
for method_name, method in inspect.getmembers(member, _is_method):
|
for method_name, method in inspect.getmembers(member, _is_method):
|
||||||
_patch_method(member, method_name, method, providers_map)
|
_patch_method(member, method_name, method, providers_map)
|
||||||
|
|
||||||
|
for patched in _patched_registry.get_from_module(module):
|
||||||
|
_bind_injections(patched, providers_map)
|
||||||
|
|
||||||
|
|
||||||
def unwire(
|
def unwire(
|
||||||
*,
|
*,
|
||||||
|
@ -203,11 +224,15 @@ def unwire(
|
||||||
for method_name, method in inspect.getmembers(member, inspect.isfunction):
|
for method_name, method in inspect.getmembers(member, inspect.isfunction):
|
||||||
_unpatch(member, method_name, method)
|
_unpatch(member, method_name, method)
|
||||||
|
|
||||||
|
for patched in _patched_registry.get_from_module(module):
|
||||||
|
_unbind_injections(patched)
|
||||||
|
|
||||||
|
|
||||||
def inject(fn: F) -> F:
|
def inject(fn: F) -> F:
|
||||||
"""Decorate callable with injecting decorator."""
|
"""Decorate callable with injecting decorator."""
|
||||||
reference_injections, reference_closing = _fetch_reference_injections(fn)
|
reference_injections, reference_closing = _fetch_reference_injections(fn)
|
||||||
patched = _get_patched(fn, reference_injections, reference_closing)
|
patched = _get_patched(fn, reference_injections, reference_closing)
|
||||||
|
_patched_registry.add(patched)
|
||||||
return cast(F, patched)
|
return cast(F, patched)
|
||||||
|
|
||||||
|
|
||||||
|
@ -222,6 +247,7 @@ def _patch_fn(
|
||||||
if not reference_injections:
|
if not reference_injections:
|
||||||
return
|
return
|
||||||
fn = _get_patched(fn, reference_injections, reference_closing)
|
fn = _get_patched(fn, reference_injections, reference_closing)
|
||||||
|
_patched_registry.add(fn)
|
||||||
|
|
||||||
_bind_injections(fn, providers_map)
|
_bind_injections(fn, providers_map)
|
||||||
|
|
||||||
|
@ -247,6 +273,7 @@ def _patch_method(
|
||||||
if not reference_injections:
|
if not reference_injections:
|
||||||
return
|
return
|
||||||
fn = _get_patched(fn, reference_injections, reference_closing)
|
fn = _get_patched(fn, reference_injections, reference_closing)
|
||||||
|
_patched_registry.add(fn)
|
||||||
|
|
||||||
_bind_injections(fn, providers_map)
|
_bind_injections(fn, providers_map)
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from dependency_injector.wiring import Provide, Provider
|
from dependency_injector.wiring import inject, Provide, Provider
|
||||||
|
|
||||||
from .container import Container, SubContainer
|
from .container import Container, SubContainer
|
||||||
from .service import Service
|
from .service import Service
|
||||||
|
@ -65,3 +65,17 @@ def test_provide_from_different_containers(
|
||||||
some_value: int = Provide[SubContainer.int_object],
|
some_value: int = Provide[SubContainer.int_object],
|
||||||
):
|
):
|
||||||
return service, some_value
|
return service, some_value
|
||||||
|
|
||||||
|
|
||||||
|
class ClassDecorator:
|
||||||
|
def __init__(self, fn):
|
||||||
|
self._fn = fn
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self._fn(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ClassDecorator
|
||||||
|
@inject
|
||||||
|
def test_class_decorator(service: Service = Provide[Container.service]):
|
||||||
|
return service
|
||||||
|
|
|
@ -226,6 +226,10 @@ class WiringTest(unittest.TestCase):
|
||||||
self.assertEqual(result_2.init_counter, 0)
|
self.assertEqual(result_2.init_counter, 0)
|
||||||
self.assertEqual(result_2.shutdown_counter, 0)
|
self.assertEqual(result_2.shutdown_counter, 0)
|
||||||
|
|
||||||
|
def test_class_decorator(self):
|
||||||
|
service = module.test_class_decorator()
|
||||||
|
self.assertIsInstance(service, Service)
|
||||||
|
|
||||||
|
|
||||||
class WiringAndFastAPITest(unittest.TestCase):
|
class WiringAndFastAPITest(unittest.TestCase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user