Factory deep init injections (#277)

* Add factory deep context providing

* Add example

* Add test
This commit is contained in:
Roman Mogylatov 2020-08-06 16:33:06 -04:00 committed by GitHub
parent 2fc3606671
commit 4a8133204c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 2157 additions and 406 deletions

View File

@ -0,0 +1,48 @@
"""`Factory` providers deep init injections example."""
from dependency_injector import providers
class Regularizer:
def __init__(self, alpha):
self.alpha = alpha
class Loss:
def __init__(self, regularizer):
self.regularizer = regularizer
class ClassificationTask:
def __init__(self, loss):
self.loss = loss
class Algorithm:
def __init__(self, task):
self.task = task
algorithm_factory = providers.Factory(
Algorithm,
task=providers.Factory(
ClassificationTask,
loss=providers.Factory(
Loss,
regularizer=providers.Factory(
Regularizer,
),
),
),
)
if __name__ == '__main__':
algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5)
assert algorithm_1.task.loss.regularizer.alpha == 0.5
algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7)
assert algorithm_2.task.loss.regularizer.alpha == 0.7
algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8))
assert algorithm_3.task.loss.regularizer.alpha == 0.8

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -250,11 +250,38 @@ cdef inline object __get_value(Injection self):
return self.__value()
cdef inline object __get_value_kwargs(Injection self, dict kwargs):
if self.__call == 0:
return self.__value
return self.__value(**kwargs)
cdef inline tuple __separate_prefixed_kwargs(dict kwargs):
cdef dict plain_kwargs = {}
cdef dict prefixed_kwargs = {}
for key, value in kwargs.items():
if '__' not in key:
plain_kwargs[key] = value
continue
index = key.index('__')
prefix, name = key[:index], key[index+2:]
if prefix not in prefixed_kwargs:
prefixed_kwargs[prefix] = {}
prefixed_kwargs[prefix][name] = value
return plain_kwargs, prefixed_kwargs
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline tuple __provide_positional_args(tuple args,
cdef inline tuple __provide_positional_args(
tuple args,
tuple inj_args,
int inj_args_len):
int inj_args_len,
):
cdef int index
cdef list positional_args
cdef PositionalInjection injection
@ -273,11 +300,15 @@ cdef inline tuple __provide_positional_args(tuple args,
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline dict __provide_keyword_args(dict kwargs,
cdef inline dict __provide_keyword_args(
dict kwargs,
tuple inj_kwargs,
int inj_kwargs_len):
int inj_kwargs_len,
):
cdef int index
cdef object name
cdef object value
cdef dict prefixed
cdef NamedInjection kw_injection
if len(kwargs) == 0:
@ -286,20 +317,33 @@ cdef inline dict __provide_keyword_args(dict kwargs,
name = __get_name(kw_injection)
kwargs[name] = __get_value(kw_injection)
else:
kwargs, prefixed = __separate_prefixed_kwargs(kwargs)
for index in range(inj_kwargs_len):
kw_injection = <NamedInjection>inj_kwargs[index]
name = __get_name(kw_injection)
if name not in kwargs:
kwargs[name] = __get_value(kw_injection)
if name in kwargs:
continue
if name in prefixed:
value = __get_value_kwargs(kw_injection, prefixed[name])
else:
value = __get_value(kw_injection)
kwargs[name] = value
return kwargs
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline object __inject_attributes(object instance,
cdef inline object __inject_attributes(
object instance,
tuple attributes,
int attributes_len):
int attributes_len,
):
cdef NamedInjection attr_injection
for index in range(attributes_len):
attr_injection = <NamedInjection>attributes[index]

View File

@ -166,6 +166,45 @@ class FactoryTests(unittest.TestCase):
self.assertEqual(instance.init_arg3, 33)
self.assertEqual(instance.init_arg4, 44)
def test_call_with_deep_context_kwargs(self):
"""`Factory` providers deep init injections example."""
class Regularizer:
def __init__(self, alpha):
self.alpha = alpha
class Loss:
def __init__(self, regularizer):
self.regularizer = regularizer
class ClassificationTask:
def __init__(self, loss):
self.loss = loss
class Algorithm:
def __init__(self, task):
self.task = task
algorithm_factory = providers.Factory(
Algorithm,
task=providers.Factory(
ClassificationTask,
loss=providers.Factory(
Loss,
regularizer=providers.Factory(
Regularizer,
),
),
),
)
algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5)
algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7)
algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8))
self.assertEqual(algorithm_1.task.loss.regularizer.alpha, 0.5)
self.assertEqual(algorithm_2.task.loss.regularizer.alpha, 0.7)
self.assertEqual(algorithm_3.task.loss.regularizer.alpha, 0.8)
def test_fluent_interface(self):
provider = providers.Factory(Example) \
.add_args(1, 2) \