mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-11-01 00:17:55 +03:00 
			
		
		
		
	Factory deep init injections (#277)
* Add factory deep context providing * Add example * Add test
This commit is contained in:
		
							parent
							
								
									2fc3606671
								
							
						
					
					
						commit
						4a8133204c
					
				
							
								
								
									
										48
									
								
								examples/providers/factory_deep_init_injections.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								examples/providers/factory_deep_init_injections.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,48 @@ | |||
| """`Factory` providers deep init injections example.""" | ||||
| 
 | ||||
| from dependency_injector import providers | ||||
| 
 | ||||
| 
 | ||||
| class Regularizer: | ||||
|     def __init__(self, alpha): | ||||
|         self.alpha = alpha | ||||
| 
 | ||||
| 
 | ||||
| class Loss: | ||||
|     def __init__(self, regularizer): | ||||
|         self.regularizer = regularizer | ||||
| 
 | ||||
| 
 | ||||
| class ClassificationTask: | ||||
|     def __init__(self, loss): | ||||
|         self.loss = loss | ||||
| 
 | ||||
| 
 | ||||
| class Algorithm: | ||||
|     def __init__(self, task): | ||||
|         self.task = task | ||||
| 
 | ||||
| 
 | ||||
| algorithm_factory = providers.Factory( | ||||
|     Algorithm, | ||||
|     task=providers.Factory( | ||||
|         ClassificationTask, | ||||
|         loss=providers.Factory( | ||||
|             Loss, | ||||
|             regularizer=providers.Factory( | ||||
|                 Regularizer, | ||||
|             ), | ||||
|         ), | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5) | ||||
|     assert algorithm_1.task.loss.regularizer.alpha == 0.5 | ||||
| 
 | ||||
|     algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7) | ||||
|     assert algorithm_2.task.loss.regularizer.alpha == 0.7 | ||||
| 
 | ||||
|     algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8)) | ||||
|     assert algorithm_3.task.loss.regularizer.alpha == 0.8 | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -250,11 +250,38 @@ cdef inline object __get_value(Injection self): | |||
|     return self.__value() | ||||
| 
 | ||||
| 
 | ||||
| cdef inline object __get_value_kwargs(Injection self, dict kwargs): | ||||
|     if self.__call == 0: | ||||
|         return self.__value | ||||
|     return self.__value(**kwargs) | ||||
| 
 | ||||
| 
 | ||||
| cdef inline tuple __separate_prefixed_kwargs(dict kwargs): | ||||
|     cdef dict plain_kwargs = {} | ||||
|     cdef dict prefixed_kwargs = {} | ||||
| 
 | ||||
|     for key, value in kwargs.items(): | ||||
|         if '__' not in key: | ||||
|             plain_kwargs[key] = value | ||||
|             continue | ||||
| 
 | ||||
|         index = key.index('__') | ||||
|         prefix, name = key[:index], key[index+2:] | ||||
| 
 | ||||
|         if prefix not in prefixed_kwargs: | ||||
|             prefixed_kwargs[prefix] = {} | ||||
|         prefixed_kwargs[prefix][name] = value | ||||
| 
 | ||||
|     return plain_kwargs, prefixed_kwargs | ||||
| 
 | ||||
| 
 | ||||
