Add wiring registry

This commit is contained in:
Roman Mogylatov 2020-11-15 15:17:13 -05:00
parent 06d865c7b7
commit 4057a79cf8
3 changed files with 47 additions and 2 deletions

View File

@ -6,7 +6,7 @@ import importlib
import pkgutil import pkgutil
import sys import sys
from types import ModuleType from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast from typing import Optional, Iterable, Iterator, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
from typing import GenericMeta from typing import GenericMeta
@ -32,6 +32,24 @@ F = TypeVar('F', bound=Callable[..., Any])
Container = Any Container = Any
class Registry:
def __init__(self):
self._storage = set()
def add(self, patched: Callable[..., Any]) -> None:
self._storage.add(patched)
def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
for patched in self._storage:
if patched.__module__ != module.__name__:
continue
yield patched
_patched_registry = Registry()
class ProvidersMap: class ProvidersMap:
def __init__(self, container): def __init__(self, container):
@ -181,6 +199,9 @@ def wire(
for method_name, method in inspect.getmembers(member, _is_method): for method_name, method in inspect.getmembers(member, _is_method):
_patch_method(member, method_name, method, providers_map) _patch_method(member, method_name, method, providers_map)
for patched in _patched_registry.get_from_module(module):
_bind_injections(patched, providers_map)
def unwire( def unwire(
*, *,
@ -203,11 +224,15 @@ def unwire(
for method_name, method in inspect.getmembers(member, inspect.isfunction): for method_name, method in inspect.getmembers(member, inspect.isfunction):
_unpatch(member, method_name, method) _unpatch(member, method_name, method)
for patched in _patched_registry.get_from_module(module):
_unbind_injections(patched)
def inject(fn: F) -> F: def inject(fn: F) -> F:
"""Decorate callable with injecting decorator.""" """Decorate callable with injecting decorator."""
reference_injections, reference_closing = _fetch_reference_injections(fn) reference_injections, reference_closing = _fetch_reference_injections(fn)
patched = _get_patched(fn, reference_injections, reference_closing) patched = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(patched)
return cast(F, patched) return cast(F, patched)
@ -222,6 +247,7 @@ def _patch_fn(
if not reference_injections: if not reference_injections:
return return
fn = _get_patched(fn, reference_injections, reference_closing) fn = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(fn)
_bind_injections(fn, providers_map) _bind_injections(fn, providers_map)
@ -247,6 +273,7 @@ def _patch_method(
if not reference_injections: if not reference_injections:
return return
fn = _get_patched(fn, reference_injections, reference_closing) fn = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(fn)
_bind_injections(fn, providers_map) _bind_injections(fn, providers_map)

View File

@ -3,7 +3,7 @@
from decimal import Decimal from decimal import Decimal
from typing import Callable from typing import Callable
from dependency_injector.wiring import Provide, Provider from dependency_injector.wiring import inject, Provide, Provider
from .container import Container, SubContainer from .container import Container, SubContainer
from .service import Service from .service import Service
@ -65,3 +65,17 @@ def test_provide_from_different_containers(
some_value: int = Provide[SubContainer.int_object], some_value: int = Provide[SubContainer.int_object],
): ):
return service, some_value return service, some_value
class ClassDecorator:
def __init__(self, fn):
self._fn = fn
def __call__(self, *args, **kwargs):
return self._fn(*args, **kwargs)
@ClassDecorator
@inject
def test_class_decorator(service: Service = Provide[Container.service]):
return service

View File

@ -226,6 +226,10 @@ class WiringTest(unittest.TestCase):
self.assertEqual(result_2.init_counter, 0) self.assertEqual(result_2.init_counter, 0)
self.assertEqual(result_2.shutdown_counter, 0) self.assertEqual(result_2.shutdown_counter, 0)
def test_class_decorator(self):
service = module.test_class_decorator()
self.assertIsInstance(service, Service)
class WiringAndFastAPITest(unittest.TestCase): class WiringAndFastAPITest(unittest.TestCase):