diff --git a/dependency_injector/providers.py b/dependency_injector/providers.py index f53ac14d..c767e42b 100644 --- a/dependency_injector/providers.py +++ b/dependency_injector/providers.py @@ -246,6 +246,16 @@ class Factory(Provider): some_object = factory() """ + provided_type = None + """Provided type. + + If provided type is defined, :py:class:`Factory` checks that + :py:attr:`Factory.provides` is subclass of + :py:attr:`Factory.provided_type`. + + :type: type | None + """ + __slots__ = ('provides', 'args', 'kwargs', 'attributes', 'methods') def __init__(self, provides, *args, **kwargs): @@ -261,10 +271,7 @@ class Factory(Provider): :param kwargs: Dictionary of injections. :type kwargs: dict """ - if not callable(provides): - raise Error('Factory provider expects to get callable, ' + - 'got {0} instead'.format(str(provides))) - self.provides = provides + self.provides = self._ensure_provides_type(provides) """Class or other callable that provides object for creation. :type: type | callable @@ -328,6 +335,31 @@ class Factory(Provider): return instance + def _ensure_provides_type(self, provides): + """Check if provided type is valid type for this factory. + + :param provides: Factory provided type + :type provides: type + + :raise: :py:exc:`dependency_injector.errors.Error` if ``provides`` + is not callable + :raise: :py:exc:`dependency_injector.errors.Error` if ``provides`` + doesn't meet factory provided type + + :return: validated ``provides`` + :rtype: type + """ + if not callable(provides): + raise Error('Factory provider expects to get callable, ' + + 'got {0} instead'.format(str(provides))) + if (self.__class__.provided_type and + not issubclass(provides, self.__class__.provided_type)): + raise Error('{0} can provide only {1} instances'.format( + '.'.join((self.__class__.__module__, + self.__class__.__name__)), + self.__class__.provided_type)) + return provides + def __str__(self): """Return string representation of provider. diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 75a24c5b..b791b1f6 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -9,7 +9,8 @@ follows `Semantic versioning`_ Development version ------------------- -- No features. +- Add possibility to validate ``Factory`` provided type on ``Factory`` + initialization. 1.11.2 ------ diff --git a/tests/test_providers.py b/tests/test_providers.py index 98cf9f20..3621958c 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -176,6 +176,41 @@ class FactoryTests(unittest.TestCase): """Test creation of provider with a callable.""" self.assertTrue(providers.Factory(credits)) + def test_init_with_valid_provided_type(self): + """Test creation with not valid provided type.""" + class ExampleFactory(providers.Factory): + """Example factory.""" + + provided_type = Example + + example_factory = ExampleFactory(Example, 1, 2) + + self.assertIsInstance(example_factory(), Example) + + def test_init_with_valid_provided_subtype(self): + """Test creation with not valid provided type.""" + class ExampleFactory(providers.Factory): + """Example factory.""" + + provided_type = Example + + class NewExampe(Example): + """Example class subclass.""" + + example_factory = ExampleFactory(NewExampe, 1, 2) + + self.assertIsInstance(example_factory(), NewExampe) + + def test_init_with_invalid_provided_type(self): + """Test creation with not valid provided type.""" + class ExampleFactory(providers.Factory): + """Example factory.""" + + provided_type = Example + + with self.assertRaises(errors.Error): + ExampleFactory(list) + def test_init_with_not_callable(self): """Test creation of provider with not a callable.""" self.assertRaises(errors.Error, providers.Factory, 123)