Add validation of provided type for Factory provider

This commit is contained in:
Roman Mogilatov 2015-12-11 16:01:07 +02:00
parent ee37558946
commit 10e76f65d7
3 changed files with 73 additions and 5 deletions

View File

@ -246,6 +246,16 @@ class Factory(Provider):
some_object = factory()
"""
provided_type = None
"""Provided type.
If provided type is defined, :py:class:`Factory` checks that
:py:attr:`Factory.provides` is subclass of
:py:attr:`Factory.provided_type`.
:type: type | None
"""
__slots__ = ('provides', 'args', 'kwargs', 'attributes', 'methods')
def __init__(self, provides, *args, **kwargs):
@ -261,10 +271,7 @@ class Factory(Provider):
:param kwargs: Dictionary of injections.
:type kwargs: dict
"""
if not callable(provides):
raise Error('Factory provider expects to get callable, ' +
'got {0} instead'.format(str(provides)))
self.provides = provides
self.provides = self._ensure_provides_type(provides)
"""Class or other callable that provides object for creation.
:type: type | callable
@ -328,6 +335,31 @@ class Factory(Provider):
return instance
def _ensure_provides_type(self, provides):
"""Check if provided type is valid type for this factory.
:param provides: Factory provided type
:type provides: type
:raise: :py:exc:`dependency_injector.errors.Error` if ``provides``
is not callable
:raise: :py:exc:`dependency_injector.errors.Error` if ``provides``
doesn't meet factory provided type
:return: validated ``provides``
:rtype: type
"""
if not callable(provides):
raise Error('Factory provider expects to get callable, ' +
'got {0} instead'.format(str(provides)))
if (self.__class__.provided_type and
not issubclass(provides, self.__class__.provided_type)):
raise Error('{0} can provide only {1} instances'.format(
'.'.join((self.__class__.__module__,
self.__class__.__name__)),
self.__class__.provided_type))
return provides
def __str__(self):
"""Return string representation of provider.

View File

@ -9,7 +9,8 @@ follows `Semantic versioning`_
Development version
-------------------
- No features.
- Add possibility to validate ``Factory`` provided type on ``Factory``
initialization.
1.11.2
------

View File

@ -176,6 +176,41 @@ class FactoryTests(unittest.TestCase):
"""Test creation of provider with a callable."""
self.assertTrue(providers.Factory(credits))
def test_init_with_valid_provided_type(self):
"""Test creation with not valid provided type."""
class ExampleFactory(providers.Factory):
"""Example factory."""
provided_type = Example
example_factory = ExampleFactory(Example, 1, 2)
self.assertIsInstance(example_factory(), Example)
def test_init_with_valid_provided_subtype(self):
"""Test creation with not valid provided type."""
class ExampleFactory(providers.Factory):
"""Example factory."""
provided_type = Example
class NewExampe(Example):
"""Example class subclass."""
example_factory = ExampleFactory(NewExampe, 1, 2)
self.assertIsInstance(example_factory(), NewExampe)
def test_init_with_invalid_provided_type(self):
"""Test creation with not valid provided type."""
class ExampleFactory(providers.Factory):
"""Example factory."""
provided_type = Example
with self.assertRaises(errors.Error):
ExampleFactory(list)
def test_init_with_not_callable(self):
"""Test creation of provider with not a callable."""
self.assertRaises(errors.Error, providers.Factory, 123)