Add unwire functionality

This commit is contained in:
Roman Mogylatov 2020-09-19 15:34:39 -04:00
parent cae99da84f
commit f3619d696f
4 changed files with 1650 additions and 1140 deletions

File diff suppressed because it is too large Load Diff

View File

@ -12,11 +12,14 @@ from .providers cimport (
if sys.version_info[:2] >= (3, 6):
from .wiring import wire
from .wiring import wire, unwire
else:
def wire(*args, **kwargs):
raise NotADirectoryError('Wiring requires Python 3.6 or above')
def unwire(*args, **kwargs):
raise NotADirectoryError('Wiring requires Python 3.6 or above')
class DynamicContainer(object):
"""Dynamic inversion of control container.
@ -55,9 +58,11 @@ class DynamicContainer(object):
:rtype: None
"""
self.provider_type = Provider
self.providers = dict()
self.providers = {}
self.overridden = tuple()
self.declarative_parent = None
self.wired_to_modules = []
self.wired_to_packages = []
super(DynamicContainer, self).__init__()
def __deepcopy__(self, memo):
@ -196,7 +201,7 @@ class DynamicContainer(object):
def wire(self, modules=None, packages=None):
"""Wire container providers with provided packages and modules by name.
"""Wire container providers with provided packages and modules.
:rtype: None
"""
@ -206,6 +211,23 @@ class DynamicContainer(object):
packages=packages,
)
if modules:
self.wired_to_modules.extend(modules)
if packages:
self.wired_to_packages.extend(packages)
def unwire(self):
"""Unwire container providers from previously wired packages and modules."""
unwire(
modules=self.wired_to_modules,
packages=self.wired_to_packages,
)
self.wired_to_modules.clear()
self.wired_to_packages.clear()
class DeclarativeContainerMetaClass(type):
"""Declarative inversion of control container meta class."""

View File

@ -27,7 +27,7 @@ def wire(
modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None,
) -> None:
"""Wire container providers with provided packages and modules by name."""
"""Wire container providers with provided packages and modules."""
if not modules:
modules = []
@ -43,6 +43,27 @@ def wire(
_patch_cls(member, container)
def unwire(
*,
modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None,
) -> None:
"""Wire provided packages and modules with previous wired providers."""
if not modules:
modules = []
if packages:
for package in packages:
modules.extend(_fetch_modules(package))
for module in modules:
for name, member in inspect.getmembers(module):
if inspect.isfunction(member):
_unpatch_fn(module, name, member)
elif inspect.isclass(member):
_unpatch_cls(member,)
def _patch_cls(
cls: Type[Any],
container: AnyContainer,
@ -50,14 +71,21 @@ def _patch_cls(
if not hasattr(cls, '__init__'):
return
init_method = getattr(cls, '__init__')
injections = _resolve_injections(init_method, container)
if not injections:
return
setattr(cls, '__init__', _patch_with_injections(init_method, injections))
def _unpatch_cls(cls: Type[Any]) -> None:
if not hasattr(cls, '__init__'):
return
init_method = getattr(cls, '__init__')
if not _is_patched(init_method):
return
setattr(cls, '__init__', _get_original_from_patched(init_method))
def _patch_fn(
module: ModuleType,
name: str,
@ -67,10 +95,19 @@ def _patch_fn(
injections = _resolve_injections(fn, container)
if not injections:
return
setattr(module, name, _patch_with_injections(fn, injections))
def _unpatch_fn(
module: ModuleType,
name: str,
fn: Callable[..., Any],
) -> None:
if not _is_patched(fn):
return
setattr(module, name, _get_original_from_patched(fn))
def _resolve_injections(fn: Callable[..., Any], container: AnyContainer) -> Dict[str, Any]:
config = _resolve_container_config(container)
@ -148,9 +185,22 @@ def _patch_with_injections(fn, injections):
to_inject.update(kwargs)
return fn(*args, **to_inject)
_patched.__wired__ = True
_patched.__original__ = fn
_patched.__injections__ = injections
return _patched
def _is_patched(fn):
return getattr(fn, '__wired__', False) is True
def _get_original_from_patched(fn):
return getattr(fn, '__original__')
class ClassGetItemMeta(GenericMeta):
def __getitem__(cls, item):
# Spike for Python 3.6

View File

@ -10,13 +10,13 @@ class WiringTest(unittest.TestCase):
container: Container
@classmethod
def setUpClass(cls) -> None:
cls.container = Container(config={'a': {'b': {'c': 10}}})
cls.container.wire(
def setUp(self) -> None:
self.container = Container(config={'a': {'b': {'c': 10}}})
self.container.wire(
modules=[module],
packages=[package],
)
self.addCleanup(self.container.unwire)
def test_package_lookup(self):
from .package.subpackage.submodule import test_function