mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-10-31 16:07:51 +03:00 
			
		
		
		
	Add wiring of class methods
This commit is contained in:
		
							parent
							
								
									c4b5494b6b
								
							
						
					
					
						commit
						949a91b657
					
				|  | @ -158,7 +158,8 @@ def wire( | ||||||
|             if inspect.isfunction(member): |             if inspect.isfunction(member): | ||||||
|                 _patch_fn(module, name, member, providers_map) |                 _patch_fn(module, name, member, providers_map) | ||||||
|             elif inspect.isclass(member): |             elif inspect.isclass(member): | ||||||
|                 _patch_cls(member, providers_map) |                 for method_name, method in inspect.getmembers(member, inspect.isfunction): | ||||||
|  |                     _patch_fn(member, method_name, method, providers_map) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def unwire( | def unwire( | ||||||
|  | @ -179,29 +180,8 @@ def unwire( | ||||||
|             if inspect.isfunction(member): |             if inspect.isfunction(member): | ||||||
|                 _unpatch_fn(module, name, member) |                 _unpatch_fn(module, name, member) | ||||||
|             elif inspect.isclass(member): |             elif inspect.isclass(member): | ||||||
|                 _unpatch_cls(member,) |                 for method_name, method in inspect.getmembers(member, inspect.isfunction): | ||||||
| 
 |                     _unpatch_fn(member, method_name, method) | ||||||
| 
 |  | ||||||
| def _patch_cls( |  | ||||||
|         cls: Type[Any], |  | ||||||
|         providers_map: ProvidersMap, |  | ||||||
| ) -> None: |  | ||||||
|     if not hasattr(cls, '__init__'): |  | ||||||
|         return |  | ||||||
|     init_method = getattr(cls, '__init__') |  | ||||||
|     injections = _resolve_injections(init_method, providers_map) |  | ||||||
|     if not injections: |  | ||||||
|         return |  | ||||||
|     setattr(cls, '__init__', _patch_with_injections(init_method, injections)) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def _unpatch_cls(cls: Type[Any]) -> None: |  | ||||||
|     if not hasattr(cls, '__init__'): |  | ||||||
|         return |  | ||||||
|     init_method = getattr(cls, '__init__') |  | ||||||
|     if not _is_patched(init_method): |  | ||||||
|         return |  | ||||||
|     setattr(cls, '__init__', _get_original_from_patched(init_method)) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _patch_fn( | def _patch_fn( | ||||||
|  |  | ||||||
|  | @ -14,6 +14,9 @@ class TestClass: | ||||||
|     def __init__(self, service: Service = Provide[Container.service]): |     def __init__(self, service: Service = Provide[Container.service]): | ||||||
|         self.service = service |         self.service = service | ||||||
| 
 | 
 | ||||||
|  |     def method(self, service: Service = Provide[Container.service]): | ||||||
|  |         return service | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def test_function(service: Service = Provide[Container.service]): | def test_function(service: Service = Provide[Container.service]): | ||||||
|     return service |     return service | ||||||
|  |  | ||||||
|  | @ -1,7 +1,7 @@ | ||||||
| from decimal import Decimal | from decimal import Decimal | ||||||
| import unittest | import unittest | ||||||
| 
 | 
 | ||||||
| from dependency_injector.wiring import wire | from dependency_injector.wiring import wire, Provide | ||||||
| 
 | 
 | ||||||
| from . import module, package | from . import module, package | ||||||
| from .service import Service | from .service import Service | ||||||
|  | @ -35,6 +35,11 @@ class WiringTest(unittest.TestCase): | ||||||
|         test_class_object = module.TestClass(service=test_service) |         test_class_object = module.TestClass(service=test_service) | ||||||
|         self.assertIs(test_class_object.service, test_service) |         self.assertIs(test_class_object.service, test_service) | ||||||
| 
 | 
 | ||||||
|  |     def test_class_method_wiring(self): | ||||||
|  |         test_class_object = module.TestClass() | ||||||
|  |         service = test_class_object.method() | ||||||
|  |         self.assertIsInstance(service, Service) | ||||||
|  | 
 | ||||||
|     def test_function_wiring(self): |     def test_function_wiring(self): | ||||||
|         service = module.test_function() |         service = module.test_function() | ||||||
|         self.assertIsInstance(service, Service) |         self.assertIsInstance(service, Service) | ||||||
|  | @ -104,3 +109,26 @@ class WiringTest(unittest.TestCase): | ||||||
|                 modules=[module], |                 modules=[module], | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|  |     def test_unwire_function(self): | ||||||
|  |         self.container.unwire() | ||||||
|  |         self.assertIsInstance(module.test_function(), Provide) | ||||||
|  | 
 | ||||||
|  |     def test_unwire_class(self): | ||||||
|  |         self.container.unwire() | ||||||
|  |         test_class_object = module.TestClass() | ||||||
|  |         self.assertIsInstance(test_class_object.service, Provide) | ||||||
|  | 
 | ||||||
|  |     def test_unwire_class_method(self): | ||||||
|  |         self.container.unwire() | ||||||
|  |         test_class_object = module.TestClass() | ||||||
|  |         self.assertIsInstance(test_class_object.method(), Provide) | ||||||
|  | 
 | ||||||
|  |     def test_unwire_package_function(self): | ||||||
|  |         self.container.unwire() | ||||||
|  |         from .package.subpackage.submodule import test_function | ||||||
|  |         self.assertIsInstance(test_function(), Provide) | ||||||
|  | 
 | ||||||
|  |     def test_unwire_package_function_by_reference(self): | ||||||
|  |         from .package.subpackage import submodule | ||||||
|  |         self.container.unwire() | ||||||
|  |         self.assertIsInstance(submodule.test_function(), Provide) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user