Add support of positional argument injections for Callable provider

Also current commit contains:
- Some refactoring of internals
- Additional unit tests for Factory and Singleton providers
This commit is contained in:
Roman Mogilatov 2015-10-21 11:44:25 +03:00
parent 402539ed7f
commit 59b98959bc
5 changed files with 147 additions and 90 deletions

View File

@ -1,11 +1,14 @@
"""Injections module.""" """Injections module."""
import sys import sys
import itertools
import six import six
from .utils import is_provider from .utils import is_provider
from .utils import ensure_is_injection from .utils import is_injection
from .utils import get_injectable_kwargs from .utils import is_arg_injection
from .utils import is_kwarg_injection
from .errors import Error from .errors import Error
@ -77,11 +80,7 @@ def inject(*args, **kwargs):
:type injection: Injection :type injection: Injection
:return: (callable) -> (callable) :return: (callable) -> (callable)
""" """
injections = tuple(KwArg(name, value) injections = _parse_kwargs_injections(args, kwargs)
for name, value in six.iteritems(kwargs))
if args:
injections += tuple(ensure_is_injection(injection)
for injection in args)
def decorator(callback_or_cls): def decorator(callback_or_cls):
"""Dependency injection decorator.""" """Dependency injection decorator."""
@ -107,10 +106,41 @@ def inject(*args, **kwargs):
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
"""Decorated with dependency injection callback.""" """Decorated with dependency injection callback."""
return callback(*args, return callback(*args,
**get_injectable_kwargs(kwargs, **_get_injectable_kwargs(kwargs,
decorated.injections)) decorated.injections))
decorated.injections = injections decorated.injections = injections
return decorated return decorated
return decorator return decorator
def _parse_args_injections(args):
"""Parse positional argument injections according to current syntax."""
return tuple(Arg(arg) if not is_injection(arg) else arg
for arg in args
if not is_injection(arg) or is_arg_injection(arg))
def _parse_kwargs_injections(args, kwargs):
"""Parse keyword argument injections according to current syntax."""
kwarg_injections = tuple(injection
for injection in args
if is_kwarg_injection(injection))
if kwargs:
kwarg_injections += tuple(KwArg(name, value)
for name, value in six.iteritems(kwargs))
return kwarg_injections
def _get_injectable_args(context_args, arg_injections):
"""Return tuple of positional arguments, patched with injections."""
return itertools.chain((arg.value for arg in arg_injections), context_args)
def _get_injectable_kwargs(context_kwargs, kwarg_injections):
"""Return dictionary of keyword arguments, patched with injections."""
injectable_kwargs = dict((kwarg.name, kwarg.value)
for kwarg in kwarg_injections)
injectable_kwargs.update(context_kwargs)
return injectable_kwargs

View File

