This commit is contained in:
Roman Mogylatov 2020-08-06 16:30:53 -04:00
parent ffca0dac0f
commit 04a16ecf0b

View File

@ -166,6 +166,45 @@ class FactoryTests(unittest.TestCase):
self.assertEqual(instance.init_arg3, 33)
self.assertEqual(instance.init_arg4, 44)
def test_call_with_deep_context_kwargs(self):
"""`Factory` providers deep init injections example."""
class Regularizer:
def __init__(self, alpha):
self.alpha = alpha
class Loss:
def __init__(self, regularizer):
self.regularizer = regularizer
class ClassificationTask:
def __init__(self, loss):
self.loss = loss
class Algorithm:
def __init__(self, task):
self.task = task
algorithm_factory = providers.Factory(
Algorithm,
task=providers.Factory(
ClassificationTask,
loss=providers.Factory(
Loss,
regularizer=providers.Factory(
Regularizer,
),
),
),
)
algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5)
algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7)
algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8))
self.assertEqual(algorithm_1.task.loss.regularizer.alpha, 0.5)
self.assertEqual(algorithm_2.task.loss.regularizer.alpha, 0.7)
self.assertEqual(algorithm_3.task.loss.regularizer.alpha, 0.8)
def test_fluent_interface(self):
provider = providers.Factory(Example) \
.add_args(1, 2) \