Add typing improvements in wiring module

This commit is contained in:
Roman Mogylatov 2022-07-23 23:09:31 -04:00
parent 99c18cd6c2
commit ac04bc60f2

View File

@ -91,7 +91,7 @@ Container = Any
class PatchedRegistry:
def __init__(self):
def __init__(self) -> None:
self._callables: Dict[Callable[..., Any], "PatchedCallable"] = {}
self._attributes: Set[PatchedAttribute] = set()
@ -110,7 +110,7 @@ class PatchedRegistry:
def has_callable(self, fn: Callable[..., Any]) -> bool:
return fn in self._callables
def register_attribute(self, patched: "PatchedAttribute"):
def register_attribute(self, patched: "PatchedAttribute") -> None:
self._attributes.add(patched)
def get_attributes_from_module(self, module: ModuleType) -> Iterator["PatchedAttribute"]:
@ -119,7 +119,7 @@ class PatchedRegistry:
continue
yield attribute
def clear_module_attributes(self, module: ModuleType):
def clear_module_attributes(self, module: ModuleType) -> None:
for attribute in self._attributes.copy():
if not attribute.is_in_module(module):
continue
@ -143,7 +143,7 @@ class PatchedCallable:
original: Optional[Callable[..., Any]] = None,
reference_injections: Optional[Dict[Any, Any]] = None,
reference_closing: Optional[Dict[Any, Any]] = None,
):
) -> None:
self.patched = patched
self.original = original
@ -175,7 +175,7 @@ class PatchedCallable:
class PatchedAttribute:
def __init__(self, member: Any, name: str, marker: "_Marker"):
def __init__(self, member: Any, name: str, marker: "_Marker") -> None:
self.member = member
self.name = name
self.marker = marker
@ -195,7 +195,7 @@ class ProvidersMap:
CONTAINER_STRING_ID = "<container>"
def __init__(self, container):
def __init__(self, container) -> None:
self._container = container
self._map = self._create_providers_map(
current_container=container,
@ -629,15 +629,19 @@ def _fetch_modules(package):
return modules
def _is_method(member):
def _is_method(member) -> bool:
return inspect.ismethod(member) or inspect.isfunction(member)
def _is_marker(member):
def _is_marker(member) -> bool:
return isinstance(member, _Marker)
def _get_patched(fn, reference_injections, reference_closing):
def _get_patched(
fn: F,
reference_injections: Dict[Any, Any],
reference_closing: Dict[Any, Any],
) -> F:
patched_object = PatchedCallable(
original=fn,
reference_injections=reference_injections,
@ -659,7 +663,7 @@ def _is_fastapi_depends(param: Any) -> bool:
return fastapi and isinstance(param, fastapi.params.Depends)
def _is_patched(fn):
def _is_patched(fn) -> bool:
return _patched_registry.has_callable(fn)
@ -688,7 +692,7 @@ class Modifier:
class TypeModifier(Modifier):
def __init__(self, type_: Type):
def __init__(self, type_: Type) -> None:
self.type_ = type_
def modify(
@ -716,7 +720,7 @@ def as_(type_: Type) -> TypeModifier:
class RequiredModifier(Modifier):
def __init__(self):
def __init__(self) -> None:
self.type_modifier = None
def as_int(self) -> "RequiredModifier":
@ -772,7 +776,7 @@ class ProvidedInstance(Modifier):
TYPE_ITEM = "item"
TYPE_CALL = "call"
def __init__(self):
def __init__(self) -> None:
self.segments = []
def __getattr__(self, item):
@ -857,32 +861,32 @@ class AutoLoader:
Automatically wire containers when modules are imported.
"""
def __init__(self):
def __init__(self) -> None:
self.containers = []
self._path_hook = None
def register_containers(self, *containers):
def register_containers(self, *containers) -> None:
self.containers.extend(containers)
if not self.installed:
self.install()
def unregister_containers(self, *containers):
def unregister_containers(self, *containers) -> None:
for container in containers:
self.containers.remove(container)
if not self.containers:
self.uninstall()
def wire_module(self, module):
def wire_module(self, module) -> None:
for container in self.containers:
container.wire(modules=[module])
@property
def installed(self):
def installed(self) -> bool:
return self._path_hook in sys.path_hooks
def install(self):
def install(self) -> None:
if self.installed:
return
@ -913,7 +917,7 @@ class AutoLoader:
sys.path_importer_cache.clear()
importlib.invalidate_caches()
def uninstall(self):
def uninstall(self) -> None:
if not self.installed:
return
@ -958,7 +962,7 @@ from ._cwiring import _async_inject # noqa
# Wiring uses the following Python wrapper because there is
# no possibility to compile a first-type citizen coroutine in Cython.
def _get_async_patched(fn, patched: PatchedCallable):
def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
async def _patched(*args, **kwargs):
return await _async_inject(