mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-11-04 09:57:37 +03:00 
			
		
		
		
	Add wiring registry
This commit is contained in:
		
							parent
							
								
									06d865c7b7
								
							
						
					
					
						commit
						4057a79cf8
					
				| 
						 | 
					@ -6,7 +6,7 @@ import importlib
 | 
				
			||||||
import pkgutil
 | 
					import pkgutil
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from types import ModuleType
 | 
					from types import ModuleType
 | 
				
			||||||
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
 | 
					from typing import Optional, Iterable, Iterator, Callable, Any, Tuple, Dict, Generic, TypeVar, Type, cast
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if sys.version_info < (3, 7):
 | 
					if sys.version_info < (3, 7):
 | 
				
			||||||
    from typing import GenericMeta
 | 
					    from typing import GenericMeta
 | 
				
			||||||
| 
						 | 
					@ -32,6 +32,24 @@ F = TypeVar('F', bound=Callable[..., Any])
 | 
				
			||||||
Container = Any
 | 
					Container = Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Registry:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        self._storage = set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add(self, patched: Callable[..., Any]) -> None:
 | 
				
			||||||
 | 
					        self._storage.add(patched)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
 | 
				
			||||||
 | 
					        for patched in self._storage:
 | 
				
			||||||
 | 
					            if patched.__module__ != module.__name__:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            yield patched
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_patched_registry = Registry()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ProvidersMap:
 | 
					class ProvidersMap:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, container):
 | 
					    def __init__(self, container):
 | 
				
			||||||
| 
						 | 
					@ -181,6 +199,9 @@ def wire(
 | 
				
			||||||
                for method_name, method in inspect.getmembers(member, _is_method):
 | 
					                for method_name, method in inspect.getmembers(member, _is_method):
 | 
				
			||||||
                    _patch_method(member, method_name, method, providers_map)
 | 
					                    _patch_method(member, method_name, method, providers_map)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for patched in _patched_registry.get_from_module(module):
 | 
				
			||||||
 | 
					            _bind_injections(patched, providers_map)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def unwire(
 | 
					def unwire(
 | 
				
			||||||
        *,
 | 
					        *,
 | 
				
			||||||
| 
						 | 
					@ -203,11 +224,15 @@ def unwire(
 | 
				
			||||||
                for method_name, method in inspect.getmembers(member, inspect.isfunction):
 | 
					                for method_name, method in inspect.getmembers(member, inspect.isfunction):
 | 
				
			||||||
                    _unpatch(member, method_name, method)
 | 
					                    _unpatch(member, method_name, method)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for patched in _patched_registry.get_from_module(module):
 | 
				
			||||||
 | 
					            _unbind_injections(patched)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def inject(fn: F) -> F:
 | 
					def inject(fn: F) -> F:
 | 
				
			||||||
    """Decorate callable with injecting decorator."""
 | 
					    """Decorate callable with injecting decorator."""
 | 
				
			||||||
    reference_injections, reference_closing = _fetch_reference_injections(fn)
 | 
					    reference_injections, reference_closing = _fetch_reference_injections(fn)
 | 
				
			||||||
    patched = _get_patched(fn, reference_injections, reference_closing)
 | 
					    patched = _get_patched(fn, reference_injections, reference_closing)
 | 
				
			||||||
 | 
					    _patched_registry.add(patched)
 | 
				
			||||||
    return cast(F, patched)
 | 
					    return cast(F, patched)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -222,6 +247,7 @@ def _patch_fn(
 | 
				
			||||||
        if not reference_injections:
 | 
					        if not reference_injections:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        fn = _get_patched(fn, reference_injections, reference_closing)
 | 
					        fn = _get_patched(fn, reference_injections, reference_closing)
 | 
				
			||||||
 | 
					        _patched_registry.add(fn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    _bind_injections(fn, providers_map)
 | 
					    _bind_injections(fn, providers_map)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -247,6 +273,7 @@ def _patch_method(
 | 
				
			||||||
        if not reference_injections:
 | 
					        if not reference_injections:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
        fn = _get_patched(fn, reference_injections, reference_closing)
 | 
					        fn = _get_patched(fn, reference_injections, reference_closing)
 | 
				
			||||||
 | 
					        _patched_registry.add(fn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    _bind_injections(fn, providers_map)
 | 
					    _bind_injections(fn, providers_map)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,7 +3,7 @@
 | 
				
			||||||
from decimal import Decimal
 | 
					from decimal import Decimal
 | 
				
			||||||
from typing import Callable
 | 
					from typing import Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from dependency_injector.wiring import Provide, Provider
 | 
					from dependency_injector.wiring import inject, Provide, Provider
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .container import Container, SubContainer
 | 
					from .container import Container, SubContainer
 | 
				
			||||||
from .service import Service
 | 
					from .service import Service
 | 
				
			||||||
| 
						 | 
					@ -65,3 +65,17 @@ def test_provide_from_different_containers(
 | 
				
			||||||
        some_value: int = Provide[SubContainer.int_object],
 | 
					        some_value: int = Provide[SubContainer.int_object],
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    return service, some_value
 | 
					    return service, some_value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ClassDecorator:
 | 
				
			||||||
 | 
					    def __init__(self, fn):
 | 
				
			||||||
 | 
					        self._fn = fn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        return self._fn(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@ClassDecorator
 | 
				
			||||||
 | 
					@inject
 | 
				
			||||||
 | 
					def test_class_decorator(service: Service = Provide[Container.service]):
 | 
				
			||||||
 | 
					    return service
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -226,6 +226,10 @@ class WiringTest(unittest.TestCase):
 | 
				
			||||||
        self.assertEqual(result_2.init_counter, 0)
 | 
					        self.assertEqual(result_2.init_counter, 0)
 | 
				
			||||||
        self.assertEqual(result_2.shutdown_counter, 0)
 | 
					        self.assertEqual(result_2.shutdown_counter, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_class_decorator(self):
 | 
				
			||||||
 | 
					        service = module.test_class_decorator()
 | 
				
			||||||
 | 
					        self.assertIsInstance(service, Service)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class WiringAndFastAPITest(unittest.TestCase):
 | 
					class WiringAndFastAPITest(unittest.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user