| @cython.boundscheck(False) | ||||
| @cython.wraparound(False) | ||||
| cdef inline tuple __provide_positional_args(tuple args, | ||||
| cdef inline tuple __provide_positional_args( | ||||
|         tuple args, | ||||
|         tuple inj_args, | ||||
|                                             int inj_args_len): | ||||
|         int inj_args_len, | ||||
| ): | ||||
|     cdef int index | ||||
|     cdef list positional_args | ||||
|     cdef PositionalInjection injection | ||||
|  | @ -273,11 +300,15 @@ cdef inline tuple __provide_positional_args(tuple args, | |||
| 
 | ||||
| @cython.boundscheck(False) | ||||
| @cython.wraparound(False) | ||||
| cdef inline dict __provide_keyword_args(dict kwargs, | ||||
| cdef inline dict __provide_keyword_args( | ||||
|         dict kwargs, | ||||
|         tuple inj_kwargs, | ||||
|                                         int inj_kwargs_len): | ||||
|         int inj_kwargs_len, | ||||
| ): | ||||
|     cdef int index | ||||
|     cdef object name | ||||
|     cdef object value | ||||
|     cdef dict prefixed | ||||
|     cdef NamedInjection kw_injection | ||||
| 
 | ||||
|     if len(kwargs) == 0: | ||||
|  | @ -286,20 +317,33 @@ cdef inline dict __provide_keyword_args(dict kwargs, | |||
|             name = __get_name(kw_injection) | ||||
|             kwargs[name] = __get_value(kw_injection) | ||||
|     else: | ||||
|         kwargs, prefixed = __separate_prefixed_kwargs(kwargs) | ||||
| 
 | ||||
| 
 | ||||
|         for index in range(inj_kwargs_len): | ||||
|             kw_injection = <NamedInjection>inj_kwargs[index] | ||||
|             name = __get_name(kw_injection) | ||||
|             if name not in kwargs: | ||||
|                 kwargs[name] = __get_value(kw_injection) | ||||
| 
 | ||||
|             if name in kwargs: | ||||
|                 continue | ||||
| 
 | ||||
|             if name in prefixed: | ||||
|                 value = __get_value_kwargs(kw_injection, prefixed[name]) | ||||
|             else: | ||||
|                 value = __get_value(kw_injection) | ||||
| 
 | ||||
|             kwargs[name] = value | ||||
| 
 | ||||
|     return kwargs | ||||
| 
 | ||||
| 
 | ||||
| @cython.boundscheck(False) | ||||
| @cython.wraparound(False) | ||||
| cdef inline object __inject_attributes(object instance, | ||||
| cdef inline object __inject_attributes( | ||||
|         object instance, | ||||
|         tuple attributes, | ||||
|                                        int attributes_len): | ||||
|         int attributes_len, | ||||
| ): | ||||
|     cdef NamedInjection attr_injection | ||||
|     for index in range(attributes_len): | ||||
|         attr_injection = <NamedInjection>attributes[index] | ||||
|  |  | |||
|  | @ -166,6 +166,45 @@ class FactoryTests(unittest.TestCase): | |||
|         self.assertEqual(instance.init_arg3, 33) | ||||
|         self.assertEqual(instance.init_arg4, 44) | ||||
| 
 | ||||
|     def test_call_with_deep_context_kwargs(self): | ||||
|         """`Factory` providers deep init injections example.""" | ||||
|         class Regularizer: | ||||
|             def __init__(self, alpha): | ||||
|                 self.alpha = alpha | ||||
| 
 | ||||
|         class Loss: | ||||
|             def __init__(self, regularizer): | ||||
|                 self.regularizer = regularizer | ||||
| 
 | ||||
|         class ClassificationTask: | ||||
|             def __init__(self, loss): | ||||
|                 self.loss = loss | ||||
| 
 | ||||
|         class Algorithm: | ||||
|             def __init__(self, task): | ||||
|                 self.task = task | ||||
| 
 | ||||
|         algorithm_factory = providers.Factory( | ||||
|             Algorithm, | ||||
|             task=providers.Factory( | ||||
|                 ClassificationTask, | ||||
|                 loss=providers.Factory( | ||||
|                     Loss, | ||||
|                     regularizer=providers.Factory( | ||||
|                         Regularizer, | ||||
|                     ), | ||||
|                 ), | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
|         algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5) | ||||
|         algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7) | ||||
|         algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8)) | ||||
| 
 | ||||
|         self.assertEqual(algorithm_1.task.loss.regularizer.alpha, 0.5) | ||||
|         self.assertEqual(algorithm_2.task.loss.regularizer.alpha, 0.7) | ||||
|         self.assertEqual(algorithm_3.task.loss.regularizer.alpha, 0.8) | ||||
| 
 | ||||
|     def test_fluent_interface(self): | ||||
|         provider = providers.Factory(Example) \ | ||||
|             .add_args(1, 2) \ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user