Refactor asserts in provider tests

This commit is contained in:
Roman Mogylatov 2021-10-15 14:29:37 -04:00
parent 33784d13a3
commit 42031a1b71
19 changed files with 1686 additions and 1920 deletions

View File

@ -15,6 +15,7 @@ Develop
- Add support of ``with`` statement for ``container.override_providers()`` method. - Add support of ``with`` statement for ``container.override_providers()`` method.
- Drop support of Python 3.4. There are no immediate breaking changes, but Dependency Injector - Drop support of Python 3.4. There are no immediate breaking changes, but Dependency Injector
will no longer be tested on Python 3.4 and any bugs will not be fixed. will no longer be tested on Python 3.4 and any bugs will not be fixed.
- Fix ``Dependency.is_defined`` attribute to always return boolean value.
- Update documentation and fix typos. - Update documentation and fix typos.
4.36.2 4.36.2

View File

@ -17675,7 +17675,7 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_14set_de
* @property * @property
* def is_defined(self): # <<<<<<<<<<<<<< * def is_defined(self): # <<<<<<<<<<<<<<
* """Return True if dependency is defined.""" * """Return True if dependency is defined."""
* return self.__last_overriding or self.__default * return self.__last_overriding is not None or self.__default is not None
*/ */
/* Python wrapper */ /* Python wrapper */
@ -17696,6 +17696,7 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_10is_def
__Pyx_RefNannyDeclarations __Pyx_RefNannyDeclarations
PyObject *__pyx_t_1 = NULL; PyObject *__pyx_t_1 = NULL;
int __pyx_t_2; int __pyx_t_2;
PyObject *__pyx_t_3 = NULL;
int __pyx_lineno = 0; int __pyx_lineno = 0;
const char *__pyx_filename = NULL; const char *__pyx_filename = NULL;
int __pyx_clineno = 0; int __pyx_clineno = 0;
@ -17704,20 +17705,25 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_10is_def
/* "dependency_injector/providers.pyx":774 /* "dependency_injector/providers.pyx":774
* def is_defined(self): * def is_defined(self):
* """Return True if dependency is defined.""" * """Return True if dependency is defined."""
* return self.__last_overriding or self.__default # <<<<<<<<<<<<<< * return self.__last_overriding is not None or self.__default is not None # <<<<<<<<<<<<<<
* *
* def provided_by(self, provider): * def provided_by(self, provider):
*/ */
__Pyx_XDECREF(__pyx_r); __Pyx_XDECREF(__pyx_r);
__pyx_t_2 = __Pyx_PyObject_IsTrue(((PyObject *)__pyx_v_self->__pyx_base.__pyx___last_overriding)); if (unlikely(__pyx_t_2 < 0)) __PYX_ERR(1, 774, __pyx_L1_error) __pyx_t_2 = (((PyObject *)__pyx_v_self->__pyx_base.__pyx___last_overriding) != Py_None);
if (!__pyx_t_2) { if (!__pyx_t_2) {
} else { } else {
__Pyx_INCREF(((PyObject *)__pyx_v_self->__pyx_base.__pyx___last_overriding)); __pyx_t_3 = __Pyx_PyBool_FromLong(__pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 774, __pyx_L1_error)
__pyx_t_1 = ((PyObject *)__pyx_v_self->__pyx_base.__pyx___last_overriding); __Pyx_GOTREF(__pyx_t_3);
__pyx_t_1 = __pyx_t_3;
__pyx_t_3 = 0;
goto __pyx_L3_bool_binop_done; goto __pyx_L3_bool_binop_done;
} }
__Pyx_INCREF(__pyx_v_self->__pyx___default); __pyx_t_2 = (__pyx_v_self->__pyx___default != Py_None);
__pyx_t_1 = __pyx_v_self->__pyx___default; __pyx_t_3 = __Pyx_PyBool_FromLong(__pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 774, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
__pyx_t_1 = __pyx_t_3;
__pyx_t_3 = 0;
__pyx_L3_bool_binop_done:; __pyx_L3_bool_binop_done:;
__pyx_r = __pyx_t_1; __pyx_r = __pyx_t_1;
__pyx_t_1 = 0; __pyx_t_1 = 0;
@ -17728,12 +17734,13 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_10is_def
* @property * @property
* def is_defined(self): # <<<<<<<<<<<<<< * def is_defined(self): # <<<<<<<<<<<<<<
* """Return True if dependency is defined.""" * """Return True if dependency is defined."""
* return self.__last_overriding or self.__default * return self.__last_overriding is not None or self.__default is not None
*/ */
/* function exit code */ /* function exit code */
__pyx_L1_error:; __pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1); __Pyx_XDECREF(__pyx_t_1);
__Pyx_XDECREF(__pyx_t_3);
__Pyx_AddTraceback("dependency_injector.providers.Dependency.is_defined.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); __Pyx_AddTraceback("dependency_injector.providers.Dependency.is_defined.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = NULL; __pyx_r = NULL;
__pyx_L0:; __pyx_L0:;
@ -17743,7 +17750,7 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_10is_def
} }
/* "dependency_injector/providers.pyx":776 /* "dependency_injector/providers.pyx":776
* return self.__last_overriding or self.__default * return self.__last_overriding is not None or self.__default is not None
* *
* def provided_by(self, provider): # <<<<<<<<<<<<<< * def provided_by(self, provider): # <<<<<<<<<<<<<<
* """Set external dependency provider. * """Set external dependency provider.
@ -17805,7 +17812,7 @@ static PyObject *__pyx_pf_19dependency_injector_9providers_10Dependency_16provid
goto __pyx_L0; goto __pyx_L0;
/* "dependency_injector/providers.pyx":776 /* "dependency_injector/providers.pyx":776
* return self.__last_overriding or self.__default * return self.__last_overriding is not None or self.__default is not None
* *
* def provided_by(self, provider): # <<<<<<<<<<<<<< * def provided_by(self, provider): # <<<<<<<<<<<<<<
* """Set external dependency provider. * """Set external dependency provider.

View File

@ -771,7 +771,7 @@ cdef class Dependency(Provider):
@property @property
def is_defined(self): def is_defined(self):
"""Return True if dependency is defined.""" """Return True if dependency is defined."""
return self.__last_overriding or self.__default return self.__last_overriding is not None or self.__default is not None
def provided_by(self, provider): def provided_by(self, provider):
"""Set external dependency provider. """Set external dependency provider.

View File

@ -1,6 +1,7 @@
import sys import sys
from dependency_injector import providers, errors from dependency_injector import providers, errors
from pytest import raises
class Example(object): class Example(object):
@ -21,23 +22,21 @@ class _BaseSingletonTestCase(object):
singleton_cls = None singleton_cls = None
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.singleton_cls(Example))) assert providers.is_provider(self.singleton_cls(Example)) is True
def test_init_with_callable(self):
self.assertTrue(self.singleton_cls(credits))
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, self.singleton_cls, 123) with raises(errors.Error):
self.singleton_cls(123)
def test_init_optional_provides(self): def test_init_optional_provides(self):
provider = self.singleton_cls() provider = self.singleton_cls()
provider.set_provides(object) provider.set_provides(object)
self.assertIs(provider.provides, object) assert provider.provides is object
self.assertIsInstance(provider(), object) assert isinstance(provider(), object)
def test_set_provides_returns_self(self): def test_set_provides_returns_self(self):
provider = self.singleton_cls() provider = self.singleton_cls()
self.assertIs(provider.set_provides(object), provider) assert provider.set_provides(object) is provider
def test_init_with_valid_provided_type(self): def test_init_with_valid_provided_type(self):
class ExampleProvider(self.singleton_cls): class ExampleProvider(self.singleton_cls):
@ -45,7 +44,7 @@ class _BaseSingletonTestCase(object):
example_provider = ExampleProvider(Example, 1, 2) example_provider = ExampleProvider(Example, 1, 2)
self.assertIsInstance(example_provider(), Example) assert isinstance(example_provider(), Example)
def test_init_with_valid_provided_subtype(self): def test_init_with_valid_provided_subtype(self):
class ExampleProvider(self.singleton_cls): class ExampleProvider(self.singleton_cls):
@ -56,18 +55,18 @@ class _BaseSingletonTestCase(object):
example_provider = ExampleProvider(NewExampe, 1, 2) example_provider = ExampleProvider(NewExampe, 1, 2)
self.assertIsInstance(example_provider(), NewExampe) assert isinstance(example_provider(), NewExampe)
def test_init_with_invalid_provided_type(self): def test_init_with_invalid_provided_type(self):
class ExampleProvider(self.singleton_cls): class ExampleProvider(self.singleton_cls):
provided_type = Example provided_type = Example
with self.assertRaises(errors.Error): with raises(errors.Error):
ExampleProvider(list) ExampleProvider(list)
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Singleton(Example) provider = providers.Singleton(Example)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) assert isinstance(provider.provided, providers.ProvidedInstance)
def test_call(self): def test_call(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
@ -75,9 +74,9 @@ class _BaseSingletonTestCase(object):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertIs(instance1, instance2) assert instance1 is instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_init_positional_args(self): def test_call_with_init_positional_args(self):
provider = self.singleton_cls(Example, "i1", "i2") provider = self.singleton_cls(Example, "i1", "i2")
@ -85,15 +84,15 @@ class _BaseSingletonTestCase(object):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertEqual(instance1.init_arg1, "i1") assert instance1.init_arg1 == "i1"
self.assertEqual(instance1.init_arg2, "i2") assert instance1.init_arg2 == "i2"
self.assertEqual(instance2.init_arg1, "i1") assert instance2.init_arg1 == "i1"
self.assertEqual(instance2.init_arg2, "i2") assert instance2.init_arg2 == "i2"
self.assertIs(instance1, instance2) assert instance1 is instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_init_keyword_args(self): def test_call_with_init_keyword_args(self):
provider = self.singleton_cls(Example, init_arg1="i1", init_arg2="i2") provider = self.singleton_cls(Example, init_arg1="i1", init_arg2="i2")
@ -101,15 +100,15 @@ class _BaseSingletonTestCase(object):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertEqual(instance1.init_arg1, "i1") assert instance1.init_arg1 == "i1"
self.assertEqual(instance1.init_arg2, "i2") assert instance1.init_arg2 == "i2"
self.assertEqual(instance2.init_arg1, "i1") assert instance2.init_arg1 == "i1"
self.assertEqual(instance2.init_arg2, "i2") assert instance2.init_arg2 == "i2"
self.assertIs(instance1, instance2) assert instance1 is instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_init_positional_and_keyword_args(self): def test_call_with_init_positional_and_keyword_args(self):
provider = self.singleton_cls(Example, "i1", init_arg2="i2") provider = self.singleton_cls(Example, "i1", init_arg2="i2")
@ -117,15 +116,15 @@ class _BaseSingletonTestCase(object):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertEqual(instance1.init_arg1, "i1") assert instance1.init_arg1 == "i1"
self.assertEqual(instance1.init_arg2, "i2") assert instance1.init_arg2 == "i2"
self.assertEqual(instance2.init_arg1, "i1") assert instance2.init_arg1 == "i1"
self.assertEqual(instance2.init_arg2, "i2") assert instance2.init_arg2 == "i2"
self.assertIs(instance1, instance2) assert instance1 is instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_attributes(self): def test_call_with_attributes(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
@ -134,45 +133,45 @@ class _BaseSingletonTestCase(object):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertEqual(instance1.attribute1, "a1") assert instance1.attribute1 == "a1"
self.assertEqual(instance1.attribute2, "a2") assert instance1.attribute2 == "a2"
self.assertEqual(instance2.attribute1, "a1") assert instance2.attribute1 == "a1"
self.assertEqual(instance2.attribute2, "a2") assert instance2.attribute2 == "a2"
self.assertIs(instance1, instance2) assert instance1 is instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_context_args(self): def test_call_with_context_args(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
instance = provider(11, 22) instance = provider(11, 22)
self.assertEqual(instance.init_arg1, 11) assert instance.init_arg1 == 11
self.assertEqual(instance.init_arg2, 22) assert instance.init_arg2 == 22
def test_call_with_context_kwargs(self): def test_call_with_context_kwargs(self):
provider = self.singleton_cls(Example, init_arg1=1) provider = self.singleton_cls(Example, init_arg1=1)
instance1 = provider(init_arg2=22) instance1 = provider(init_arg2=22)
self.assertEqual(instance1.init_arg1, 1) assert instance1.init_arg1 == 1
self.assertEqual(instance1.init_arg2, 22) assert instance1.init_arg2 == 22
# Instance is created earlier # Instance is created earlier
instance1 = provider(init_arg1=11, init_arg2=22) instance1 = provider(init_arg1=11, init_arg2=22)
self.assertEqual(instance1.init_arg1, 1) assert instance1.init_arg1 == 1
self.assertEqual(instance1.init_arg2, 22) assert instance1.init_arg2 == 22
def test_call_with_context_args_and_kwargs(self): def test_call_with_context_args_and_kwargs(self):
provider = self.singleton_cls(Example, 11) provider = self.singleton_cls(Example, 11)
instance = provider(22, init_arg3=33, init_arg4=44) instance = provider(22, init_arg3=33, init_arg4=44)
self.assertEqual(instance.init_arg1, 11) assert instance.init_arg1 == 11
self.assertEqual(instance.init_arg2, 22) assert instance.init_arg2 == 22
self.assertEqual(instance.init_arg3, 33) assert instance.init_arg3 == 33
self.assertEqual(instance.init_arg4, 44) assert instance.init_arg4 == 44
def test_fluent_interface(self): def test_fluent_interface(self):
provider = self.singleton_cls(Example) \ provider = self.singleton_cls(Example) \
@ -182,48 +181,48 @@ class _BaseSingletonTestCase(object):
instance = provider() instance = provider()
self.assertEqual(instance.init_arg1, 1) assert instance.init_arg1 == 1
self.assertEqual(instance.init_arg2, 2) assert instance.init_arg2 == 2
self.assertEqual(instance.init_arg3, 3) assert instance.init_arg3 == 3
self.assertEqual(instance.init_arg4, 4) assert instance.init_arg4 == 4
self.assertEqual(instance.attribute1, 5) assert instance.attribute1 == 5
self.assertEqual(instance.attribute2, 6) assert instance.attribute2 == 6
def test_set_args(self): def test_set_args(self):
provider = self.singleton_cls(Example) \ provider = self.singleton_cls(Example) \
.add_args(1, 2) \ .add_args(1, 2) \
.set_args(3, 4) .set_args(3, 4)
self.assertEqual(provider.args, (3, 4)) assert provider.args == (3, 4)
def test_set_kwargs(self): def test_set_kwargs(self):
provider = self.singleton_cls(Example) \ provider = self.singleton_cls(Example) \
.add_kwargs(init_arg3=3, init_arg4=4) \ .add_kwargs(init_arg3=3, init_arg4=4) \
.set_kwargs(init_arg3=4, init_arg4=5) .set_kwargs(init_arg3=4, init_arg4=5)
self.assertEqual(provider.kwargs, dict(init_arg3=4, init_arg4=5)) assert provider.kwargs == dict(init_arg3=4, init_arg4=5)
def test_set_attributes(self): def test_set_attributes(self):
provider = self.singleton_cls(Example) \ provider = self.singleton_cls(Example) \
.add_attributes(attribute1=5, attribute2=6) \ .add_attributes(attribute1=5, attribute2=6) \
.set_attributes(attribute1=6, attribute2=7) .set_attributes(attribute1=6, attribute2=7)
self.assertEqual(provider.attributes, dict(attribute1=6, attribute2=7)) assert provider.attributes == dict(attribute1=6, attribute2=7)
def test_clear_args(self): def test_clear_args(self):
provider = self.singleton_cls(Example) \ provider = self.singleton_cls(Example) \
.add_args(1, 2) \ .add_args(1, 2) \
.clear_args() .clear_args()
self.assertEqual(provider.args, tuple()) assert provider.args == tuple()
def test_clear_kwargs(self): def test_clear_kwargs(self):
provider = self.singleton_cls(Example) \ provider = self.singleton_cls(Example) \
.add_kwargs(init_arg3=3, init_arg4=4) \ .add_kwargs(init_arg3=3, init_arg4=4) \
.clear_kwargs() .clear_kwargs()
self.assertEqual(provider.kwargs, dict()) assert provider.kwargs == dict()
def test_clear_attributes(self): def test_clear_attributes(self):
provider = self.singleton_cls(Example) \ provider = self.singleton_cls(Example) \
.add_attributes(attribute1=5, attribute2=6) \ .add_attributes(attribute1=5, attribute2=6) \
.clear_attributes() .clear_attributes()
self.assertEqual(provider.attributes, dict()) assert provider.attributes == dict()
def test_call_overridden(self): def test_call_overridden(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
@ -236,27 +235,26 @@ class _BaseSingletonTestCase(object):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertIs(instance1, instance2) assert instance1 is instance2
self.assertIsInstance(instance1, list) assert isinstance(instance1, list)
self.assertIsInstance(instance2, list) assert isinstance(instance2, list)
def test_deepcopy(self): def test_deepcopy(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIs(provider.cls, provider_copy.cls) assert provider.cls is provider_copy.cls
self.assertIsInstance(provider, self.singleton_cls) assert isinstance(provider, self.singleton_cls)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
provider_copy_memo = self.singleton_cls(Example) provider_copy_memo = self.singleton_cls(Example)
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo})
provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_args(self): def test_deepcopy_args(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
@ -269,13 +267,13 @@ class _BaseSingletonTestCase(object):
dependent_provider_copy1 = provider_copy.args[0] dependent_provider_copy1 = provider_copy.args[0]
dependent_provider_copy2 = provider_copy.args[1] dependent_provider_copy2 = provider_copy.args[1]
self.assertNotEqual(provider.args, provider_copy.args) assert provider.args != provider_copy.args
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_kwargs(self): def test_deepcopy_kwargs(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
@ -288,13 +286,13 @@ class _BaseSingletonTestCase(object):
dependent_provider_copy1 = provider_copy.kwargs["a1"] dependent_provider_copy1 = provider_copy.kwargs["a1"]
dependent_provider_copy2 = provider_copy.kwargs["a2"] dependent_provider_copy2 = provider_copy.kwargs["a2"]
self.assertNotEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs != provider_copy.kwargs
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_attributes(self): def test_deepcopy_attributes(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
@ -307,13 +305,13 @@ class _BaseSingletonTestCase(object):
dependent_provider_copy1 = provider_copy.attributes["a1"] dependent_provider_copy1 = provider_copy.attributes["a1"]
dependent_provider_copy2 = provider_copy.attributes["a2"] dependent_provider_copy2 = provider_copy.attributes["a2"]
self.assertNotEqual(provider.attributes, provider_copy.attributes) assert provider.attributes != provider_copy.attributes
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
@ -324,12 +322,12 @@ class _BaseSingletonTestCase(object):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0] object_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIs(provider.cls, provider_copy.cls) assert provider.cls is provider_copy.cls
self.assertIsInstance(provider, self.singleton_cls) assert isinstance(provider, self.singleton_cls)
self.assertIsNot(object_provider, object_provider_copy) assert object_provider is not object_provider_copy
self.assertIsInstance(object_provider_copy, providers.Object) assert isinstance(object_provider_copy, providers.Object)
def test_deepcopy_with_sys_streams(self): def test_deepcopy_with_sys_streams(self):
provider = providers.Singleton(Example) provider = providers.Singleton(Example)
@ -339,24 +337,24 @@ class _BaseSingletonTestCase(object):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider_copy, providers.Singleton) assert isinstance(provider_copy, providers.Singleton)
self.assertIs(provider.args[0], sys.stdin) assert provider.args[0] is sys.stdin
self.assertIs(provider.kwargs["a2"], sys.stdout) assert provider.kwargs["a2"] is sys.stdout
self.assertIs(provider.attributes["a3"], sys.stderr) assert provider.attributes["a3"] is sys.stderr
def test_reset(self): def test_reset(self):
provider = self.singleton_cls(object) provider = self.singleton_cls(object)
instance1 = provider() instance1 = provider()
self.assertIsInstance(instance1, object) assert isinstance(instance1, object)
provider.reset() provider.reset()
instance2 = provider() instance2 = provider()
self.assertIsInstance(instance2, object) assert isinstance(instance2, object)
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
def test_reset_with_singleton(self): def test_reset_with_singleton(self):
dependent_singleton = providers.Singleton(object) dependent_singleton = providers.Singleton(object)
@ -364,14 +362,14 @@ class _BaseSingletonTestCase(object):
dependent_instance = dependent_singleton() dependent_instance = dependent_singleton()
instance1 = provider() instance1 = provider()
self.assertIs(instance1["dependency"], dependent_instance) assert instance1["dependency"] is dependent_instance
provider.reset() provider.reset()
instance2 = provider() instance2 = provider()
self.assertIs(instance1["dependency"], dependent_instance) assert instance1["dependency"] is dependent_instance
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
def test_reset_context_manager(self): def test_reset_context_manager(self):
singleton = self.singleton_cls(object) singleton = self.singleton_cls(object)
@ -380,7 +378,7 @@ class _BaseSingletonTestCase(object):
with singleton.reset(): with singleton.reset():
instance2 = singleton() instance2 = singleton()
instance3 = singleton() instance3 = singleton()
self.assertEqual(len({instance1, instance2, instance3}), 3) assert len({instance1, instance2, instance3}) == 3
def test_reset_context_manager_as_attribute(self): def test_reset_context_manager_as_attribute(self):
singleton = self.singleton_cls(object) singleton = self.singleton_cls(object)
@ -388,7 +386,7 @@ class _BaseSingletonTestCase(object):
with singleton.reset() as alias: with singleton.reset() as alias:
pass pass
self.assertIs(singleton, alias) assert singleton is alias
def test_full_reset(self): def test_full_reset(self):
dependent_singleton = providers.Singleton(object) dependent_singleton = providers.Singleton(object)
@ -396,15 +394,15 @@ class _BaseSingletonTestCase(object):
dependent_instance1 = dependent_singleton() dependent_instance1 = dependent_singleton()
instance1 = provider() instance1 = provider()
self.assertIs(instance1["dependency"], dependent_instance1) assert instance1["dependency"] is dependent_instance1
provider.full_reset() provider.full_reset()
dependent_instance2 = dependent_singleton() dependent_instance2 = dependent_singleton()
instance2 = provider() instance2 = provider()
self.assertIsNot(instance2["dependency"], dependent_instance1) assert instance2["dependency"] is not dependent_instance1
self.assertIsNot(dependent_instance1, dependent_instance2) assert dependent_instance1 is not dependent_instance2
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
def test_full_reset_context_manager(self): def test_full_reset_context_manager(self):
class Item: class Item:
@ -419,11 +417,8 @@ class _BaseSingletonTestCase(object):
instance2 = singleton() instance2 = singleton()
instance3 = singleton() instance3 = singleton()
self.assertEqual(len({instance1, instance2, instance3}), 3) assert len({instance1, instance2, instance3}) == 3
self.assertEqual( assert len({instance1.dependency, instance2.dependency, instance3.dependency}) == 3
len({instance1.dependency, instance2.dependency, instance3.dependency}),
3,
)
def test_full_reset_context_manager_as_attribute(self): def test_full_reset_context_manager_as_attribute(self):
singleton = self.singleton_cls(object) singleton = self.singleton_cls(object)
@ -431,4 +426,4 @@ class _BaseSingletonTestCase(object):
with singleton.full_reset() as alias: with singleton.full_reset() as alias:
pass pass
self.assertIs(singleton, alias) assert singleton is alias

View File

@ -8,6 +8,7 @@ from dependency_injector import (
providers, providers,
errors, errors,
) )
from pytest import raises
class ProviderTests(unittest.TestCase): class ProviderTests(unittest.TestCase):
@ -16,10 +17,11 @@ class ProviderTests(unittest.TestCase):
self.provider = providers.Provider() self.provider = providers.Provider()
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.provider)) assert providers.is_provider(self.provider) is True
def test_call(self): def test_call(self):
self.assertRaises(NotImplementedError, self.provider.__call__) with raises(NotImplementedError):
self.provider()
def test_delegate(self): def test_delegate(self):
with warnings.catch_warnings(): with warnings.catch_warnings():
@ -27,32 +29,32 @@ class ProviderTests(unittest.TestCase):
delegate1 = self.provider.delegate() delegate1 = self.provider.delegate()
delegate2 = self.provider.delegate() delegate2 = self.provider.delegate()
self.assertIsInstance(delegate1, providers.Delegate) assert isinstance(delegate1, providers.Delegate)
self.assertIs(delegate1(), self.provider) assert delegate1() is self.provider
self.assertIsInstance(delegate2, providers.Delegate) assert isinstance(delegate2, providers.Delegate)
self.assertIs(delegate2(), self.provider) assert delegate2() is self.provider
self.assertIsNot(delegate1, delegate2) assert delegate1 is not delegate2
def test_provider(self): def test_provider(self):
delegate1 = self.provider.provider delegate1 = self.provider.provider
self.assertIsInstance(delegate1, providers.Delegate) assert isinstance(delegate1, providers.Delegate)
self.assertIs(delegate1(), self.provider) assert delegate1() is self.provider
delegate2 = self.provider.provider delegate2 = self.provider.provider
self.assertIsInstance(delegate2, providers.Delegate) assert isinstance(delegate2, providers.Delegate)
self.assertIs(delegate2(), self.provider) assert delegate2() is self.provider
self.assertIsNot(delegate1, delegate2) assert delegate1 is not delegate2
def test_override(self): def test_override(self):
overriding_provider = providers.Provider() overriding_provider = providers.Provider()
self.provider.override(overriding_provider) self.provider.override(overriding_provider)
self.assertTrue(self.provider.overridden) assert self.provider.overridden == (overriding_provider,)
self.assertIs(self.provider.last_overriding, overriding_provider) assert self.provider.last_overriding is overriding_provider
def test_double_override(self): def test_double_override(self):
overriding_provider1 = providers.Object(1) overriding_provider1 = providers.Object(1)
@ -61,21 +63,23 @@ class ProviderTests(unittest.TestCase):
self.provider.override(overriding_provider1) self.provider.override(overriding_provider1)
overriding_provider1.override(overriding_provider2) overriding_provider1.override(overriding_provider2)
self.assertEqual(self.provider(), overriding_provider2()) assert self.provider() == overriding_provider2()
def test_overriding_context(self): def test_overriding_context(self):
overriding_provider = providers.Provider() overriding_provider = providers.Provider()
with self.provider.override(overriding_provider): with self.provider.override(overriding_provider):
self.assertTrue(self.provider.overridden) assert self.provider.overridden == (overriding_provider,)
self.assertFalse(self.provider.overridden) assert self.provider.overridden == tuple()
assert not self.provider.overridden
def test_override_with_itself(self): def test_override_with_itself(self):
self.assertRaises(errors.Error, self.provider.override, self.provider) with raises(errors.Error):
self.provider.override(self.provider)
def test_override_with_not_provider(self): def test_override_with_not_provider(self):
obj = object() obj = object()
self.provider.override(obj) self.provider.override(obj)
self.assertIs(self.provider(), obj) assert self.provider() is obj
def test_reset_last_overriding(self): def test_reset_last_overriding(self):
overriding_provider1 = providers.Provider() overriding_provider1 = providers.Provider()
@ -84,38 +88,40 @@ class ProviderTests(unittest.TestCase):
self.provider.override(overriding_provider1) self.provider.override(overriding_provider1)
self.provider.override(overriding_provider2) self.provider.override(overriding_provider2)
self.assertIs(self.provider.overridden[-1], overriding_provider2) assert self.provider.overridden[-1] is overriding_provider2
self.assertIs(self.provider.last_overriding, overriding_provider2) assert self.provider.last_overriding is overriding_provider2
self.provider.reset_last_overriding() self.provider.reset_last_overriding()
self.assertIs(self.provider.overridden[-1], overriding_provider1) assert self.provider.overridden[-1] is overriding_provider1
self.assertIs(self.provider.last_overriding, overriding_provider1) assert self.provider.last_overriding is overriding_provider1
self.provider.reset_last_overriding() self.provider.reset_last_overriding()
self.assertFalse(self.provider.overridden) assert self.provider.overridden == tuple()
self.assertIsNone(self.provider.last_overriding) assert not self.provider.overridden
assert self.provider.last_overriding is None
def test_reset_last_overriding_of_not_overridden_provider(self): def test_reset_last_overriding_of_not_overridden_provider(self):
self.assertRaises(errors.Error, self.provider.reset_last_overriding) with raises(errors.Error):
self.provider.reset_last_overriding()
def test_reset_override(self): def test_reset_override(self):
overriding_provider = providers.Provider() overriding_provider = providers.Provider()
self.provider.override(overriding_provider) self.provider.override(overriding_provider)
self.assertTrue(self.provider.overridden) assert self.provider.overridden
self.assertEqual(self.provider.overridden, (overriding_provider,)) assert self.provider.overridden == (overriding_provider,)
self.provider.reset_override() self.provider.reset_override()
self.assertEqual(self.provider.overridden, tuple()) assert self.provider.overridden == tuple()
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Provider() provider = providers.Provider()
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Provider) assert isinstance(provider, providers.Provider)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Provider() provider = providers.Provider()
@ -124,7 +130,7 @@ class ProviderTests(unittest.TestCase):
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(
provider, memo={id(provider): provider_copy_memo}) provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Provider() provider = providers.Provider()
@ -135,56 +141,57 @@ class ProviderTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
overriding_provider_copy = provider_copy.overridden[0] overriding_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Provider) assert isinstance(provider, providers.Provider)
self.assertIsNot(overriding_provider, overriding_provider_copy) assert overriding_provider is not overriding_provider_copy
self.assertIsInstance(overriding_provider_copy, providers.Provider) assert isinstance(overriding_provider_copy, providers.Provider)
def test_repr(self): def test_repr(self):
self.assertEqual(repr(self.provider), assert repr(self.provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"Provider() at {0}>".format(hex(id(self.provider)))) "Provider() at {0}>".format(hex(id(self.provider)))
)
class ObjectProviderTests(unittest.TestCase): class ObjectProviderTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Object(object()))) assert providers.is_provider(providers.Object(object())) is True
def test_init_optional_provides(self): def test_init_optional_provides(self):
instance = object() instance = object()
provider = providers.Object() provider = providers.Object()
provider.set_provides(instance) provider.set_provides(instance)
self.assertIs(provider.provides, instance) assert provider.provides is instance
self.assertIs(provider(), instance) assert provider() is instance
def test_set_provides_returns_self(self): def test_set_provides_returns_self(self):
provider = providers.Object() provider = providers.Object()
self.assertIs(provider.set_provides(object()), provider) assert provider.set_provides(object()) is provider
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Object(object()) provider = providers.Object(object())
self.assertIsInstance(provider.provided, providers.ProvidedInstance) assert isinstance(provider.provided, providers.ProvidedInstance)
def test_call_object_provider(self): def test_call_object_provider(self):
obj = object() obj = object()
self.assertIs(providers.Object(obj)(), obj) assert providers.Object(obj)() is obj
def test_call_overridden_object_provider(self): def test_call_overridden_object_provider(self):
obj1 = object() obj1 = object()
obj2 = object() obj2 = object()
provider = providers.Object(obj1) provider = providers.Object(obj1)
provider.override(providers.Object(obj2)) provider.override(providers.Object(obj2))
self.assertIs(provider(), obj2) assert provider() is obj2
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Object(1) provider = providers.Object(1)
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Object) assert isinstance(provider, providers.Object)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Object(1) provider = providers.Object(1)
@ -193,7 +200,7 @@ class ObjectProviderTests(unittest.TestCase):
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(
provider, memo={id(provider): provider_copy_memo}) provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Object(1) provider = providers.Object(1)
@ -204,11 +211,11 @@ class ObjectProviderTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
overriding_provider_copy = provider_copy.overridden[0] overriding_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Object) assert isinstance(provider, providers.Object)
self.assertIsNot(overriding_provider, overriding_provider_copy) assert overriding_provider is not overriding_provider_copy
self.assertIsInstance(overriding_provider_copy, providers.Provider) assert isinstance(overriding_provider_copy, providers.Provider)
def test_deepcopy_doesnt_copy_provided_object(self): def test_deepcopy_doesnt_copy_provided_object(self):
# Fixes bug #231 # Fixes bug #231
@ -218,46 +225,45 @@ class ObjectProviderTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIs(provider(), some_object) assert provider() is some_object
self.assertIs(provider_copy(), some_object) assert provider_copy() is some_object
def test_repr(self): def test_repr(self):
some_object = object() some_object = object()
provider = providers.Object(some_object) provider = providers.Object(some_object)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"Object({0}) at {1}>".format( "Object({0}) at {1}>".format(repr(some_object), hex(id(provider)))
repr(some_object), )
hex(id(provider))))
class SelfProviderTests(unittest.TestCase): class SelfProviderTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Self())) assert providers.is_provider(providers.Self()) is True
def test_call_object_provider(self): def test_call_object_provider(self):
container = containers.DeclarativeContainer() container = containers.DeclarativeContainer()
self.assertIs(providers.Self(container)(), container) assert providers.Self(container)() is container
def test_set_container(self): def test_set_container(self):
container = containers.DeclarativeContainer() container = containers.DeclarativeContainer()
provider = providers.Self() provider = providers.Self()
provider.set_container(container) provider.set_container(container)
self.assertIs(provider(), container) assert provider() is container
def test_set_alt_names(self): def test_set_alt_names(self):
provider = providers.Self() provider = providers.Self()
provider.set_alt_names({"foo", "bar", "baz"}) provider.set_alt_names({"foo", "bar", "baz"})
self.assertEqual(set(provider.alt_names), {"foo", "bar", "baz"}) assert set(provider.alt_names) == {"foo", "bar", "baz"}
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Self() provider = providers.Self()
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Self) assert isinstance(provider, providers.Self)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Self() provider = providers.Self()
@ -266,7 +272,7 @@ class SelfProviderTests(unittest.TestCase):
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(
provider, memo={id(provider): provider_copy_memo}) provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Self() provider = providers.Self()
@ -277,20 +283,19 @@ class SelfProviderTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
overriding_provider_copy = provider_copy.overridden[0] overriding_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Self) assert isinstance(provider, providers.Self)
self.assertIsNot(overriding_provider, overriding_provider_copy) assert overriding_provider is not overriding_provider_copy
self.assertIsInstance(overriding_provider_copy, providers.Provider) assert isinstance(overriding_provider_copy, providers.Provider)
def test_repr(self): def test_repr(self):
container = containers.DeclarativeContainer() container = containers.DeclarativeContainer()
provider = providers.Self(container) provider = providers.Self(container)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"Self({0}) at {1}>".format( "Self({0}) at {1}>".format(repr(container), hex(id(provider)))
repr(container), )
hex(id(provider))))
class DelegateTests(unittest.TestCase): class DelegateTests(unittest.TestCase):
@ -300,34 +305,34 @@ class DelegateTests(unittest.TestCase):
self.delegate = providers.Delegate(self.delegated) self.delegate = providers.Delegate(self.delegated)
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.delegate)) assert providers.is_provider(self.delegate) is True
def test_init_optional_provides(self): def test_init_optional_provides(self):
provider = providers.Delegate() provider = providers.Delegate()
provider.set_provides(self.delegated) provider.set_provides(self.delegated)
self.assertIs(provider.provides, self.delegated) assert provider.provides is self.delegated
self.assertIs(provider(), self.delegated) assert provider() is self.delegated
def test_set_provides_returns_self(self): def test_set_provides_returns_self(self):
provider = providers.Delegate() provider = providers.Delegate()
self.assertIs(provider.set_provides(self.delegated), provider) assert provider.set_provides(self.delegated) is provider
def test_init_with_not_provider(self): def test_init_with_not_provider(self):
self.assertRaises(errors.Error, providers.Delegate, object()) with raises(errors.Error):
providers.Delegate(object())
def test_call(self): def test_call(self):
delegated1 = self.delegate() delegated1 = self.delegate()
delegated2 = self.delegate() delegated2 = self.delegate()
self.assertIs(delegated1, self.delegated) assert delegated1 is self.delegated
self.assertIs(delegated2, self.delegated) assert delegated2 is self.delegated
def test_repr(self): def test_repr(self):
self.assertEqual(repr(self.delegate), assert repr(self.delegate) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"Delegate({0}) at {1}>".format( "Delegate({0}) at {1}>".format(repr(self.delegated), hex(id(self.delegate)))
repr(self.delegated), )
hex(id(self.delegate))))
class DependencyTests(unittest.TestCase): class DependencyTests(unittest.TestCase):
@ -341,20 +346,21 @@ class DependencyTests(unittest.TestCase):
provider.set_instance_of(list) provider.set_instance_of(list)
provider.set_default(list_provider) provider.set_default(list_provider)
self.assertIs(provider.instance_of, list) assert provider.instance_of is list
self.assertIs(provider.default, list_provider) assert provider.default is list_provider
self.assertEqual(provider(), [1, 2, 3]) assert provider() == [1, 2, 3]
def test_set_instance_of_returns_self(self): def test_set_instance_of_returns_self(self):
provider = providers.Dependency() provider = providers.Dependency()
self.assertIs(provider.set_instance_of(list), provider) assert provider.set_instance_of(list) is provider
def test_set_default_returns_self(self): def test_set_default_returns_self(self):
provider = providers.Dependency() provider = providers.Dependency()
self.assertIs(provider.set_default(providers.Provider()), provider) assert provider.set_default(providers.Provider()) is provider
def test_init_with_not_class(self): def test_init_with_not_class(self):
self.assertRaises(TypeError, providers.Dependency, object()) with raises(TypeError):
providers.Dependency(object())
def test_with_abc(self): def test_with_abc(self):
try: try:
@ -365,59 +371,59 @@ class DependencyTests(unittest.TestCase):
provider = providers.Dependency(collections_abc.Mapping) provider = providers.Dependency(collections_abc.Mapping)
provider.provided_by(providers.Factory(dict)) provider.provided_by(providers.Factory(dict))
self.assertIsInstance(provider(), collections_abc.Mapping) assert isinstance(provider(), collections_abc.Mapping)
self.assertIsInstance(provider(), dict) assert isinstance(provider(), dict)
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.provider)) assert providers.is_provider(self.provider) is True
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
self.assertIsInstance(self.provider.provided, providers.ProvidedInstance) assert isinstance(self.provider.provided, providers.ProvidedInstance)
def test_default(self): def test_default(self):
provider = providers.Dependency(instance_of=dict, default={"foo": "bar"}) provider = providers.Dependency(instance_of=dict, default={"foo": "bar"})
self.assertEqual(provider(), {"foo": "bar"}) assert provider() == {"foo": "bar"}
def test_default_attribute(self): def test_default_attribute(self):
provider = providers.Dependency(instance_of=dict, default={"foo": "bar"}) provider = providers.Dependency(instance_of=dict, default={"foo": "bar"})
self.assertEqual(provider.default(), {"foo": "bar"}) assert provider.default() == {"foo": "bar"}
def test_default_provider(self): def test_default_provider(self):
provider = providers.Dependency(instance_of=dict, default=providers.Factory(dict, foo="bar")) provider = providers.Dependency(instance_of=dict, default=providers.Factory(dict, foo="bar"))
self.assertEqual(provider.default(), {"foo": "bar"}) assert provider.default() == {"foo": "bar"}
def test_default_attribute_provider(self): def test_default_attribute_provider(self):
default = providers.Factory(dict, foo="bar") default = providers.Factory(dict, foo="bar")
provider = providers.Dependency(instance_of=dict, default=default) provider = providers.Dependency(instance_of=dict, default=default)
self.assertEqual(provider.default(), {"foo": "bar"}) assert provider.default() == {"foo": "bar"}
self.assertIs(provider.default, default) assert provider.default is default
def test_is_defined(self): def test_is_defined(self):
provider = providers.Dependency() provider = providers.Dependency()
self.assertFalse(provider.is_defined) assert provider.is_defined is False
def test_is_defined_when_overridden(self): def test_is_defined_when_overridden(self):
provider = providers.Dependency() provider = providers.Dependency()
provider.override("value") provider.override("value")
self.assertTrue(provider.is_defined) assert provider.is_defined is True
def test_is_defined_with_default(self): def test_is_defined_with_default(self):
provider = providers.Dependency(default="value") provider = providers.Dependency(default="value")
self.assertTrue(provider.is_defined) assert provider.is_defined is True
def test_call_overridden(self): def test_call_overridden(self):
self.provider.provided_by(providers.Factory(list)) self.provider.provided_by(providers.Factory(list))
self.assertIsInstance(self.provider(), list) assert isinstance(self.provider(), list)
def test_call_overridden_but_not_instance_of(self): def test_call_overridden_but_not_instance_of(self):
self.provider.provided_by(providers.Factory(dict)) self.provider.provided_by(providers.Factory(dict))
self.assertRaises(errors.Error, self.provider) with raises(errors.Error):
self.provider()
def test_call_undefined(self): def test_call_undefined(self):
with self.assertRaises(errors.Error) as context: with raises(errors.Error, match="Dependency is not defined"):
self.provider() self.provider()
self.assertEqual(str(context.exception), "Dependency is not defined")
def test_call_undefined_error_message_with_container_instance_parent(self): def test_call_undefined_error_message_with_container_instance_parent(self):
class UserService: class UserService:
@ -434,10 +440,9 @@ class DependencyTests(unittest.TestCase):
container = Container() container = Container()
with self.assertRaises(errors.Error) as context: with raises(errors.Error) as exception_info:
container.user_service() container.user_service()
assert str(exception_info.value) == "Dependency \"Container.database\" is not defined"
self.assertEqual(str(context.exception), "Dependency \"Container.database\" is not defined")
def test_call_undefined_error_message_with_container_provider_parent_deep(self): def test_call_undefined_error_message_with_container_provider_parent_deep(self):
class Database: class Database:
@ -468,13 +473,9 @@ class DependencyTests(unittest.TestCase):
container = Container() container = Container()
with self.assertRaises(errors.Error) as context: with raises(errors.Error) as exception_info:
container.services().user() container.services().user()
assert str(exception_info.value) == "Dependency \"Container.services.gateways.database_client\" is not defined"
self.assertEqual(
str(context.exception),
"Dependency \"Container.services.gateways.database_client\" is not defined",
)
def test_call_undefined_error_message_with_dependenciescontainer_provider_parent(self): def test_call_undefined_error_message_with_dependenciescontainer_provider_parent(self):
class UserService: class UserService:
@ -491,13 +492,9 @@ class DependencyTests(unittest.TestCase):
services = Services() services = Services()
with self.assertRaises(errors.Error) as context: with raises(errors.Error) as exception_info:
services.user() services.user()
assert str(exception_info.value) == "Dependency \"Services.gateways.database_client\" is not defined"
self.assertEqual(
str(context.exception),
"Dependency \"Services.gateways.database_client\" is not defined",
)
def test_assign_parent(self): def test_assign_parent(self):
parent = providers.DependenciesContainer() parent = providers.DependenciesContainer()
@ -505,23 +502,23 @@ class DependencyTests(unittest.TestCase):
provider.assign_parent(parent) provider.assign_parent(parent)
self.assertIs(provider.parent, parent) assert provider.parent is parent
def test_parent_name(self): def test_parent_name(self):
container = containers.DynamicContainer() container = containers.DynamicContainer()
provider = providers.Dependency() provider = providers.Dependency()
container.name = provider container.name = provider
self.assertEqual(provider.parent_name, "name") assert provider.parent_name == "name"
def test_parent_name_with_deep_parenting(self): def test_parent_name_with_deep_parenting(self):
provider = providers.Dependency() provider = providers.Dependency()
container = providers.DependenciesContainer(name=provider) container = providers.DependenciesContainer(name=provider)
_ = providers.DependenciesContainer(container=container) _ = providers.DependenciesContainer(container=container)
self.assertEqual(provider.parent_name, "container.name") assert provider.parent_name == "container.name"
def test_parent_name_is_none(self): def test_parent_name_is_none(self):
provider = providers.DependenciesContainer() provider = providers.DependenciesContainer()
self.assertIsNone(provider.parent_name) assert provider.parent_name is None
def test_parent_deepcopy(self): def test_parent_deepcopy(self):
container = containers.DynamicContainer() container = containers.DynamicContainer()
@ -530,12 +527,12 @@ class DependencyTests(unittest.TestCase):
copied = providers.deepcopy(container) copied = providers.deepcopy(container)
self.assertIs(container.name.parent, container) assert container.name.parent is container
self.assertIs(copied.name.parent, copied) assert copied.name.parent is copied
self.assertIsNot(container, copied) assert container is not copied
self.assertIsNot(container.name, copied.name) assert container.name is not copied.name
self.assertIsNot(container.name.parent, copied.name.parent) assert container.name.parent is not copied.name.parent
def test_forward_attr_to_default(self): def test_forward_attr_to_default(self):
default = providers.Configuration() default = providers.Configuration()
@ -543,7 +540,7 @@ class DependencyTests(unittest.TestCase):
provider = providers.Dependency(default=default) provider = providers.Dependency(default=default)
provider.from_dict({"foo": "bar"}) provider.from_dict({"foo": "bar"})
self.assertEqual(default(), {"foo": "bar"}) assert default() == {"foo": "bar"}
def test_forward_attr_to_overriding(self): def test_forward_attr_to_overriding(self):
overriding = providers.Configuration() overriding = providers.Configuration()
@ -552,11 +549,11 @@ class DependencyTests(unittest.TestCase):
provider.override(overriding) provider.override(overriding)
provider.from_dict({"foo": "bar"}) provider.from_dict({"foo": "bar"})
self.assertEqual(overriding(), {"foo": "bar"}) assert overriding() == {"foo": "bar"}
def test_forward_attr_to_none(self): def test_forward_attr_to_none(self):
provider = providers.Dependency() provider = providers.Dependency()
with self.assertRaises(AttributeError): with raises(AttributeError):
provider.from_dict provider.from_dict
def test_deepcopy(self): def test_deepcopy(self):
@ -564,8 +561,8 @@ class DependencyTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Dependency) assert isinstance(provider, providers.Dependency)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Dependency(int) provider = providers.Dependency(int)
@ -574,7 +571,7 @@ class DependencyTests(unittest.TestCase):
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(
provider, memo={id(provider): provider_copy_memo}) provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Dependency(int) provider = providers.Dependency(int)
@ -585,11 +582,11 @@ class DependencyTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
overriding_provider_copy = provider_copy.overridden[0] overriding_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Dependency) assert isinstance(provider, providers.Dependency)
self.assertIsNot(overriding_provider, overriding_provider_copy) assert overriding_provider is not overriding_provider_copy
self.assertIsInstance(overriding_provider_copy, providers.Provider) assert isinstance(overriding_provider_copy, providers.Provider)
def test_deep_copy_default_object(self): def test_deep_copy_default_object(self):
default = {"foo": "bar"} default = {"foo": "bar"}
@ -597,8 +594,8 @@ class DependencyTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIs(provider_copy(), default) assert provider_copy() is default
self.assertIs(provider_copy.default(), default) assert provider_copy.default() is default
def test_deep_copy_default_provider(self): def test_deep_copy_default_provider(self):
bar = object() bar = object()
@ -607,9 +604,9 @@ class DependencyTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertEqual(provider_copy(), {"foo": bar}) assert provider_copy() == {"foo": bar}
self.assertEqual(provider_copy.default(), {"foo": bar}) assert provider_copy.default() == {"foo": bar}
self.assertIs(provider_copy()["foo"], bar) assert provider_copy()["foo"] is bar
def test_with_container_default_object(self): def test_with_container_default_object(self):
default = {"foo": "bar"} default = {"foo": "bar"}
@ -619,8 +616,8 @@ class DependencyTests(unittest.TestCase):
container = Container() container = Container()
self.assertIs(container.provider(), default) assert container.provider() is default
self.assertIs(container.provider.default(), default) assert container.provider.default() is default
def test_with_container_default_provider(self): def test_with_container_default_provider(self):
bar = object() bar = object()
@ -630,9 +627,9 @@ class DependencyTests(unittest.TestCase):
container = Container() container = Container()
self.assertEqual(container.provider(), {"foo": bar}) assert container.provider() == {"foo": bar}
self.assertEqual(container.provider.default(), {"foo": bar}) assert container.provider.default() == {"foo": bar}
self.assertIs(container.provider()["foo"], bar) assert container.provider()["foo"] is bar
def test_with_container_default_provider_with_overriding(self): def test_with_container_default_provider_with_overriding(self):
bar = object() bar = object()
@ -643,16 +640,15 @@ class DependencyTests(unittest.TestCase):
container = Container(provider=providers.Factory(dict, foo=providers.Object(baz))) container = Container(provider=providers.Factory(dict, foo=providers.Object(baz)))
self.assertEqual(container.provider(), {"foo": baz}) assert container.provider() == {"foo": baz}
self.assertEqual(container.provider.default(), {"foo": bar}) assert container.provider.default() == {"foo": bar}
self.assertIs(container.provider()["foo"], baz) assert container.provider()["foo"] is baz
def test_repr(self): def test_repr(self):
self.assertEqual(repr(self.provider), assert repr(self.provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"Dependency({0}) at {1}>".format( "Dependency({0}) at {1}>".format(repr(list), hex(id(self.provider)))
repr(list), )
hex(id(self.provider))))
def test_repr_in_container(self): def test_repr_in_container(self):
class Container(containers.DeclarativeContainer): class Container(containers.DeclarativeContainer):
@ -660,11 +656,13 @@ class DependencyTests(unittest.TestCase):
container = Container() container = Container()
self.assertEqual(repr(container.dependency), assert repr(container.dependency) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"Dependency({0}) at {1}, container name: \"Container.dependency\">".format( "Dependency({0}) at {1}, container name: \"Container.dependency\">".format(
repr(int), repr(int),
hex(id(container.dependency)))) hex(id(container.dependency)),
)
)
class ExternalDependencyTests(unittest.TestCase): class ExternalDependencyTests(unittest.TestCase):
@ -673,7 +671,7 @@ class ExternalDependencyTests(unittest.TestCase):
self.provider = providers.ExternalDependency(instance_of=list) self.provider = providers.ExternalDependency(instance_of=list)
def test_is_instance(self): def test_is_instance(self):
self.assertIsInstance(self.provider, providers.Dependency) assert isinstance(self.provider, providers.Dependency)
class DependenciesContainerTests(unittest.TestCase): class DependenciesContainerTests(unittest.TestCase):
@ -690,47 +688,46 @@ class DependenciesContainerTests(unittest.TestCase):
has_dependency = hasattr(self.provider, "dependency") has_dependency = hasattr(self.provider, "dependency")
dependency = self.provider.dependency dependency = self.provider.dependency
self.assertIsInstance(dependency, providers.Dependency) assert isinstance(dependency, providers.Dependency)
self.assertIs(dependency, self.provider.dependency) assert dependency is self.provider.dependency
self.assertTrue(has_dependency) assert has_dependency is True
self.assertIsNone(dependency.last_overriding) assert dependency.last_overriding is None
def test_getattr_with_container(self): def test_getattr_with_container(self):
self.provider.override(self.container) self.provider.override(self.container)
dependency = self.provider.dependency dependency = self.provider.dependency
self.assertTrue(dependency.overridden) assert dependency.overridden == (self.container.dependency,)
self.assertIs(dependency.last_overriding, self.container.dependency) assert dependency.last_overriding is self.container.dependency
def test_providers(self): def test_providers(self):
dependency1 = self.provider.dependency1 dependency1 = self.provider.dependency1
dependency2 = self.provider.dependency2 dependency2 = self.provider.dependency2
self.assertEqual(self.provider.providers, {"dependency1": dependency1, assert self.provider.providers == {"dependency1": dependency1, "dependency2": dependency2}
"dependency2": dependency2})
def test_override(self): def test_override(self):
dependency = self.provider.dependency dependency = self.provider.dependency
self.provider.override(self.container) self.provider.override(self.container)
self.assertTrue(dependency.overridden) assert dependency.overridden == (self.container.dependency,)
self.assertIs(dependency.last_overriding, self.container.dependency) assert dependency.last_overriding is self.container.dependency
def test_reset_last_overriding(self): def test_reset_last_overriding(self):
dependency = self.provider.dependency dependency = self.provider.dependency
self.provider.override(self.container) self.provider.override(self.container)
self.provider.reset_last_overriding() self.provider.reset_last_overriding()
self.assertIsNone(dependency.last_overriding) assert dependency.last_overriding is None
self.assertIsNone(dependency.last_overriding) assert dependency.last_overriding is None
def test_reset_override(self): def test_reset_override(self):
dependency = self.provider.dependency dependency = self.provider.dependency
self.provider.override(self.container) self.provider.override(self.container)
self.provider.reset_override() self.provider.reset_override()
self.assertFalse(dependency.overridden) assert dependency.overridden == tuple()
self.assertFalse(dependency.overridden) assert not dependency.overridden
def test_assign_parent(self): def test_assign_parent(self):
parent = providers.DependenciesContainer() parent = providers.DependenciesContainer()
@ -738,23 +735,23 @@ class DependenciesContainerTests(unittest.TestCase):
provider.assign_parent(parent) provider.assign_parent(parent)
self.assertIs(provider.parent, parent) assert provider.parent is parent
def test_parent_name(self): def test_parent_name(self):
container = containers.DynamicContainer() container = containers.DynamicContainer()
provider = providers.DependenciesContainer() provider = providers.DependenciesContainer()
container.name = provider container.name = provider
self.assertEqual(provider.parent_name, "name") assert provider.parent_name == "name"
def test_parent_name_with_deep_parenting(self): def test_parent_name_with_deep_parenting(self):
provider = providers.DependenciesContainer() provider = providers.DependenciesContainer()
container = providers.DependenciesContainer(name=provider) container = providers.DependenciesContainer(name=provider)
_ = providers.DependenciesContainer(container=container) _ = providers.DependenciesContainer(container=container)
self.assertEqual(provider.parent_name, "container.name") assert provider.parent_name == "container.name"
def test_parent_name_is_none(self): def test_parent_name_is_none(self):
provider = providers.DependenciesContainer() provider = providers.DependenciesContainer()
self.assertIsNone(provider.parent_name) assert provider.parent_name is None
def test_parent_deepcopy(self): def test_parent_deepcopy(self):
container = containers.DynamicContainer() container = containers.DynamicContainer()
@ -763,29 +760,29 @@ class DependenciesContainerTests(unittest.TestCase):
copied = providers.deepcopy(container) copied = providers.deepcopy(container)
self.assertIs(container.name.parent, container) assert container.name.parent is container
self.assertIs(copied.name.parent, copied) assert copied.name.parent is copied
self.assertIsNot(container, copied) assert container is not copied
self.assertIsNot(container.name, copied.name) assert container.name is not copied.name
self.assertIsNot(container.name.parent, copied.name.parent) assert container.name.parent is not copied.name.parent
def test_parent_set_on__getattr__(self): def test_parent_set_on__getattr__(self):
provider = providers.DependenciesContainer() provider = providers.DependenciesContainer()
self.assertIsInstance(provider.name, providers.Dependency) assert isinstance(provider.name, providers.Dependency)
self.assertIs(provider.name.parent, provider) assert provider.name.parent is provider
def test_parent_set_on__init__(self): def test_parent_set_on__init__(self):
provider = providers.Dependency() provider = providers.Dependency()
container = providers.DependenciesContainer(name=provider) container = providers.DependenciesContainer(name=provider)
self.assertIs(container.name, provider) assert container.name is provider
self.assertIs(container.name.parent, container) assert container.name.parent is container
def test_resolve_provider_name(self): def test_resolve_provider_name(self):
container = providers.DependenciesContainer() container = providers.DependenciesContainer()
self.assertEqual(container.resolve_provider_name(container.name), "name") assert container.resolve_provider_name(container.name) == "name"
def test_resolve_provider_name_no_provider(self): def test_resolve_provider_name_no_provider(self):
container = providers.DependenciesContainer() container = providers.DependenciesContainer()
with self.assertRaises(errors.Error): with raises(errors.Error):
container.resolve_provider_name(providers.Provider()) container.resolve_provider_name(providers.Provider())

View File

@ -8,6 +8,7 @@ from dependency_injector import (
providers, providers,
errors, errors,
) )
from pytest import raises
def _example(arg1, arg2, arg3, arg4): def _example(arg1, arg2, arg3, arg4):
@ -16,88 +17,84 @@ def _example(arg1, arg2, arg3, arg4):
class CallableTests(unittest.TestCase): class CallableTests(unittest.TestCase):
def test_init_with_callable(self): def test_is_provider(self):
self.assertTrue(providers.Callable(_example)) assert providers.is_provider(providers.Callable(_example)) is True
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, providers.Callable, 123) with raises(errors.Error):
providers.Callable(123)
def test_init_optional_provides(self): def test_init_optional_provides(self):
provider = providers.Callable() provider = providers.Callable()
provider.set_provides(object) provider.set_provides(object)
self.assertIs(provider.provides, object) assert provider.provides is object
self.assertIsInstance(provider(), object) assert isinstance(provider(), object)
def test_set_provides_returns_self(self): def test_set_provides_returns_self(self):
provider = providers.Callable() provider = providers.Callable()
self.assertIs(provider.set_provides(object), provider) assert provider.set_provides(object) is provider
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) assert isinstance(provider.provided, providers.ProvidedInstance)
def test_call(self): def test_call(self):
provider = providers.Callable(lambda: True) provider = providers.Callable(lambda: True)
self.assertTrue(provider()) assert provider() is True
def test_call_with_positional_args(self): def test_call_with_positional_args(self):
provider = providers.Callable(_example, provider = providers.Callable(_example, 1, 2, 3, 4)
1, 2, 3, 4) assert provider() == (1, 2, 3, 4)
self.assertTupleEqual(provider(), (1, 2, 3, 4))
def test_call_with_keyword_args(self): def test_call_with_keyword_args(self):
provider = providers.Callable(_example, provider = providers.Callable(_example, arg1=1, arg2=2, arg3=3, arg4=4)
arg1=1, arg2=2, arg3=3, arg4=4) assert provider() == (1, 2, 3, 4)
self.assertTupleEqual(provider(), (1, 2, 3, 4))
def test_call_with_positional_and_keyword_args(self): def test_call_with_positional_and_keyword_args(self):
provider = providers.Callable(_example, provider = providers.Callable(_example, 1, 2, arg3=3, arg4=4)
1, 2, assert provider() == (1, 2, 3, 4)
arg3=3, arg4=4)
self.assertTupleEqual(provider(), (1, 2, 3, 4))
def test_call_with_context_args(self): def test_call_with_context_args(self):
provider = providers.Callable(_example, 1, 2) provider = providers.Callable(_example, 1, 2)
self.assertTupleEqual(provider(3, 4), (1, 2, 3, 4)) assert provider(3, 4) == (1, 2, 3, 4)
def test_call_with_context_kwargs(self): def test_call_with_context_kwargs(self):
provider = providers.Callable(_example, arg1=1) provider = providers.Callable(_example, arg1=1)
self.assertTupleEqual(provider(arg2=2, arg3=3, arg4=4), (1, 2, 3, 4)) assert provider(arg2=2, arg3=3, arg4=4) == (1, 2, 3, 4)
def test_call_with_context_args_and_kwargs(self): def test_call_with_context_args_and_kwargs(self):
provider = providers.Callable(_example, 1) provider = providers.Callable(_example, 1)
self.assertTupleEqual(provider(2, arg3=3, arg4=4), (1, 2, 3, 4)) assert provider(2, arg3=3, arg4=4) == (1, 2, 3, 4)
def test_fluent_interface(self): def test_fluent_interface(self):
provider = providers.Singleton(_example) \ provider = providers.Singleton(_example) \
.add_args(1, 2) \ .add_args(1, 2) \
.add_kwargs(arg3=3, arg4=4) .add_kwargs(arg3=3, arg4=4)
assert provider() == (1, 2, 3, 4)
self.assertTupleEqual(provider(), (1, 2, 3, 4))
def test_set_args(self): def test_set_args(self):
provider = providers.Callable(_example) \ provider = providers.Callable(_example) \
.add_args(1, 2) \ .add_args(1, 2) \
.set_args(3, 4) .set_args(3, 4)
self.assertEqual(provider.args, (3, 4)) assert provider.args == (3, 4)
def test_set_kwargs(self): def test_set_kwargs(self):
provider = providers.Callable(_example) \ provider = providers.Callable(_example) \
.add_kwargs(init_arg3=3, init_arg4=4) \ .add_kwargs(init_arg3=3, init_arg4=4) \
.set_kwargs(init_arg3=4, init_arg4=5) .set_kwargs(init_arg3=4, init_arg4=5)
self.assertEqual(provider.kwargs, dict(init_arg3=4, init_arg4=5)) assert provider.kwargs == dict(init_arg3=4, init_arg4=5)
def test_clear_args(self): def test_clear_args(self):
provider = providers.Callable(_example) \ provider = providers.Callable(_example) \
.add_args(1, 2) \ .add_args(1, 2) \
.clear_args() .clear_args()
self.assertEqual(provider.args, tuple()) assert provider.args == tuple()
def test_clear_kwargs(self): def test_clear_kwargs(self):
provider = providers.Callable(_example) \ provider = providers.Callable(_example) \
.add_kwargs(init_arg3=3, init_arg4=4) \ .add_kwargs(init_arg3=3, init_arg4=4) \
.clear_kwargs() .clear_kwargs()
self.assertEqual(provider.kwargs, dict()) assert provider.kwargs == dict()
def test_call_overridden(self): def test_call_overridden(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
@ -105,25 +102,24 @@ class CallableTests(unittest.TestCase):
provider.override(providers.Object((4, 3, 2, 1))) provider.override(providers.Object((4, 3, 2, 1)))
provider.override(providers.Object((1, 2, 3, 4))) provider.override(providers.Object((1, 2, 3, 4)))
self.assertTupleEqual(provider(), (1, 2, 3, 4)) assert provider() == (1, 2, 3, 4)
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIs(provider.provides, provider_copy.provides) assert provider.provides is provider_copy.provides
self.assertIsInstance(provider, providers.Callable) assert isinstance(provider, providers.Callable)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
provider_copy_memo = providers.Callable(_example) provider_copy_memo = providers.Callable(_example)
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo})
provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_args(self): def test_deepcopy_args(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
@ -136,15 +132,13 @@ class CallableTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.args[0] dependent_provider_copy1 = provider_copy.args[0]
dependent_provider_copy2 = provider_copy.args[1] dependent_provider_copy2 = provider_copy.args[1]
self.assertNotEqual(provider.args, provider_copy.args) assert provider.args != provider_copy.args
self.assertIs(dependent_provider1.provides, assert dependent_provider1.provides is dependent_provider_copy1.provides
dependent_provider_copy1.provides) assert dependent_provider1 is not dependent_provider_copy1
self.assertIsNot(dependent_provider1, dependent_provider_copy1)
self.assertIs(dependent_provider2.provides, assert dependent_provider2.provides is dependent_provider_copy2.provides
dependent_provider_copy2.provides) assert dependent_provider2 is not dependent_provider_copy2
self.assertIsNot(dependent_provider2, dependent_provider_copy2)
def test_deepcopy_kwargs(self): def test_deepcopy_kwargs(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
@ -157,15 +151,13 @@ class CallableTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.kwargs["a1"] dependent_provider_copy1 = provider_copy.kwargs["a1"]
dependent_provider_copy2 = provider_copy.kwargs["a2"] dependent_provider_copy2 = provider_copy.kwargs["a2"]
self.assertNotEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs != provider_copy.kwargs
self.assertIs(dependent_provider1.provides, assert dependent_provider1.provides is dependent_provider_copy1.provides
dependent_provider_copy1.provides) assert dependent_provider1 is not dependent_provider_copy1
self.assertIsNot(dependent_provider1, dependent_provider_copy1)
self.assertIs(dependent_provider2.provides, assert dependent_provider2.provides is dependent_provider_copy2.provides
dependent_provider_copy2.provides) assert dependent_provider2 is not dependent_provider_copy2
self.assertIsNot(dependent_provider2, dependent_provider_copy2)
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
@ -176,12 +168,12 @@ class CallableTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0] object_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIs(provider.provides, provider_copy.provides) assert provider.provides is provider_copy.provides
self.assertIsInstance(provider, providers.Callable) assert isinstance(provider, providers.Callable)
self.assertIsNot(object_provider, object_provider_copy) assert object_provider is not object_provider_copy
self.assertIsInstance(object_provider_copy, providers.Object) assert isinstance(object_provider_copy, providers.Object)
def test_deepcopy_with_sys_streams(self): def test_deepcopy_with_sys_streams(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
@ -190,50 +182,44 @@ class CallableTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider_copy, providers.Callable) assert isinstance(provider_copy, providers.Callable)
self.assertIs(provider.args[0], sys.stdin) assert provider.args[0] is sys.stdin
self.assertIs(provider.kwargs["a2"], sys.stdout) assert provider.kwargs["a2"] is sys.stdout
def test_repr(self): def test_repr(self):
provider = providers.Callable(_example) provider = providers.Callable(_example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "Callable({0}) at {1}>".format(repr(_example), hex(id(provider)))
"Callable({0}) at {1}>".format( )
repr(_example),
hex(id(provider))))
class DelegatedCallableTests(unittest.TestCase): class DelegatedCallableTests(unittest.TestCase):
def test_inheritance(self): def test_inheritance(self):
self.assertIsInstance(providers.DelegatedCallable(_example), assert isinstance(providers.DelegatedCallable(_example),
providers.Callable) providers.Callable)
def test_is_provider(self): def test_is_provider(self):
self.assertTrue( assert providers.is_provider(providers.DelegatedCallable(_example)) is True
providers.is_provider(providers.DelegatedCallable(_example)))
def test_is_delegated_provider(self): def test_is_delegated_provider(self):
provider = providers.DelegatedCallable(_example) provider = providers.DelegatedCallable(_example)
self.assertTrue(providers.is_delegated(provider)) assert providers.is_delegated(provider) is True
def test_repr(self): def test_repr(self):
provider = providers.DelegatedCallable(_example) provider = providers.DelegatedCallable(_example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "DelegatedCallable({0}) at {1}>".format(repr(_example), hex(id(provider)))
"DelegatedCallable({0}) at {1}>".format( )
repr(_example),
hex(id(provider))))
class AbstractCallableTests(unittest.TestCase): class AbstractCallableTests(unittest.TestCase):
def test_inheritance(self): def test_inheritance(self):
self.assertIsInstance(providers.AbstractCallable(_example), assert isinstance(providers.AbstractCallable(_example), providers.Callable)
providers.Callable)
def test_call_overridden_by_callable(self): def test_call_overridden_by_callable(self):
def _abstract_example(): def _abstract_example():
@ -242,7 +228,7 @@ class AbstractCallableTests(unittest.TestCase):
provider = providers.AbstractCallable(_abstract_example) provider = providers.AbstractCallable(_abstract_example)
provider.override(providers.Callable(_example)) provider.override(providers.Callable(_example))
self.assertTrue(provider(1, 2, 3, 4), (1, 2, 3, 4)) assert provider(1, 2, 3, 4) == (1, 2, 3, 4)
def test_call_overridden_by_delegated_callable(self): def test_call_overridden_by_delegated_callable(self):
def _abstract_example(): def _abstract_example():
@ -251,34 +237,33 @@ class AbstractCallableTests(unittest.TestCase):
provider = providers.AbstractCallable(_abstract_example) provider = providers.AbstractCallable(_abstract_example)
provider.override(providers.DelegatedCallable(_example)) provider.override(providers.DelegatedCallable(_example))
self.assertTrue(provider(1, 2, 3, 4), (1, 2, 3, 4)) assert provider(1, 2, 3, 4) == (1, 2, 3, 4)
def test_call_not_overridden(self): def test_call_not_overridden(self):
provider = providers.AbstractCallable(_example) provider = providers.AbstractCallable(_example)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider(1, 2, 3, 4) provider(1, 2, 3, 4)
def test_override_by_not_callable(self): def test_override_by_not_callable(self):
provider = providers.AbstractCallable(_example) provider = providers.AbstractCallable(_example)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider.override(providers.Factory(object)) provider.override(providers.Factory(object))
def test_provide_not_implemented(self): def test_provide_not_implemented(self):
provider = providers.AbstractCallable(_example) provider = providers.AbstractCallable(_example)
with self.assertRaises(NotImplementedError): with raises(NotImplementedError):
provider._provide((1, 2, 3, 4), dict()) provider._provide((1, 2, 3, 4), dict())
def test_repr(self): def test_repr(self):
provider = providers.AbstractCallable(_example) provider = providers.AbstractCallable(_example)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"AbstractCallable({0}) at {1}>".format( "AbstractCallable({0}) at {1}>".format(repr(_example), hex(id(provider)))
repr(_example), )
hex(id(provider))))
class CallableDelegateTests(unittest.TestCase): class CallableDelegateTests(unittest.TestCase):
@ -288,9 +273,8 @@ class CallableDelegateTests(unittest.TestCase):
self.delegate = providers.CallableDelegate(self.delegated) self.delegate = providers.CallableDelegate(self.delegated)
def test_is_delegate(self): def test_is_delegate(self):
self.assertIsInstance(self.delegate, providers.Delegate) assert isinstance(self.delegate, providers.Delegate)
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, with raises(errors.Error):
providers.CallableDelegate, providers.CallableDelegate(providers.Object(object()))
providers.Object(object()))

File diff suppressed because it is too large Load Diff

View File

@ -5,6 +5,7 @@ import copy
import unittest import unittest
from dependency_injector import containers, providers, errors from dependency_injector import containers, providers, errors
from pytest import raises
TEST_VALUE_1 = "core_section_value1" TEST_VALUE_1 = "core_section_value1"
@ -45,13 +46,13 @@ class ContainerTests(unittest.TestCase):
def test(self): def test(self):
application = TestApplication(config=_copied(TEST_CONFIG_1)) application = TestApplication(config=_copied(TEST_CONFIG_1))
self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_1}) assert application.dict_factory() == {"value": TEST_VALUE_1}
def test_double_override(self): def test_double_override(self):
application = TestApplication() application = TestApplication()
application.config.override(_copied(TEST_CONFIG_1)) application.config.override(_copied(TEST_CONFIG_1))
application.config.override(_copied(TEST_CONFIG_2)) application.config.override(_copied(TEST_CONFIG_2))
self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_2}) assert application.dict_factory() == {"value": TEST_VALUE_2}
def test_override(self): def test_override(self):
# See: https://github.com/ets-labs/python-dependency-injector/issues/354 # See: https://github.com/ets-labs/python-dependency-injector/issues/354
@ -69,7 +70,7 @@ class ContainerTests(unittest.TestCase):
b = B(d=D()) b = B(d=D())
result = b.a().bar() result = b.a().bar()
self.assertEqual(result, "foo++") assert result == "foo++"
def test_override_not_root_provider(self): def test_override_not_root_provider(self):
# See: https://github.com/ets-labs/python-dependency-injector/issues/379 # See: https://github.com/ets-labs/python-dependency-injector/issues/379
@ -105,33 +106,20 @@ class ContainerTests(unittest.TestCase):
container="using_factory", container="using_factory",
foo="bar" foo="bar"
)) ))
self.assertEqual( assert container_using_factory.root_container().print_settings() == {"container": "using_factory", "foo": "bar"}
container_using_factory.root_container().print_settings(), assert container_using_factory.not_root_container().print_settings() == {"container": "using_factory", "foo": "bar"}
{"container": "using_factory", "foo": "bar"},
)
self.assertEqual(
container_using_factory.not_root_container().print_settings(),
{"container": "using_factory", "foo": "bar"},
)
container_using_container = TestContainer(settings=dict( container_using_container = TestContainer(settings=dict(
container="using_container", container="using_container",
foo="bar" foo="bar"
)) ))
self.assertEqual( assert container_using_container.root_container().print_settings() == {"container": "using_container", "foo": "bar"}
container_using_container.root_container().print_settings(), assert container_using_container.not_root_container().print_settings() == {"container": "using_container", "foo": "bar"}
{"container": "using_container", "foo": "bar"},
)
self.assertEqual(
container_using_container.not_root_container().print_settings(),
{"container": "using_container", "foo": "bar"},
)
def test_override_by_not_a_container(self): def test_override_by_not_a_container(self):
provider = providers.Container(TestCore) provider = providers.Container(TestCore)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider.override(providers.Object("foo")) provider.override(providers.Object("foo"))
def test_lazy_overriding(self): def test_lazy_overriding(self):
@ -151,7 +139,7 @@ class ContainerTests(unittest.TestCase):
b = B(d=D()) b = B(d=D())
result = b.a().bar() result = b.a().bar()
self.assertEqual(result, "foo++") assert result == "foo++"
def test_lazy_overriding_deep(self): def test_lazy_overriding_deep(self):
# Extended version of test_lazy_overriding() # Extended version of test_lazy_overriding()
@ -174,7 +162,7 @@ class ContainerTests(unittest.TestCase):
b = B(d=D()) b = B(d=D())
result = b.a().c().bar() result = b.a().c().bar()
self.assertEqual(result, "foo++") assert result == "foo++"
def test_reset_last_overriding(self): def test_reset_last_overriding(self):
application = TestApplication(config=_copied(TEST_CONFIG_1)) application = TestApplication(config=_copied(TEST_CONFIG_1))
@ -182,7 +170,7 @@ class ContainerTests(unittest.TestCase):
application.core.reset_last_overriding() application.core.reset_last_overriding()
self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_1}) assert application.dict_factory() == {"value": TEST_VALUE_1}
def test_reset_last_overriding_only_overridden(self): def test_reset_last_overriding_only_overridden(self):
application = TestApplication(config=_copied(TEST_CONFIG_1)) application = TestApplication(config=_copied(TEST_CONFIG_1))
@ -190,17 +178,17 @@ class ContainerTests(unittest.TestCase):
application.core.reset_last_overriding() application.core.reset_last_overriding()
self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_1}) assert application.dict_factory() == {"value": TEST_VALUE_1}
def test_override_context_manager(self): def test_override_context_manager(self):
application = TestApplication(config=_copied(TEST_CONFIG_1)) application = TestApplication(config=_copied(TEST_CONFIG_1))
overriding_core = TestCore(config=_copied(TEST_CONFIG_2["core"])) overriding_core = TestCore(config=_copied(TEST_CONFIG_2["core"]))
with application.core.override(overriding_core) as context_core: with application.core.override(overriding_core) as context_core:
self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_2}) assert application.dict_factory() == {"value": TEST_VALUE_2}
self.assertIs(context_core(), overriding_core) assert context_core() is overriding_core
self.assertEqual(application.dict_factory(), {"value": TEST_VALUE_1}) assert application.dict_factory() == {"value": TEST_VALUE_1}
def test_reset_override(self): def test_reset_override(self):
application = TestApplication(config=_copied(TEST_CONFIG_1)) application = TestApplication(config=_copied(TEST_CONFIG_1))
@ -208,7 +196,7 @@ class ContainerTests(unittest.TestCase):
application.core.reset_override() application.core.reset_override()
self.assertEqual(application.dict_factory(), {"value": None}) assert application.dict_factory() == {"value": None}
def test_reset_override_only_overridden(self): def test_reset_override_only_overridden(self):
application = TestApplication(config=_copied(TEST_CONFIG_1)) application = TestApplication(config=_copied(TEST_CONFIG_1))
@ -216,7 +204,7 @@ class ContainerTests(unittest.TestCase):
application.core.reset_override() application.core.reset_override()
self.assertEqual(application.dict_factory(), {"value": None}) assert application.dict_factory() == {"value": None}
def test_assign_parent(self): def test_assign_parent(self):
parent = providers.DependenciesContainer() parent = providers.DependenciesContainer()
@ -224,23 +212,23 @@ class ContainerTests(unittest.TestCase):
provider.assign_parent(parent) provider.assign_parent(parent)
self.assertIs(provider.parent, parent) assert provider.parent is parent
def test_parent_name(self): def test_parent_name(self):
container = containers.DynamicContainer() container = containers.DynamicContainer()
provider = providers.Container(TestCore) provider = providers.Container(TestCore)
container.name = provider container.name = provider
self.assertEqual(provider.parent_name, "name") assert provider.parent_name == "name"
def test_parent_name_with_deep_parenting(self): def test_parent_name_with_deep_parenting(self):
provider = providers.Container(TestCore) provider = providers.Container(TestCore)
container = providers.DependenciesContainer(name=provider) container = providers.DependenciesContainer(name=provider)
_ = providers.DependenciesContainer(container=container) _ = providers.DependenciesContainer(container=container)
self.assertEqual(provider.parent_name, "container.name") assert provider.parent_name == "container.name"
def test_parent_name_is_none(self): def test_parent_name_is_none(self):
provider = providers.Container(TestCore) provider = providers.Container(TestCore)
self.assertIsNone(provider.parent_name) assert provider.parent_name is None
def test_parent_deepcopy(self): def test_parent_deepcopy(self):
container = containers.DynamicContainer() container = containers.DynamicContainer()
@ -249,18 +237,18 @@ class ContainerTests(unittest.TestCase):
copied = providers.deepcopy(container) copied = providers.deepcopy(container)
self.assertIs(container.name.parent, container) assert container.name.parent is container
self.assertIs(copied.name.parent, copied) assert copied.name.parent is copied
self.assertIsNot(container, copied) assert container is not copied
self.assertIsNot(container.name, copied.name) assert container.name is not copied.name
self.assertIsNot(container.name.parent, copied.name.parent) assert container.name.parent is not copied.name.parent
def test_resolve_provider_name(self): def test_resolve_provider_name(self):
container = providers.Container(TestCore) container = providers.Container(TestCore)
self.assertEqual(container.resolve_provider_name(container.value_getter), "value_getter") assert container.resolve_provider_name(container.value_getter) == "value_getter"
def test_resolve_provider_name_no_provider(self): def test_resolve_provider_name_no_provider(self):
container = providers.Container(TestCore) container = providers.Container(TestCore)
with self.assertRaises(errors.Error): with raises(errors.Error):
container.resolve_provider_name(providers.Provider()) container.resolve_provider_name(providers.Provider())

View File

@ -8,6 +8,7 @@ from dependency_injector import (
providers, providers,
errors, errors,
) )
from pytest import raises
# Runtime import to get asyncutils module # Runtime import to get asyncutils module
import os import os
@ -38,84 +39,75 @@ def run(main):
class CoroutineTests(AsyncTestCase): class CoroutineTests(AsyncTestCase):
def test_init_with_coroutine(self): def test_init_with_coroutine(self):
self.assertTrue(providers.Coroutine(_example)) assert isinstance(providers.Coroutine(_example), providers.Coroutine)
def test_init_with_not_coroutine(self): def test_init_with_not_coroutine(self):
self.assertRaises(errors.Error, providers.Coroutine, lambda: None) with raises(errors.Error):
providers.Coroutine(lambda: None)
def test_init_optional_provides(self): def test_init_optional_provides(self):
provider = providers.Coroutine() provider = providers.Coroutine()
provider.set_provides(_example) provider.set_provides(_example)
self.assertIs(provider.provides, _example) assert provider.provides is _example
self.assertEqual(run(provider(1, 2, 3, 4)), (1, 2, 3, 4)) assert run(provider(1, 2, 3, 4)) == (1, 2, 3, 4)
def test_set_provides_returns_self(self): def test_set_provides_returns_self(self):
provider = providers.Coroutine() provider = providers.Coroutine()
self.assertIs(provider.set_provides(_example), provider) assert provider.set_provides(_example) is provider
def test_call_with_positional_args(self): def test_call_with_positional_args(self):
provider = providers.Coroutine(_example, 1, 2, 3, 4) provider = providers.Coroutine(_example, 1, 2, 3, 4)
self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4)) assert self._run(provider()) == (1, 2, 3, 4)
def test_call_with_keyword_args(self): def test_call_with_keyword_args(self):
provider = providers.Coroutine(_example, provider = providers.Coroutine(_example, arg1=1, arg2=2, arg3=3, arg4=4)
arg1=1, arg2=2, arg3=3, arg4=4) assert self._run(provider()) == (1, 2, 3, 4)
self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4))
def test_call_with_positional_and_keyword_args(self): def test_call_with_positional_and_keyword_args(self):
provider = providers.Coroutine(_example, provider = providers.Coroutine(_example, 1, 2, arg3=3, arg4=4)
1, 2, assert run(provider()) == (1, 2, 3, 4)
arg3=3, arg4=4)
self.assertTupleEqual(run(provider()), (1, 2, 3, 4))
def test_call_with_context_args(self): def test_call_with_context_args(self):
provider = providers.Coroutine(_example, 1, 2) provider = providers.Coroutine(_example, 1, 2)
self.assertTupleEqual(self._run(provider(3, 4)), (1, 2, 3, 4)) assert self._run(provider(3, 4)) == (1, 2, 3, 4)
def test_call_with_context_kwargs(self): def test_call_with_context_kwargs(self):
provider = providers.Coroutine(_example, arg1=1) provider = providers.Coroutine(_example, arg1=1)
self.assertTupleEqual( assert self._run(provider(arg2=2, arg3=3, arg4=4)) == (1, 2, 3, 4)
self._run(provider(arg2=2, arg3=3, arg4=4)),
(1, 2, 3, 4),
)
def test_call_with_context_args_and_kwargs(self): def test_call_with_context_args_and_kwargs(self):
provider = providers.Coroutine(_example, 1) provider = providers.Coroutine(_example, 1)
self.assertTupleEqual( assert self._run(provider(2, arg3=3, arg4=4)) == (1, 2, 3, 4)
self._run(provider(2, arg3=3, arg4=4)),
(1, 2, 3, 4),
)
def test_fluent_interface(self): def test_fluent_interface(self):
provider = providers.Coroutine(_example) \ provider = providers.Coroutine(_example) \
.add_args(1, 2) \ .add_args(1, 2) \
.add_kwargs(arg3=3, arg4=4) .add_kwargs(arg3=3, arg4=4)
assert self._run(provider()) == (1, 2, 3, 4)
self.assertTupleEqual(self._run(provider()), (1, 2, 3, 4))
def test_set_args(self): def test_set_args(self):
provider = providers.Coroutine(_example) \ provider = providers.Coroutine(_example) \
.add_args(1, 2) \ .add_args(1, 2) \
.set_args(3, 4) .set_args(3, 4)
self.assertEqual(provider.args, (3, 4)) assert provider.args == (3, 4)
def test_set_kwargs(self): def test_set_kwargs(self):
provider = providers.Coroutine(_example) \ provider = providers.Coroutine(_example) \
.add_kwargs(init_arg3=3, init_arg4=4) \ .add_kwargs(init_arg3=3, init_arg4=4) \
.set_kwargs(init_arg3=4, init_arg4=5) .set_kwargs(init_arg3=4, init_arg4=5)
self.assertEqual(provider.kwargs, dict(init_arg3=4, init_arg4=5)) assert provider.kwargs == dict(init_arg3=4, init_arg4=5)
def test_clear_args(self): def test_clear_args(self):
provider = providers.Coroutine(_example) \ provider = providers.Coroutine(_example) \
.add_args(1, 2) \ .add_args(1, 2) \
.clear_args() .clear_args()
self.assertEqual(provider.args, tuple()) assert provider.args == tuple()
def test_clear_kwargs(self): def test_clear_kwargs(self):
provider = providers.Coroutine(_example) \ provider = providers.Coroutine(_example) \
.add_kwargs(init_arg3=3, init_arg4=4) \ .add_kwargs(init_arg3=3, init_arg4=4) \
.clear_kwargs() .clear_kwargs()
self.assertEqual(provider.kwargs, dict()) assert provider.kwargs == dict()
def test_call_overridden(self): def test_call_overridden(self):
provider = providers.Coroutine(_example) provider = providers.Coroutine(_example)
@ -123,16 +115,16 @@ class CoroutineTests(AsyncTestCase):
provider.override(providers.Object((4, 3, 2, 1))) provider.override(providers.Object((4, 3, 2, 1)))
provider.override(providers.Object((1, 2, 3, 4))) provider.override(providers.Object((1, 2, 3, 4)))
self.assertTupleEqual(provider(), (1, 2, 3, 4)) assert provider() == (1, 2, 3, 4)
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Coroutine(_example) provider = providers.Coroutine(_example)
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIs(provider.provides, provider_copy.provides) assert provider.provides is provider_copy.provides
self.assertIsInstance(provider, providers.Coroutine) assert isinstance(provider, providers.Coroutine)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Coroutine(_example) provider = providers.Coroutine(_example)
@ -141,7 +133,7 @@ class CoroutineTests(AsyncTestCase):
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(
provider, memo={id(provider): provider_copy_memo}) provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_args(self): def test_deepcopy_args(self):
provider = providers.Coroutine(_example) provider = providers.Coroutine(_example)
@ -154,15 +146,13 @@ class CoroutineTests(AsyncTestCase):
dependent_provider_copy1 = provider_copy.args[0] dependent_provider_copy1 = provider_copy.args[0]
dependent_provider_copy2 = provider_copy.args[1] dependent_provider_copy2 = provider_copy.args[1]
self.assertNotEqual(provider.args, provider_copy.args) assert provider.args != provider_copy.args
self.assertIs(dependent_provider1.provides, assert dependent_provider1.provides is dependent_provider_copy1.provides
dependent_provider_copy1.provides) assert dependent_provider1 is not dependent_provider_copy1
self.assertIsNot(dependent_provider1, dependent_provider_copy1)
self.assertIs(dependent_provider2.provides, assert dependent_provider2.provides is dependent_provider_copy2.provides
dependent_provider_copy2.provides) assert dependent_provider2 is not dependent_provider_copy2
self.assertIsNot(dependent_provider2, dependent_provider_copy2)
def test_deepcopy_kwargs(self): def test_deepcopy_kwargs(self):
provider = providers.Coroutine(_example) provider = providers.Coroutine(_example)
@ -175,15 +165,13 @@ class CoroutineTests(AsyncTestCase):
dependent_provider_copy1 = provider_copy.kwargs["a1"] dependent_provider_copy1 = provider_copy.kwargs["a1"]
dependent_provider_copy2 = provider_copy.kwargs["a2"] dependent_provider_copy2 = provider_copy.kwargs["a2"]
self.assertNotEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs != provider_copy.kwargs
self.assertIs(dependent_provider1.provides, assert dependent_provider1.provides is dependent_provider_copy1.provides
dependent_provider_copy1.provides) assert dependent_provider1 is not dependent_provider_copy1
self.assertIsNot(dependent_provider1, dependent_provider_copy1)
self.assertIs(dependent_provider2.provides, assert dependent_provider2.provides is dependent_provider_copy2.provides
dependent_provider_copy2.provides) assert dependent_provider2 is not dependent_provider_copy2
self.assertIsNot(dependent_provider2, dependent_provider_copy2)
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Coroutine(_example) provider = providers.Coroutine(_example)
@ -194,51 +182,48 @@ class CoroutineTests(AsyncTestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0] object_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIs(provider.provides, provider_copy.provides) assert provider.provides is provider_copy.provides
self.assertIsInstance(provider, providers.Callable) assert isinstance(provider, providers.Callable)
self.assertIsNot(object_provider, object_provider_copy) assert object_provider is not object_provider_copy
self.assertIsInstance(object_provider_copy, providers.Object) assert isinstance(object_provider_copy, providers.Object)
def test_repr(self): def test_repr(self):
provider = providers.Coroutine(_example) provider = providers.Coroutine(_example)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"Coroutine({0}) at {1}>".format( "Coroutine({0}) at {1}>".format(repr(_example), hex(id(provider)))
repr(_example), )
hex(id(provider))))
class DelegatedCoroutineTests(unittest.TestCase): class DelegatedCoroutineTests(unittest.TestCase):
def test_inheritance(self): def test_inheritance(self):
self.assertIsInstance(providers.DelegatedCoroutine(_example), assert isinstance(providers.DelegatedCoroutine(_example),
providers.Coroutine) providers.Coroutine)
def test_is_provider(self): def test_is_provider(self):
self.assertTrue( assert providers.is_provider(providers.DelegatedCoroutine(_example)) is True
providers.is_provider(providers.DelegatedCoroutine(_example)))
def test_is_delegated_provider(self): def test_is_delegated_provider(self):
provider = providers.DelegatedCoroutine(_example) provider = providers.DelegatedCoroutine(_example)
self.assertTrue(providers.is_delegated(provider)) assert providers.is_delegated(provider) is True
def test_repr(self): def test_repr(self):
provider = providers.DelegatedCoroutine(_example) provider = providers.DelegatedCoroutine(_example)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"DelegatedCoroutine({0}) at {1}>".format( "DelegatedCoroutine({0}) at {1}>".format(repr(_example), hex(id(provider)))
repr(_example), )
hex(id(provider))))
class AbstractCoroutineTests(AsyncTestCase): class AbstractCoroutineTests(AsyncTestCase):
def test_inheritance(self): def test_inheritance(self):
self.assertIsInstance(providers.AbstractCoroutine(_example), assert isinstance(providers.AbstractCoroutine(_example),
providers.Coroutine) providers.Coroutine)
def test_call_overridden_by_coroutine(self): def test_call_overridden_by_coroutine(self):
@ -252,7 +237,7 @@ class AbstractCoroutineTests(AsyncTestCase):
provider = providers.AbstractCoroutine(_abstract_example) provider = providers.AbstractCoroutine(_abstract_example)
provider.override(providers.Coroutine(_example)) provider.override(providers.Coroutine(_example))
self.assertTrue(self._run(provider(1, 2, 3, 4)), (1, 2, 3, 4)) assert self._run(provider(1, 2, 3, 4)) == (1, 2, 3, 4)
def test_call_overridden_by_delegated_coroutine(self): def test_call_overridden_by_delegated_coroutine(self):
with warnings.catch_warnings(): with warnings.catch_warnings():
@ -265,34 +250,33 @@ class AbstractCoroutineTests(AsyncTestCase):
provider = providers.AbstractCoroutine(_abstract_example) provider = providers.AbstractCoroutine(_abstract_example)
provider.override(providers.DelegatedCoroutine(_example)) provider.override(providers.DelegatedCoroutine(_example))
self.assertTrue(self._run(provider(1, 2, 3, 4)), (1, 2, 3, 4)) assert self._run(provider(1, 2, 3, 4)) == (1, 2, 3, 4)
def test_call_not_overridden(self): def test_call_not_overridden(self):
provider = providers.AbstractCoroutine(_example) provider = providers.AbstractCoroutine(_example)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider(1, 2, 3, 4) provider(1, 2, 3, 4)
def test_override_by_not_coroutine(self): def test_override_by_not_coroutine(self):
provider = providers.AbstractCoroutine(_example) provider = providers.AbstractCoroutine(_example)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider.override(providers.Factory(object)) provider.override(providers.Factory(object))
def test_provide_not_implemented(self): def test_provide_not_implemented(self):
provider = providers.AbstractCoroutine(_example) provider = providers.AbstractCoroutine(_example)
with self.assertRaises(NotImplementedError): with raises(NotImplementedError):
provider._provide((1, 2, 3, 4), dict()) provider._provide((1, 2, 3, 4), dict())
def test_repr(self): def test_repr(self):
provider = providers.AbstractCoroutine(_example) provider = providers.AbstractCoroutine(_example)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"AbstractCoroutine({0}) at {1}>".format( "AbstractCoroutine({0}) at {1}>".format(repr(_example), hex(id(provider)))
repr(_example), )
hex(id(provider))))
class CoroutineDelegateTests(unittest.TestCase): class CoroutineDelegateTests(unittest.TestCase):
@ -302,9 +286,8 @@ class CoroutineDelegateTests(unittest.TestCase):
self.delegate = providers.CoroutineDelegate(self.delegated) self.delegate = providers.CoroutineDelegate(self.delegated)
def test_is_delegate(self): def test_is_delegate(self):
self.assertIsInstance(self.delegate, providers.Delegate) assert isinstance(self.delegate, providers.Delegate)
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, with raises(errors.Error):
providers.CoroutineDelegate, providers.CoroutineDelegate(providers.Object(object()))
providers.Object(object()))

View File

@ -10,11 +10,11 @@ from dependency_injector import providers
class DictTests(unittest.TestCase): class DictTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Dict())) assert providers.is_provider(providers.Dict()) is True
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Dict() provider = providers.Dict()
self.assertIsInstance(provider.provided, providers.ProvidedInstance) assert isinstance(provider.provided, providers.ProvidedInstance)
def test_init_with_non_string_keys(self): def test_init_with_non_string_keys(self):
a1 = object() a1 = object()
@ -24,10 +24,10 @@ class DictTests(unittest.TestCase):
dict1 = provider() dict1 = provider()
dict2 = provider() dict2 = provider()
self.assertEqual(dict1, {a1: "i1", a2: "i2"}) assert dict1 == {a1: "i1", a2: "i2"}
self.assertEqual(dict2, {a1: "i1", a2: "i2"}) assert dict2 == {a1: "i1", a2: "i2"}
self.assertIsNot(dict1, dict2) assert dict1 is not dict2
def test_init_with_string_and_non_string_keys(self): def test_init_with_string_and_non_string_keys(self):
a1 = object() a1 = object()
@ -36,10 +36,10 @@ class DictTests(unittest.TestCase):
dict1 = provider() dict1 = provider()
dict2 = provider() dict2 = provider()
self.assertEqual(dict1, {a1: "i1", "a2": "i2"}) assert dict1 == {a1: "i1", "a2": "i2"}
self.assertEqual(dict2, {a1: "i1", "a2": "i2"}) assert dict2 == {a1: "i1", "a2": "i2"}
self.assertIsNot(dict1, dict2) assert dict1 is not dict2
def test_call_with_init_keyword_args(self): def test_call_with_init_keyword_args(self):
provider = providers.Dict(a1="i1", a2="i2") provider = providers.Dict(a1="i1", a2="i2")
@ -47,35 +47,32 @@ class DictTests(unittest.TestCase):
dict1 = provider() dict1 = provider()
dict2 = provider() dict2 = provider()
self.assertEqual(dict1, {"a1": "i1", "a2": "i2"}) assert dict1 == {"a1": "i1", "a2": "i2"}
self.assertEqual(dict2, {"a1": "i1", "a2": "i2"}) assert dict2 == {"a1": "i1", "a2": "i2"}
self.assertIsNot(dict1, dict2) assert dict1 is not dict2
def test_call_with_context_keyword_args(self): def test_call_with_context_keyword_args(self):
provider = providers.Dict(a1="i1", a2="i2") provider = providers.Dict(a1="i1", a2="i2")
self.assertEqual( assert provider(a3="i3", a4="i4") == {"a1": "i1", "a2": "i2", "a3": "i3", "a4": "i4"}
provider(a3="i3", a4="i4"),
{"a1": "i1", "a2": "i2", "a3": "i3", "a4": "i4"},
)
def test_call_with_provider(self): def test_call_with_provider(self):
provider = providers.Dict( provider = providers.Dict(
a1=providers.Factory(str, "i1"), a1=providers.Factory(str, "i1"),
a2=providers.Factory(str, "i2"), a2=providers.Factory(str, "i2"),
) )
self.assertEqual(provider(), {"a1": "i1", "a2": "i2"}) assert provider() == {"a1": "i1", "a2": "i2"}
def test_fluent_interface(self): def test_fluent_interface(self):
provider = providers.Dict() \ provider = providers.Dict() \
.add_kwargs(a1="i1", a2="i2") .add_kwargs(a1="i1", a2="i2")
self.assertEqual(provider(), {"a1": "i1", "a2": "i2"}) assert provider() == {"a1": "i1", "a2": "i2"}
def test_add_kwargs(self): def test_add_kwargs(self):
provider = providers.Dict() \ provider = providers.Dict() \
.add_kwargs(a1="i1") \ .add_kwargs(a1="i1") \
.add_kwargs(a2="i2") .add_kwargs(a2="i2")
self.assertEqual(provider.kwargs, {"a1": "i1", "a2": "i2"}) assert provider.kwargs == {"a1": "i1", "a2": "i2"}
def test_add_kwargs_non_string_keys(self): def test_add_kwargs_non_string_keys(self):
a1 = object() a1 = object()
@ -83,20 +80,20 @@ class DictTests(unittest.TestCase):
provider = providers.Dict() \ provider = providers.Dict() \
.add_kwargs({a1: "i1"}) \ .add_kwargs({a1: "i1"}) \
.add_kwargs({a2: "i2"}) .add_kwargs({a2: "i2"})
self.assertEqual(provider.kwargs, {a1: "i1", a2: "i2"}) assert provider.kwargs == {a1: "i1", a2: "i2"}
def test_add_kwargs_string_and_non_string_keys(self): def test_add_kwargs_string_and_non_string_keys(self):
a2 = object() a2 = object()
provider = providers.Dict() \ provider = providers.Dict() \
.add_kwargs(a1="i1") \ .add_kwargs(a1="i1") \
.add_kwargs({a2: "i2"}) .add_kwargs({a2: "i2"})
self.assertEqual(provider.kwargs, {"a1": "i1", a2: "i2"}) assert provider.kwargs == {"a1": "i1", a2: "i2"}
def test_set_kwargs(self): def test_set_kwargs(self):
provider = providers.Dict() \ provider = providers.Dict() \
.add_kwargs(a1="i1", a2="i2") \ .add_kwargs(a1="i1", a2="i2") \
.set_kwargs(a3="i3", a4="i4") .set_kwargs(a3="i3", a4="i4")
self.assertEqual(provider.kwargs, {"a3": "i3", "a4": "i4"}) assert provider.kwargs == {"a3": "i3", "a4": "i4"}
def test_set_kwargs_non_string_keys(self): def test_set_kwargs_non_string_keys(self):
a3 = object() a3 = object()
@ -104,20 +101,20 @@ class DictTests(unittest.TestCase):
provider = providers.Dict() \ provider = providers.Dict() \
.add_kwargs(a1="i1", a2="i2") \ .add_kwargs(a1="i1", a2="i2") \
.set_kwargs({a3: "i3", a4: "i4"}) .set_kwargs({a3: "i3", a4: "i4"})
self.assertEqual(provider.kwargs, {a3: "i3", a4: "i4"}) assert provider.kwargs == {a3: "i3", a4: "i4"}
def test_set_kwargs_string_and_non_string_keys(self): def test_set_kwargs_string_and_non_string_keys(self):
a3 = object() a3 = object()
provider = providers.Dict() \ provider = providers.Dict() \
.add_kwargs(a1="i1", a2="i2") \ .add_kwargs(a1="i1", a2="i2") \
.set_kwargs({a3: "i3"}, a4="i4") .set_kwargs({a3: "i3"}, a4="i4")
self.assertEqual(provider.kwargs, {a3: "i3", "a4": "i4"}) assert provider.kwargs == {a3: "i3", "a4": "i4"}
def test_clear_kwargs(self): def test_clear_kwargs(self):
provider = providers.Dict() \ provider = providers.Dict() \
.add_kwargs(a1="i1", a2="i2") \ .add_kwargs(a1="i1", a2="i2") \
.clear_kwargs() .clear_kwargs()
self.assertEqual(provider.kwargs, {}) assert provider.kwargs == {}
def test_call_overridden(self): def test_call_overridden(self):
provider = providers.Dict(a1="i1", a2="i2") provider = providers.Dict(a1="i1", a2="i2")
@ -130,18 +127,18 @@ class DictTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertEqual(instance1, {"a3": "i3", "a4": "i4"}) assert instance1 == {"a3": "i3", "a4": "i4"}
self.assertEqual(instance2, {"a3": "i3", "a4": "i4"}) assert instance2 == {"a3": "i3", "a4": "i4"}
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Dict(a1="i1", a2="i2") provider = providers.Dict(a1="i1", a2="i2")
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs == provider_copy.kwargs
self.assertIsInstance(provider, providers.Dict) assert isinstance(provider, providers.Dict)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Dict(a1="i1", a2="i2") provider = providers.Dict(a1="i1", a2="i2")
@ -152,7 +149,7 @@ class DictTests(unittest.TestCase):
memo={id(provider): provider_copy_memo}, memo={id(provider): provider_copy_memo},
) )
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_kwargs(self): def test_deepcopy_kwargs(self):
provider = providers.Dict() provider = providers.Dict()
@ -165,13 +162,13 @@ class DictTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.kwargs["d1"] dependent_provider_copy1 = provider_copy.kwargs["d1"]
dependent_provider_copy2 = provider_copy.kwargs["d2"] dependent_provider_copy2 = provider_copy.kwargs["d2"]
self.assertNotEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs != provider_copy.kwargs
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_kwargs_non_string_keys(self): def test_deepcopy_kwargs_non_string_keys(self):
a1 = object() a1 = object()
@ -186,13 +183,13 @@ class DictTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.kwargs[a1] dependent_provider_copy1 = provider_copy.kwargs[a1]
dependent_provider_copy2 = provider_copy.kwargs[a2] dependent_provider_copy2 = provider_copy.kwargs[a2]
self.assertNotEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs != provider_copy.kwargs
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Dict() provider = providers.Dict()
@ -203,12 +200,12 @@ class DictTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0] object_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs == provider_copy.kwargs
self.assertIsInstance(provider, providers.Dict) assert isinstance(provider, providers.Dict)
self.assertIsNot(object_provider, object_provider_copy) assert object_provider is not object_provider_copy
self.assertIsInstance(object_provider_copy, providers.Object) assert isinstance(object_provider_copy, providers.Object)
def test_deepcopy_with_sys_streams(self): def test_deepcopy_with_sys_streams(self):
provider = providers.Dict() provider = providers.Dict()
@ -216,16 +213,15 @@ class DictTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider_copy, providers.Dict) assert isinstance(provider_copy, providers.Dict)
self.assertIs(provider.kwargs["stdin"], sys.stdin) assert provider.kwargs["stdin"] is sys.stdin
self.assertIs(provider.kwargs["stdout"], sys.stdout) assert provider.kwargs["stdout"] is sys.stdout
self.assertIs(provider.kwargs["stderr"], sys.stderr) assert provider.kwargs["stderr"] is sys.stderr
def test_repr(self): def test_repr(self):
provider = providers.Dict(a1=1, a2=2) provider = providers.Dict(a1=1, a2=2)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"Dict({0}) at {1}>".format( "Dict({0}) at {1}>".format(repr(provider.kwargs), hex(id(provider)))
repr(provider.kwargs), )
hex(id(provider))))

View File

@ -5,10 +5,10 @@ import sys
import unittest import unittest
from dependency_injector import ( from dependency_injector import (
containers,
providers, providers,
errors, errors,
) )
from pytest import raises
class Example(object): class Example(object):
@ -27,23 +27,21 @@ class Example(object):
class FactoryTests(unittest.TestCase): class FactoryTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Factory(Example))) assert providers.is_provider(providers.Factory(Example)) is True
def test_init_with_callable(self):
self.assertTrue(providers.Factory(credits))
def test_init_with_not_callable(self): def test_init_with_not_callable(self):
self.assertRaises(errors.Error, providers.Factory, 123) with raises(errors.Error):
providers.Factory(123)
def test_init_optional_provides(self): def test_init_optional_provides(self):
provider = providers.Factory() provider = providers.Factory()
provider.set_provides(object) provider.set_provides(object)
self.assertIs(provider.provides, object) assert provider.provides is object
self.assertIsInstance(provider(), object) assert isinstance(provider(), object)
def test_set_provides_returns_self(self): def test_set_provides_returns_self(self):
provider = providers.Factory() provider = providers.Factory()
self.assertIs(provider.set_provides(object), provider) assert provider.set_provides(object) is provider
def test_init_with_valid_provided_type(self): def test_init_with_valid_provided_type(self):
class ExampleProvider(providers.Factory): class ExampleProvider(providers.Factory):
@ -51,7 +49,7 @@ class FactoryTests(unittest.TestCase):
example_provider = ExampleProvider(Example, 1, 2) example_provider = ExampleProvider(Example, 1, 2)
self.assertIsInstance(example_provider(), Example) assert isinstance(example_provider(), Example)
def test_init_with_valid_provided_subtype(self): def test_init_with_valid_provided_subtype(self):
class ExampleProvider(providers.Factory): class ExampleProvider(providers.Factory):
@ -62,18 +60,18 @@ class FactoryTests(unittest.TestCase):
example_provider = ExampleProvider(NewExampe, 1, 2) example_provider = ExampleProvider(NewExampe, 1, 2)
self.assertIsInstance(example_provider(), NewExampe) assert isinstance(example_provider(), NewExampe)
def test_init_with_invalid_provided_type(self): def test_init_with_invalid_provided_type(self):
class ExampleProvider(providers.Factory): class ExampleProvider(providers.Factory):
provided_type = Example provided_type = Example
with self.assertRaises(errors.Error): with raises(errors.Error):
ExampleProvider(list) ExampleProvider(list)
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) assert isinstance(provider.provided, providers.ProvidedInstance)
def test_call(self): def test_call(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
@ -81,9 +79,9 @@ class FactoryTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_init_positional_args(self): def test_call_with_init_positional_args(self):
provider = providers.Factory(Example, "i1", "i2") provider = providers.Factory(Example, "i1", "i2")
@ -91,15 +89,15 @@ class FactoryTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertEqual(instance1.init_arg1, "i1") assert instance1.init_arg1 == "i1"
self.assertEqual(instance1.init_arg2, "i2") assert instance1.init_arg2 == "i2"
self.assertEqual(instance2.init_arg1, "i1") assert instance2.init_arg1 == "i1"
self.assertEqual(instance2.init_arg2, "i2") assert instance2.init_arg2 == "i2"
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_init_keyword_args(self): def test_call_with_init_keyword_args(self):
provider = providers.Factory(Example, init_arg1="i1", init_arg2="i2") provider = providers.Factory(Example, init_arg1="i1", init_arg2="i2")
@ -107,15 +105,15 @@ class FactoryTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertEqual(instance1.init_arg1, "i1") assert instance1.init_arg1 == "i1"
self.assertEqual(instance1.init_arg2, "i2") assert instance1.init_arg2 == "i2"
self.assertEqual(instance2.init_arg1, "i1") assert instance2.init_arg1 == "i1"
self.assertEqual(instance2.init_arg2, "i2") assert instance2.init_arg2 == "i2"
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_init_positional_and_keyword_args(self): def test_call_with_init_positional_and_keyword_args(self):
provider = providers.Factory(Example, "i1", init_arg2="i2") provider = providers.Factory(Example, "i1", init_arg2="i2")
@ -123,15 +121,15 @@ class FactoryTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertEqual(instance1.init_arg1, "i1") assert instance1.init_arg1 == "i1"
self.assertEqual(instance1.init_arg2, "i2") assert instance1.init_arg2 == "i2"
self.assertEqual(instance2.init_arg1, "i1") assert instance2.init_arg1 == "i1"
self.assertEqual(instance2.init_arg2, "i2") assert instance2.init_arg2 == "i2"
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_attributes(self): def test_call_with_attributes(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
@ -140,46 +138,46 @@ class FactoryTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertEqual(instance1.attribute1, "a1") assert instance1.attribute1 == "a1"
self.assertEqual(instance1.attribute2, "a2") assert instance1.attribute2 == "a2"
self.assertEqual(instance2.attribute1, "a1") assert instance2.attribute1 == "a1"
self.assertEqual(instance2.attribute2, "a2") assert instance2.attribute2 == "a2"
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_call_with_context_args(self): def test_call_with_context_args(self):
provider = providers.Factory(Example, 11, 22) provider = providers.Factory(Example, 11, 22)
instance = provider(33, 44) instance = provider(33, 44)
self.assertEqual(instance.init_arg1, 11) assert instance.init_arg1 == 11
self.assertEqual(instance.init_arg2, 22) assert instance.init_arg2 == 22
self.assertEqual(instance.init_arg3, 33) assert instance.init_arg3 == 33
self.assertEqual(instance.init_arg4, 44) assert instance.init_arg4 == 44
def test_call_with_context_kwargs(self): def test_call_with_context_kwargs(self):
provider = providers.Factory(Example, init_arg1=1) provider = providers.Factory(Example, init_arg1=1)
instance1 = provider(init_arg2=22) instance1 = provider(init_arg2=22)
self.assertEqual(instance1.init_arg1, 1) assert instance1.init_arg1 == 1
self.assertEqual(instance1.init_arg2, 22) assert instance1.init_arg2 == 22
instance2 = provider(init_arg1=11, init_arg2=22) instance2 = provider(init_arg1=11, init_arg2=22)
self.assertEqual(instance2.init_arg1, 11) assert instance2.init_arg1 == 11
self.assertEqual(instance2.init_arg2, 22) assert instance2.init_arg2 == 22
def test_call_with_context_args_and_kwargs(self): def test_call_with_context_args_and_kwargs(self):
provider = providers.Factory(Example, 11) provider = providers.Factory(Example, 11)
instance = provider(22, init_arg3=33, init_arg4=44) instance = provider(22, init_arg3=33, init_arg4=44)
self.assertEqual(instance.init_arg1, 11) assert instance.init_arg1 == 11
self.assertEqual(instance.init_arg2, 22) assert instance.init_arg2 == 22
self.assertEqual(instance.init_arg3, 33) assert instance.init_arg3 == 33
self.assertEqual(instance.init_arg4, 44) assert instance.init_arg4 == 44
def test_call_with_deep_context_kwargs(self): def test_call_with_deep_context_kwargs(self):
"""`Factory` providers deep init injections example.""" """`Factory` providers deep init injections example."""
@ -216,9 +214,9 @@ class FactoryTests(unittest.TestCase):
algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7) algorithm_2 = algorithm_factory(task__loss__regularizer__alpha=0.7)
algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8)) algorithm_3 = algorithm_factory(task__loss__regularizer=Regularizer(alpha=0.8))
self.assertEqual(algorithm_1.task.loss.regularizer.alpha, 0.5) assert algorithm_1.task.loss.regularizer.alpha == 0.5
self.assertEqual(algorithm_2.task.loss.regularizer.alpha, 0.7) assert algorithm_2.task.loss.regularizer.alpha == 0.7
self.assertEqual(algorithm_3.task.loss.regularizer.alpha, 0.8) assert algorithm_3.task.loss.regularizer.alpha == 0.8
def test_fluent_interface(self): def test_fluent_interface(self):
provider = providers.Factory(Example) \ provider = providers.Factory(Example) \
@ -228,48 +226,48 @@ class FactoryTests(unittest.TestCase):
instance = provider() instance = provider()
self.assertEqual(instance.init_arg1, 1) assert instance.init_arg1 == 1
self.assertEqual(instance.init_arg2, 2) assert instance.init_arg2 == 2
self.assertEqual(instance.init_arg3, 3) assert instance.init_arg3 == 3
self.assertEqual(instance.init_arg4, 4) assert instance.init_arg4 == 4
self.assertEqual(instance.attribute1, 5) assert instance.attribute1 == 5
self.assertEqual(instance.attribute2, 6) assert instance.attribute2 == 6
def test_set_args(self): def test_set_args(self):
provider = providers.Factory(Example) \ provider = providers.Factory(Example) \
.add_args(1, 2) \ .add_args(1, 2) \
.set_args(3, 4) .set_args(3, 4)
self.assertEqual(provider.args, (3, 4)) assert provider.args == (3, 4)
def test_set_kwargs(self): def test_set_kwargs(self):
provider = providers.Factory(Example) \ provider = providers.Factory(Example) \
.add_kwargs(init_arg3=3, init_arg4=4) \ .add_kwargs(init_arg3=3, init_arg4=4) \
.set_kwargs(init_arg3=4, init_arg4=5) .set_kwargs(init_arg3=4, init_arg4=5)
self.assertEqual(provider.kwargs, dict(init_arg3=4, init_arg4=5)) assert provider.kwargs == dict(init_arg3=4, init_arg4=5)
def test_set_attributes(self): def test_set_attributes(self):
provider = providers.Factory(Example) \ provider = providers.Factory(Example) \
.add_attributes(attribute1=5, attribute2=6) \ .add_attributes(attribute1=5, attribute2=6) \
.set_attributes(attribute1=6, attribute2=7) .set_attributes(attribute1=6, attribute2=7)
self.assertEqual(provider.attributes, dict(attribute1=6, attribute2=7)) assert provider.attributes == dict(attribute1=6, attribute2=7)
def test_clear_args(self): def test_clear_args(self):
provider = providers.Factory(Example) \ provider = providers.Factory(Example) \
.add_args(1, 2) \ .add_args(1, 2) \
.clear_args() .clear_args()
self.assertEqual(provider.args, tuple()) assert provider.args == tuple()
def test_clear_kwargs(self): def test_clear_kwargs(self):
provider = providers.Factory(Example) \ provider = providers.Factory(Example) \
.add_kwargs(init_arg3=3, init_arg4=4) \ .add_kwargs(init_arg3=3, init_arg4=4) \
.clear_kwargs() .clear_kwargs()
self.assertEqual(provider.kwargs, dict()) assert provider.kwargs == dict()
def test_clear_attributes(self): def test_clear_attributes(self):
provider = providers.Factory(Example) \ provider = providers.Factory(Example) \
.add_attributes(attribute1=5, attribute2=6) \ .add_attributes(attribute1=5, attribute2=6) \
.clear_attributes() .clear_attributes()
self.assertEqual(provider.attributes, dict()) assert provider.attributes == dict()
def test_call_overridden(self): def test_call_overridden(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
@ -282,27 +280,26 @@ class FactoryTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertIsInstance(instance1, list) assert isinstance(instance1, list)
self.assertIsInstance(instance2, list) assert isinstance(instance2, list)
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIs(provider.cls, provider_copy.cls) assert provider.cls is provider_copy.cls
self.assertIsInstance(provider, providers.Factory) assert isinstance(provider, providers.Factory)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
provider_copy_memo = providers.Factory(Example) provider_copy_memo = providers.Factory(Example)
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo})
provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_args(self): def test_deepcopy_args(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
@ -315,13 +312,13 @@ class FactoryTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.args[0] dependent_provider_copy1 = provider_copy.args[0]
dependent_provider_copy2 = provider_copy.args[1] dependent_provider_copy2 = provider_copy.args[1]
self.assertNotEqual(provider.args, provider_copy.args) assert provider.args != provider_copy.args
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_kwargs(self): def test_deepcopy_kwargs(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
@ -334,13 +331,13 @@ class FactoryTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.kwargs["a1"] dependent_provider_copy1 = provider_copy.kwargs["a1"]
dependent_provider_copy2 = provider_copy.kwargs["a2"] dependent_provider_copy2 = provider_copy.kwargs["a2"]
self.assertNotEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs != provider_copy.kwargs
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_attributes(self): def test_deepcopy_attributes(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
@ -353,13 +350,13 @@ class FactoryTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.attributes["a1"] dependent_provider_copy1 = provider_copy.attributes["a1"]
dependent_provider_copy2 = provider_copy.attributes["a2"] dependent_provider_copy2 = provider_copy.attributes["a2"]
self.assertNotEqual(provider.attributes, provider_copy.attributes) assert provider.attributes != provider_copy.attributes
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
@ -370,12 +367,12 @@ class FactoryTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0] object_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIs(provider.cls, provider_copy.cls) assert provider.cls is provider_copy.cls
self.assertIsInstance(provider, providers.Factory) assert isinstance(provider, providers.Factory)
self.assertIsNot(object_provider, object_provider_copy) assert object_provider is not object_provider_copy
self.assertIsInstance(object_provider_copy, providers.Object) assert isinstance(object_provider_copy, providers.Object)
def test_deepcopy_with_sys_streams(self): def test_deepcopy_with_sys_streams(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
@ -385,90 +382,79 @@ class FactoryTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider_copy, providers.Factory) assert isinstance(provider_copy, providers.Factory)
self.assertIs(provider.args[0], sys.stdin) assert provider.args[0] is sys.stdin
self.assertIs(provider.kwargs["a2"], sys.stdout) assert provider.kwargs["a2"] is sys.stdout
self.assertIs(provider.attributes["a3"], sys.stderr) assert provider.attributes["a3"] is sys.stderr
def test_repr(self): def test_repr(self):
provider = providers.Factory(Example) provider = providers.Factory(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "Factory({0}) at {1}>".format(repr(Example), hex(id(provider)))
"Factory({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
class DelegatedFactoryTests(unittest.TestCase): class DelegatedFactoryTests(unittest.TestCase):
def test_inheritance(self): def test_inheritance(self):
self.assertIsInstance(providers.DelegatedFactory(object), assert isinstance(providers.DelegatedFactory(object),
providers.Factory) providers.Factory)
def test_is_provider(self): def test_is_provider(self):
self.assertTrue( assert providers.is_provider(providers.DelegatedFactory(object)) is True
providers.is_provider(providers.DelegatedFactory(object)))
def test_is_delegated_provider(self): def test_is_delegated_provider(self):
self.assertTrue( assert providers.is_delegated(providers.DelegatedFactory(object)) is True
providers.is_delegated(providers.DelegatedFactory(object)))
def test_repr(self): def test_repr(self):
provider = providers.DelegatedFactory(Example) provider = providers.DelegatedFactory(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "DelegatedFactory({0}) at {1}>".format(repr(Example), hex(id(provider)))
"DelegatedFactory({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
class AbstractFactoryTests(unittest.TestCase): class AbstractFactoryTests(unittest.TestCase):
def test_inheritance(self): def test_inheritance(self):
self.assertIsInstance(providers.AbstractFactory(Example), assert isinstance(providers.AbstractFactory(Example),
providers.Factory) providers.Factory)
def test_call_overridden_by_factory(self): def test_call_overridden_by_factory(self):
provider = providers.AbstractFactory(object) provider = providers.AbstractFactory(object)
provider.override(providers.Factory(Example)) provider.override(providers.Factory(Example))
self.assertIsInstance(provider(), Example) assert isinstance(provider(), Example)
def test_call_overridden_by_delegated_factory(self): def test_call_overridden_by_delegated_factory(self):
provider = providers.AbstractFactory(object) provider = providers.AbstractFactory(object)
provider.override(providers.DelegatedFactory(Example)) provider.override(providers.DelegatedFactory(Example))
self.assertIsInstance(provider(), Example) assert isinstance(provider(), Example)
def test_call_not_overridden(self): def test_call_not_overridden(self):
provider = providers.AbstractFactory(object) provider = providers.AbstractFactory(object)
with raises(errors.Error):
with self.assertRaises(errors.Error):
provider() provider()
def test_override_by_not_factory(self): def test_override_by_not_factory(self):
provider = providers.AbstractFactory(object) provider = providers.AbstractFactory(object)
with raises(errors.Error):
with self.assertRaises(errors.Error):
provider.override(providers.Callable(object)) provider.override(providers.Callable(object))
def test_provide_not_implemented(self): def test_provide_not_implemented(self):
provider = providers.AbstractFactory(Example) provider = providers.AbstractFactory(Example)
with raises(NotImplementedError):
with self.assertRaises(NotImplementedError):
provider._provide(tuple(), dict()) provider._provide(tuple(), dict())
def test_repr(self): def test_repr(self):
provider = providers.AbstractFactory(Example) provider = providers.AbstractFactory(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "AbstractFactory({0}) at {1}>".format(repr(Example), hex(id(provider)))
"AbstractFactory({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
class FactoryDelegateTests(unittest.TestCase): class FactoryDelegateTests(unittest.TestCase):
@ -478,12 +464,11 @@ class FactoryDelegateTests(unittest.TestCase):
self.delegate = providers.FactoryDelegate(self.delegated) self.delegate = providers.FactoryDelegate(self.delegated)
def test_is_delegate(self): def test_is_delegate(self):
self.assertIsInstance(self.delegate, providers.Delegate) assert isinstance(self.delegate, providers.Delegate)
def test_init_with_not_factory(self): def test_init_with_not_factory(self):
self.assertRaises(errors.Error, with raises(errors.Error):
providers.FactoryDelegate, providers.FactoryDelegate(providers.Object(object()))
providers.Object(object()))
class FactoryAggregateTests(unittest.TestCase): class FactoryAggregateTests(unittest.TestCase):
@ -503,10 +488,10 @@ class FactoryAggregateTests(unittest.TestCase):
) )
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.factory_aggregate)) assert providers.is_provider(self.factory_aggregate) is True
def test_is_delegated_provider(self): def test_is_delegated_provider(self):
self.assertTrue(providers.is_delegated(self.factory_aggregate)) assert providers.is_delegated(self.factory_aggregate) is True
def test_init_with_non_string_keys(self): def test_init_with_non_string_keys(self):
factory = providers.FactoryAggregate({ factory = providers.FactoryAggregate({
@ -517,28 +502,25 @@ class FactoryAggregateTests(unittest.TestCase):
object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4) object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4)
object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44) object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44)
self.assertIsInstance(object_a, self.ExampleA) assert isinstance(object_a, self.ExampleA)
self.assertEqual(object_a.init_arg1, 1) assert object_a.init_arg1 == 1
self.assertEqual(object_a.init_arg2, 2) assert object_a.init_arg2 == 2
self.assertEqual(object_a.init_arg3, 3) assert object_a.init_arg3 == 3
self.assertEqual(object_a.init_arg4, 4) assert object_a.init_arg4 == 4
self.assertIsInstance(object_b, self.ExampleB) assert isinstance(object_b, self.ExampleB)
self.assertEqual(object_b.init_arg1, 11) assert object_b.init_arg1 == 11
self.assertEqual(object_b.init_arg2, 22) assert object_b.init_arg2 == 22
self.assertEqual(object_b.init_arg3, 33) assert object_b.init_arg3 == 33
self.assertEqual(object_b.init_arg4, 44) assert object_b.init_arg4 == 44
self.assertEqual( assert factory.factories == {
factory.factories, self.ExampleA: self.example_a_factory,
{ self.ExampleB: self.example_b_factory,
self.ExampleA: self.example_a_factory, }
self.ExampleB: self.example_b_factory,
},
)
def test_init_with_not_a_factory(self): def test_init_with_not_a_factory(self):
with self.assertRaises(errors.Error): with raises(errors.Error):
providers.FactoryAggregate( providers.FactoryAggregate(
example_a=providers.Factory(self.ExampleA), example_a=providers.Factory(self.ExampleA),
example_b=object()) example_b=object())
@ -549,15 +531,12 @@ class FactoryAggregateTests(unittest.TestCase):
example_a=self.example_a_factory, example_a=self.example_a_factory,
example_b=self.example_b_factory, example_b=self.example_b_factory,
) )
self.assertEqual( assert provider.factories == {
provider.factories, "example_a": self.example_a_factory,
{ "example_b": self.example_b_factory,
"example_a": self.example_a_factory, }
"example_b": self.example_b_factory, assert isinstance(provider("example_a"), self.ExampleA)
}, assert isinstance(provider("example_b"), self.ExampleB)
)
self.assertIsInstance(provider("example_a"), self.ExampleA)
self.assertIsInstance(provider("example_b"), self.ExampleB)
def test_set_factories_with_non_string_keys(self): def test_set_factories_with_non_string_keys(self):
factory = providers.FactoryAggregate() factory = providers.FactoryAggregate()
@ -569,29 +548,26 @@ class FactoryAggregateTests(unittest.TestCase):
object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4) object_a = factory(self.ExampleA, 1, 2, init_arg3=3, init_arg4=4)
object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44) object_b = factory(self.ExampleB, 11, 22, init_arg3=33, init_arg4=44)
self.assertIsInstance(object_a, self.ExampleA) assert isinstance(object_a, self.ExampleA)
self.assertEqual(object_a.init_arg1, 1) assert object_a.init_arg1 == 1
self.assertEqual(object_a.init_arg2, 2) assert object_a.init_arg2 == 2
self.assertEqual(object_a.init_arg3, 3) assert object_a.init_arg3 == 3
self.assertEqual(object_a.init_arg4, 4) assert object_a.init_arg4 == 4
self.assertIsInstance(object_b, self.ExampleB) assert isinstance(object_b, self.ExampleB)
self.assertEqual(object_b.init_arg1, 11) assert object_b.init_arg1 == 11
self.assertEqual(object_b.init_arg2, 22) assert object_b.init_arg2 == 22
self.assertEqual(object_b.init_arg3, 33) assert object_b.init_arg3 == 33
self.assertEqual(object_b.init_arg4, 44) assert object_b.init_arg4 == 44
self.assertEqual( assert factory.factories == {
factory.factories, self.ExampleA: self.example_a_factory,
{ self.ExampleB: self.example_b_factory,
self.ExampleA: self.example_a_factory, }
self.ExampleB: self.example_b_factory,
},
)
def test_set_factories_returns_self(self): def test_set_factories_returns_self(self):
provider = providers.FactoryAggregate() provider = providers.FactoryAggregate()
self.assertIs(provider.set_factories(example_a=self.example_a_factory), provider) assert provider.set_factories(example_a=self.example_a_factory) is provider
def test_call(self): def test_call(self):
object_a = self.factory_aggregate("example_a", object_a = self.factory_aggregate("example_a",
@ -599,17 +575,17 @@ class FactoryAggregateTests(unittest.TestCase):
object_b = self.factory_aggregate("example_b", object_b = self.factory_aggregate("example_b",
11, 22, init_arg3=33, init_arg4=44) 11, 22, init_arg3=33, init_arg4=44)
self.assertIsInstance(object_a, self.ExampleA) assert isinstance(object_a, self.ExampleA)
self.assertEqual(object_a.init_arg1, 1) assert object_a.init_arg1 == 1
self.assertEqual(object_a.init_arg2, 2) assert object_a.init_arg2 == 2
self.assertEqual(object_a.init_arg3, 3) assert object_a.init_arg3 == 3
self.assertEqual(object_a.init_arg4, 4) assert object_a.init_arg4 == 4
self.assertIsInstance(object_b, self.ExampleB) assert isinstance(object_b, self.ExampleB)
self.assertEqual(object_b.init_arg1, 11) assert object_b.init_arg1 == 11
self.assertEqual(object_b.init_arg2, 22) assert object_b.init_arg2 == 22
self.assertEqual(object_b.init_arg3, 33) assert object_b.init_arg3 == 33
self.assertEqual(object_b.init_arg4, 44) assert object_b.init_arg4 == 44
def test_call_factory_name_as_kwarg(self): def test_call_factory_name_as_kwarg(self):
object_a = self.factory_aggregate( object_a = self.factory_aggregate(
@ -619,50 +595,51 @@ class FactoryAggregateTests(unittest.TestCase):
init_arg3=3, init_arg3=3,
init_arg4=4, init_arg4=4,
) )
self.assertIsInstance(object_a, self.ExampleA) assert isinstance(object_a, self.ExampleA)
self.assertEqual(object_a.init_arg1, 1) assert object_a.init_arg1 == 1
self.assertEqual(object_a.init_arg2, 2) assert object_a.init_arg2 == 2
self.assertEqual(object_a.init_arg3, 3) assert object_a.init_arg3 == 3
self.assertEqual(object_a.init_arg4, 4) assert object_a.init_arg4 == 4
def test_call_no_factory_name(self): def test_call_no_factory_name(self):
with self.assertRaises(TypeError): with raises(TypeError):
self.factory_aggregate() self.factory_aggregate()
def test_call_no_such_provider(self): def test_call_no_such_provider(self):
with self.assertRaises(errors.NoSuchProviderError): with raises(errors.NoSuchProviderError):
self.factory_aggregate("unknown") self.factory_aggregate("unknown")
def test_overridden(self): def test_overridden(self):
with self.assertRaises(errors.Error): with raises(errors.Error):
self.factory_aggregate.override(providers.Object(object())) self.factory_aggregate.override(providers.Object(object()))
def test_getattr(self): def test_getattr(self):
self.assertIs(self.factory_aggregate.example_a, self.example_a_factory) assert self.factory_aggregate.example_a is self.example_a_factory
self.assertIs(self.factory_aggregate.example_b, self.example_b_factory) assert self.factory_aggregate.example_b is self.example_b_factory
def test_getattr_no_such_provider(self): def test_getattr_no_such_provider(self):
with self.assertRaises(errors.NoSuchProviderError): with raises(errors.NoSuchProviderError):
self.factory_aggregate.unknown self.factory_aggregate.unknown
def test_factories(self): def test_factories(self):
self.assertDictEqual(self.factory_aggregate.factories, assert self.factory_aggregate.factories == dict(
dict(example_a=self.example_a_factory, example_a=self.example_a_factory,
example_b=self.example_b_factory)) example_b=self.example_b_factory,
)
def test_deepcopy(self): def test_deepcopy(self):
provider_copy = providers.deepcopy(self.factory_aggregate) provider_copy = providers.deepcopy(self.factory_aggregate)
self.assertIsNot(self.factory_aggregate, provider_copy) assert self.factory_aggregate is not provider_copy
self.assertIsInstance(provider_copy, type(self.factory_aggregate)) assert isinstance(provider_copy, type(self.factory_aggregate))
self.assertIsNot(self.factory_aggregate.example_a, provider_copy.example_a) assert self.factory_aggregate.example_a is not provider_copy.example_a
self.assertIsInstance(self.factory_aggregate.example_a, type(provider_copy.example_a)) assert isinstance(self.factory_aggregate.example_a, type(provider_copy.example_a))
self.assertIs(self.factory_aggregate.example_a.cls, provider_copy.example_a.cls) assert self.factory_aggregate.example_a.cls is provider_copy.example_a.cls
self.assertIsNot(self.factory_aggregate.example_b, provider_copy.example_b) assert self.factory_aggregate.example_b is not provider_copy.example_b
self.assertIsInstance(self.factory_aggregate.example_b, type(provider_copy.example_b)) assert isinstance(self.factory_aggregate.example_b, type(provider_copy.example_b))
self.assertIs(self.factory_aggregate.example_b.cls, provider_copy.example_b.cls) assert self.factory_aggregate.example_b.cls is provider_copy.example_b.cls
def test_deepcopy_with_non_string_keys(self): def test_deepcopy_with_non_string_keys(self):
factory_aggregate = providers.FactoryAggregate({ factory_aggregate = providers.FactoryAggregate({
@ -671,20 +648,22 @@ class FactoryAggregateTests(unittest.TestCase):
}) })
provider_copy = providers.deepcopy(factory_aggregate) provider_copy = providers.deepcopy(factory_aggregate)
self.assertIsNot(factory_aggregate, provider_copy) assert factory_aggregate is not provider_copy
self.assertIsInstance(provider_copy, type(factory_aggregate)) assert isinstance(provider_copy, type(factory_aggregate))
self.assertIsNot(factory_aggregate.factories[self.ExampleA], provider_copy.factories[self.ExampleA]) assert factory_aggregate.factories[self.ExampleA] is not provider_copy.factories[self.ExampleA]
self.assertIsInstance(factory_aggregate.factories[self.ExampleA], type(provider_copy.factories[self.ExampleA])) assert isinstance(factory_aggregate.factories[self.ExampleA], type(provider_copy.factories[self.ExampleA]))
self.assertIs(factory_aggregate.factories[self.ExampleA].cls, provider_copy.factories[self.ExampleA].cls) assert factory_aggregate.factories[self.ExampleA].cls is provider_copy.factories[self.ExampleA].cls
self.assertIsNot(factory_aggregate.factories[self.ExampleB], provider_copy.factories[self.ExampleB]) assert factory_aggregate.factories[self.ExampleB] is not provider_copy.factories[self.ExampleB]
self.assertIsInstance(factory_aggregate.factories[self.ExampleB], type(provider_copy.factories[self.ExampleB])) assert isinstance(factory_aggregate.factories[self.ExampleB], type(provider_copy.factories[self.ExampleB]))
self.assertIs(factory_aggregate.factories[self.ExampleB].cls, provider_copy.factories[self.ExampleB].cls) assert factory_aggregate.factories[self.ExampleB].cls is provider_copy.factories[self.ExampleB].cls
def test_repr(self): def test_repr(self):
self.assertEqual(repr(self.factory_aggregate), assert repr(self.factory_aggregate) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"FactoryAggregate({0}) at {1}>".format( "FactoryAggregate({0}) at {1}>".format(
repr(self.factory_aggregate.factories), repr(self.factory_aggregate.factories),
hex(id(self.factory_aggregate)))) hex(id(self.factory_aggregate)),
)
)

View File

@ -9,11 +9,11 @@ class PositionalInjectionTests(unittest.TestCase):
def test_isinstance(self): def test_isinstance(self):
injection = providers.PositionalInjection(1) injection = providers.PositionalInjection(1)
self.assertIsInstance(injection, providers.Injection) assert isinstance(injection, providers.Injection)
def test_get_value_with_not_provider(self): def test_get_value_with_not_provider(self):
injection = providers.PositionalInjection(123) injection = providers.PositionalInjection(123)
self.assertEqual(injection.get_value(), 123) assert injection.get_value() == 123
def test_get_value_with_factory(self): def test_get_value_with_factory(self):
injection = providers.PositionalInjection(providers.Factory(object)) injection = providers.PositionalInjection(providers.Factory(object))
@ -21,14 +21,14 @@ class PositionalInjectionTests(unittest.TestCase):
obj1 = injection.get_value() obj1 = injection.get_value()
obj2 = injection.get_value() obj2 = injection.get_value()
self.assertIs(type(obj1), object) assert type(obj1) is object
self.assertIs(type(obj2), object) assert type(obj2) is object
self.assertIsNot(obj1, obj2) assert obj1 is not obj2
def test_get_original_value(self): def test_get_original_value(self):
provider = providers.Factory(object) provider = providers.Factory(object)
injection = providers.PositionalInjection(provider) injection = providers.PositionalInjection(provider)
self.assertIs(injection.get_original_value(), provider) assert injection.get_original_value() is provider
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Factory(object) provider = providers.Factory(object)
@ -36,9 +36,8 @@ class PositionalInjectionTests(unittest.TestCase):
injection_copy = providers.deepcopy(injection) injection_copy = providers.deepcopy(injection)
self.assertIsNot(injection_copy, injection) assert injection_copy is not injection
self.assertIsNot(injection_copy.get_original_value(), assert injection_copy.get_original_value() is not injection.get_original_value()
injection.get_original_value())
def test_deepcopy_memo(self): def test_deepcopy_memo(self):
provider = providers.Factory(object) provider = providers.Factory(object)
@ -48,40 +47,38 @@ class PositionalInjectionTests(unittest.TestCase):
injection_copy = providers.deepcopy( injection_copy = providers.deepcopy(
injection, {id(injection): injection_copy_orig}) injection, {id(injection): injection_copy_orig})
self.assertIs(injection_copy, injection_copy_orig) assert injection_copy is injection_copy_orig
self.assertIs(injection_copy.get_original_value(), assert injection_copy.get_original_value() is injection.get_original_value()
injection.get_original_value())
class NamedInjectionTests(unittest.TestCase): class NamedInjectionTests(unittest.TestCase):
def test_isinstance(self): def test_isinstance(self):
injection = providers.NamedInjection("name", 1) injection = providers.NamedInjection("name", 1)
self.assertIsInstance(injection, providers.Injection) assert isinstance(injection, providers.Injection)
def test_get_name(self): def test_get_name(self):
injection = providers.NamedInjection("name", 123) injection = providers.NamedInjection("name", 123)
self.assertEqual(injection.get_name(), "name") assert injection.get_name() == "name"
def test_get_value_with_not_provider(self): def test_get_value_with_not_provider(self):
injection = providers.NamedInjection("name", 123) injection = providers.NamedInjection("name", 123)
self.assertEqual(injection.get_value(), 123) assert injection.get_value() == 123
def test_get_value_with_factory(self): def test_get_value_with_factory(self):
injection = providers.NamedInjection("name", injection = providers.NamedInjection("name", providers.Factory(object))
providers.Factory(object))
obj1 = injection.get_value() obj1 = injection.get_value()
obj2 = injection.get_value() obj2 = injection.get_value()
self.assertIs(type(obj1), object) assert type(obj1) is object
self.assertIs(type(obj2), object) assert type(obj2) is object
self.assertIsNot(obj1, obj2) assert obj1 is not obj2
def test_get_original_value(self): def test_get_original_value(self):
provider = providers.Factory(object) provider = providers.Factory(object)
injection = providers.NamedInjection("name", provider) injection = providers.NamedInjection("name", provider)
self.assertIs(injection.get_original_value(), provider) assert injection.get_original_value() is provider
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Factory(object) provider = providers.Factory(object)
@ -89,9 +86,8 @@ class NamedInjectionTests(unittest.TestCase):
injection_copy = providers.deepcopy(injection) injection_copy = providers.deepcopy(injection)
self.assertIsNot(injection_copy, injection) assert injection_copy is not injection
self.assertIsNot(injection_copy.get_original_value(), assert injection_copy.get_original_value() is not injection.get_original_value()
injection.get_original_value())
def test_deepcopy_memo(self): def test_deepcopy_memo(self):
provider = providers.Factory(object) provider = providers.Factory(object)
@ -101,6 +97,5 @@ class NamedInjectionTests(unittest.TestCase):
injection_copy = providers.deepcopy( injection_copy = providers.deepcopy(
injection, {id(injection): injection_copy_orig}) injection, {id(injection): injection_copy_orig})
self.assertIs(injection_copy, injection_copy_orig) assert injection_copy is injection_copy_orig
self.assertIs(injection_copy.get_original_value(), assert injection_copy.get_original_value() is injection.get_original_value()
injection.get_original_value())

View File

@ -10,11 +10,11 @@ from dependency_injector import providers
class ListTests(unittest.TestCase): class ListTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.List())) assert providers.is_provider(providers.List()) is True
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.List() provider = providers.List()
self.assertIsInstance(provider.provided, providers.ProvidedInstance) assert isinstance(provider.provided, providers.ProvidedInstance)
def test_call_with_init_positional_args(self): def test_call_with_init_positional_args(self):
provider = providers.List("i1", "i2") provider = providers.List("i1", "i2")
@ -22,33 +22,30 @@ class ListTests(unittest.TestCase):
list1 = provider() list1 = provider()
list2 = provider() list2 = provider()
self.assertEqual(list1, ["i1", "i2"]) assert list1 == ["i1", "i2"]
self.assertEqual(list2, ["i1", "i2"]) assert list2 == ["i1", "i2"]
assert list1 is not list2
self.assertIsNot(list1, list2)
def test_call_with_context_args(self): def test_call_with_context_args(self):
provider = providers.List("i1", "i2") provider = providers.List("i1", "i2")
assert provider("i3", "i4") == ["i1", "i2", "i3", "i4"]
self.assertEqual(provider("i3", "i4"), ["i1", "i2", "i3", "i4"])
def test_fluent_interface(self): def test_fluent_interface(self):
provider = providers.List() \ provider = providers.List() \
.add_args(1, 2) .add_args(1, 2)
assert provider() == [1, 2]
self.assertEqual(provider(), [1, 2])
def test_set_args(self): def test_set_args(self):
provider = providers.List() \ provider = providers.List() \
.add_args(1, 2) \ .add_args(1, 2) \
.set_args(3, 4) .set_args(3, 4)
self.assertEqual(provider.args, (3, 4)) assert provider.args == (3, 4)
def test_clear_args(self): def test_clear_args(self):
provider = providers.List() \ provider = providers.List() \
.add_args(1, 2) \ .add_args(1, 2) \
.clear_args() .clear_args()
self.assertEqual(provider.args, tuple()) assert provider.args == tuple()
def test_call_overridden(self): def test_call_overridden(self):
provider = providers.List(1, 2) provider = providers.List(1, 2)
@ -61,27 +58,26 @@ class ListTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertEqual(instance1, [3, 4]) assert instance1 == [3, 4]
self.assertEqual(instance2, [3, 4]) assert instance2 == [3, 4]
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.List(1, 2) provider = providers.List(1, 2)
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertEqual(provider.args, provider_copy.args) assert provider.args == provider_copy.args
self.assertIsInstance(provider, providers.List) assert isinstance(provider, providers.List)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.List(1, 2) provider = providers.List(1, 2)
provider_copy_memo = providers.List(1, 2) provider_copy_memo = providers.List(1, 2)
provider_copy = providers.deepcopy( provider_copy = providers.deepcopy(provider, memo={id(provider): provider_copy_memo})
provider, memo={id(provider): provider_copy_memo})
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_args(self): def test_deepcopy_args(self):
provider = providers.List() provider = providers.List()
@ -94,13 +90,13 @@ class ListTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.args[0] dependent_provider_copy1 = provider_copy.args[0]
dependent_provider_copy2 = provider_copy.args[1] dependent_provider_copy2 = provider_copy.args[1]
self.assertNotEqual(provider.args, provider_copy.args) assert provider.args != provider_copy.args
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.List() provider = providers.List()
@ -111,12 +107,12 @@ class ListTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0] object_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertEqual(provider.args, provider_copy.args) assert provider.args == provider_copy.args
self.assertIsInstance(provider, providers.List) assert isinstance(provider, providers.List)
self.assertIsNot(object_provider, object_provider_copy) assert object_provider is not object_provider_copy
self.assertIsInstance(object_provider_copy, providers.Object) assert isinstance(object_provider_copy, providers.Object)
def test_deepcopy_with_sys_streams(self): def test_deepcopy_with_sys_streams(self):
provider = providers.List() provider = providers.List()
@ -124,17 +120,15 @@ class ListTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider_copy, providers.List) assert isinstance(provider_copy, providers.List)
self.assertIs(provider.args[0], sys.stdin) assert provider.args[0] is sys.stdin
self.assertIs(provider.args[1], sys.stdout) assert provider.args[1] is sys.stdout
self.assertIs(provider.args[2], sys.stderr) assert provider.args[2] is sys.stderr
def test_repr(self): def test_repr(self):
provider = providers.List(1, 2) provider = providers.List(1, 2)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "List({0}) at {1}>".format(repr(list(provider.args)), hex(id(provider)))
"List({0}) at {1}>".format( )
repr(list(provider.args)),
hex(id(provider))))

View File

@ -70,60 +70,51 @@ class ProvidedInstanceTests(unittest.TestCase):
self.container = Container() self.container = Container()
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(self.container.service.provided)) assert providers.is_provider(self.container.service.provided) is True
def test_attribute(self): def test_attribute(self):
client = self.container.client_attribute() client = self.container.client_attribute()
self.assertEqual(client.value, "foo") assert client.value == "foo"
def test_item(self): def test_item(self):
client = self.container.client_item() client = self.container.client_item()
self.assertEqual(client.value, "foo") assert client.value == "foo"
def test_attribute_item(self): def test_attribute_item(self):
client = self.container.client_attribute_item() client = self.container.client_attribute_item()
self.assertEqual(client.value, "foo") assert client.value == "foo"
def test_method_call(self): def test_method_call(self):
client = self.container.client_method_call() client = self.container.client_method_call()
self.assertEqual(client.value, "foo") assert client.value == "foo"
def test_method_closure_call(self): def test_method_closure_call(self):
client = self.container.client_method_closure_call() client = self.container.client_method_closure_call()
self.assertEqual(client.value, "foo") assert client.value == "foo"
def test_provided_call(self): def test_provided_call(self):
client = self.container.client_provided_call() client = self.container.client_provided_call()
self.assertEqual(client.value, "foo") assert client.value == "foo"
def test_call_overridden(self): def test_call_overridden(self):
value = "bar" value = "bar"
with self.container.service.override(Service(value)): with self.container.service.override(Service(value)):
self.assertEqual(self.container.client_attribute().value, value) assert self.container.client_attribute().value == value
self.assertEqual(self.container.client_item().value, value) assert self.container.client_item().value == value
self.assertEqual(self.container.client_attribute_item().value, value) assert self.container.client_attribute_item().value == value
self.assertEqual(self.container.client_method_call().value, value) assert self.container.client_method_call().value == value
def test_repr_provided_instance(self): def test_repr_provided_instance(self):
provider = self.container.service.provided provider = self.container.service.provided
self.assertEqual( assert repr(provider) == "ProvidedInstance(\"{0}\")".format(repr(self.container.service))
"ProvidedInstance(\"{0}\")".format(repr(self.container.service)),
repr(provider),
)
def test_repr_attribute_getter(self): def test_repr_attribute_getter(self):
provider = self.container.service.provided.value provider = self.container.service.provided.value
self.assertEqual( assert repr(provider) == "AttributeGetter(\"value\")"
"AttributeGetter(\"value\")",
repr(provider),
)
def test_repr_item_getter(self): def test_repr_item_getter(self):
provider = self.container.service.provided["test-test"] provider = self.container.service.provided["test-test"]
self.assertEqual( assert repr(provider) == "ItemGetter(\"test-test\")"
"ItemGetter(\"test-test\")",
repr(provider),
)
class LazyInitTests(unittest.TestCase): class LazyInitTests(unittest.TestCase):
@ -132,36 +123,36 @@ class LazyInitTests(unittest.TestCase):
provides = providers.Object(object()) provides = providers.Object(object())
provider = providers.ProvidedInstance() provider = providers.ProvidedInstance()
provider.set_provides(provides) provider.set_provides(provides)
self.assertIs(provider.provides, provides) assert provider.provides is provides
self.assertIs(provider.set_provides(providers.Provider()), provider) assert provider.set_provides(providers.Provider()) is provider
def test_attribute_getter(self): def test_attribute_getter(self):
provides = providers.Object(object()) provides = providers.Object(object())
provider = providers.AttributeGetter() provider = providers.AttributeGetter()
provider.set_provides(provides) provider.set_provides(provides)
provider.set_name("__dict__") provider.set_name("__dict__")
self.assertIs(provider.provides, provides) assert provider.provides is provides
self.assertEqual(provider.name, "__dict__") assert provider.name == "__dict__"
self.assertIs(provider.set_provides(providers.Provider()), provider) assert provider.set_provides(providers.Provider()) is provider
self.assertIs(provider.set_name("__dict__"), provider) assert provider.set_name("__dict__") is provider
def test_item_getter(self): def test_item_getter(self):
provides = providers.Object({"foo": "bar"}) provides = providers.Object({"foo": "bar"})
provider = providers.ItemGetter() provider = providers.ItemGetter()
provider.set_provides(provides) provider.set_provides(provides)
provider.set_name("foo") provider.set_name("foo")
self.assertIs(provider.provides, provides) assert provider.provides is provides
self.assertEqual(provider.name, "foo") assert provider.name == "foo"
self.assertIs(provider.set_provides(providers.Provider()), provider) assert provider.set_provides(providers.Provider()) is provider
self.assertIs(provider.set_name("foo"), provider) assert provider.set_name("foo") is provider
def test_method_caller(self): def test_method_caller(self):
provides = providers.Object(lambda: 42) provides = providers.Object(lambda: 42)
provider = providers.MethodCaller() provider = providers.MethodCaller()
provider.set_provides(provides) provider.set_provides(provides)
self.assertIs(provider.provides, provides) assert provider.provides is provides
self.assertEqual(provider(), 42) assert provider() == 42
self.assertIs(provider.set_provides(providers.Provider()), provider) assert provider.set_provides(providers.Provider()) is provider
class ProvidedInstancePuzzleTests(unittest.TestCase): class ProvidedInstancePuzzleTests(unittest.TestCase):
@ -189,17 +180,13 @@ class ProvidedInstancePuzzleTests(unittest.TestCase):
) )
result = test_list() result = test_list()
assert result == [
self.assertEqual( 10,
result, 22,
[ service(),
10, "foo-bar",
22, "foo-bar",
service(), ]
"foo-bar",
"foo-bar",
],
)
class ProvidedInstanceInBaseClassTests(unittest.TestCase): class ProvidedInstanceInBaseClassTests(unittest.TestCase):

View File

@ -6,6 +6,7 @@ import unittest
from typing import Any from typing import Any
from dependency_injector import containers, providers, resources, errors from dependency_injector import containers, providers, resources, errors
from pytest import raises
# Runtime import to get asyncutils module # Runtime import to get asyncutils module
import os import os
@ -28,21 +29,21 @@ def init_fn(*args, **kwargs):
class ResourceTests(unittest.TestCase): class ResourceTests(unittest.TestCase):
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Resource(init_fn))) assert providers.is_provider(providers.Resource(init_fn)) is True
def test_init_optional_provides(self): def test_init_optional_provides(self):
provider = providers.Resource() provider = providers.Resource()
provider.set_provides(init_fn) provider.set_provides(init_fn)
self.assertIs(provider.provides, init_fn) assert provider.provides is init_fn
self.assertEqual(provider(), (tuple(), dict())) assert provider() == (tuple(), dict())
def test_set_provides_returns_self(self): def test_set_provides_returns_self(self):
provider = providers.Resource() provider = providers.Resource()
self.assertIs(provider.set_provides(init_fn), provider) assert provider.set_provides(init_fn) is provider
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) assert isinstance(provider.provided, providers.ProvidedInstance)
def test_injection(self): def test_injection(self):
resource = object() resource = object()
@ -61,13 +62,13 @@ class ResourceTests(unittest.TestCase):
list1 = container.dependency1() list1 = container.dependency1()
list2 = container.dependency2() list2 = container.dependency2()
self.assertEqual(list1, [resource]) assert list1 == [resource]
self.assertIs(list1[0], resource) assert list1[0] is resource
self.assertEqual(list2, [resource]) assert list2 == [resource]
self.assertIs(list2[0], resource) assert list2[0] is resource
self.assertEqual(_init.counter, 1) assert _init.counter == 1
def test_init_function(self): def test_init_function(self):
def _init(): def _init():
@ -77,12 +78,12 @@ class ResourceTests(unittest.TestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
result1 = provider() result1 = provider()
self.assertIsNone(result1) assert result1 is None
self.assertEqual(_init.counter, 1) assert _init.counter == 1
result2 = provider() result2 = provider()
self.assertIsNone(result2) assert result2 is None
self.assertEqual(_init.counter, 1) assert _init.counter == 1
provider.shutdown() provider.shutdown()
@ -98,22 +99,22 @@ class ResourceTests(unittest.TestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
result1 = provider() result1 = provider()
self.assertIsNone(result1) assert result1 is None
self.assertEqual(_init.init_counter, 1) assert _init.init_counter == 1
self.assertEqual(_init.shutdown_counter, 0) assert _init.shutdown_counter == 0
provider.shutdown() provider.shutdown()
self.assertEqual(_init.init_counter, 1) assert _init.init_counter == 1
self.assertEqual(_init.shutdown_counter, 1) assert _init.shutdown_counter == 1
result2 = provider() result2 = provider()
self.assertIsNone(result2) assert result2 is None
self.assertEqual(_init.init_counter, 2) assert _init.init_counter == 2
self.assertEqual(_init.shutdown_counter, 1) assert _init.shutdown_counter == 1
provider.shutdown() provider.shutdown()
self.assertEqual(_init.init_counter, 2) assert _init.init_counter == 2
self.assertEqual(_init.shutdown_counter, 2) assert _init.shutdown_counter == 2
def test_init_class(self): def test_init_class(self):
class TestResource(resources.Resource): class TestResource(resources.Resource):
@ -129,22 +130,22 @@ class ResourceTests(unittest.TestCase):
provider = providers.Resource(TestResource) provider = providers.Resource(TestResource)
result1 = provider() result1 = provider()
self.assertIsNone(result1) assert result1 is None
self.assertEqual(TestResource.init_counter, 1) assert TestResource.init_counter == 1
self.assertEqual(TestResource.shutdown_counter, 0) assert TestResource.shutdown_counter == 0
provider.shutdown() provider.shutdown()
self.assertEqual(TestResource.init_counter, 1) assert TestResource.init_counter == 1
self.assertEqual(TestResource.shutdown_counter, 1) assert TestResource.shutdown_counter == 1
result2 = provider() result2 = provider()
self.assertIsNone(result2) assert result2 is None
self.assertEqual(TestResource.init_counter, 2) assert TestResource.init_counter == 2
self.assertEqual(TestResource.shutdown_counter, 1) assert TestResource.shutdown_counter == 1
provider.shutdown() provider.shutdown()
self.assertEqual(TestResource.init_counter, 2) assert TestResource.init_counter == 2
self.assertEqual(TestResource.shutdown_counter, 2) assert TestResource.shutdown_counter == 2
def test_init_class_generic_typing(self): def test_init_class_generic_typing(self):
# See issue: https://github.com/ets-labs/python-dependency-injector/issues/488 # See issue: https://github.com/ets-labs/python-dependency-injector/issues/488
@ -157,27 +158,27 @@ class ResourceTests(unittest.TestCase):
def shutdown(self, resource: TestDependency) -> None: ... def shutdown(self, resource: TestDependency) -> None: ...
self.assertTrue(issubclass(TestResource, resources.Resource)) assert issubclass(TestResource, resources.Resource) is True
def test_init_class_abc_init_definition_is_required(self): def test_init_class_abc_init_definition_is_required(self):
class TestResource(resources.Resource): class TestResource(resources.Resource):
... ...
with self.assertRaises(TypeError) as context: with raises(TypeError) as context:
TestResource() TestResource()
self.assertIn("Can't instantiate abstract class TestResource", str(context.exception)) assert "Can't instantiate abstract class TestResource" in str(context.value)
self.assertIn("init", str(context.exception)) assert "init" in str(context.value)
def test_init_class_abc_shutdown_definition_is_not_required(self): def test_init_class_abc_shutdown_definition_is_not_required(self):
class TestResource(resources.Resource): class TestResource(resources.Resource):
def init(self): def init(self):
... ...
self.assertTrue(hasattr(TestResource(), "shutdown")) assert hasattr(TestResource(), "shutdown") is True
def test_init_not_callable(self): def test_init_not_callable(self):
provider = providers.Resource(1) provider = providers.Resource(1)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider.init() provider.init()
def test_init_and_shutdown(self): def test_init_and_shutdown(self):
@ -192,22 +193,22 @@ class ResourceTests(unittest.TestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
result1 = provider.init() result1 = provider.init()
self.assertIsNone(result1) assert result1 is None
self.assertEqual(_init.init_counter, 1) assert _init.init_counter == 1
self.assertEqual(_init.shutdown_counter, 0) assert _init.shutdown_counter == 0
provider.shutdown() provider.shutdown()
self.assertEqual(_init.init_counter, 1) assert _init.init_counter == 1
self.assertEqual(_init.shutdown_counter, 1) assert _init.shutdown_counter == 1
result2 = provider.init() result2 = provider.init()
self.assertIsNone(result2) assert result2 is None
self.assertEqual(_init.init_counter, 2) assert _init.init_counter == 2
self.assertEqual(_init.shutdown_counter, 1) assert _init.shutdown_counter == 1
provider.shutdown() provider.shutdown()
self.assertEqual(_init.init_counter, 2) assert _init.init_counter == 2
self.assertEqual(_init.shutdown_counter, 2) assert _init.shutdown_counter == 2
def test_shutdown_of_not_initialized(self): def test_shutdown_of_not_initialized(self):
def _init(): def _init():
@ -216,52 +217,51 @@ class ResourceTests(unittest.TestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
result = provider.shutdown() result = provider.shutdown()
self.assertIsNone(result) assert result is None
def test_initialized(self): def test_initialized(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
self.assertFalse(provider.initialized) assert provider.initialized is False
provider.init() provider.init()
self.assertTrue(provider.initialized) assert provider.initialized is True
provider.shutdown() provider.shutdown()
self.assertFalse(provider.initialized) assert provider.initialized is False
def test_call_with_context_args(self): def test_call_with_context_args(self):
provider = providers.Resource(init_fn, "i1", "i2") provider = providers.Resource(init_fn, "i1", "i2")
self.assertEqual(provider("i3", i4=4), (("i1", "i2", "i3"), {"i4": 4})) assert provider("i3", i4=4) == (("i1", "i2", "i3"), {"i4": 4})
def test_fluent_interface(self): def test_fluent_interface(self):
provider = providers.Resource(init_fn) \ provider = providers.Resource(init_fn) \
.add_args(1, 2) \ .add_args(1, 2) \
.add_kwargs(a3=3, a4=4) .add_kwargs(a3=3, a4=4)
assert provider() == ((1, 2), {"a3": 3, "a4": 4})
self.assertEqual(provider(), ((1, 2), {"a3": 3, "a4": 4}))
def test_set_args(self): def test_set_args(self):
provider = providers.Resource(init_fn) \ provider = providers.Resource(init_fn) \
.add_args(1, 2) \ .add_args(1, 2) \
.set_args(3, 4) .set_args(3, 4)
self.assertEqual(provider.args, (3, 4)) assert provider.args == (3, 4)
def test_clear_args(self): def test_clear_args(self):
provider = providers.Resource(init_fn) \ provider = providers.Resource(init_fn) \
.add_args(1, 2) \ .add_args(1, 2) \
.clear_args() .clear_args()
self.assertEqual(provider.args, tuple()) assert provider.args == tuple()
def test_set_kwargs(self): def test_set_kwargs(self):
provider = providers.Resource(init_fn) \ provider = providers.Resource(init_fn) \
.add_kwargs(a1="i1", a2="i2") \ .add_kwargs(a1="i1", a2="i2") \
.set_kwargs(a3="i3", a4="i4") .set_kwargs(a3="i3", a4="i4")
self.assertEqual(provider.kwargs, {"a3": "i3", "a4": "i4"}) assert provider.kwargs == {"a3": "i3", "a4": "i4"}
def test_clear_kwargs(self): def test_clear_kwargs(self):
provider = providers.Resource(init_fn) \ provider = providers.Resource(init_fn) \
.add_kwargs(a1="i1", a2="i2") \ .add_kwargs(a1="i1", a2="i2") \
.clear_kwargs() .clear_kwargs()
self.assertEqual(provider.kwargs, {}) assert provider.kwargs == {}
def test_call_overridden(self): def test_call_overridden(self):
provider = providers.Resource(init_fn, 1) provider = providers.Resource(init_fn, 1)
@ -274,25 +274,25 @@ class ResourceTests(unittest.TestCase):
instance1 = provider() instance1 = provider()
instance2 = provider() instance2 = provider()
self.assertIs(instance1, instance2) assert instance1 is instance2
self.assertEqual(instance1, ((3,), {})) assert instance1 == ((3,), {})
self.assertEqual(instance2, ((3,), {})) assert instance2 == ((3,), {})
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Resource(init_fn, 1, 2, a3=3, a4=4) provider = providers.Resource(init_fn, 1, 2, a3=3, a4=4)
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertEqual(provider.args, provider_copy.args) assert provider.args == provider_copy.args
self.assertEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs == provider_copy.kwargs
self.assertIsInstance(provider, providers.Resource) assert isinstance(provider, providers.Resource)
def test_deepcopy_initialized(self): def test_deepcopy_initialized(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
provider.init() provider.init()
with self.assertRaises(errors.Error): with raises(errors.Error):
providers.deepcopy(provider) providers.deepcopy(provider)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
@ -304,7 +304,7 @@ class ResourceTests(unittest.TestCase):
memo={id(provider): provider_copy_memo}, memo={id(provider): provider_copy_memo},
) )
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_args(self): def test_deepcopy_args(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
@ -317,13 +317,13 @@ class ResourceTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.args[0] dependent_provider_copy1 = provider_copy.args[0]
dependent_provider_copy2 = provider_copy.args[1] dependent_provider_copy2 = provider_copy.args[1]
self.assertNotEqual(provider.args, provider_copy.args) assert provider.args != provider_copy.args
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_kwargs(self): def test_deepcopy_kwargs(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
@ -336,13 +336,13 @@ class ResourceTests(unittest.TestCase):
dependent_provider_copy1 = provider_copy.kwargs["d1"] dependent_provider_copy1 = provider_copy.kwargs["d1"]
dependent_provider_copy2 = provider_copy.kwargs["d2"] dependent_provider_copy2 = provider_copy.kwargs["d2"]
self.assertNotEqual(provider.kwargs, provider_copy.kwargs) assert provider.kwargs != provider_copy.kwargs
self.assertIs(dependent_provider1.cls, dependent_provider_copy1.cls) assert dependent_provider1.cls is dependent_provider_copy1.cls
self.assertIsNot(dependent_provider1, dependent_provider_copy1) assert dependent_provider1 is not dependent_provider_copy1
self.assertIs(dependent_provider2.cls, dependent_provider_copy2.cls) assert dependent_provider2.cls is dependent_provider_copy2.cls
self.assertIsNot(dependent_provider2, dependent_provider_copy2) assert dependent_provider2 is not dependent_provider_copy2
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
@ -353,12 +353,12 @@ class ResourceTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0] object_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertEqual(provider.args, provider_copy.args) assert provider.args == provider_copy.args
self.assertIsInstance(provider, providers.Resource) assert isinstance(provider, providers.Resource)
self.assertIsNot(object_provider, object_provider_copy) assert object_provider is not object_provider_copy
self.assertIsInstance(object_provider_copy, providers.Object) assert isinstance(object_provider_copy, providers.Object)
def test_deepcopy_with_sys_streams(self): def test_deepcopy_with_sys_streams(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
@ -366,17 +366,16 @@ class ResourceTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider_copy, providers.Resource) assert isinstance(provider_copy, providers.Resource)
self.assertIs(provider.args[0], sys.stdin) assert provider.args[0] is sys.stdin
self.assertIs(provider.args[1], sys.stdout) assert provider.args[1] is sys.stdout
self.assertIs(provider.args[2], sys.stderr) assert provider.args[2] is sys.stderr
def test_repr(self): def test_repr(self):
provider = providers.Resource(init_fn) provider = providers.Resource(init_fn)
self.assertEqual( assert repr(provider) == (
repr(provider),
"<dependency_injector.providers.Resource({0}) at {1}>".format( "<dependency_injector.providers.Resource({0}) at {1}>".format(
repr(init_fn), repr(init_fn),
hex(id(provider)), hex(id(provider)),
@ -398,12 +397,12 @@ class AsyncResourceTest(AsyncTestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
result1 = self._run(provider()) result1 = self._run(provider())
self.assertIs(result1, resource) assert result1 is resource
self.assertEqual(_init.counter, 1) assert _init.counter == 1
result2 = self._run(provider()) result2 = self._run(provider())
self.assertIs(result2, resource) assert result2 is resource
self.assertEqual(_init.counter, 1) assert _init.counter == 1
self._run(provider.shutdown()) self._run(provider.shutdown())
@ -425,22 +424,22 @@ class AsyncResourceTest(AsyncTestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
result1 = self._run(provider()) result1 = self._run(provider())
self.assertIs(result1, resource) assert result1 is resource
self.assertEqual(_init.init_counter, 1) assert _init.init_counter == 1
self.assertEqual(_init.shutdown_counter, 0) assert _init.shutdown_counter == 0
self._run(provider.shutdown()) self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 1) assert _init.init_counter == 1
self.assertEqual(_init.shutdown_counter, 1) assert _init.shutdown_counter == 1
result2 = self._run(provider()) result2 = self._run(provider())
self.assertIs(result2, resource) assert result2 is resource
self.assertEqual(_init.init_counter, 2) assert _init.init_counter == 2
self.assertEqual(_init.shutdown_counter, 1) assert _init.shutdown_counter == 1
self._run(provider.shutdown()) self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 2) assert _init.init_counter == 2
self.assertEqual(_init.shutdown_counter, 2) assert _init.shutdown_counter == 2
def test_init_async_class(self): def test_init_async_class(self):
resource = object() resource = object()
@ -462,22 +461,22 @@ class AsyncResourceTest(AsyncTestCase):
provider = providers.Resource(TestResource) provider = providers.Resource(TestResource)
result1 = self._run(provider()) result1 = self._run(provider())
self.assertIs(result1, resource) assert result1 is resource
self.assertEqual(TestResource.init_counter, 1) assert TestResource.init_counter == 1
self.assertEqual(TestResource.shutdown_counter, 0) assert TestResource.shutdown_counter == 0
self._run(provider.shutdown()) self._run(provider.shutdown())
self.assertEqual(TestResource.init_counter, 1) assert TestResource.init_counter == 1
self.assertEqual(TestResource.shutdown_counter, 1) assert TestResource.shutdown_counter == 1
result2 = self._run(provider()) result2 = self._run(provider())
self.assertIs(result2, resource) assert result2 is resource
self.assertEqual(TestResource.init_counter, 2) assert TestResource.init_counter == 2
self.assertEqual(TestResource.shutdown_counter, 1) assert TestResource.shutdown_counter == 1
self._run(provider.shutdown()) self._run(provider.shutdown())
self.assertEqual(TestResource.init_counter, 2) assert TestResource.init_counter == 2
self.assertEqual(TestResource.shutdown_counter, 2) assert TestResource.shutdown_counter == 2
def test_init_async_class_generic_typing(self): def test_init_async_class_generic_typing(self):
# See issue: https://github.com/ets-labs/python-dependency-injector/issues/488 # See issue: https://github.com/ets-labs/python-dependency-injector/issues/488
@ -490,24 +489,24 @@ class AsyncResourceTest(AsyncTestCase):
async def shutdown(self, resource: TestDependency) -> None: ... async def shutdown(self, resource: TestDependency) -> None: ...
self.assertTrue(issubclass(TestAsyncResource, resources.AsyncResource)) assert issubclass(TestAsyncResource, resources.AsyncResource) is True
def test_init_async_class_abc_init_definition_is_required(self): def test_init_async_class_abc_init_definition_is_required(self):
class TestAsyncResource(resources.AsyncResource): class TestAsyncResource(resources.AsyncResource):
... ...
with self.assertRaises(TypeError) as context: with raises(TypeError) as context:
TestAsyncResource() TestAsyncResource()
self.assertIn("Can't instantiate abstract class TestAsyncResource", str(context.exception)) assert "Can't instantiate abstract class TestAsyncResource" in str(context.value)
self.assertIn("init", str(context.exception)) assert "init" in str(context.value)
def test_init_async_class_abc_shutdown_definition_is_not_required(self): def test_init_async_class_abc_shutdown_definition_is_not_required(self):
class TestAsyncResource(resources.AsyncResource): class TestAsyncResource(resources.AsyncResource):
async def init(self): async def init(self):
... ...
self.assertTrue(hasattr(TestAsyncResource(), "shutdown")) assert hasattr(TestAsyncResource(), "shutdown") is True
self.assertTrue(inspect.iscoroutinefunction(TestAsyncResource.shutdown)) assert inspect.iscoroutinefunction(TestAsyncResource.shutdown) is True
def test_init_with_error(self): def test_init_with_error(self):
async def _init(): async def _init():
@ -516,14 +515,14 @@ class AsyncResourceTest(AsyncTestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
future = provider() future = provider()
self.assertTrue(provider.initialized) assert provider.initialized is True
self.assertTrue(provider.is_async_mode_enabled()) assert provider.is_async_mode_enabled() is True
with self.assertRaises(RuntimeError): with raises(RuntimeError):
self._run(future) self._run(future)
self.assertFalse(provider.initialized) assert provider.initialized is False
self.assertTrue(provider.is_async_mode_enabled()) assert provider.is_async_mode_enabled() is True
def test_init_async_gen_with_error(self): def test_init_async_gen_with_error(self):
async def _init(): async def _init():
@ -533,14 +532,14 @@ class AsyncResourceTest(AsyncTestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
future = provider() future = provider()
self.assertTrue(provider.initialized) assert provider.initialized is True
self.assertTrue(provider.is_async_mode_enabled()) assert provider.is_async_mode_enabled() is True
with self.assertRaises(RuntimeError): with raises(RuntimeError):
self._run(future) self._run(future)
self.assertFalse(provider.initialized) assert provider.initialized is False
self.assertTrue(provider.is_async_mode_enabled()) assert provider.is_async_mode_enabled() is True
def test_init_async_subclass_with_error(self): def test_init_async_subclass_with_error(self):
class _Resource(resources.AsyncResource): class _Resource(resources.AsyncResource):
@ -553,24 +552,24 @@ class AsyncResourceTest(AsyncTestCase):
provider = providers.Resource(_Resource) provider = providers.Resource(_Resource)
future = provider() future = provider()
self.assertTrue(provider.initialized) assert provider.initialized is True
self.assertTrue(provider.is_async_mode_enabled()) assert provider.is_async_mode_enabled() is True
with self.assertRaises(RuntimeError): with raises(RuntimeError):
self._run(future) self._run(future)
self.assertFalse(provider.initialized) assert provider.initialized is False
self.assertTrue(provider.is_async_mode_enabled()) assert provider.is_async_mode_enabled() is True
def test_init_with_dependency_to_other_resource(self): def test_init_with_dependency_to_other_resource(self):
# See: https://github.com/ets-labs/python-dependency-injector/issues/361 # See: https://github.com/ets-labs/python-dependency-injector/issues/361
async def init_db_connection(db_url: str): async def init_db_connection(db_url: str):
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
yield {"connection": "ok", "url": db_url} yield {"connection": "OK", "url": db_url}
async def init_user_session(db): async def init_user_session(db):
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
yield {"session": "ok", "db": db} yield {"session": "OK", "db": db}
class Container(containers.DeclarativeContainer): class Container(containers.DeclarativeContainer):
config = providers.Configuration() config = providers.Configuration()
@ -593,11 +592,7 @@ class AsyncResourceTest(AsyncTestCase):
await container.shutdown_resources() await container.shutdown_resources()
result = self._run(main()) result = self._run(main())
assert result == {"session": "OK", "db": {"connection": "OK", "url": "postgres://..."}}
self.assertEqual(
result,
{"session": "ok", "db": {"connection": "ok", "url": "postgres://..."}},
)
def test_init_and_shutdown_methods(self): def test_init_and_shutdown_methods(self):
async def _init(): async def _init():
@ -615,20 +610,20 @@ class AsyncResourceTest(AsyncTestCase):
provider = providers.Resource(_init) provider = providers.Resource(_init)
self._run(provider.init()) self._run(provider.init())
self.assertEqual(_init.init_counter, 1) assert _init.init_counter == 1
self.assertEqual(_init.shutdown_counter, 0) assert _init.shutdown_counter == 0
self._run(provider.shutdown()) self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 1) assert _init.init_counter == 1
self.assertEqual(_init.shutdown_counter, 1) assert _init.shutdown_counter == 1
self._run(provider.init()) self._run(provider.init())
self.assertEqual(_init.init_counter, 2) assert _init.init_counter == 2
self.assertEqual(_init.shutdown_counter, 1) assert _init.shutdown_counter == 1
self._run(provider.shutdown()) self._run(provider.shutdown())
self.assertEqual(_init.init_counter, 2) assert _init.init_counter == 2
self.assertEqual(_init.shutdown_counter, 2) assert _init.shutdown_counter == 2
def test_shutdown_of_not_initialized(self): def test_shutdown_of_not_initialized(self):
async def _init(): async def _init():
@ -638,7 +633,7 @@ class AsyncResourceTest(AsyncTestCase):
provider.enable_async_mode() provider.enable_async_mode()
result = self._run(provider.shutdown()) result = self._run(provider.shutdown())
self.assertIsNone(result) assert result is None
def test_concurrent_init(self): def test_concurrent_init(self):
resource = object() resource = object()
@ -658,8 +653,8 @@ class AsyncResourceTest(AsyncTestCase):
), ),
) )
self.assertIs(result1, resource) assert result1 is resource
self.assertEqual(_init.counter, 1) assert _init.counter == 1
self.assertIs(result2, resource) assert result2 is resource
self.assertEqual(_init.counter, 1) assert _init.counter == 1

View File

@ -7,6 +7,7 @@ import sys
import unittest import unittest
from dependency_injector import providers, errors from dependency_injector import providers, errors
from pytest import raises
class SelectorTests(unittest.TestCase): class SelectorTests(unittest.TestCase):
@ -14,7 +15,7 @@ class SelectorTests(unittest.TestCase):
selector = providers.Configuration() selector = providers.Configuration()
def test_is_provider(self): def test_is_provider(self):
self.assertTrue(providers.is_provider(providers.Selector(self.selector))) assert providers.is_provider(providers.Selector(self.selector)) is True
def test_init_optional(self): def test_init_optional(self):
one = providers.Object(1) one = providers.Object(1)
@ -24,23 +25,23 @@ class SelectorTests(unittest.TestCase):
provider.set_selector(self.selector) provider.set_selector(self.selector)
provider.set_providers(one=one, two=two) provider.set_providers(one=one, two=two)
self.assertEqual(provider.providers, {"one": one, "two": two}) assert provider.providers == {"one": one, "two": two}
with self.selector.override("one"): with self.selector.override("one"):
self.assertEqual(provider(), one()) assert provider() == one()
with self.selector.override("two"): with self.selector.override("two"):
self.assertEqual(provider(), two()) assert provider() == two()
def test_set_selector_returns_self(self): def test_set_selector_returns_self(self):
provider = providers.Selector() provider = providers.Selector()
self.assertIs(provider.set_selector(self.selector), provider) assert provider.set_selector(self.selector) is provider
def test_set_providers_returns_self(self): def test_set_providers_returns_self(self):
provider = providers.Selector() provider = providers.Selector()
self.assertIs(provider.set_providers(one=providers.Provider()), provider) assert provider.set_providers(one=providers.Provider()) is provider
def test_provided_instance_provider(self): def test_provided_instance_provider(self):
provider = providers.Selector(self.selector) provider = providers.Selector(self.selector)
self.assertIsInstance(provider.provided, providers.ProvidedInstance) assert isinstance(provider.provided, providers.ProvidedInstance)
def test_call(self): def test_call(self):
provider = providers.Selector( provider = providers.Selector(
@ -50,10 +51,10 @@ class SelectorTests(unittest.TestCase):
) )
with self.selector.override("one"): with self.selector.override("one"):
self.assertEqual(provider(), 1) assert provider() == 1
with self.selector.override("two"): with self.selector.override("two"):
self.assertEqual(provider(), 2) assert provider() == 2
def test_call_undefined_provider(self): def test_call_undefined_provider(self):
provider = providers.Selector( provider = providers.Selector(
@ -63,7 +64,7 @@ class SelectorTests(unittest.TestCase):
) )
with self.selector.override("three"): with self.selector.override("three"):
with self.assertRaises(errors.Error): with raises(errors.Error):
provider() provider()
def test_call_selector_is_none(self): def test_call_selector_is_none(self):
@ -74,7 +75,7 @@ class SelectorTests(unittest.TestCase):
) )
with self.selector.override(None): with self.selector.override(None):
with self.assertRaises(errors.Error): with raises(errors.Error):
provider() provider()
def test_call_any_callable(self): def test_call_any_callable(self):
@ -84,10 +85,10 @@ class SelectorTests(unittest.TestCase):
two=providers.Object(2), two=providers.Object(2),
) )
self.assertEqual(provider(), 1) assert provider() == 1
self.assertEqual(provider(), 2) assert provider() == 2
self.assertEqual(provider(), 1) assert provider() == 1
self.assertEqual(provider(), 2) assert provider() == 2
def test_call_with_context_args(self): def test_call_with_context_args(self):
provider = providers.Selector( provider = providers.Selector(
@ -98,8 +99,8 @@ class SelectorTests(unittest.TestCase):
with self.selector.override("one"): with self.selector.override("one"):
args, kwargs = provider(1, 2, three=3, four=4) args, kwargs = provider(1, 2, three=3, four=4)
self.assertEqual(args, (1, 2)) assert args == (1, 2)
self.assertEqual(kwargs, {"three": 3, "four": 4}) assert kwargs == {"three": 3, "four": 4}
def test_getattr(self): def test_getattr(self):
provider_one = providers.Object(1) provider_one = providers.Object(1)
@ -111,8 +112,8 @@ class SelectorTests(unittest.TestCase):
two=provider_two, two=provider_two,
) )
self.assertIs(provider.one, provider_one) assert provider.one is provider_one
self.assertIs(provider.two, provider_two) assert provider.two is provider_two
def test_getattr_attribute_error(self): def test_getattr_attribute_error(self):
provider_one = providers.Object(1) provider_one = providers.Object(1)
@ -124,7 +125,7 @@ class SelectorTests(unittest.TestCase):
two=provider_two, two=provider_two,
) )
with self.assertRaises(AttributeError): with raises(AttributeError):
_ = provider.provider_three _ = provider.provider_three
def test_call_overridden(self): def test_call_overridden(self):
@ -136,7 +137,7 @@ class SelectorTests(unittest.TestCase):
provider.override(overriding_provider2) provider.override(overriding_provider2)
with self.selector.override("sample"): with self.selector.override("sample"):
self.assertEqual(provider(), 3) assert provider() == 3
def test_providers_attribute(self): def test_providers_attribute(self):
provider_one = providers.Object(1) provider_one = providers.Object(1)
@ -147,16 +148,15 @@ class SelectorTests(unittest.TestCase):
one=provider_one, one=provider_one,
two=provider_two, two=provider_two,
) )
assert provider.providers == {"one": provider_one, "two": provider_two}
self.assertEqual(provider.providers, {"one": provider_one, "two": provider_two})
def test_deepcopy(self): def test_deepcopy(self):
provider = providers.Selector(self.selector) provider = providers.Selector(self.selector)
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Selector) assert isinstance(provider, providers.Selector)
def test_deepcopy_from_memo(self): def test_deepcopy_from_memo(self):
provider = providers.Selector(self.selector) provider = providers.Selector(self.selector)
@ -167,7 +167,7 @@ class SelectorTests(unittest.TestCase):
memo={id(provider): provider_copy_memo}, memo={id(provider): provider_copy_memo},
) )
self.assertIs(provider_copy, provider_copy_memo) assert provider_copy is provider_copy_memo
def test_deepcopy_overridden(self): def test_deepcopy_overridden(self):
provider = providers.Selector(self.selector) provider = providers.Selector(self.selector)
@ -178,11 +178,11 @@ class SelectorTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
object_provider_copy = provider_copy.overridden[0] object_provider_copy = provider_copy.overridden[0]
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider, providers.Selector) assert isinstance(provider, providers.Selector)
self.assertIsNot(object_provider, object_provider_copy) assert object_provider is not object_provider_copy
self.assertIsInstance(object_provider_copy, providers.Object) assert isinstance(object_provider_copy, providers.Object)
def test_deepcopy_with_sys_streams(self): def test_deepcopy_with_sys_streams(self):
provider = providers.Selector( provider = providers.Selector(
@ -195,17 +195,17 @@ class SelectorTests(unittest.TestCase):
provider_copy = providers.deepcopy(provider) provider_copy = providers.deepcopy(provider)
self.assertIsNot(provider, provider_copy) assert provider is not provider_copy
self.assertIsInstance(provider_copy, providers.Selector) assert isinstance(provider_copy, providers.Selector)
with self.selector.override("stdin"): with self.selector.override("stdin"):
self.assertIs(provider(), sys.stdin) assert provider() is sys.stdin
with self.selector.override("stdout"): with self.selector.override("stdout"):
self.assertIs(provider(), sys.stdout) assert provider() is sys.stdout
with self.selector.override("stderr"): with self.selector.override("stderr"):
self.assertIs(provider(), sys.stderr) assert provider() is sys.stderr
def test_repr(self): def test_repr(self):
provider = providers.Selector( provider = providers.Selector(
@ -214,10 +214,7 @@ class SelectorTests(unittest.TestCase):
two=providers.Object(2), two=providers.Object(2),
) )
self.assertIn( assert "<dependency_injector.providers.Selector({0}".format(repr(self.selector)) in repr(provider)
"<dependency_injector.providers.Selector({0}".format(repr(self.selector)), assert "one={0}".format(repr(provider.one)) in repr(provider)
repr(provider), assert "two={0}".format(repr(provider.two)) in repr(provider)
) assert "at {0}".format(hex(id(provider))) in repr(provider)
self.assertIn("one={0}".format(repr(provider.one)), repr(provider))
self.assertIn("two={0}".format(repr(provider.two)), repr(provider))
self.assertIn("at {0}".format(hex(id(provider))), repr(provider))

View File

@ -6,6 +6,7 @@ from dependency_injector import (
providers, providers,
errors, errors,
) )
from pytest import raises
from .singleton_common import Example, _BaseSingletonTestCase from .singleton_common import Example, _BaseSingletonTestCase
@ -16,12 +17,10 @@ class SingletonTests(_BaseSingletonTestCase, unittest.TestCase):
def test_repr(self): def test_repr(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "Singleton({0}) at {1}>".format(repr(Example), hex(id(provider)))
"Singleton({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
class DelegatedSingletonTests(_BaseSingletonTestCase, unittest.TestCase): class DelegatedSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
@ -30,16 +29,14 @@ class DelegatedSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
def test_is_delegated_provider(self): def test_is_delegated_provider(self):
provider = self.singleton_cls(object) provider = self.singleton_cls(object)
self.assertTrue(providers.is_delegated(provider)) assert providers.is_delegated(provider) is True
def test_repr(self): def test_repr(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "DelegatedSingleton({0}) at {1}>".format(repr(Example), hex(id(provider)))
"DelegatedSingleton({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
class ThreadLocalSingletonTests(_BaseSingletonTestCase, unittest.TestCase): class ThreadLocalSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
@ -48,25 +45,23 @@ class ThreadLocalSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
def test_repr(self): def test_repr(self):
provider = providers.ThreadLocalSingleton(Example) provider = providers.ThreadLocalSingleton(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "ThreadLocalSingleton({0}) at {1}>".format(repr(Example), hex(id(provider)))
"ThreadLocalSingleton({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
def test_reset(self): def test_reset(self):
provider = providers.ThreadLocalSingleton(Example) provider = providers.ThreadLocalSingleton(Example)
instance1 = provider() instance1 = provider()
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
provider.reset() provider.reset()
instance2 = provider() instance2 = provider()
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
def test_reset_clean(self): def test_reset_clean(self):
provider = providers.ThreadLocalSingleton(Example) provider = providers.ThreadLocalSingleton(Example)
@ -76,7 +71,7 @@ class ThreadLocalSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
provider.reset() provider.reset()
instance2 = provider() instance2 = provider()
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
class DelegatedThreadLocalSingletonTests(_BaseSingletonTestCase, class DelegatedThreadLocalSingletonTests(_BaseSingletonTestCase,
@ -86,16 +81,15 @@ class DelegatedThreadLocalSingletonTests(_BaseSingletonTestCase,
def test_is_delegated_provider(self): def test_is_delegated_provider(self):
provider = self.singleton_cls(object) provider = self.singleton_cls(object)
self.assertTrue(providers.is_delegated(provider)) assert providers.is_delegated(provider) is True
def test_repr(self): def test_repr(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"DelegatedThreadLocalSingleton({0}) at {1}>".format( "DelegatedThreadLocalSingleton({0}) at {1}>".format(repr(Example), hex(id(provider)))
repr(Example), )
hex(id(provider))))
class ThreadSafeSingletonTests(_BaseSingletonTestCase, unittest.TestCase): class ThreadSafeSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
@ -104,12 +98,10 @@ class ThreadSafeSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
def test_repr(self): def test_repr(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "ThreadSafeSingleton({0}) at {1}>".format(repr(Example), hex(id(provider)))
"ThreadSafeSingleton({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
class DelegatedThreadSafeSingletonTests(_BaseSingletonTestCase, class DelegatedThreadSafeSingletonTests(_BaseSingletonTestCase,
@ -119,40 +111,38 @@ class DelegatedThreadSafeSingletonTests(_BaseSingletonTestCase,
def test_is_delegated_provider(self): def test_is_delegated_provider(self):
provider = self.singleton_cls(object) provider = self.singleton_cls(object)
self.assertTrue(providers.is_delegated(provider)) assert providers.is_delegated(provider) is True
def test_repr(self): def test_repr(self):
provider = self.singleton_cls(Example) provider = self.singleton_cls(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "DelegatedThreadSafeSingleton({0}) at {1}>".format(repr(Example), hex(id(provider)))
"DelegatedThreadSafeSingleton({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
class AbstractSingletonTests(unittest.TestCase): class AbstractSingletonTests(unittest.TestCase):
def test_inheritance(self): def test_inheritance(self):
self.assertIsInstance(providers.AbstractSingleton(Example), assert isinstance(providers.AbstractSingleton(Example),
providers.BaseSingleton) providers.BaseSingleton)
def test_call_overridden_by_singleton(self): def test_call_overridden_by_singleton(self):
provider = providers.AbstractSingleton(object) provider = providers.AbstractSingleton(object)
provider.override(providers.Singleton(Example)) provider.override(providers.Singleton(Example))
self.assertIsInstance(provider(), Example) assert isinstance(provider(), Example)
def test_call_overridden_by_delegated_singleton(self): def test_call_overridden_by_delegated_singleton(self):
provider = providers.AbstractSingleton(object) provider = providers.AbstractSingleton(object)
provider.override(providers.DelegatedSingleton(Example)) provider.override(providers.DelegatedSingleton(Example))
self.assertIsInstance(provider(), Example) assert isinstance(provider(), Example)
def test_call_not_overridden(self): def test_call_not_overridden(self):
provider = providers.AbstractSingleton(object) provider = providers.AbstractSingleton(object)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider() provider()
def test_reset_overridden(self): def test_reset_overridden(self):
@ -165,30 +155,28 @@ class AbstractSingletonTests(unittest.TestCase):
instance2 = provider() instance2 = provider()
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
def test_reset_not_overridden(self): def test_reset_not_overridden(self):
provider = providers.AbstractSingleton(object) provider = providers.AbstractSingleton(object)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider.reset() provider.reset()
def test_override_by_not_singleton(self): def test_override_by_not_singleton(self):
provider = providers.AbstractSingleton(object) provider = providers.AbstractSingleton(object)
with self.assertRaises(errors.Error): with raises(errors.Error):
provider.override(providers.Factory(object)) provider.override(providers.Factory(object))
def test_repr(self): def test_repr(self):
provider = providers.AbstractSingleton(Example) provider = providers.AbstractSingleton(Example)
assert repr(provider) == (
self.assertEqual(repr(provider), "<dependency_injector.providers."
"<dependency_injector.providers." "AbstractSingleton({0}) at {1}>".format(repr(Example), hex(id(provider)))
"AbstractSingleton({0}) at {1}>".format( )
repr(Example),
hex(id(provider))))
class SingletonDelegateTests(unittest.TestCase): class SingletonDelegateTests(unittest.TestCase):
@ -198,9 +186,9 @@ class SingletonDelegateTests(unittest.TestCase):
self.delegate = providers.SingletonDelegate(self.delegated) self.delegate = providers.SingletonDelegate(self.delegated)
def test_is_delegate(self): def test_is_delegate(self):
self.assertIsInstance(self.delegate, providers.Delegate) assert isinstance(self.delegate, providers.Delegate)
def test_init_with_not_singleton(self): def test_init_with_not_singleton(self):
self.assertRaises(errors.Error, raises(errors.Error,
providers.SingletonDelegate, providers.SingletonDelegate,
providers.Object(object())) providers.Object(object()))

View File

@ -12,24 +12,23 @@ class ContextLocalSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
def test_repr(self): def test_repr(self):
provider = providers.ContextLocalSingleton(Example) provider = providers.ContextLocalSingleton(Example)
self.assertEqual(repr(provider), assert repr(provider) == (
"<dependency_injector.providers." "<dependency_injector.providers."
"ContextLocalSingleton({0}) at {1}>".format( "ContextLocalSingleton({0}) at {1}>".format(repr(Example), hex(id(provider)))
repr(Example), )
hex(id(provider))))
def test_reset(self): def test_reset(self):
provider = providers.ContextLocalSingleton(Example) provider = providers.ContextLocalSingleton(Example)
instance1 = provider() instance1 = provider()
self.assertIsInstance(instance1, Example) assert isinstance(instance1, Example)
provider.reset() provider.reset()
instance2 = provider() instance2 = provider()
self.assertIsInstance(instance2, Example) assert isinstance(instance2, Example)
self.assertIsNot(instance1, instance2) assert instance1 is not instance2
def test_reset_clean(self): def test_reset_clean(self):
provider = providers.ContextLocalSingleton(Example) provider = providers.ContextLocalSingleton(Example)
@ -39,4 +38,4 @@ class ContextLocalSingletonTests(_BaseSingletonTestCase, unittest.TestCase):
provider.reset() provider.reset()
instance2 = provider() instance2 = provider()
self.assertIsNot(instance1, instance2) assert instance1 is not instance2

View File

@ -18,10 +18,10 @@ class TraverseTests(unittest.TestCase):
all_providers = list(providers.traverse(provider1)) all_providers = list(providers.traverse(provider1))
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
def test_traverse_types_filtering(self): def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict) provider1 = providers.Resource(dict)
@ -36,9 +36,9 @@ class TraverseTests(unittest.TestCase):
all_providers = list(providers.traverse(provider, types=[providers.Resource])) all_providers = list(providers.traverse(provider, types=[providers.Resource]))
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
class ProviderTests(unittest.TestCase): class ProviderTests(unittest.TestCase):
@ -56,10 +56,10 @@ class ProviderTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
def test_traversal_overriding_nested(self): def test_traversal_overriding_nested(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -75,10 +75,10 @@ class ProviderTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
def test_traverse_types_filtering(self): def test_traverse_types_filtering(self):
provider1 = providers.Resource(dict) provider1 = providers.Resource(dict)
@ -93,9 +93,9 @@ class ProviderTests(unittest.TestCase):
all_providers = list(provider.traverse(types=[providers.Resource])) all_providers = list(provider.traverse(types=[providers.Resource]))
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
class ObjectTests(unittest.TestCase): class ObjectTests(unittest.TestCase):
@ -103,7 +103,7 @@ class ObjectTests(unittest.TestCase):
def test_traversal(self): def test_traversal(self):
provider = providers.Object("string") provider = providers.Object("string")
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0) assert len(all_providers) == 0
def test_traversal_provider(self): def test_traversal_provider(self):
another_provider = providers.Provider() another_provider = providers.Provider()
@ -111,8 +111,8 @@ class ObjectTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1) assert len(all_providers) == 1
self.assertIn(another_provider, all_providers) assert another_provider in all_providers
def test_traversal_provider_and_overriding(self): def test_traversal_provider_and_overriding(self):
another_provider_1 = providers.Provider() another_provider_1 = providers.Provider()
@ -126,10 +126,10 @@ class ObjectTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(another_provider_1, all_providers) assert another_provider_1 in all_providers
self.assertIn(another_provider_2, all_providers) assert another_provider_2 in all_providers
self.assertIn(another_provider_3, all_providers) assert another_provider_3 in all_providers
class DelegateTests(unittest.TestCase): class DelegateTests(unittest.TestCase):
@ -140,8 +140,8 @@ class DelegateTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1) assert len(all_providers) == 1
self.assertIn(another_provider, all_providers) assert another_provider in all_providers
def test_traversal_provider_and_overriding(self): def test_traversal_provider_and_overriding(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -156,10 +156,10 @@ class DelegateTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
class DependencyTests(unittest.TestCase): class DependencyTests(unittest.TestCase):
@ -167,7 +167,7 @@ class DependencyTests(unittest.TestCase):
def test_traversal(self): def test_traversal(self):
provider = providers.Dependency() provider = providers.Dependency()
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0) assert len(all_providers) == 0
def test_traversal_default(self): def test_traversal_default(self):
another_provider = providers.Provider() another_provider = providers.Provider()
@ -175,8 +175,8 @@ class DependencyTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1) assert len(all_providers) == 1
self.assertIn(another_provider, all_providers) assert another_provider in all_providers
def test_traversal_overriding(self): def test_traversal_overriding(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -189,9 +189,9 @@ class DependencyTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
class DependenciesContainerTests(unittest.TestCase): class DependenciesContainerTests(unittest.TestCase):
@ -199,7 +199,7 @@ class DependenciesContainerTests(unittest.TestCase):
def test_traversal(self): def test_traversal(self):
provider = providers.DependenciesContainer() provider = providers.DependenciesContainer()
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0) assert len(all_providers) == 0
def test_traversal_default(self): def test_traversal_default(self):
another_provider = providers.Provider() another_provider = providers.Provider()
@ -207,8 +207,8 @@ class DependenciesContainerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1) assert len(all_providers) == 1
self.assertIn(another_provider, all_providers) assert another_provider in all_providers
def test_traversal_fluent_interface(self): def test_traversal_fluent_interface(self):
provider = providers.DependenciesContainer() provider = providers.DependenciesContainer()
@ -217,9 +217,9 @@ class DependenciesContainerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traversal_overriding(self): def test_traversal_overriding(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -234,12 +234,12 @@ class DependenciesContainerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5) assert len(all_providers) == 5
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
self.assertIn(provider.provider1, all_providers) assert provider.provider1 in all_providers
self.assertIn(provider.provider2, all_providers) assert provider.provider2 in all_providers
class CallableTests(unittest.TestCase): class CallableTests(unittest.TestCase):
@ -247,7 +247,7 @@ class CallableTests(unittest.TestCase):
def test_traverse(self): def test_traverse(self):
provider = providers.Callable(dict) provider = providers.Callable(dict)
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0) assert len(all_providers) == 0
def test_traverse_args(self): def test_traverse_args(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -256,9 +256,9 @@ class CallableTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_kwargs(self): def test_traverse_kwargs(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -267,9 +267,9 @@ class CallableTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -281,9 +281,9 @@ class CallableTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_provides(self): def test_traverse_provides(self):
provider1 = providers.Callable(list) provider1 = providers.Callable(list)
@ -295,10 +295,10 @@ class CallableTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
class ConfigurationTests(unittest.TestCase): class ConfigurationTests(unittest.TestCase):
@ -311,10 +311,10 @@ class ConfigurationTests(unittest.TestCase):
all_providers = list(config.traverse()) all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(option1, all_providers) assert option1 in all_providers
self.assertIn(option2, all_providers) assert option2 in all_providers
self.assertIn(option3, all_providers) assert option3 in all_providers
def test_traverse_typed(self): def test_traverse_typed(self):
config = providers.Configuration() config = providers.Configuration()
@ -323,8 +323,8 @@ class ConfigurationTests(unittest.TestCase):
all_providers = list(typed_option.traverse()) all_providers = list(typed_option.traverse())
self.assertEqual(len(all_providers), 1) assert len(all_providers) == 1
self.assertIn(option, all_providers) assert option in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
options = {"option1": {"option2": "option2"}} options = {"option1": {"option2": "option2"}}
@ -333,10 +333,10 @@ class ConfigurationTests(unittest.TestCase):
all_providers = list(config.traverse()) all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 1) assert len(all_providers) == 1
overridden, = all_providers overridden, = all_providers
self.assertEqual(overridden(), options) assert overridden() == options
self.assertIs(overridden, config.last_overriding) assert overridden is config.last_overriding
def test_traverse_overridden_option_1(self): def test_traverse_overridden_option_1(self):
options = {"option2": "option2"} options = {"option2": "option2"}
@ -345,9 +345,9 @@ class ConfigurationTests(unittest.TestCase):
all_providers = list(config.traverse()) all_providers = list(config.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(config.option1, all_providers) assert config.option1 in all_providers
self.assertIn(config.last_overriding, all_providers) assert config.last_overriding in all_providers
def test_traverse_overridden_option_2(self): def test_traverse_overridden_option_2(self):
options = {"option2": "option2"} options = {"option2": "option2"}
@ -356,7 +356,7 @@ class ConfigurationTests(unittest.TestCase):
all_providers = list(config.option1.traverse()) all_providers = list(config.option1.traverse())
self.assertEqual(len(all_providers), 0) assert len(all_providers) == 0
class FactoryTests(unittest.TestCase): class FactoryTests(unittest.TestCase):
@ -364,7 +364,7 @@ class FactoryTests(unittest.TestCase):
def test_traverse(self): def test_traverse(self):
provider = providers.Factory(dict) provider = providers.Factory(dict)
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0) assert len(all_providers) == 0
def test_traverse_args(self): def test_traverse_args(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -373,9 +373,9 @@ class FactoryTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_kwargs(self): def test_traverse_kwargs(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -384,9 +384,9 @@ class FactoryTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_attributes(self): def test_traverse_attributes(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -396,9 +396,9 @@ class FactoryTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -410,9 +410,9 @@ class FactoryTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_provides(self): def test_traverse_provides(self):
provider1 = providers.Callable(list) provider1 = providers.Callable(list)
@ -424,10 +424,10 @@ class FactoryTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
class FactoryAggregateTests(unittest.TestCase): class FactoryAggregateTests(unittest.TestCase):
@ -439,9 +439,9 @@ class FactoryAggregateTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(factory1, all_providers) assert factory1 in all_providers
self.assertIn(factory2, all_providers) assert factory2 in all_providers
class BaseSingletonTests(unittest.TestCase): class BaseSingletonTests(unittest.TestCase):
@ -449,7 +449,7 @@ class BaseSingletonTests(unittest.TestCase):
def test_traverse(self): def test_traverse(self):
provider = providers.Singleton(dict) provider = providers.Singleton(dict)
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0) assert len(all_providers) == 0
def test_traverse_args(self): def test_traverse_args(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -458,9 +458,9 @@ class BaseSingletonTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_kwargs(self): def test_traverse_kwargs(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -469,9 +469,9 @@ class BaseSingletonTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_attributes(self): def test_traverse_attributes(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -481,9 +481,9 @@ class BaseSingletonTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -495,9 +495,9 @@ class BaseSingletonTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_provides(self): def test_traverse_provides(self):
provider1 = providers.Callable(list) provider1 = providers.Callable(list)
@ -509,10 +509,10 @@ class BaseSingletonTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
class ListTests(unittest.TestCase): class ListTests(unittest.TestCase):
@ -524,9 +524,9 @@ class ListTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -538,10 +538,10 @@ class ListTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
class DictTests(unittest.TestCase): class DictTests(unittest.TestCase):
@ -553,9 +553,9 @@ class DictTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -567,10 +567,10 @@ class DictTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provider3, all_providers) assert provider3 in all_providers
class ResourceTests(unittest.TestCase): class ResourceTests(unittest.TestCase):
@ -578,7 +578,7 @@ class ResourceTests(unittest.TestCase):
def test_traverse(self): def test_traverse(self):
provider = providers.Resource(dict) provider = providers.Resource(dict)
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 0) assert len(all_providers) == 0
def test_traverse_args(self): def test_traverse_args(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -587,9 +587,9 @@ class ResourceTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_kwargs(self): def test_traverse_kwargs(self):
provider1 = providers.Object("bar") provider1 = providers.Object("bar")
@ -598,9 +598,9 @@ class ResourceTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Resource(list) provider1 = providers.Resource(list)
@ -612,9 +612,9 @@ class ResourceTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_provides(self): def test_traverse_provides(self):
provider1 = providers.Callable(list) provider1 = providers.Callable(list)
@ -623,8 +623,8 @@ class ResourceTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1) assert len(all_providers) == 1
self.assertIn(provider1, all_providers) assert provider1 in all_providers
class ContainerTests(unittest.TestCase): class ContainerTests(unittest.TestCase):
@ -638,11 +638,8 @@ class ContainerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertEqual( assert {list, dict} == {provider.provides for provider in all_providers}
{provider.provides for provider in all_providers},
{list, dict},
)
def test_traverse_overridden(self): def test_traverse_overridden(self):
class Container1(containers.DeclarativeContainer): class Container1(containers.DeclarativeContainer):
@ -660,17 +657,14 @@ class ContainerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 5) assert len(all_providers) == 5
self.assertEqual( assert {list, dict, tuple, str} == {
{ provider.provides
provider.provides for provider in all_providers
for provider in all_providers if isinstance(provider, providers.Callable)
if isinstance(provider, providers.Callable) }
}, assert provider.last_overriding in all_providers
{list, dict, tuple, str}, assert provider.last_overriding() is container2
)
self.assertIn(provider.last_overriding, all_providers)
self.assertIs(provider.last_overriding(), container2)
class SelectorTests(unittest.TestCase): class SelectorTests(unittest.TestCase):
@ -688,9 +682,9 @@ class SelectorTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_switch(self): def test_traverse_switch(self):
switch = providers.Callable(lambda: "provider1") switch = providers.Callable(lambda: "provider1")
@ -705,10 +699,10 @@ class SelectorTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(switch, all_providers) assert switch in all_providers
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Callable(list) provider1 = providers.Callable(list)
@ -723,10 +717,10 @@ class SelectorTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(selector1, all_providers) assert selector1 in all_providers
class ProvidedInstanceTests(unittest.TestCase): class ProvidedInstanceTests(unittest.TestCase):
@ -737,8 +731,8 @@ class ProvidedInstanceTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 1) assert len(all_providers) == 1
self.assertIn(provider1, all_providers) assert provider1 in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -749,9 +743,9 @@ class ProvidedInstanceTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
class AttributeGetterTests(unittest.TestCase): class AttributeGetterTests(unittest.TestCase):
@ -763,9 +757,9 @@ class AttributeGetterTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provided, all_providers) assert provided in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -777,10 +771,10 @@ class AttributeGetterTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provided, all_providers) assert provided in all_providers
class ItemGetterTests(unittest.TestCase): class ItemGetterTests(unittest.TestCase):
@ -792,9 +786,9 @@ class ItemGetterTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 2) assert len(all_providers) == 2
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provided, all_providers) assert provided in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -806,10 +800,10 @@ class ItemGetterTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provided, all_providers) assert provided in all_providers
class MethodCallerTests(unittest.TestCase): class MethodCallerTests(unittest.TestCase):
@ -822,10 +816,10 @@ class MethodCallerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 3) assert len(all_providers) == 3
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provided, all_providers) assert provided in all_providers
self.assertIn(method, all_providers) assert method in all_providers
def test_traverse_args(self): def test_traverse_args(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -836,11 +830,11 @@ class MethodCallerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4) assert len(all_providers) == 4
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provided, all_providers) assert provided in all_providers
self.assertIn(method, all_providers) assert method in all_providers
def test_traverse_kwargs(self): def test_traverse_kwargs(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -851,11 +845,11 @@ class MethodCallerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4) assert len(all_providers) == 4
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provided, all_providers) assert provided in all_providers
self.assertIn(method, all_providers) assert method in all_providers
def test_traverse_overridden(self): def test_traverse_overridden(self):
provider1 = providers.Provider() provider1 = providers.Provider()
@ -868,8 +862,8 @@ class MethodCallerTests(unittest.TestCase):
all_providers = list(provider.traverse()) all_providers = list(provider.traverse())
self.assertEqual(len(all_providers), 4) assert len(all_providers) == 4
self.assertIn(provider1, all_providers) assert provider1 in all_providers
self.assertIn(provider2, all_providers) assert provider2 in all_providers
self.assertIn(provided, all_providers) assert provided in all_providers
self.assertIn(method, all_providers) assert method in all_providers