diff --git a/dependency_injector/providers.py b/dependency_injector/providers.py index 3a2c7a2b..92929ff4 100644 --- a/dependency_injector/providers.py +++ b/dependency_injector/providers.py @@ -1,6 +1,8 @@ """Providers module.""" -from six import class_types +import six + +from .injections import KwArg from .utils import ensure_is_provider from .utils import is_kwarg_injection @@ -100,7 +102,7 @@ class Factory(Provider): __slots__ = ('_provides', '_kwargs', '_attributes', '_methods') - def __init__(self, provides, *injections): + def __init__(self, provides, *injections, **kwargs): """Initializer.""" if not callable(provides): raise Error('Factory provider expects to get callable, ' + @@ -109,6 +111,9 @@ class Factory(Provider): self._kwargs = tuple((injection for injection in injections if is_kwarg_injection(injection))) + if kwargs: + self._kwargs += tuple((KwArg(name, value) + for name, value in six.iteritems(kwargs))) self._attributes = tuple((injection for injection in injections if is_attribute_injection(injection))) @@ -149,10 +154,10 @@ class Singleton(Provider): __slots__ = ('_instance', '_factory') - def __init__(self, provides, *injections): + def __init__(self, provides, *injections, **kwargs): """Initializer.""" self._instance = None - self._factory = Factory(provides, *injections) + self._factory = Factory(provides, *injections, **kwargs) super(Singleton, self).__init__() def _provide(self, *args, **kwargs): @@ -178,7 +183,7 @@ class ExternalDependency(Provider): def __init__(self, instance_of): """Initializer.""" - if not isinstance(instance_of, class_types): + if not isinstance(instance_of, six.class_types): raise Error('ExternalDependency provider expects to get class, ' + 'got {0} instead'.format(str(instance_of))) self._instance_of = instance_of diff --git a/examples/concept.py b/examples/concept.py index aeeb6e74..0e79dd42 100644 --- a/examples/concept.py +++ b/examples/concept.py @@ -28,17 +28,16 @@ class Catalog(di.AbstractCatalog): """Catalog of providers.""" database = di.Singleton(sqlite3.Connection, - di.KwArg('database', ':memory:'), - di.Attribute('row_factory', sqlite3.Row)) + database=':memory:') """:type: (di.Provider) -> sqlite3.Connection""" object_a_factory = di.Factory(ObjectA, - di.KwArg('db', database)) + db=database) """:type: (di.Provider) -> ObjectA""" object_b_factory = di.Factory(ObjectB, - di.KwArg('a', object_a_factory), - di.KwArg('db', database)) + a=object_a_factory, + db=database) """:type: (di.Provider) -> ObjectB""" diff --git a/tests/test_providers.py b/tests/test_providers.py index 485988cd..f3d0cd23 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -223,6 +223,28 @@ class FactoryTests(unittest.TestCase): self.assertIsInstance(instance1, self.Example) self.assertIsInstance(instance2, self.Example) + def test_call_with_init_args_simplified_syntax(self): + """Test creation of new instances with init args injections. + + Simplified syntax. + """ + provider = Factory(self.Example, + init_arg1='i1', + init_arg2='i2') + + instance1 = provider() + instance2 = provider() + + self.assertEqual(instance1.init_arg1, 'i1') + self.assertEqual(instance1.init_arg2, 'i2') + + self.assertEqual(instance2.init_arg1, 'i1') + self.assertEqual(instance2.init_arg2, 'i2') + + self.assertIsNot(instance1, instance2) + self.assertIsInstance(instance1, self.Example) + self.assertIsInstance(instance2, self.Example) + def test_call_with_attributes(self): """Test creation of new instances with attribute injections.""" provider = Factory(self.Example,