python-dependency-injector/tests/test_catalog.py

125 lines
4.1 KiB
Python

"""Dependency injector catalog unittests."""
import unittest2 as unittest
import dependency_injector as di
class CatalogsInheritanceTests(unittest.TestCase):
"""Catalogs inheritance tests."""
class CatalogA(di.AbstractCatalog):
"""Test catalog A."""
p11 = di.Provider()
p12 = di.Provider()
class CatalogB(CatalogA):
"""Test catalog B."""
p21 = di.Provider()
p22 = di.Provider()
class CatalogC(CatalogB):
"""Test catalog C."""
p31 = di.Provider()
p32 = di.Provider()
def test_cls_providers(self):
"""Test `di.AbstractCatalog.cls_providers` contents."""
self.assertDictEqual(self.CatalogA.cls_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogB.cls_providers,
dict(p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
self.assertDictEqual(self.CatalogC.cls_providers,
dict(p31=self.CatalogC.p31,
p32=self.CatalogC.p32))
def test_inherited_providers(self):
"""Test `di.AbstractCatalog.inherited_providers` contents."""
self.assertDictEqual(self.CatalogA.inherited_providers, dict())
self.assertDictEqual(self.CatalogB.inherited_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogC.inherited_providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
def test_providers(self):
"""Test `di.AbstractCatalog.inherited_providers` contents."""
self.assertDictEqual(self.CatalogA.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12))
self.assertDictEqual(self.CatalogB.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22))
self.assertDictEqual(self.CatalogC.providers,
dict(p11=self.CatalogA.p11,
p12=self.CatalogA.p12,
p21=self.CatalogB.p21,
p22=self.CatalogB.p22,
p31=self.CatalogC.p31,
p32=self.CatalogC.p32))
class CatalogTests(unittest.TestCase):
"""Catalog test cases."""
class Catalog(di.AbstractCatalog):
"""Test catalog."""
obj = di.Object(object())
another_obj = di.Object(object())
def test_get_used(self):
"""Test retrieving used provider."""
catalog = self.Catalog(self.Catalog.obj)
self.assertIsInstance(catalog.obj(), object)
def test_get_unused(self):
"""Test retrieving unused provider."""
catalog = self.Catalog()
self.assertRaises(di.Error, getattr, catalog, 'obj')
def test_all_providers_by_type(self):
"""Test getting of all catalog providers of specific type."""
self.assertTrue(len(self.Catalog.filter(di.Object)) == 2)
self.assertTrue(len(self.Catalog.filter(di.Value)) == 0)
class OverrideTests(unittest.TestCase):
"""Override decorator test cases."""
class Catalog(di.AbstractCatalog):
"""Test catalog."""
obj = di.Object(object())
another_obj = di.Object(object())
def test_overriding(self):
"""Test catalog overriding with another catalog."""
@di.override(self.Catalog)
class OverridingCatalog(self.Catalog):
"""Overriding catalog."""
obj = di.Value(1)
another_obj = di.Value(2)
self.assertEqual(self.Catalog.obj(), 1)
self.assertEqual(self.Catalog.another_obj(), 2)