mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-10-31 16:07:51 +03:00 
			
		
		
		
	Add wiring of class methods
This commit is contained in:
		
							parent
							
								
									c4b5494b6b
								
							
						
					
					
						commit
						949a91b657
					
				|  | @ -158,7 +158,8 @@ def wire( | |||
|             if inspect.isfunction(member): | ||||
|                 _patch_fn(module, name, member, providers_map) | ||||
|             elif inspect.isclass(member): | ||||
|                 _patch_cls(member, providers_map) | ||||
|                 for method_name, method in inspect.getmembers(member, inspect.isfunction): | ||||
|                     _patch_fn(member, method_name, method, providers_map) | ||||
| 
 | ||||
| 
 | ||||
| def unwire( | ||||
|  | @ -179,29 +180,8 @@ def unwire( | |||
|             if inspect.isfunction(member): | ||||
|                 _unpatch_fn(module, name, member) | ||||
|             elif inspect.isclass(member): | ||||
|                 _unpatch_cls(member,) | ||||
| 
 | ||||
| 
 | ||||
| def _patch_cls( | ||||
|         cls: Type[Any], | ||||
|         providers_map: ProvidersMap, | ||||
| ) -> None: | ||||
|     if not hasattr(cls, '__init__'): | ||||
|         return | ||||
|     init_method = getattr(cls, '__init__') | ||||
|     injections = _resolve_injections(init_method, providers_map) | ||||
|     if not injections: | ||||
|         return | ||||
|     setattr(cls, '__init__', _patch_with_injections(init_method, injections)) | ||||
| 
 | ||||
| 
 | ||||
| def _unpatch_cls(cls: Type[Any]) -> None: | ||||
|     if not hasattr(cls, '__init__'): | ||||
|         return | ||||
|     init_method = getattr(cls, '__init__') | ||||
|     if not _is_patched(init_method): | ||||
|         return | ||||
|     setattr(cls, '__init__', _get_original_from_patched(init_method)) | ||||
|                 for method_name, method in inspect.getmembers(member, inspect.isfunction): | ||||
|                     _unpatch_fn(member, method_name, method) | ||||
| 
 | ||||
| 
 | ||||
| def _patch_fn( | ||||
|  |  | |||
|  | @ -14,6 +14,9 @@ class TestClass: | |||
|     def __init__(self, service: Service = Provide[Container.service]): | ||||
|         self.service = service | ||||
| 
 | ||||
|     def method(self, service: Service = Provide[Container.service]): | ||||
|         return service | ||||
| 
 | ||||
| 
 | ||||
| def test_function(service: Service = Provide[Container.service]): | ||||
|     return service | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| from decimal import Decimal | ||||
| import unittest | ||||
| 
 | ||||
| from dependency_injector.wiring import wire | ||||
| from dependency_injector.wiring import wire, Provide | ||||
| 
 | ||||
| from . import module, package | ||||
| from .service import Service | ||||
|  | @ -35,6 +35,11 @@ class WiringTest(unittest.TestCase): | |||
|         test_class_object = module.TestClass(service=test_service) | ||||
|         self.assertIs(test_class_object.service, test_service) | ||||
| 
 | ||||
|     def test_class_method_wiring(self): | ||||
|         test_class_object = module.TestClass() | ||||
|         service = test_class_object.method() | ||||
|         self.assertIsInstance(service, Service) | ||||
| 
 | ||||
|     def test_function_wiring(self): | ||||
|         service = module.test_function() | ||||
|         self.assertIsInstance(service, Service) | ||||
|  | @ -104,3 +109,26 @@ class WiringTest(unittest.TestCase): | |||
|                 modules=[module], | ||||
|             ) | ||||
| 
 | ||||
|     def test_unwire_function(self): | ||||
|         self.container.unwire() | ||||
|         self.assertIsInstance(module.test_function(), Provide) | ||||
| 
 | ||||
|     def test_unwire_class(self): | ||||
|         self.container.unwire() | ||||
|         test_class_object = module.TestClass() | ||||
|         self.assertIsInstance(test_class_object.service, Provide) | ||||
| 
 | ||||
|     def test_unwire_class_method(self): | ||||
|         self.container.unwire() | ||||
|         test_class_object = module.TestClass() | ||||
|         self.assertIsInstance(test_class_object.method(), Provide) | ||||
| 
 | ||||
|     def test_unwire_package_function(self): | ||||
|         self.container.unwire() | ||||
|         from .package.subpackage.submodule import test_function | ||||
|         self.assertIsInstance(test_function(), Provide) | ||||
| 
 | ||||
|     def test_unwire_package_function_by_reference(self): | ||||
|         from .package.subpackage import submodule | ||||
|         self.container.unwire() | ||||
|         self.assertIsInstance(submodule.test_function(), Provide) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user