Implement wiring autoloader

This commit is contained in:
Roman Mogylatov 2021-01-28 19:18:47 -05:00
parent 9225f9dcd6
commit 41e18d2c89
2 changed files with 136 additions and 1 deletions

View File

@ -4,6 +4,7 @@ import asyncio
import functools import functools
import inspect import inspect
import importlib import importlib
import importlib.machinery
import pkgutil import pkgutil
import sys import sys
from types import ModuleType from types import ModuleType
@ -52,6 +53,11 @@ __all__ = (
'Provide', 'Provide',
'Provider', 'Provider',
'Closing', 'Closing',
'register_loader_containers',
'unregister_loader_containers',
'install_loader',
'uninstall_loader',
'is_loader_installed',
) )
T = TypeVar('T') T = TypeVar('T')
@ -535,3 +541,98 @@ class Provider(_Marker):
class Closing(_Marker): class Closing(_Marker):
... ...
class AutoLoader:
"""Auto-wiring module loader.
Automatically wire containers when modules are imported.
"""
def __init__(self):
self.containers = []
self._path_hook = None
def register_containers(self, *containers):
self.containers.extend(containers)
if not self.installed:
self.install()
def unregister_containers(self, *containers):
for container in containers:
self.containers.remove(container)
if not self.containers:
self.uninstall()
def wire_module(self, module):
for container in self.containers:
container.wire(modules=[module])
@property
def installed(self):
return self._path_hook is not None
def install(self):
if self.installed:
return
loader = self
class SourcelessFileLoader(importlib.machinery.SourcelessFileLoader):
def exec_module(self, module):
super().exec_module(module)
loader.wire_module(module)
class SourceFileLoader(importlib.machinery.SourceFileLoader):
def exec_module(self, module):
super().exec_module(module)
loader.wire_module(module)
loader_details = [
(SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
(SourceFileLoader, importlib.machinery.SOURCE_SUFFIXES),
]
self._path_hook = importlib.machinery.FileFinder.path_hook(*loader_details)
sys.path_hooks.insert(0, self._path_hook)
sys.path_importer_cache.clear()
importlib.invalidate_caches()
def uninstall(self):
if not self.installed:
return
sys.path_hooks.remove(self._path_hook)
sys.path_importer_cache.clear()
importlib.invalidate_caches()
_loader = AutoLoader()
def register_loader_containers(*containers: Container) -> None:
"""Register containers in auto-wiring module loader."""
_loader.register_containers(*containers)
def unregister_loader_containers(*containers: Container) -> None:
"""Unregister containers from auto-wiring module loader."""
_loader.unregister_containers(*containers)
def install_loader() -> None:
"""Install auto-wiring module loader hook."""
_loader.install()
def uninstall_loader() -> None:
"""Uninstall auto-wiring module loader hook."""
_loader.uninstall()
def is_loader_installed() -> bool:
"""Check if auto-wiring module loader hook is installed."""
return _loader.installed

View File

@ -1,7 +1,15 @@
import contextlib
from decimal import Decimal from decimal import Decimal
import importlib
import unittest import unittest
from dependency_injector.wiring import wire, Provide, Closing from dependency_injector.wiring import (
wire,
Provide,
Closing,
register_loader_containers,
unregister_loader_containers,
)
from dependency_injector import errors from dependency_injector import errors
# Runtime import to avoid syntax errors in samples on Python < 3.5 # Runtime import to avoid syntax errors in samples on Python < 3.5
@ -367,3 +375,29 @@ class WiringAsyncInjectionsTest(AsyncTestCase):
self.assertIs(resource2, asyncinjections.resource2) self.assertIs(resource2, asyncinjections.resource2)
self.assertEqual(asyncinjections.resource2.init_counter, 2) self.assertEqual(asyncinjections.resource2.init_counter, 2)
self.assertEqual(asyncinjections.resource2.shutdown_counter, 2) self.assertEqual(asyncinjections.resource2.shutdown_counter, 2)
class AutoLoaderTest(unittest.TestCase):
container: Container
def setUp(self) -> None:
self.container = Container(config={'a': {'b': {'c': 10}}})
importlib.reload(module)
def tearDown(self) -> None:
with contextlib.suppress(ValueError):
unregister_loader_containers(self.container)
self.container.unwire()
@classmethod
def tearDownClass(cls) -> None:
importlib.reload(module)
def test_register_container(self):
register_loader_containers(self.container)
importlib.reload(module)
service = module.test_function()
self.assertIsInstance(service, Service)