diff --git a/tests/unit/providers/test_factories_py2_py3.py b/tests/unit/providers/test_factories_py2_py3.py index d8be19fb..af2342dc 100644 --- a/tests/unit/providers/test_factories_py2_py3.py +++ b/tests/unit/providers/test_factories_py2_py3.py @@ -166,6 +166,45 @@ class FactoryTests(unittest.TestCase): self.assertEqual(instance.init_arg3, 33) self.assertEqual(instance.init_arg4, 44) + def test_call_with_deep_context_kwargs(self): + """`Factory` providers deep init injections example.""" + class Regularizer: + def __init__(self, alpha): + self.alpha = alpha + + class Loss: + def __init__(self, regularizer): + self.regularizer = regularizer + + class ClassificationTask: + def __init__(self, loss): + self.loss = loss + + class Algorithm: + def __init__(self, task): + self.task = task + + algorithm_factory = providers.Factory( + Algorithm, + task=providers.Factory( + ClassificationTask, + loss=providers.Factory( + Loss, + regularizer=providers.Factory( + Regularizer, + ), + ), + ), + ) + + algorithm_1 = algorithm_factory(task__loss__regularizer__alpha=0.5) + algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7) + algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8)) + + self.assertEqual(algorithm_1.task.loss.regularizer.alpha, 0.5) + self.assertEqual(algorithm_2.task.loss.regularizer.alpha, 0.7) + self.assertEqual(algorithm_3.task.loss.regularizer.alpha, 0.8) + def test_fluent_interface(self): provider = providers.Factory(Example) \ .add_args(1, 2) \