@ -2,17 +2,14 @@
import six import six
from .injections import Arg from .injections import _parse_args_injections
from .injections import KwArg from .injections import _parse_kwargs_injections
from .injections import _get_injectable_args
from .injections import _get_injectable_kwargs
from .utils import ensure_is_provider from .utils import ensure_is_provider
from .utils import is_injection
from .utils import is_arg_injection
from .utils import is_kwarg_injection
from .utils import is_attribute_injection from .utils import is_attribute_injection
from .utils import is_method_injection from .utils import is_method_injection
from .utils import get_injectable_args
from .utils import get_injectable_kwargs
from .utils import GLOBAL_LOCK from .utils import GLOBAL_LOCK
from .errors import Error from .errors import Error
@ -116,15 +113,8 @@ class Factory(Provider):
raise Error('Factory provider expects to get callable, ' + raise Error('Factory provider expects to get callable, ' +
'got {0} instead'.format(str(provides))) 'got {0} instead'.format(str(provides)))
self.provides = provides self.provides = provides
self.args = tuple(Arg(arg) if not is_injection(arg) else arg self.args = _parse_args_injections(args)
for arg in args self.kwargs = _parse_kwargs_injections(args, kwargs)
if not is_injection(arg) or is_arg_injection(arg))
self.kwargs = tuple(injection
for injection in args
if is_kwarg_injection(injection))
if kwargs:
self.kwargs += tuple(KwArg(name, value)
for name, value in six.iteritems(kwargs))
self.attributes = tuple(injection self.attributes = tuple(injection
for injection in args for injection in args
if is_attribute_injection(injection)) if is_attribute_injection(injection))
@ -135,8 +125,8 @@ class Factory(Provider):
def _provide(self, *args, **kwargs): def _provide(self, *args, **kwargs):
"""Return provided instance.""" """Return provided instance."""
instance = self.provides(*get_injectable_args(args, self.args), instance = self.provides(*_get_injectable_args(args, self.args),
**get_injectable_kwargs(kwargs, self.kwargs)) **_get_injectable_kwargs(kwargs, self.kwargs))
for attribute in self.attributes: for attribute in self.attributes:
setattr(instance, attribute.name, attribute.value) setattr(instance, attribute.name, attribute.value)
for method in self.methods: for method in self.methods:
@ -258,21 +248,21 @@ class Callable(Provider):
with some predefined dependency injections. with some predefined dependency injections.
""" """
__slots__ = ('callback', 'kwargs') __slots__ = ('callback', 'args', 'kwargs')
def __init__(self, callback, **kwargs): def __init__(self, callback, *args, **kwargs):
"""Initializer.""" """Initializer."""
if not callable(callback): if not callable(callback):
raise Error('Callable expected, got {0}'.format(str(callback))) raise Error('Callable expected, got {0}'.format(str(callback)))
self.callback = callback self.callback = callback
self.kwargs = tuple(KwArg(name, value) self.args = _parse_args_injections(args)
for name, value in six.iteritems(kwargs)) self.kwargs = _parse_kwargs_injections(args, kwargs)
super(Callable, self).__init__() super(Callable, self).__init__()
def _provide(self, *args, **kwargs): def _provide(self, *args, **kwargs):
"""Return provided instance.""" """Return provided instance."""
return self.callback(*args, **get_injectable_kwargs(kwargs, return self.callback(*_get_injectable_args(args, self.args),
self.kwargs)) **_get_injectable_kwargs(kwargs, self.kwargs))
class Config(Provider): class Config(Provider):

View File

@ -1,7 +1,6 @@
"""Utils module.""" """Utils module."""
import threading import threading
import itertools
import six import six
@ -87,16 +86,3 @@ def ensure_is_catalog_bundle(instance):
raise Error('Expected catalog bundle instance, ' raise Error('Expected catalog bundle instance, '
'got {0}'.format(str(instance))) 'got {0}'.format(str(instance)))
return instance return instance
def get_injectable_args(context_args, arg_injections):
"""Return tuple of positional args, patched with injections."""
return itertools.chain((arg.value for arg in arg_injections), context_args)
def get_injectable_kwargs(context_kwargs, kwarg_injections):
"""Return dictionary of keyword args, patched with injections."""
kwargs = dict((kwarg.name, kwarg.value)
for kwarg in kwarg_injections)
kwargs.update(context_kwargs)
return kwargs

View File

@ -161,10 +161,6 @@ class InjectTests(unittest.TestCase):
self.assertIsInstance(b2, list) self.assertIsInstance(b2, list)
self.assertIsNot(b1, b2) self.assertIsNot(b1, b2)
def test_decorate_with_not_injection(self):
"""Test `inject()` decorator with not an injection instance."""
self.assertRaises(di.Error, di.inject, object)
def test_decorate_class_method(self): def test_decorate_class_method(self):
"""Test `inject()` decorator with class method.""" """Test `inject()` decorator with class method."""
class Test(object): class Test(object):

View File

