From 59b98959bc35f29c1573e08bb7b8b19125479b2c Mon Sep 17 00:00:00 2001 From: Roman Mogilatov Date: Wed, 21 Oct 2015 11:44:25 +0300 Subject: [PATCH] Add support of positional argument injections for Callable provider Also current commit contains: - Some refactoring of internals - Additional unit tests for Factory and Singleton providers --- dependency_injector/injections.py | 48 +++++++++-- dependency_injector/providers.py | 38 ++++----- dependency_injector/utils.py | 14 ---- tests/test_injections.py | 4 - tests/test_providers.py | 133 +++++++++++++++++++++--------- 5 files changed, 147 insertions(+), 90 deletions(-) diff --git a/dependency_injector/injections.py b/dependency_injector/injections.py index 2806b7c0..6833c41c 100644 --- a/dependency_injector/injections.py +++ b/dependency_injector/injections.py @@ -1,11 +1,14 @@ """Injections module.""" import sys +import itertools + import six from .utils import is_provider -from .utils import ensure_is_injection -from .utils import get_injectable_kwargs +from .utils import is_injection +from .utils import is_arg_injection +from .utils import is_kwarg_injection from .errors import Error @@ -77,11 +80,7 @@ def inject(*args, **kwargs): :type injection: Injection :return: (callable) -> (callable) """ - injections = tuple(KwArg(name, value) - for name, value in six.iteritems(kwargs)) - if args: - injections += tuple(ensure_is_injection(injection) - for injection in args) + injections = _parse_kwargs_injections(args, kwargs) def decorator(callback_or_cls): """Dependency injection decorator.""" @@ -107,10 +106,41 @@ def inject(*args, **kwargs): def decorated(*args, **kwargs): """Decorated with dependency injection callback.""" return callback(*args, - **get_injectable_kwargs(kwargs, - decorated.injections)) + **_get_injectable_kwargs(kwargs, + decorated.injections)) decorated.injections = injections return decorated return decorator + + +def _parse_args_injections(args): + """Parse positional argument injections according to current syntax.""" + return tuple(Arg(arg) if not is_injection(arg) else arg + for arg in args + if not is_injection(arg) or is_arg_injection(arg)) + + +def _parse_kwargs_injections(args, kwargs): + """Parse keyword argument injections according to current syntax.""" + kwarg_injections = tuple(injection + for injection in args + if is_kwarg_injection(injection)) + if kwargs: + kwarg_injections += tuple(KwArg(name, value) + for name, value in six.iteritems(kwargs)) + return kwarg_injections + + +def _get_injectable_args(context_args, arg_injections): + """Return tuple of positional arguments, patched with injections.""" + return itertools.chain((arg.value for arg in arg_injections), context_args) + + +def _get_injectable_kwargs(context_kwargs, kwarg_injections): + """Return dictionary of keyword arguments, patched with injections.""" + injectable_kwargs = dict((kwarg.name, kwarg.value) + for kwarg in kwarg_injections) + injectable_kwargs.update(context_kwargs) + return injectable_kwargs diff --git a/dependency_injector/providers.py b/dependency_injector/providers.py index ed7d2145..b9f49696 100644 --- a/dependency_injector/providers.py +++ b/dependency_injector/providers.py @@ -2,17 +2,14 @@ import six -from .injections import Arg -from .injections import KwArg +from .injections import _parse_args_injections +from .injections import _parse_kwargs_injections +from .injections import _get_injectable_args +from .injections import _get_injectable_kwargs from .utils import ensure_is_provider -from .utils import is_injection -from .utils import is_arg_injection -from .utils import is_kwarg_injection from .utils import is_attribute_injection from .utils import is_method_injection -from .utils import get_injectable_args -from .utils import get_injectable_kwargs from .utils import GLOBAL_LOCK from .errors import Error @@ -116,15 +113,8 @@ class Factory(Provider): raise Error('Factory provider expects to get callable, ' + 'got {0} instead'.format(str(provides))) self.provides = provides - self.args = tuple(Arg(arg) if not is_injection(arg) else arg - for arg in args - if not is_injection(arg) or is_arg_injection(arg)) - self.kwargs = tuple(injection - for injection in args - if is_kwarg_injection(injection)) - if kwargs: - self.kwargs += tuple(KwArg(name, value) - for name, value in six.iteritems(kwargs)) + self.args = _parse_args_injections(args) + self.kwargs = _parse_kwargs_injections(args, kwargs) self.attributes = tuple(injection for injection in args if is_attribute_injection(injection)) @@ -135,8 +125,8 @@ class Factory(Provider): def _provide(self, *args, **kwargs): """Return provided instance.""" - instance = self.provides(*get_injectable_args(args, self.args), - **get_injectable_kwargs(kwargs, self.kwargs)) + instance = self.provides(*_get_injectable_args(args, self.args), + **_get_injectable_kwargs(kwargs, self.kwargs)) for attribute in self.attributes: setattr(instance, attribute.name, attribute.value) for method in self.methods: @@ -258,21 +248,21 @@ class Callable(Provider): with some predefined dependency injections. """ - __slots__ = ('callback', 'kwargs') + __slots__ = ('callback', 'args', 'kwargs') - def __init__(self, callback, **kwargs): + def __init__(self, callback, *args, **kwargs): """Initializer.""" if not callable(callback): raise Error('Callable expected, got {0}'.format(str(callback))) self.callback = callback - self.kwargs = tuple(KwArg(name, value) - for name, value in six.iteritems(kwargs)) + self.args = _parse_args_injections(args) + self.kwargs = _parse_kwargs_injections(args, kwargs) super(Callable, self).__init__() def _provide(self, *args, **kwargs): """Return provided instance.""" - return self.callback(*args, **get_injectable_kwargs(kwargs, - self.kwargs)) + return self.callback(*_get_injectable_args(args, self.args), + **_get_injectable_kwargs(kwargs, self.kwargs)) class Config(Provider): diff --git a/dependency_injector/utils.py b/dependency_injector/utils.py index 007b70b0..903afc8b 100644 --- a/dependency_injector/utils.py +++ b/dependency_injector/utils.py @@ -1,7 +1,6 @@ """Utils module.""" import threading -import itertools import six @@ -87,16 +86,3 @@ def ensure_is_catalog_bundle(instance): raise Error('Expected catalog bundle instance, ' 'got {0}'.format(str(instance))) return instance - - -def get_injectable_args(context_args, arg_injections): - """Return tuple of positional args, patched with injections.""" - return itertools.chain((arg.value for arg in arg_injections), context_args) - - -def get_injectable_kwargs(context_kwargs, kwarg_injections): - """Return dictionary of keyword args, patched with injections.""" - kwargs = dict((kwarg.name, kwarg.value) - for kwarg in kwarg_injections) - kwargs.update(context_kwargs) - return kwargs diff --git a/tests/test_injections.py b/tests/test_injections.py index dd8573bd..d52ebf0b 100644 --- a/tests/test_injections.py +++ b/tests/test_injections.py @@ -161,10 +161,6 @@ class InjectTests(unittest.TestCase): self.assertIsInstance(b2, list) self.assertIsNot(b1, b2) - def test_decorate_with_not_injection(self): - """Test `inject()` decorator with not an injection instance.""" - self.assertRaises(di.Error, di.inject, object) - def test_decorate_class_method(self): """Test `inject()` decorator with class method.""" class Test(object): diff --git a/tests/test_providers.py b/tests/test_providers.py index 9887f952..7161acb9 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -7,7 +7,8 @@ import dependency_injector as di class Example(object): """Example class for Factory provider tests.""" - def __init__(self, init_arg1=None, init_arg2=None): + def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None, + init_arg4=None): """Initializer. :param init_arg1: @@ -16,6 +17,8 @@ class Example(object): """ self.init_arg1 = init_arg1 self.init_arg2 = init_arg2 + self.init_arg3 = init_arg3 + self.init_arg4 = init_arg4 self.attribute1 = None self.attribute2 = None @@ -304,11 +307,13 @@ class FactoryTests(unittest.TestCase): def test_call_with_context_args(self): """Test creation of new instances with context args.""" - provider = di.Factory(Example) - instance = provider(11, 22) + provider = di.Factory(Example, 11, 22) + instance = provider(33, 44) self.assertEqual(instance.init_arg1, 11) self.assertEqual(instance.init_arg2, 22) + self.assertEqual(instance.init_arg3, 33) + self.assertEqual(instance.init_arg4, 44) def test_call_with_context_kwargs(self): """Test creation of new instances with context kwargs.""" @@ -319,9 +324,19 @@ class FactoryTests(unittest.TestCase): self.assertEqual(instance1.init_arg1, 1) self.assertEqual(instance1.init_arg2, 22) - instance1 = provider(init_arg1=11, init_arg2=22) - self.assertEqual(instance1.init_arg1, 11) - self.assertEqual(instance1.init_arg2, 22) + instance2 = provider(init_arg1=11, init_arg2=22) + self.assertEqual(instance2.init_arg1, 11) + self.assertEqual(instance2.init_arg2, 22) + + def test_call_with_context_args_and_kwargs(self): + """Test creation of new instances with context args and kwargs.""" + provider = di.Factory(Example, 11) + instance = provider(22, init_arg3=33, init_arg4=44) + + self.assertEqual(instance.init_arg1, 11) + self.assertEqual(instance.init_arg2, 22) + self.assertEqual(instance.init_arg3, 33) + self.assertEqual(instance.init_arg4, 44) def test_call_overridden(self): """Test creation of new instances on overridden provider.""" @@ -521,6 +536,16 @@ class SingletonTests(unittest.TestCase): self.assertEqual(instance1.init_arg1, 1) self.assertEqual(instance1.init_arg2, 22) + def test_call_with_context_args_and_kwargs(self): + """Test getting of instances with context args and kwargs.""" + provider = di.Singleton(Example, 11) + instance = provider(22, init_arg3=33, init_arg4=44) + + self.assertEqual(instance.init_arg1, 11) + self.assertEqual(instance.init_arg2, 22) + self.assertEqual(instance.init_arg3, 33) + self.assertEqual(instance.init_arg4, 44) + def test_call_overridden(self): """Test getting of instances on overridden provider.""" provider = di.Singleton(Example) @@ -653,52 +678,82 @@ class StaticProvidersTests(unittest.TestCase): class CallableTests(unittest.TestCase): """Callable test cases.""" - def example(self, arg1, arg2, arg3): + def example(self, arg1, arg2, arg3, arg4): """Example callback.""" - return arg1, arg2, arg3 + return arg1, arg2, arg3, arg4 - def setUp(self): - """Set test cases environment up.""" - self.provider = di.Callable(self.example, - arg1='a1', - arg2='a2', - arg3='a3') + def test_init_with_callable(self): + """Test creation of provider with a callable.""" + self.assertTrue(di.Callable(self.example)) def test_init_with_not_callable(self): - """Test creation of provider with not callable.""" + """Test creation of provider with not a callable.""" self.assertRaises(di.Error, di.Callable, 123) - def test_is_provider(self): - """Test `is_provider` check.""" - self.assertTrue(di.is_provider(self.provider)) - def test_call(self): - """Test provider call.""" - self.assertEqual(self.provider(), ('a1', 'a2', 'a3')) + """Test call.""" + provider = di.Callable(lambda: True) + self.assertTrue(provider()) - def test_call_with_args(self): - """Test provider call with kwargs priority.""" + def test_call_with_positional_args(self): + """Test call with positional args. + + New simplified syntax. + """ + provider = di.Callable(self.example, 1, 2, 3, 4) + self.assertTupleEqual(provider(), (1, 2, 3, 4)) + + def test_call_with_keyword_args(self): + """Test call with keyword args. + + New simplified syntax. + """ + provider = di.Callable(self.example, arg1=1, arg2=2, arg3=3, arg4=4) + self.assertTupleEqual(provider(), (1, 2, 3, 4)) + + def test_call_with_positional_and_keyword_args(self): + """Test call with positional and keyword args. + + Simplified syntax of positional and keyword arg injections. + """ + provider = di.Callable(self.example, 1, 2, arg3=3, arg4=4) + self.assertTupleEqual(provider(), (1, 2, 3, 4)) + + def test_call_with_positional_and_keyword_args_extended_syntax(self): + """Test call with positional and keyword args. + + Extended syntax of positional and keyword arg injections. + """ provider = di.Callable(self.example, - arg3='a3') - self.assertEqual(provider(1, 2), (1, 2, 'a3')) + di.Arg(1), + di.Arg(2), + di.KwArg('arg3', 3), + di.KwArg('arg4', 4)) + self.assertTupleEqual(provider(), (1, 2, 3, 4)) - def test_call_with_kwargs_priority(self): - """Test provider call with kwargs priority.""" - self.assertEqual(self.provider(arg1=1, arg3=3), (1, 'a2', 3)) + def test_call_with_context_args(self): + """Test call with context args.""" + provider = di.Callable(self.example, 1, 2) + self.assertTupleEqual(provider(3, 4), (1, 2, 3, 4)) + + def test_call_with_context_kwargs(self): + """Test call with context kwargs.""" + provider = di.Callable(self.example, + di.KwArg('arg1', 1)) + self.assertTupleEqual(provider(arg2=2, arg3=3, arg4=4), (1, 2, 3, 4)) + + def test_call_with_context_args_and_kwargs(self): + """Test call with context args and kwargs.""" + provider = di.Callable(self.example, 1) + self.assertTupleEqual(provider(2, arg3=3, arg4=4), (1, 2, 3, 4)) def test_call_overridden(self): - """Test overridden provider call.""" - overriding_provider1 = di.Value((1, 2, 3)) - overriding_provider2 = di.Value((3, 2, 1)) + """Test creation of new instances on overridden provider.""" + provider = di.Callable(self.example) + provider.override(di.Value((4, 3, 2, 1))) + provider.override(di.Value((1, 2, 3, 4))) - self.provider.override(overriding_provider1) - self.provider.override(overriding_provider2) - - result1 = self.provider() - result2 = self.provider() - - self.assertEqual(result1, (3, 2, 1)) - self.assertEqual(result2, (3, 2, 1)) + self.assertTupleEqual(provider(), (1, 2, 3, 4)) class ConfigTests(unittest.TestCase):