mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-22 09:36:48 +03:00
Factory deep init injections (#277)
* Add factory deep context providing * Add example * Add test
This commit is contained in:
parent
2fc3606671
commit
4a8133204c
48
examples/providers/factory_deep_init_injections.py
Normal file
48
examples/providers/factory_deep_init_injections.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
"""`Factory` providers deep init injections example."""
|
||||
|
||||
from dependency_injector import providers
|
||||
|
||||
|
||||
class Regularizer:
|
||||
def __init__(self, alpha):
|
||||
self.alpha = alpha
|
||||
|
||||
|
||||
class Loss:
|
||||
def __init__(self, regularizer):
|
||||
self.regularizer = regularizer
|
||||
|
||||
|
||||
class ClassificationTask:
|
||||
def __init__(self, loss):
|
||||
self.loss = loss
|
||||
|
||||
|
||||
class Algorithm:
|
||||
def __init__(self, task):
|
||||
self.task = task
|
||||
|
||||
|
||||
algorithm_factory = providers.Factory(
|
||||
Algorithm,
|
||||
task=providers.Factory(
|
||||
ClassificationTask,
|
||||
loss=providers.Factory(
|
||||
Loss,
|
||||
regularizer=providers.Factory(
|
||||
Regularizer,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5)
|
||||
assert algorithm_1.task.loss.regularizer.alpha == 0.5
|
||||
|
||||
algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7)
|
||||
assert algorithm_2.task.loss.regularizer.alpha == 0.7
|
||||
|
||||
algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8))
|
||||
assert algorithm_3.task.loss.regularizer.alpha == 0.8
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -250,11 +250,38 @@ cdef inline object __get_value(Injection self):
|
|||
return self.__value()
|
||||
|
||||
|
||||
cdef inline object __get_value_kwargs(Injection self, dict kwargs):
|
||||
if self.__call == 0:
|
||||
return self.__value
|
||||
return self.__value(**kwargs)
|
||||
|
||||
|
||||
cdef inline tuple __separate_prefixed_kwargs(dict kwargs):
|
||||
cdef dict plain_kwargs = {}
|
||||
cdef dict prefixed_kwargs = {}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if '__' not in key:
|
||||
plain_kwargs[key] = value
|
||||
continue
|
||||
|
||||
index = key.index('__')
|
||||
prefix, name = key[:index], key[index+2:]
|
||||
|
||||
if prefix not in prefixed_kwargs:
|
||||
prefixed_kwargs[prefix] = {}
|
||||
prefixed_kwargs[prefix][name] = value
|
||||
|
||||
return plain_kwargs, prefixed_kwargs
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef inline tuple __provide_positional_args(tuple args,
|
||||
cdef inline tuple __provide_positional_args(
|
||||
tuple args,
|
||||
tuple inj_args,
|
||||
int inj_args_len):
|
||||
int inj_args_len,
|
||||
):
|
||||
cdef int index
|
||||
cdef list positional_args
|
||||
cdef PositionalInjection injection
|
||||
|
@ -273,11 +300,15 @@ cdef inline tuple __provide_positional_args(tuple args,
|
|||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef inline dict __provide_keyword_args(dict kwargs,
|
||||
cdef inline dict __provide_keyword_args(
|
||||
dict kwargs,
|
||||
tuple inj_kwargs,
|
||||
int inj_kwargs_len):
|
||||
int inj_kwargs_len,
|
||||
):
|
||||
cdef int index
|
||||
cdef object name
|
||||
cdef object value
|
||||
cdef dict prefixed
|
||||
cdef NamedInjection kw_injection
|
||||
|
||||
if len(kwargs) == 0:
|
||||
|
@ -286,20 +317,33 @@ cdef inline dict __provide_keyword_args(dict kwargs,
|
|||
name = __get_name(kw_injection)
|
||||
kwargs[name] = __get_value(kw_injection)
|
||||
else:
|
||||
kwargs, prefixed = __separate_prefixed_kwargs(kwargs)
|
||||
|
||||
|
||||
for index in range(inj_kwargs_len):
|
||||
kw_injection = <NamedInjection>inj_kwargs[index]
|
||||
name = __get_name(kw_injection)
|
||||
if name not in kwargs:
|
||||
kwargs[name] = __get_value(kw_injection)
|
||||
|
||||
if name in kwargs:
|
||||
continue
|
||||
|
||||
if name in prefixed:
|
||||
value = __get_value_kwargs(kw_injection, prefixed[name])
|
||||
else:
|
||||
value = __get_value(kw_injection)
|
||||
|
||||
kwargs[name] = value
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef inline object __inject_attributes(object instance,
|
||||
cdef inline object __inject_attributes(
|
||||
object instance,
|
||||
tuple attributes,
|
||||
int attributes_len):
|
||||
int attributes_len,
|
||||
):
|
||||
cdef NamedInjection attr_injection
|
||||
for index in range(attributes_len):
|
||||
attr_injection = <NamedInjection>attributes[index]
|
||||
|
|
|
@ -166,6 +166,45 @@ class FactoryTests(unittest.TestCase):
|
|||
self.assertEqual(instance.init_arg3, 33)
|
||||
self.assertEqual(instance.init_arg4, 44)
|
||||
|
||||
def test_call_with_deep_context_kwargs(self):
|
||||
"""`Factory` providers deep init injections example."""
|
||||
class Regularizer:
|
||||
def __init__(self, alpha):
|
||||
self.alpha = alpha
|
||||
|
||||
class Loss:
|
||||
def __init__(self, regularizer):
|
||||
self.regularizer = regularizer
|
||||
|
||||
class ClassificationTask:
|
||||
def __init__(self, loss):
|
||||
self.loss = loss
|
||||
|
||||
class Algorithm:
|
||||
def __init__(self, task):
|
||||
self.task = task
|
||||
|
||||
algorithm_factory = providers.Factory(
|
||||
Algorithm,
|
||||
task=providers.Factory(
|
||||
ClassificationTask,
|
||||
loss=providers.Factory(
|
||||
Loss,
|
||||
regularizer=providers.Factory(
|
||||
Regularizer,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5)
|
||||
algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7)
|
||||
algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8))
|
||||
|
||||
self.assertEqual(algorithm_1.task.loss.regularizer.alpha, 0.5)
|
||||
self.assertEqual(algorithm_2.task.loss.regularizer.alpha, 0.7)
|
||||
self.assertEqual(algorithm_3.task.loss.regularizer.alpha, 0.8)
|
||||
|
||||
def test_fluent_interface(self):
|
||||
provider = providers.Factory(Example) \
|
||||
.add_args(1, 2) \
|
||||
|
|
Loading…
Reference in New Issue
Block a user