mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-22 09:36:48 +03:00
Add support of positional argument injections for Callable provider
Also current commit contains: - Some refactoring of internals - Additional unit tests for Factory and Singleton providers
This commit is contained in:
parent
402539ed7f
commit
59b98959bc
|
@ -1,11 +1,14 @@
|
|||
"""Injections module."""
|
||||
|
||||
import sys
|
||||
import itertools
|
||||
|
||||
import six
|
||||
|
||||
from .utils import is_provider
|
||||
from .utils import ensure_is_injection
|
||||
from .utils import get_injectable_kwargs
|
||||
from .utils import is_injection
|
||||
from .utils import is_arg_injection
|
||||
from .utils import is_kwarg_injection
|
||||
|
||||
from .errors import Error
|
||||
|
||||
|
@ -77,11 +80,7 @@ def inject(*args, **kwargs):
|
|||
:type injection: Injection
|
||||
:return: (callable) -> (callable)
|
||||
"""
|
||||
injections = tuple(KwArg(name, value)
|
||||
for name, value in six.iteritems(kwargs))
|
||||
if args:
|
||||
injections += tuple(ensure_is_injection(injection)
|
||||
for injection in args)
|
||||
injections = _parse_kwargs_injections(args, kwargs)
|
||||
|
||||
def decorator(callback_or_cls):
|
||||
"""Dependency injection decorator."""
|
||||
|
@ -107,10 +106,41 @@ def inject(*args, **kwargs):
|
|||
def decorated(*args, **kwargs):
|
||||
"""Decorated with dependency injection callback."""
|
||||
return callback(*args,
|
||||
**get_injectable_kwargs(kwargs,
|
||||
**_get_injectable_kwargs(kwargs,
|
||||
decorated.injections))
|
||||
|
||||
decorated.injections = injections
|
||||
|
||||
return decorated
|
||||
return decorator
|
||||
|
||||
|
||||
def _parse_args_injections(args):
|
||||
"""Parse positional argument injections according to current syntax."""
|
||||
return tuple(Arg(arg) if not is_injection(arg) else arg
|
||||
for arg in args
|
||||
if not is_injection(arg) or is_arg_injection(arg))
|
||||
|
||||
|
||||
def _parse_kwargs_injections(args, kwargs):
|
||||
"""Parse keyword argument injections according to current syntax."""
|
||||
kwarg_injections = tuple(injection
|
||||
for injection in args
|
||||
if is_kwarg_injection(injection))
|
||||
if kwargs:
|
||||
kwarg_injections += tuple(KwArg(name, value)
|
||||
for name, value in six.iteritems(kwargs))
|
||||
return kwarg_injections
|
||||
|
||||
|
||||
def _get_injectable_args(context_args, arg_injections):
|
||||
"""Return tuple of positional arguments, patched with injections."""
|
||||
return itertools.chain((arg.value for arg in arg_injections), context_args)
|
||||
|
||||
|
||||
def _get_injectable_kwargs(context_kwargs, kwarg_injections):
|
||||
"""Return dictionary of keyword arguments, patched with injections."""
|
||||
injectable_kwargs = dict((kwarg.name, kwarg.value)
|
||||
for kwarg in kwarg_injections)
|
||||
injectable_kwargs.update(context_kwargs)
|
||||
return injectable_kwargs
|
||||
|
|
|
@ -2,17 +2,14 @@
|
|||
|
||||
import six
|
||||
|
||||
from .injections import Arg
|
||||
from .injections import KwArg
|
||||
from .injections import _parse_args_injections
|
||||
from .injections import _parse_kwargs_injections
|
||||
from .injections import _get_injectable_args
|
||||
from .injections import _get_injectable_kwargs
|
||||
|
||||
from .utils import ensure_is_provider
|
||||
from .utils import is_injection
|
||||
from .utils import is_arg_injection
|
||||
from .utils import is_kwarg_injection
|
||||
from .utils import is_attribute_injection
|
||||
from .utils import is_method_injection
|
||||
from .utils import get_injectable_args
|
||||
from .utils import get_injectable_kwargs
|
||||
from .utils import GLOBAL_LOCK
|
||||
|
||||
from .errors import Error
|
||||
|
@ -116,15 +113,8 @@ class Factory(Provider):
|
|||
raise Error('Factory provider expects to get callable, ' +
|
||||
'got {0} instead'.format(str(provides)))
|
||||
self.provides = provides
|
||||
self.args = tuple(Arg(arg) if not is_injection(arg) else arg
|
||||
for arg in args
|
||||
if not is_injection(arg) or is_arg_injection(arg))
|
||||
self.kwargs = tuple(injection
|
||||
for injection in args
|
||||
if is_kwarg_injection(injection))
|
||||
if kwargs:
|
||||
self.kwargs += tuple(KwArg(name, value)
|
||||
for name, value in six.iteritems(kwargs))
|
||||
self.args = _parse_args_injections(args)
|
||||
self.kwargs = _parse_kwargs_injections(args, kwargs)
|
||||
self.attributes = tuple(injection
|
||||
for injection in args
|
||||
if is_attribute_injection(injection))
|
||||
|
@ -135,8 +125,8 @@ class Factory(Provider):
|
|||
|
||||
def _provide(self, *args, **kwargs):
|
||||
"""Return provided instance."""
|
||||
instance = self.provides(*get_injectable_args(args, self.args),
|
||||
**get_injectable_kwargs(kwargs, self.kwargs))
|
||||
instance = self.provides(*_get_injectable_args(args, self.args),
|
||||
**_get_injectable_kwargs(kwargs, self.kwargs))
|
||||
for attribute in self.attributes:
|
||||
setattr(instance, attribute.name, attribute.value)
|
||||
for method in self.methods:
|
||||
|
@ -258,21 +248,21 @@ class Callable(Provider):
|
|||
with some predefined dependency injections.
|
||||
"""
|
||||
|
||||
__slots__ = ('callback', 'kwargs')
|
||||
__slots__ = ('callback', 'args', 'kwargs')
|
||||
|
||||
def __init__(self, callback, **kwargs):
|
||||
def __init__(self, callback, *args, **kwargs):
|
||||
"""Initializer."""
|
||||
if not callable(callback):
|
||||
raise Error('Callable expected, got {0}'.format(str(callback)))
|
||||
self.callback = callback
|
||||
self.kwargs = tuple(KwArg(name, value)
|
||||
for name, value in six.iteritems(kwargs))
|
||||
self.args = _parse_args_injections(args)
|
||||
self.kwargs = _parse_kwargs_injections(args, kwargs)
|
||||
super(Callable, self).__init__()
|
||||
|
||||
def _provide(self, *args, **kwargs):
|
||||
"""Return provided instance."""
|
||||
return self.callback(*args, **get_injectable_kwargs(kwargs,
|
||||
self.kwargs))
|
||||
return self.callback(*_get_injectable_args(args, self.args),
|
||||
**_get_injectable_kwargs(kwargs, self.kwargs))
|
||||
|
||||
|
||||
class Config(Provider):
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
"""Utils module."""
|
||||
|
||||
import threading
|
||||
import itertools
|
||||
|
||||
import six
|
||||
|
||||
|
@ -87,16 +86,3 @@ def ensure_is_catalog_bundle(instance):
|
|||
raise Error('Expected catalog bundle instance, '
|
||||
'got {0}'.format(str(instance)))
|
||||
return instance
|
||||
|
||||
|
||||
def get_injectable_args(context_args, arg_injections):
|
||||
"""Return tuple of positional args, patched with injections."""
|
||||
return itertools.chain((arg.value for arg in arg_injections), context_args)
|
||||
|
||||
|
||||
def get_injectable_kwargs(context_kwargs, kwarg_injections):
|
||||
"""Return dictionary of keyword args, patched with injections."""
|
||||
kwargs = dict((kwarg.name, kwarg.value)
|
||||
for kwarg in kwarg_injections)
|
||||
kwargs.update(context_kwargs)
|
||||
return kwargs
|
||||
|
|
|
@ -161,10 +161,6 @@ class InjectTests(unittest.TestCase):
|
|||
self.assertIsInstance(b2, list)
|
||||
self.assertIsNot(b1, b2)
|
||||
|
||||
def test_decorate_with_not_injection(self):
|
||||
"""Test `inject()` decorator with not an injection instance."""
|
||||
self.assertRaises(di.Error, di.inject, object)
|
||||
|
||||
def test_decorate_class_method(self):
|
||||
"""Test `inject()` decorator with class method."""
|
||||
class Test(object):
|
||||
|
|
|
@ -7,7 +7,8 @@ import dependency_injector as di
|
|||
class Example(object):
|
||||
"""Example class for Factory provider tests."""
|
||||
|
||||
def __init__(self, init_arg1=None, init_arg2=None):
|
||||
def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None,
|
||||
init_arg4=None):
|
||||
"""Initializer.
|
||||
|
||||
:param init_arg1:
|
||||
|
@ -16,6 +17,8 @@ class Example(object):
|
|||
"""
|
||||
self.init_arg1 = init_arg1
|
||||
self.init_arg2 = init_arg2
|
||||
self.init_arg3 = init_arg3
|
||||
self.init_arg4 = init_arg4
|
||||
|
||||
self.attribute1 = None
|
||||
self.attribute2 = None
|
||||
|
@ -304,11 +307,13 @@ class FactoryTests(unittest.TestCase):
|
|||
|
||||
def test_call_with_context_args(self):
|
||||
"""Test creation of new instances with context args."""
|
||||
provider = di.Factory(Example)
|
||||
instance = provider(11, 22)
|
||||
provider = di.Factory(Example, 11, 22)
|
||||
instance = provider(33, 44)
|
||||
|
||||
self.assertEqual(instance.init_arg1, 11)
|
||||
self.assertEqual(instance.init_arg2, 22)
|
||||
self.assertEqual(instance.init_arg3, 33)
|
||||
self.assertEqual(instance.init_arg4, 44)
|
||||
|
||||
def test_call_with_context_kwargs(self):
|
||||
"""Test creation of new instances with context kwargs."""
|
||||
|
@ -319,9 +324,19 @@ class FactoryTests(unittest.TestCase):
|
|||
self.assertEqual(instance1.init_arg1, 1)
|
||||
self.assertEqual(instance1.init_arg2, 22)
|
||||
|
||||
instance1 = provider(init_arg1=11, init_arg2=22)
|
||||
self.assertEqual(instance1.init_arg1, 11)
|
||||
self.assertEqual(instance1.init_arg2, 22)
|
||||
instance2 = provider(init_arg1=11, init_arg2=22)
|
||||
self.assertEqual(instance2.init_arg1, 11)
|
||||
self.assertEqual(instance2.init_arg2, 22)
|
||||
|
||||
def test_call_with_context_args_and_kwargs(self):
|
||||
"""Test creation of new instances with context args and kwargs."""
|
||||
provider = di.Factory(Example, 11)
|
||||
instance = provider(22, init_arg3=33, init_arg4=44)
|
||||
|
||||
self.assertEqual(instance.init_arg1, 11)
|
||||
self.assertEqual(instance.init_arg2, 22)
|
||||
self.assertEqual(instance.init_arg3, 33)
|
||||
self.assertEqual(instance.init_arg4, 44)
|
||||
|
||||
def test_call_overridden(self):
|
||||
"""Test creation of new instances on overridden provider."""
|
||||
|
@ -521,6 +536,16 @@ class SingletonTests(unittest.TestCase):
|
|||
self.assertEqual(instance1.init_arg1, 1)
|
||||
self.assertEqual(instance1.init_arg2, 22)
|
||||
|
||||
def test_call_with_context_args_and_kwargs(self):
|
||||
"""Test getting of instances with context args and kwargs."""
|
||||
provider = di.Singleton(Example, 11)
|
||||
instance = provider(22, init_arg3=33, init_arg4=44)
|
||||
|
||||
self.assertEqual(instance.init_arg1, 11)
|
||||
self.assertEqual(instance.init_arg2, 22)
|
||||
self.assertEqual(instance.init_arg3, 33)
|
||||
self.assertEqual(instance.init_arg4, 44)
|
||||
|
||||
def test_call_overridden(self):
|
||||
"""Test getting of instances on overridden provider."""
|
||||
provider = di.Singleton(Example)
|
||||
|
@ -653,52 +678,82 @@ class StaticProvidersTests(unittest.TestCase):
|
|||
class CallableTests(unittest.TestCase):
|
||||
"""Callable test cases."""
|
||||
|
||||
def example(self, arg1, arg2, arg3):
|
||||
def example(self, arg1, arg2, arg3, arg4):
|
||||
"""Example callback."""
|
||||
return arg1, arg2, arg3
|
||||
return arg1, arg2, arg3, arg4
|
||||
|
||||
def setUp(self):
|
||||
"""Set test cases environment up."""
|
||||
self.provider = di.Callable(self.example,
|
||||
arg1='a1',
|
||||
arg2='a2',
|
||||
arg3='a3')
|
||||
def test_init_with_callable(self):
|
||||
"""Test creation of provider with a callable."""
|
||||
self.assertTrue(di.Callable(self.example))
|
||||
|
||||
def test_init_with_not_callable(self):
|
||||
"""Test creation of provider with not callable."""
|
||||
"""Test creation of provider with not a callable."""
|
||||
self.assertRaises(di.Error, di.Callable, 123)
|
||||
|
||||
def test_is_provider(self):
|
||||
"""Test `is_provider` check."""
|
||||
self.assertTrue(di.is_provider(self.provider))
|
||||
|
||||
def test_call(self):
|
||||
"""Test provider call."""
|
||||
self.assertEqual(self.provider(), ('a1', 'a2', 'a3'))
|
||||
"""Test call."""
|
||||
provider = di.Callable(lambda: True)
|
||||
self.assertTrue(provider())
|
||||
|
||||
def test_call_with_args(self):
|
||||
"""Test provider call with kwargs priority."""
|
||||
def test_call_with_positional_args(self):
|
||||
"""Test call with positional args.
|
||||
|
||||
New simplified syntax.
|
||||
"""
|
||||
provider = di.Callable(self.example, 1, 2, 3, 4)
|
||||
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||
|
||||
def test_call_with_keyword_args(self):
|
||||
"""Test call with keyword args.
|
||||
|
||||
New simplified syntax.
|
||||
"""
|
||||
provider = di.Callable(self.example, arg1=1, arg2=2, arg3=3, arg4=4)
|
||||
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||
|
||||
def test_call_with_positional_and_keyword_args(self):
|
||||
"""Test call with positional and keyword args.
|
||||
|
||||
Simplified syntax of positional and keyword arg injections.
|
||||
"""
|
||||
provider = di.Callable(self.example, 1, 2, arg3=3, arg4=4)
|
||||
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||
|
||||
def test_call_with_positional_and_keyword_args_extended_syntax(self):
|
||||
"""Test call with positional and keyword args.
|
||||
|
||||
Extended syntax of positional and keyword arg injections.
|
||||
"""
|
||||
provider = di.Callable(self.example,
|
||||
arg3='a3')
|
||||
self.assertEqual(provider(1, 2), (1, 2, 'a3'))
|
||||
di.Arg(1),
|
||||
di.Arg(2),
|
||||
di.KwArg('arg3', 3),
|
||||
di.KwArg('arg4', 4))
|
||||
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||
|
||||
def test_call_with_kwargs_priority(self):
|
||||
"""Test provider call with kwargs priority."""
|
||||
self.assertEqual(self.provider(arg1=1, arg3=3), (1, 'a2', 3))
|
||||
def test_call_with_context_args(self):
|
||||
"""Test call with context args."""
|
||||
provider = di.Callable(self.example, 1, 2)
|
||||
self.assertTupleEqual(provider(3, 4), (1, 2, 3, 4))
|
||||
|
||||
def test_call_with_context_kwargs(self):
|
||||
"""Test call with context kwargs."""
|
||||
provider = di.Callable(self.example,
|
||||
di.KwArg('arg1', 1))
|
||||
self.assertTupleEqual(provider(arg2=2, arg3=3, arg4=4), (1, 2, 3, 4))
|
||||
|
||||
def test_call_with_context_args_and_kwargs(self):
|
||||
"""Test call with context args and kwargs."""
|
||||
provider = di.Callable(self.example, 1)
|
||||
self.assertTupleEqual(provider(2, arg3=3, arg4=4), (1, 2, 3, 4))
|
||||
|
||||
def test_call_overridden(self):
|
||||
"""Test overridden provider call."""
|
||||
overriding_provider1 = di.Value((1, 2, 3))
|
||||
overriding_provider2 = di.Value((3, 2, 1))
|
||||
"""Test creation of new instances on overridden provider."""
|
||||
provider = di.Callable(self.example)
|
||||
provider.override(di.Value((4, 3, 2, 1)))
|
||||
provider.override(di.Value((1, 2, 3, 4)))
|
||||
|
||||
self.provider.override(overriding_provider1)
|
||||
self.provider.override(overriding_provider2)
|
||||
|
||||
result1 = self.provider()
|
||||
result2 = self.provider()
|
||||
|
||||
self.assertEqual(result1, (3, 2, 1))
|
||||
self.assertEqual(result2, (3, 2, 1))
|
||||
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||
|
||||
|
||||
class ConfigTests(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue
Block a user