@ -7,7 +7,8 @@ import dependency_injector as di
class Example(object): class Example(object):
"""Example class for Factory provider tests.""" """Example class for Factory provider tests."""
def __init__(self, init_arg1=None, init_arg2=None): def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None,
init_arg4=None):
"""Initializer. """Initializer.
:param init_arg1: :param init_arg1:
@ -16,6 +17,8 @@ class Example(object):
""" """
self.init_arg1 = init_arg1 self.init_arg1 = init_arg1
self.init_arg2 = init_arg2 self.init_arg2 = init_arg2
self.init_arg3 = init_arg3
self.init_arg4 = init_arg4
self.attribute1 = None self.attribute1 = None
self.attribute2 = None self.attribute2 = None
@ -304,11 +307,13 @@ class FactoryTests(unittest.TestCase):
def test_call_with_context_args(self): def test_call_with_context_args(self):
"""Test creation of new instances with context args.""" """Test creation of new instances with context args."""
provider = di.Factory(Example) provider = di.Factory(Example, 11, 22)
instance = provider(11, 22) instance = provider(33, 44)
self.assertEqual(instance.init_arg1, 11) self.assertEqual(instance.init_arg1, 11)
self.assertEqual(instance.init_arg2, 22) self.assertEqual(instance.init_arg2, 22)
self.assertEqual(instance.init_arg3, 33)
self.assertEqual(instance.init_arg4, 44)
def test_call_with_context_kwargs(self): def test_call_with_context_kwargs(self):
"""Test creation of new instances with context kwargs.""" """Test creation of new instances with context kwargs."""
@ -319,9 +324,19 @@ class FactoryTests(unittest.TestCase):
self.assertEqual(instance1.init_arg1, 1) self.assertEqual(instance1.init_arg1, 1)
self.assertEqual(instance1.init_arg2, 22) self.assertEqual(instance1.init_arg2, 22)
instance1 = provider(init_arg1=11, init_arg2=22) instance2 = provider(init_arg1=11, init_arg2=22)
self.assertEqual(instance1.init_arg1, 11) self.assertEqual(instance2.init_arg1, 11)
self.assertEqual(instance1.init_arg2, 22) self.assertEqual(instance2.init_arg2, 22)
def test_call_with_context_args_and_kwargs(self):
"""Test creation of new instances with context args and kwargs."""
provider = di.Factory(Example, 11)
instance = provider(22, init_arg3=33, init_arg4=44)
self.assertEqual(instance.init_arg1, 11)
self.assertEqual(instance.init_arg2, 22)
self.assertEqual(instance.init_arg3, 33)
self.assertEqual(instance.init_arg4, 44)
def test_call_overridden(self): def test_call_overridden(self):
"""Test creation of new instances on overridden provider.""" """Test creation of new instances on overridden provider."""
@ -521,6 +536,16 @@ class SingletonTests(unittest.TestCase):
self.assertEqual(instance1.init_arg1, 1) self.assertEqual(instance1.init_arg1, 1)
self.assertEqual(instance1.init_arg2, 22) self.assertEqual(instance1.init_arg2, 22)
def test_call_with_context_args_and_kwargs(self):
"""Test getting of instances with context args and kwargs."""
provider = di.Singleton(Example, 11)
instance = provider(22, init_arg3=33, init_arg4=44)
self.assertEqual(instance.init_arg1, 11)
self.assertEqual(instance.init_arg2, 22)
self.assertEqual(instance.init_arg3, 33)
self.assertEqual(instance.init_arg4, 44)
def test_call_overridden(self): def test_call_overridden(self):
"""Test getting of instances on overridden provider.""" """Test getting of instances on overridden provider."""
provider = di.Singleton(Example) provider = di.Singleton(Example)
@ -653,52 +678,82 @@ class StaticProvidersTests(unittest.TestCase):
class CallableTests(unittest.TestCase): class CallableTests(unittest.TestCase):
"""Callable test cases.""" """Callable test cases."""
def example(self, arg1, arg2, arg3): def example(self, arg1, arg2, arg3, arg4):
"""Example callback.""" """Example callback."""
return arg1, arg2, arg3 return arg1, arg2, arg3, arg4
def setUp(self): def test_init_with_callable(self):
"""Set test cases environment up.""" """Test creation of provider with a callable."""
self.provider = di.Callable(self.example, self.assertTrue(di.Callable(self.example))
arg1='a1',
arg2='a2',
arg3='a3')
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
"""Test creation of provider with not callable.""" """Test creation of provider with not a callable."""
self.assertRaises(di.Error, di.Callable, 123) self.assertRaises(di.Error, di.Callable, 123)
def test_is_provider(self):
"""Test `is_provider` check."""
self.assertTrue(di.is_provider(self.provider))
def test_call(self): def test_call(self):
"""Test provider call.""" """Test call."""
self.assertEqual(self.provider(), ('a1', 'a2', 'a3')) provider = di.Callable(lambda: True)
self.assertTrue(provider())
def test_call_with_args(self): def test_call_with_positional_args(self):
"""Test provider call with kwargs priority.""" """Test call with positional args.
New simplified syntax.
"""
provider = di.Callable(self.example, 1, 2, 3, 4)
self.assertTupleEqual(provider(), (1, 2, 3, 4))
def test_call_with_keyword_args(self):
"""Test call with keyword args.
New simplified syntax.
"""
provider = di.Callable(self.example, arg1=1, arg2=2, arg3=3, arg4=4)
self.assertTupleEqual(provider(), (1, 2, 3, 4))
def test_call_with_positional_and_keyword_args(self):
"""Test call with positional and keyword args.
Simplified syntax of positional and keyword arg injections.
"""
provider = di.Callable(self.example, 1, 2, arg3=3, arg4=4)
self.assertTupleEqual(provider(), (1, 2, 3, 4))
def test_call_with_positional_and_keyword_args_extended_syntax(self):
"""Test call with positional and keyword args.
Extended syntax of positional and keyword arg injections.
"""
provider = di.Callable(self.example, provider = di.Callable(self.example,
arg3='a3') di.Arg(1),
self.assertEqual(provider(1, 2), (1, 2, 'a3')) di.Arg(2),
di.KwArg('arg3', 3),
di.KwArg('arg4', 4))
self.assertTupleEqual(provider(), (1, 2, 3, 4))
def test_call_with_kwargs_priority(self): def test_call_with_context_args(self):
"""Test provider call with kwargs priority.""" """Test call with context args."""
self.assertEqual(self.provider(arg1=1, arg3=3), (1, 'a2', 3)) provider = di.Callable(self.example, 1, 2)
self.assertTupleEqual(provider(3, 4), (1, 2, 3, 4))
def test_call_with_context_kwargs(self):
"""Test call with context kwargs."""
provider = di.Callable(self.example,
di.KwArg('arg1', 1))
self.assertTupleEqual(provider(arg2=2, arg3=3, arg4=4), (1, 2, 3, 4))
def test_call_with_context_args_and_kwargs(self):
"""Test call with context args and kwargs."""
provider = di.Callable(self.example, 1)
self.assertTupleEqual(provider(2, arg3=3, arg4=4), (1, 2, 3, 4))
def test_call_overridden(self): def test_call_overridden(self):
"""Test overridden provider call.""" """Test creation of new instances on overridden provider."""
overriding_provider1 = di.Value((1, 2, 3)) provider = di.Callable(self.example)
overriding_provider2 = di.Value((3, 2, 1)) provider.override(di.Value((4, 3, 2, 1)))
provider.override(di.Value((1, 2, 3, 4)))
self.provider.override(overriding_provider1) self.assertTupleEqual(provider(), (1, 2, 3, 4))
self.provider.override(overriding_provider2)
result1 = self.provider()
result2 = self.provider()
self.assertEqual(result1, (3, 2, 1))
self.assertEqual(result2, (3, 2, 1))
class ConfigTests(unittest.TestCase): class ConfigTests(unittest.TestCase):