Add wiring registry

This commit is contained in:
Roman Mogylatov 2020-11-15 15:17:13 -05:00
parent 06d865c7b7
commit 4057a79cf8
3 changed files with 47 additions and 2 deletions

View File

@ -6,7 +6,7 @@ import importlib
import pkgutil
import sys
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
from typing import Optional, Iterable, Iterator, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
if sys.version_info < (3, 7):
from typing import GenericMeta
@ -32,6 +32,24 @@ F = TypeVar('F', bound=Callable[..., Any])
Container = Any
class Registry:
def __init__(self):
self._storage = set()
def add(self, patched: Callable[..., Any]) -> None:
self._storage.add(patched)
def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
for patched in self._storage:
if patched.__module__ != module.__name__:
continue
yield patched
_patched_registry = Registry()
class ProvidersMap:
def __init__(self, container):
@ -181,6 +199,9 @@ def wire(
for method_name, method in inspect.getmembers(member, _is_method):
_patch_method(member, method_name, method, providers_map)
for patched in _patched_registry.get_from_module(module):
_bind_injections(patched, providers_map)
def unwire(
*,
@ -203,11 +224,15 @@ def unwire(
for method_name, method in inspect.getmembers(member, inspect.isfunction):
_unpatch(member, method_name, method)
for patched in _patched_registry.get_from_module(module):
_unbind_injections(patched)
def inject(fn: F) -> F:
"""Decorate callable with injecting decorator."""
reference_injections, reference_closing = _fetch_reference_injections(fn)
patched = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(patched)
return cast(F, patched)
@ -222,6 +247,7 @@ def _patch_fn(
if not reference_injections:
return
fn = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(fn)
_bind_injections(fn, providers_map)
@ -247,6 +273,7 @@ def _patch_method(
if not reference_injections:
return
fn = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(fn)
_bind_injections(fn, providers_map)

View File

@ -3,7 +3,7 @@
from decimal import Decimal
from typing import Callable
from dependency_injector.wiring import Provide, Provider
from dependency_injector.wiring import inject, Provide, Provider
from .container import Container, SubContainer
from .service import Service
@ -65,3 +65,17 @@ def test_provide_from_different_containers(
some_value: int = Provide[SubContainer.int_object],
):
return service, some_value
class ClassDecorator:
def __init__(self, fn):
self._fn = fn
def __call__(self, *args, **kwargs):
return self._fn(*args, **kwargs)
@ClassDecorator
@inject
def test_class_decorator(service: Service = Provide[Container.service]):
return service

View File

@ -226,6 +226,10 @@ class WiringTest(unittest.TestCase):
self.assertEqual(result_2.init_counter, 0)
self.assertEqual(result_2.shutdown_counter, 0)
def test_class_decorator(self):
service = module.test_class_decorator()
self.assertIsInstance(service, Service)
class WiringAndFastAPITest(unittest.TestCase):