mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-22 09:36:48 +03:00
Factory deep init injections (#277)
* Add factory deep context providing * Add example * Add test
This commit is contained in:
parent
2fc3606671
commit
4a8133204c
48
examples/providers/factory_deep_init_injections.py
Normal file
48
examples/providers/factory_deep_init_injections.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
"""`Factory` providers deep init injections example."""
|
||||||
|
|
||||||
|
from dependency_injector import providers
|
||||||
|
|
||||||
|
|
||||||
|
class Regularizer:
|
||||||
|
def __init__(self, alpha):
|
||||||
|
self.alpha = alpha
|
||||||
|
|
||||||
|
|
||||||
|
class Loss:
|
||||||
|
def __init__(self, regularizer):
|
||||||
|
self.regularizer = regularizer
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationTask:
|
||||||
|
def __init__(self, loss):
|
||||||
|
self.loss = loss
|
||||||
|
|
||||||
|
|
||||||
|
class Algorithm:
|
||||||
|
def __init__(self, task):
|
||||||
|
self.task = task
|
||||||
|
|
||||||
|
|
||||||
|
algorithm_factory = providers.Factory(
|
||||||
|
Algorithm,
|
||||||
|
task=providers.Factory(
|
||||||
|
ClassificationTask,
|
||||||
|
loss=providers.Factory(
|
||||||
|
Loss,
|
||||||
|
regularizer=providers.Factory(
|
||||||
|
Regularizer,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5)
|
||||||
|
assert algorithm_1.task.loss.regularizer.alpha == 0.5
|
||||||
|
|
||||||
|
algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7)
|
||||||
|
assert algorithm_2.task.loss.regularizer.alpha == 0.7
|
||||||
|
|
||||||
|
algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8))
|
||||||
|
assert algorithm_3.task.loss.regularizer.alpha == 0.8
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -250,11 +250,38 @@ cdef inline object __get_value(Injection self):
|
||||||
return self.__value()
|
return self.__value()
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline object __get_value_kwargs(Injection self, dict kwargs):
|
||||||
|
if self.__call == 0:
|
||||||
|
return self.__value
|
||||||
|
return self.__value(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline tuple __separate_prefixed_kwargs(dict kwargs):
|
||||||
|
cdef dict plain_kwargs = {}
|
||||||
|
cdef dict prefixed_kwargs = {}
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if '__' not in key:
|
||||||
|
plain_kwargs[key] = value
|
||||||
|
continue
|
||||||
|
|
||||||
|
index = key.index('__')
|
||||||
|
prefix, name = key[:index], key[index+2:]
|
||||||
|
|
||||||
|
if prefix not in prefixed_kwargs:
|
||||||
|
prefixed_kwargs[prefix] = {}
|
||||||
|
prefixed_kwargs[prefix][name] = value
|
||||||
|
|
||||||
|
return plain_kwargs, prefixed_kwargs
|
||||||
|
|
||||||
|
|
||||||
@cython.boundscheck(False)
|
@cython.boundscheck(False)
|
||||||
@cython.wraparound(False)
|
@cython.wraparound(False)
|
||||||
cdef inline tuple __provide_positional_args(tuple args,
|
cdef inline tuple __provide_positional_args(
|
||||||
tuple inj_args,
|
tuple args,
|
||||||
int inj_args_len):
|
tuple inj_args,
|
||||||
|
int inj_args_len,
|
||||||
|
):
|
||||||
cdef int index
|
cdef int index
|
||||||
cdef list positional_args
|
cdef list positional_args
|
||||||
cdef PositionalInjection injection
|
cdef PositionalInjection injection
|
||||||
|
@ -273,11 +300,15 @@ cdef inline tuple __provide_positional_args(tuple args,
|
||||||
|
|
||||||
@cython.boundscheck(False)
|
@cython.boundscheck(False)
|
||||||
@cython.wraparound(False)
|
@cython.wraparound(False)
|
||||||
cdef inline dict __provide_keyword_args(dict kwargs,
|
cdef inline dict __provide_keyword_args(
|
||||||
tuple inj_kwargs,
|
dict kwargs,
|
||||||
int inj_kwargs_len):
|
tuple inj_kwargs,
|
||||||
|
int inj_kwargs_len,
|
||||||
|
):
|
||||||
cdef int index
|
cdef int index
|
||||||
cdef object name
|
cdef object name
|
||||||
|
cdef object value
|
||||||
|
cdef dict prefixed
|
||||||
cdef NamedInjection kw_injection
|
cdef NamedInjection kw_injection
|
||||||
|
|
||||||
if len(kwargs) == 0:
|
if len(kwargs) == 0:
|
||||||
|
@ -286,20 +317,33 @@ cdef inline dict __provide_keyword_args(dict kwargs,
|
||||||
name = __get_name(kw_injection)
|
name = __get_name(kw_injection)
|
||||||
kwargs[name] = __get_value(kw_injection)
|
kwargs[name] = __get_value(kw_injection)
|
||||||
else:
|
else:
|
||||||
|
kwargs, prefixed = __separate_prefixed_kwargs(kwargs)
|
||||||
|
|
||||||
|
|
||||||
for index in range(inj_kwargs_len):
|
for index in range(inj_kwargs_len):
|
||||||
kw_injection = <NamedInjection>inj_kwargs[index]
|
kw_injection = <NamedInjection>inj_kwargs[index]
|
||||||
name = __get_name(kw_injection)
|
name = __get_name(kw_injection)
|
||||||
if name not in kwargs:
|
|
||||||
kwargs[name] = __get_value(kw_injection)
|
if name in kwargs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if name in prefixed:
|
||||||
|
value = __get_value_kwargs(kw_injection, prefixed[name])
|
||||||
|
else:
|
||||||
|
value = __get_value(kw_injection)
|
||||||
|
|
||||||
|
kwargs[name] = value
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
@cython.boundscheck(False)
|
@cython.boundscheck(False)
|
||||||
@cython.wraparound(False)
|
@cython.wraparound(False)
|
||||||
cdef inline object __inject_attributes(object instance,
|
cdef inline object __inject_attributes(
|
||||||
tuple attributes,
|
object instance,
|
||||||
int attributes_len):
|
tuple attributes,
|
||||||
|
int attributes_len,
|
||||||
|
):
|
||||||
cdef NamedInjection attr_injection
|
cdef NamedInjection attr_injection
|
||||||
for index in range(attributes_len):
|
for index in range(attributes_len):
|
||||||
attr_injection = <NamedInjection>attributes[index]
|
attr_injection = <NamedInjection>attributes[index]
|
||||||
|
|
|
@ -166,6 +166,45 @@ class FactoryTests(unittest.TestCase):
|
||||||
self.assertEqual(instance.init_arg3, 33)
|
self.assertEqual(instance.init_arg3, 33)
|
||||||
self.assertEqual(instance.init_arg4, 44)
|
self.assertEqual(instance.init_arg4, 44)
|
||||||
|
|
||||||
|
def test_call_with_deep_context_kwargs(self):
|
||||||
|
"""`Factory` providers deep init injections example."""
|
||||||
|
class Regularizer:
|
||||||
|
def __init__(self, alpha):
|
||||||
|
self.alpha = alpha
|
||||||
|
|
||||||
|
class Loss:
|
||||||
|
def __init__(self, regularizer):
|
||||||
|
self.regularizer = regularizer
|
||||||
|
|
||||||
|
class ClassificationTask:
|
||||||
|
def __init__(self, loss):
|
||||||
|
self.loss = loss
|
||||||
|
|
||||||
|
class Algorithm:
|
||||||
|
def __init__(self, task):
|
||||||
|
self.task = task
|
||||||
|
|
||||||
|
algorithm_factory = providers.Factory(
|
||||||
|
Algorithm,
|
||||||
|
task=providers.Factory(
|
||||||
|
ClassificationTask,
|
||||||
|
loss=providers.Factory(
|
||||||
|
Loss,
|
||||||
|
regularizer=providers.Factory(
|
||||||
|
Regularizer,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5)
|
||||||
|
algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7)
|
||||||
|
algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8))
|
||||||
|
|
||||||
|
self.assertEqual(algorithm_1.task.loss.regularizer.alpha, 0.5)
|
||||||
|
self.assertEqual(algorithm_2.task.loss.regularizer.alpha, 0.7)
|
||||||
|
self.assertEqual(algorithm_3.task.loss.regularizer.alpha, 0.8)
|
||||||
|
|
||||||
def test_fluent_interface(self):
|
def test_fluent_interface(self):
|
||||||
provider = providers.Factory(Example) \
|
provider = providers.Factory(Example) \
|
||||||
.add_args(1, 2) \
|
.add_args(1, 2) \
|
||||||
|
|
Loading…
Reference in New Issue
Block a user