mirror of
https://github.com/ets-labs/python-dependency-injector.git
synced 2024-11-22 09:36:48 +03:00
Add support of positional argument injections for Callable provider
Also current commit contains: - Some refactoring of internals - Additional unit tests for Factory and Singleton providers
This commit is contained in:
parent
402539ed7f
commit
59b98959bc
|
@ -1,11 +1,14 @@
|
||||||
"""Injections module."""
|
"""Injections module."""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import itertools
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from .utils import is_provider
|
from .utils import is_provider
|
||||||
from .utils import ensure_is_injection
|
from .utils import is_injection
|
||||||
from .utils import get_injectable_kwargs
|
from .utils import is_arg_injection
|
||||||
|
from .utils import is_kwarg_injection
|
||||||
|
|
||||||
from .errors import Error
|
from .errors import Error
|
||||||
|
|
||||||
|
@ -77,11 +80,7 @@ def inject(*args, **kwargs):
|
||||||
:type injection: Injection
|
:type injection: Injection
|
||||||
:return: (callable) -> (callable)
|
:return: (callable) -> (callable)
|
||||||
"""
|
"""
|
||||||
injections = tuple(KwArg(name, value)
|
injections = _parse_kwargs_injections(args, kwargs)
|
||||||
for name, value in six.iteritems(kwargs))
|
|
||||||
if args:
|
|
||||||
injections += tuple(ensure_is_injection(injection)
|
|
||||||
for injection in args)
|
|
||||||
|
|
||||||
def decorator(callback_or_cls):
|
def decorator(callback_or_cls):
|
||||||
"""Dependency injection decorator."""
|
"""Dependency injection decorator."""
|
||||||
|
@ -107,10 +106,41 @@ def inject(*args, **kwargs):
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
"""Decorated with dependency injection callback."""
|
"""Decorated with dependency injection callback."""
|
||||||
return callback(*args,
|
return callback(*args,
|
||||||
**get_injectable_kwargs(kwargs,
|
**_get_injectable_kwargs(kwargs,
|
||||||
decorated.injections))
|
decorated.injections))
|
||||||
|
|
||||||
decorated.injections = injections
|
decorated.injections = injections
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args_injections(args):
|
||||||
|
"""Parse positional argument injections according to current syntax."""
|
||||||
|
return tuple(Arg(arg) if not is_injection(arg) else arg
|
||||||
|
for arg in args
|
||||||
|
if not is_injection(arg) or is_arg_injection(arg))
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_kwargs_injections(args, kwargs):
|
||||||
|
"""Parse keyword argument injections according to current syntax."""
|
||||||
|
kwarg_injections = tuple(injection
|
||||||
|
for injection in args
|
||||||
|
if is_kwarg_injection(injection))
|
||||||
|
if kwargs:
|
||||||
|
kwarg_injections += tuple(KwArg(name, value)
|
||||||
|
for name, value in six.iteritems(kwargs))
|
||||||
|
return kwarg_injections
|
||||||
|
|
||||||
|
|
||||||
|
def _get_injectable_args(context_args, arg_injections):
|
||||||
|
"""Return tuple of positional arguments, patched with injections."""
|
||||||
|
return itertools.chain((arg.value for arg in arg_injections), context_args)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_injectable_kwargs(context_kwargs, kwarg_injections):
|
||||||
|
"""Return dictionary of keyword arguments, patched with injections."""
|
||||||
|
injectable_kwargs = dict((kwarg.name, kwarg.value)
|
||||||
|
for kwarg in kwarg_injections)
|
||||||
|
injectable_kwargs.update(context_kwargs)
|
||||||
|
return injectable_kwargs
|
||||||
|
|
|
@ -2,17 +2,14 @@
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from .injections import Arg
|
from .injections import _parse_args_injections
|
||||||
from .injections import KwArg
|
from .injections import _parse_kwargs_injections
|
||||||
|
from .injections import _get_injectable_args
|
||||||
|
from .injections import _get_injectable_kwargs
|
||||||
|
|
||||||
from .utils import ensure_is_provider
|
from .utils import ensure_is_provider
|
||||||
from .utils import is_injection
|
|
||||||
from .utils import is_arg_injection
|
|
||||||
from .utils import is_kwarg_injection
|
|
||||||
from .utils import is_attribute_injection
|
from .utils import is_attribute_injection
|
||||||
from .utils import is_method_injection
|
from .utils import is_method_injection
|
||||||
from .utils import get_injectable_args
|
|
||||||
from .utils import get_injectable_kwargs
|
|
||||||
from .utils import GLOBAL_LOCK
|
from .utils import GLOBAL_LOCK
|
||||||
|
|
||||||
from .errors import Error
|
from .errors import Error
|
||||||
|
@ -116,15 +113,8 @@ class Factory(Provider):
|
||||||
raise Error('Factory provider expects to get callable, ' +
|
raise Error('Factory provider expects to get callable, ' +
|
||||||
'got {0} instead'.format(str(provides)))
|
'got {0} instead'.format(str(provides)))
|
||||||
self.provides = provides
|
self.provides = provides
|
||||||
self.args = tuple(Arg(arg) if not is_injection(arg) else arg
|
self.args = _parse_args_injections(args)
|
||||||
for arg in args
|
self.kwargs = _parse_kwargs_injections(args, kwargs)
|
||||||
if not is_injection(arg) or is_arg_injection(arg))
|
|
||||||
self.kwargs = tuple(injection
|
|
||||||
for injection in args
|
|
||||||
if is_kwarg_injection(injection))
|
|
||||||
if kwargs:
|
|
||||||
self.kwargs += tuple(KwArg(name, value)
|
|
||||||
for name, value in six.iteritems(kwargs))
|
|
||||||
self.attributes = tuple(injection
|
self.attributes = tuple(injection
|
||||||
for injection in args
|
for injection in args
|
||||||
if is_attribute_injection(injection))
|
if is_attribute_injection(injection))
|
||||||
|
@ -135,8 +125,8 @@ class Factory(Provider):
|
||||||
|
|
||||||
def _provide(self, *args, **kwargs):
|
def _provide(self, *args, **kwargs):
|
||||||
"""Return provided instance."""
|
"""Return provided instance."""
|
||||||
instance = self.provides(*get_injectable_args(args, self.args),
|
instance = self.provides(*_get_injectable_args(args, self.args),
|
||||||
**get_injectable_kwargs(kwargs, self.kwargs))
|
**_get_injectable_kwargs(kwargs, self.kwargs))
|
||||||
for attribute in self.attributes:
|
for attribute in self.attributes:
|
||||||
setattr(instance, attribute.name, attribute.value)
|
setattr(instance, attribute.name, attribute.value)
|
||||||
for method in self.methods:
|
for method in self.methods:
|
||||||
|
@ -258,21 +248,21 @@ class Callable(Provider):
|
||||||
with some predefined dependency injections.
|
with some predefined dependency injections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ('callback', 'kwargs')
|
__slots__ = ('callback', 'args', 'kwargs')
|
||||||
|
|
||||||
def __init__(self, callback, **kwargs):
|
def __init__(self, callback, *args, **kwargs):
|
||||||
"""Initializer."""
|
"""Initializer."""
|
||||||
if not callable(callback):
|
if not callable(callback):
|
||||||
raise Error('Callable expected, got {0}'.format(str(callback)))
|
raise Error('Callable expected, got {0}'.format(str(callback)))
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
self.kwargs = tuple(KwArg(name, value)
|
self.args = _parse_args_injections(args)
|
||||||
for name, value in six.iteritems(kwargs))
|
self.kwargs = _parse_kwargs_injections(args, kwargs)
|
||||||
super(Callable, self).__init__()
|
super(Callable, self).__init__()
|
||||||
|
|
||||||
def _provide(self, *args, **kwargs):
|
def _provide(self, *args, **kwargs):
|
||||||
"""Return provided instance."""
|
"""Return provided instance."""
|
||||||
return self.callback(*args, **get_injectable_kwargs(kwargs,
|
return self.callback(*_get_injectable_args(args, self.args),
|
||||||
self.kwargs))
|
**_get_injectable_kwargs(kwargs, self.kwargs))
|
||||||
|
|
||||||
|
|
||||||
class Config(Provider):
|
class Config(Provider):
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
"""Utils module."""
|
"""Utils module."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
import itertools
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
@ -87,16 +86,3 @@ def ensure_is_catalog_bundle(instance):
|
||||||
raise Error('Expected catalog bundle instance, '
|
raise Error('Expected catalog bundle instance, '
|
||||||
'got {0}'.format(str(instance)))
|
'got {0}'.format(str(instance)))
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
def get_injectable_args(context_args, arg_injections):
|
|
||||||
"""Return tuple of positional args, patched with injections."""
|
|
||||||
return itertools.chain((arg.value for arg in arg_injections), context_args)
|
|
||||||
|
|
||||||
|
|
||||||
def get_injectable_kwargs(context_kwargs, kwarg_injections):
|
|
||||||
"""Return dictionary of keyword args, patched with injections."""
|
|
||||||
kwargs = dict((kwarg.name, kwarg.value)
|
|
||||||
for kwarg in kwarg_injections)
|
|
||||||
kwargs.update(context_kwargs)
|
|
||||||
return kwargs
|
|
||||||
|
|
|
@ -161,10 +161,6 @@ class InjectTests(unittest.TestCase):
|
||||||
self.assertIsInstance(b2, list)
|
self.assertIsInstance(b2, list)
|
||||||
self.assertIsNot(b1, b2)
|
self.assertIsNot(b1, b2)
|
||||||
|
|
||||||
def test_decorate_with_not_injection(self):
|
|
||||||
"""Test `inject()` decorator with not an injection instance."""
|
|
||||||
self.assertRaises(di.Error, di.inject, object)
|
|
||||||
|
|
||||||
def test_decorate_class_method(self):
|
def test_decorate_class_method(self):
|
||||||
"""Test `inject()` decorator with class method."""
|
"""Test `inject()` decorator with class method."""
|
||||||
class Test(object):
|
class Test(object):
|
||||||
|
|
|
@ -7,7 +7,8 @@ import dependency_injector as di
|
||||||
class Example(object):
|
class Example(object):
|
||||||
"""Example class for Factory provider tests."""
|
"""Example class for Factory provider tests."""
|
||||||
|
|
||||||
def __init__(self, init_arg1=None, init_arg2=None):
|
def __init__(self, init_arg1=None, init_arg2=None, init_arg3=None,
|
||||||
|
init_arg4=None):
|
||||||
"""Initializer.
|
"""Initializer.
|
||||||
|
|
||||||
:param init_arg1:
|
:param init_arg1:
|
||||||
|
@ -16,6 +17,8 @@ class Example(object):
|
||||||
"""
|
"""
|
||||||
self.init_arg1 = init_arg1
|
self.init_arg1 = init_arg1
|
||||||
self.init_arg2 = init_arg2
|
self.init_arg2 = init_arg2
|
||||||
|
self.init_arg3 = init_arg3
|
||||||
|
self.init_arg4 = init_arg4
|
||||||
|
|
||||||
self.attribute1 = None
|
self.attribute1 = None
|
||||||
self.attribute2 = None
|
self.attribute2 = None
|
||||||
|
@ -304,11 +307,13 @@ class FactoryTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_call_with_context_args(self):
|
def test_call_with_context_args(self):
|
||||||
"""Test creation of new instances with context args."""
|
"""Test creation of new instances with context args."""
|
||||||
provider = di.Factory(Example)
|
provider = di.Factory(Example, 11, 22)
|
||||||
instance = provider(11, 22)
|
instance = provider(33, 44)
|
||||||
|
|
||||||
self.assertEqual(instance.init_arg1, 11)
|
self.assertEqual(instance.init_arg1, 11)
|
||||||
self.assertEqual(instance.init_arg2, 22)
|
self.assertEqual(instance.init_arg2, 22)
|
||||||
|
self.assertEqual(instance.init_arg3, 33)
|
||||||
|
self.assertEqual(instance.init_arg4, 44)
|
||||||
|
|
||||||
def test_call_with_context_kwargs(self):
|
def test_call_with_context_kwargs(self):
|
||||||
"""Test creation of new instances with context kwargs."""
|
"""Test creation of new instances with context kwargs."""
|
||||||
|
@ -319,9 +324,19 @@ class FactoryTests(unittest.TestCase):
|
||||||
self.assertEqual(instance1.init_arg1, 1)
|
self.assertEqual(instance1.init_arg1, 1)
|
||||||
self.assertEqual(instance1.init_arg2, 22)
|
self.assertEqual(instance1.init_arg2, 22)
|
||||||
|
|
||||||
instance1 = provider(init_arg1=11, init_arg2=22)
|
instance2 = provider(init_arg1=11, init_arg2=22)
|
||||||
self.assertEqual(instance1.init_arg1, 11)
|
self.assertEqual(instance2.init_arg1, 11)
|
||||||
self.assertEqual(instance1.init_arg2, 22)
|
self.assertEqual(instance2.init_arg2, 22)
|
||||||
|
|
||||||
|
def test_call_with_context_args_and_kwargs(self):
|
||||||
|
"""Test creation of new instances with context args and kwargs."""
|
||||||
|
provider = di.Factory(Example, 11)
|
||||||
|
instance = provider(22, init_arg3=33, init_arg4=44)
|
||||||
|
|
||||||
|
self.assertEqual(instance.init_arg1, 11)
|
||||||
|
self.assertEqual(instance.init_arg2, 22)
|
||||||
|
self.assertEqual(instance.init_arg3, 33)
|
||||||
|
self.assertEqual(instance.init_arg4, 44)
|
||||||
|
|
||||||
def test_call_overridden(self):
|
def test_call_overridden(self):
|
||||||
"""Test creation of new instances on overridden provider."""
|
"""Test creation of new instances on overridden provider."""
|
||||||
|
@ -521,6 +536,16 @@ class SingletonTests(unittest.TestCase):
|
||||||
self.assertEqual(instance1.init_arg1, 1)
|
self.assertEqual(instance1.init_arg1, 1)
|
||||||
self.assertEqual(instance1.init_arg2, 22)
|
self.assertEqual(instance1.init_arg2, 22)
|
||||||
|
|
||||||
|
def test_call_with_context_args_and_kwargs(self):
|
||||||
|
"""Test getting of instances with context args and kwargs."""
|
||||||
|
provider = di.Singleton(Example, 11)
|
||||||
|
instance = provider(22, init_arg3=33, init_arg4=44)
|
||||||
|
|
||||||
|
self.assertEqual(instance.init_arg1, 11)
|
||||||
|
self.assertEqual(instance.init_arg2, 22)
|
||||||
|
self.assertEqual(instance.init_arg3, 33)
|
||||||
|
self.assertEqual(instance.init_arg4, 44)
|
||||||
|
|
||||||
def test_call_overridden(self):
|
def test_call_overridden(self):
|
||||||
"""Test getting of instances on overridden provider."""
|
"""Test getting of instances on overridden provider."""
|
||||||
provider = di.Singleton(Example)
|
provider = di.Singleton(Example)
|
||||||
|
@ -653,52 +678,82 @@ class StaticProvidersTests(unittest.TestCase):
|
||||||
class CallableTests(unittest.TestCase):
|
class CallableTests(unittest.TestCase):
|
||||||
"""Callable test cases."""
|
"""Callable test cases."""
|
||||||
|
|
||||||
def example(self, arg1, arg2, arg3):
|
def example(self, arg1, arg2, arg3, arg4):
|
||||||
"""Example callback."""
|
"""Example callback."""
|
||||||
return arg1, arg2, arg3
|
return arg1, arg2, arg3, arg4
|
||||||
|
|
||||||
def setUp(self):
|
def test_init_with_callable(self):
|
||||||
"""Set test cases environment up."""
|
"""Test creation of provider with a callable."""
|
||||||
self.provider = di.Callable(self.example,
|
self.assertTrue(di.Callable(self.example))
|
||||||
arg1='a1',
|
|
||||||
arg2='a2',
|
|
||||||
arg3='a3')
|
|
||||||
|
|
||||||
def test_init_with_not_callable(self):
|
def test_init_with_not_callable(self):
|
||||||
"""Test creation of provider with not callable."""
|
"""Test creation of provider with not a callable."""
|
||||||
self.assertRaises(di.Error, di.Callable, 123)
|
self.assertRaises(di.Error, di.Callable, 123)
|
||||||
|
|
||||||
def test_is_provider(self):
|
|
||||||
"""Test `is_provider` check."""
|
|
||||||
self.assertTrue(di.is_provider(self.provider))
|
|
||||||
|
|
||||||
def test_call(self):
|
def test_call(self):
|
||||||
"""Test provider call."""
|
"""Test call."""
|
||||||
self.assertEqual(self.provider(), ('a1', 'a2', 'a3'))
|
provider = di.Callable(lambda: True)
|
||||||
|
self.assertTrue(provider())
|
||||||
|
|
||||||
def test_call_with_args(self):
|
def test_call_with_positional_args(self):
|
||||||
"""Test provider call with kwargs priority."""
|
"""Test call with positional args.
|
||||||
|
|
||||||
|
New simplified syntax.
|
||||||
|
"""
|
||||||
|
provider = di.Callable(self.example, 1, 2, 3, 4)
|
||||||
|
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||||
|
|
||||||
|
def test_call_with_keyword_args(self):
|
||||||
|
"""Test call with keyword args.
|
||||||
|
|
||||||
|
New simplified syntax.
|
||||||
|
"""
|
||||||
|
provider = di.Callable(self.example, arg1=1, arg2=2, arg3=3, arg4=4)
|
||||||
|
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||||
|
|
||||||
|
def test_call_with_positional_and_keyword_args(self):
|
||||||
|
"""Test call with positional and keyword args.
|
||||||
|
|
||||||
|
Simplified syntax of positional and keyword arg injections.
|
||||||
|
"""
|
||||||
|
provider = di.Callable(self.example, 1, 2, arg3=3, arg4=4)
|
||||||
|
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||||
|
|
||||||
|
def test_call_with_positional_and_keyword_args_extended_syntax(self):
|
||||||
|
"""Test call with positional and keyword args.
|
||||||
|
|
||||||
|
Extended syntax of positional and keyword arg injections.
|
||||||
|
"""
|
||||||
provider = di.Callable(self.example,
|
provider = di.Callable(self.example,
|
||||||
arg3='a3')
|
di.Arg(1),
|
||||||
self.assertEqual(provider(1, 2), (1, 2, 'a3'))
|
di.Arg(2),
|
||||||
|
di.KwArg('arg3', 3),
|
||||||
|
di.KwArg('arg4', 4))
|
||||||
|
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||||
|
|
||||||
def test_call_with_kwargs_priority(self):
|
def test_call_with_context_args(self):
|
||||||
"""Test provider call with kwargs priority."""
|
"""Test call with context args."""
|
||||||
self.assertEqual(self.provider(arg1=1, arg3=3), (1, 'a2', 3))
|
provider = di.Callable(self.example, 1, 2)
|
||||||
|
self.assertTupleEqual(provider(3, 4), (1, 2, 3, 4))
|
||||||
|
|
||||||
|
def test_call_with_context_kwargs(self):
|
||||||
|
"""Test call with context kwargs."""
|
||||||
|
provider = di.Callable(self.example,
|
||||||
|
di.KwArg('arg1', 1))
|
||||||
|
self.assertTupleEqual(provider(arg2=2, arg3=3, arg4=4), (1, 2, 3, 4))
|
||||||
|
|
||||||
|
def test_call_with_context_args_and_kwargs(self):
|
||||||
|
"""Test call with context args and kwargs."""
|
||||||
|
provider = di.Callable(self.example, 1)
|
||||||
|
self.assertTupleEqual(provider(2, arg3=3, arg4=4), (1, 2, 3, 4))
|
||||||
|
|
||||||
def test_call_overridden(self):
|
def test_call_overridden(self):
|
||||||
"""Test overridden provider call."""
|
"""Test creation of new instances on overridden provider."""
|
||||||
overriding_provider1 = di.Value((1, 2, 3))
|
provider = di.Callable(self.example)
|
||||||
overriding_provider2 = di.Value((3, 2, 1))
|
provider.override(di.Value((4, 3, 2, 1)))
|
||||||
|
provider.override(di.Value((1, 2, 3, 4)))
|
||||||
|
|
||||||
self.provider.override(overriding_provider1)
|
self.assertTupleEqual(provider(), (1, 2, 3, 4))
|
||||||
self.provider.override(overriding_provider2)
|
|
||||||
|
|
||||||
result1 = self.provider()
|
|
||||||
result2 = self.provider()
|
|
||||||
|
|
||||||
self.assertEqual(result1, (3, 2, 1))
|
|
||||||
self.assertEqual(result2, (3, 2, 1))
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigTests(unittest.TestCase):
|
class ConfigTests(unittest.TestCase):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user