Implement string imports for wiring

This commit is contained in:
Roman Mogylatov 2021-09-28 09:58:40 -04:00
parent 44343421dc
commit f436a3c42e
6 changed files with 4227 additions and 3072 deletions

File diff suppressed because it is too large Load Diff

View File

@ -43,7 +43,7 @@ class Container:
def override_providers(self, **overriding_providers: Union[Provider, Any]) -> None: ... def override_providers(self, **overriding_providers: Union[Provider, Any]) -> None: ...
def reset_last_overriding(self) -> None: ... def reset_last_overriding(self) -> None: ...
def reset_override(self) -> None: ... def reset_override(self) -> None: ...
def wire(self, modules: Optional[Iterable[Any]] = None, packages: Optional[Iterable[Any]] = None) -> None: ... def wire(self, modules: Optional[Iterable[Any]] = None, packages: Optional[Iterable[Any]] = None, from_package: Optional[str] = None) -> None: ...
def unwire(self) -> None: ... def unwire(self) -> None: ...
def init_resources(self) -> Optional[Awaitable]: ... def init_resources(self) -> Optional[Awaitable]: ...
def shutdown_resources(self) -> Optional[Awaitable]: ... def shutdown_resources(self) -> Optional[Awaitable]: ...

View File

@ -1,7 +1,9 @@
"""Containers module.""" """Containers module."""
import contextlib
import json import json
import sys import sys
import importlib
import inspect
try: try:
import asyncio import asyncio
@ -248,11 +250,22 @@ class DynamicContainer(Container):
for provider in six.itervalues(self.providers): for provider in six.itervalues(self.providers):
provider.reset_override() provider.reset_override()
def wire(self, modules=None, packages=None): def wire(self, modules=None, packages=None, from_package=None):
"""Wire container providers with provided packages and modules. """Wire container providers with provided packages and modules.
:rtype: None :rtype: None
""" """
modules = [*modules] if modules else []
packages = [*packages] if packages else []
if from_package is None and \
(_has_any_relative_string_imports(modules) or _has_any_relative_string_imports(packages)):
with contextlib.suppress(Exception):
from_package = _resolve_calling_package_name()
modules = _resolve_string_imports(modules, from_package)
packages = _resolve_string_imports(packages, from_package)
wire( wire(
container=self, container=self,
modules=modules, modules=modules,
@ -261,7 +274,6 @@ class DynamicContainer(Container):
if modules: if modules:
self.wired_to_modules.extend(modules) self.wired_to_modules.extend(modules)
if packages: if packages:
self.wired_to_packages.extend(packages) self.wired_to_packages.extend(packages)
@ -789,3 +801,27 @@ cpdef object _check_provider_type(object container, object provider):
if not isinstance(provider, container.provider_type): if not isinstance(provider, container.provider_type):
raise errors.Error('{0} can contain only {1} ' raise errors.Error('{0} can contain only {1} '
'instances'.format(container, container.provider_type)) 'instances'.format(container, container.provider_type))
cpdef bint _has_any_relative_string_imports(object modules):
for module in modules:
if not isinstance(module, str):
continue
if module.startswith("."):
return True
else:
return False
cpdef list _resolve_string_imports(object modules, object from_package):
return [
importlib.import_module(module, from_package) if isinstance(module, str) else module
for module in modules
]
cpdef object _resolve_calling_package_name():
stack = inspect.stack()
pre_last_frame = stack[0]
module = inspect.getmodule(pre_last_frame[0])
return module.__package__

View File

@ -322,20 +322,15 @@ class InspectFilter:
def wire( # noqa: C901 def wire( # noqa: C901
container: Container, container: Container,
*, *,
modules: Optional[Iterable[Union[ModuleType, str]]] = None, modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[Union[ModuleType, str]]] = None, packages: Optional[Iterable[ModuleType]] = None,
from_package: Optional[str] = None,
) -> None: ) -> None:
"""Wire container providers with provided packages and modules.""" """Wire container providers with provided packages and modules."""
if not modules: modules = [*modules] if modules else []
modules = []
modules = _resolve_string_imports(modules, from_package)
if not packages: if packages:
packages = [] for package in packages:
packages = _resolve_string_imports(packages, from_package) modules.extend(_fetch_modules(package))
for package in packages:
modules.extend(_fetch_modules(package))
providers_map = ProvidersMap(container) providers_map = ProvidersMap(container)
@ -372,8 +367,7 @@ def unwire( # noqa: C901
packages: Optional[Iterable[ModuleType]] = None, packages: Optional[Iterable[ModuleType]] = None,
) -> None: ) -> None:
"""Wire provided packages and modules with previous wired providers.""" """Wire provided packages and modules with previous wired providers."""
if not modules: modules = [*modules] if modules else []
modules = []
if packages: if packages:
for package in packages: for package in packages:
@ -677,16 +671,6 @@ def _is_declarative_container(instance: Any) -> bool:
and getattr(instance, 'declarative_parent', None) is None) and getattr(instance, 'declarative_parent', None) is None)
def _resolve_string_imports(
modules: Optional[Iterable[Union[ModuleType, str]]],
from_package: Optional[str],
) -> List[ModuleType]:
return [
importlib.import_module(module, from_package) if isinstance(module, str) else module
for module in modules
]
class Modifier: class Modifier:
def modify( def modify(

View File

@ -0,0 +1,8 @@
"""Wiring sample package."""
def wire_with_relative_string_names(container):
container.wire(
modules=[".module"],
packages=[".package"],
)

View File

@ -36,6 +36,7 @@ from asyncutils import AsyncTestCase
from wiringsamples import module, package from wiringsamples import module, package
from wiringsamples.service import Service from wiringsamples.service import Service
from wiringsamples.container import Container, SubContainer from wiringsamples.container import Container, SubContainer
from wiringsamples.wire_relative_string_names import wire_with_relative_string_names
class WiringTest(unittest.TestCase): class WiringTest(unittest.TestCase):
@ -314,6 +315,52 @@ class WiringTest(unittest.TestCase):
self.assertIsInstance(service, Service) self.assertIsInstance(service, Service)
class WiringWithStringModuleAndPackageNamesTest(unittest.TestCase):
container: Container
def setUp(self) -> None:
self.container = Container()
self.addCleanup(self.container.unwire)
def test_absolute_names(self):
self.container.wire(
modules=["wiringsamples.module"],
packages=["wiringsamples.package"],
)
service = module.test_function()
self.assertIsInstance(service, Service)
from wiringsamples.package.subpackage.submodule import test_function
service = test_function()
self.assertIsInstance(service, Service)
def test_relative_names_with_explicit_package(self):
self.container.wire(
modules=[".module"],
packages=[".package"],
from_package="wiringsamples",
)
service = module.test_function()
self.assertIsInstance(service, Service)
from wiringsamples.package.subpackage.submodule import test_function
service = test_function()
self.assertIsInstance(service, Service)
def test_relative_names_with_auto_package(self):
wire_with_relative_string_names(self.container)
service = module.test_function()
self.assertIsInstance(service, Service)
from wiringsamples.package.subpackage.submodule import test_function
service = test_function()
self.assertIsInstance(service, Service)
class ModuleAsPackageTest(unittest.TestCase): class ModuleAsPackageTest(unittest.TestCase):
def setUp(self): def setUp(self):