diff --git a/dependency_injector/catalog.py b/dependency_injector/catalog.py index 41ea95b1..08b6d2f0 100644 --- a/dependency_injector/catalog.py +++ b/dependency_injector/catalog.py @@ -21,14 +21,14 @@ class CatalogBundle(object): def __init__(self, *providers): """Initializer.""" - self.providers = dict((provider.bind.name, provider) - for provider in providers - if self._ensure_provider_is_bound(provider)) + self.providers = dict((self.catalog.get_provider_bind_name(provider), + provider) + for provider in providers) self.__dict__.update(self.providers) super(CatalogBundle, self).__init__() def get(self, name): - """Return provider with specified name or raises error.""" + """Return provider with specified name or raise an error.""" try: return self.providers[name] except KeyError: @@ -38,16 +38,6 @@ class CatalogBundle(object): """Check if there is provider with certain name.""" return name in self.providers - def _ensure_provider_is_bound(self, provider): - """Check that provider is bound to the bundle's catalog.""" - if not provider.is_bound: - raise Error('Provider {0} is not bound to ' - 'any catalog'.format(provider)) - if provider is not self.catalog.get(provider.bind.name): - raise Error('{0} can contain providers from ' - 'catalog {0}'.format(self.__class__, self.catalog)) - return True - def _raise_undefined_provider_error(self, name): """Raise error for cases when there is no such provider in bundle.""" raise Error('Provider "{0}" is not a part of {1}'.format(name, self)) @@ -67,15 +57,6 @@ class CatalogBundle(object): __str__ = __repr__ -class Catalog(object): - """Catalog of providers.""" - - def __init__(self, name, **providers): - """Initializer.""" - self.name = name - self.providers = providers - - @six.python_2_unicode_compatible class DeclarativeCatalogMetaClass(type): """Declarative catalog meta class.""" @@ -105,13 +86,13 @@ class DeclarativeCatalogMetaClass(type): cls.Bundle = mcs.bundle_cls_factory(cls) - for name, provider in six.iteritems(cls_providers): - if provider.is_bound: - raise Error('Provider {0} has been already bound to catalog' - '{1} as "{2}"'.format(provider, - provider.bind.catalog, - provider.bind.name)) - provider.bind = ProviderBinding(cls, name) + cls.provider_names = dict() + for name, provider in six.iteritems(providers): + if provider in cls.provider_names: + raise Error('Provider {0} could not be bound to the same ' + 'catalog (or catalogs hierarchy) more ' + 'than once'.format(provider)) + cls.provider_names[provider] = name return cls @@ -151,6 +132,10 @@ class DeclarativeCatalog(object): :param providers: Dict of all catalog providers, including inherited from parent catalogs + :type provider_names: dict[dependency_injector.Provider, str] + :param provider_names: Dict of all catalog providers, including inherited + from parent catalogs + :type cls_providers: dict[str, dependency_injector.Provider] :param cls_providers: Dict of current catalog providers @@ -174,6 +159,7 @@ class DeclarativeCatalog(object): cls_providers = dict() inherited_providers = dict() providers = dict() + provider_names = dict() overridden_by = tuple() is_overridden = bool @@ -186,6 +172,19 @@ class DeclarativeCatalog(object): """Check if catalog is bundle owner.""" return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls + @classmethod + def get_provider_bind_name(cls, provider): + """Return provider's name in catalog.""" + if not cls.is_provider_bound(provider): + raise Error('Can not find bind name for {0} in catalog {1}'.format( + provider, cls)) + return cls.provider_names[provider] + + @classmethod + def is_provider_bound(cls, provider): + """Check if provider is bound to the catalog.""" + return provider in cls.provider_names + @classmethod def filter(cls, provider_type): """Return dict of providers, that are instance of provided type.""" @@ -238,17 +237,6 @@ class DeclarativeCatalog(object): AbstractCatalog = DeclarativeCatalog -class ProviderBinding(object): - """Catalog provider binding.""" - - __slots__ = ('catalog', 'name') - - def __init__(self, catalog, name): - """Initializer.""" - self.catalog = catalog - self.name = name - - def override(catalog): """Catalog overriding decorator.""" def decorator(overriding_catalog): diff --git a/dependency_injector/providers.py b/dependency_injector/providers.py index 59e1c314..22595481 100644 --- a/dependency_injector/providers.py +++ b/dependency_injector/providers.py @@ -19,12 +19,11 @@ class Provider(object): """Base provider class.""" __IS_PROVIDER__ = True - __slots__ = ('overridden_by', 'bind') + __slots__ = ('overridden_by',) def __init__(self): """Initializer.""" self.overridden_by = None - self.bind = None def __call__(self, *args, **kwargs): """Return provided instance.""" @@ -75,11 +74,6 @@ class Provider(object): """Reset all overriding providers.""" self.overridden_by = None - @property - def is_bound(self): - """Check if provider is bound to any catalog.""" - return bool(self.bind) - class Delegate(Provider): """Provider's delegate.""" diff --git a/tests/test_catalog.py b/tests/test_catalog.py index 5a28d050..0f1d37a6 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -76,17 +76,64 @@ class CatalogProvidersBindingTests(unittest.TestCase): def test_provider_is_bound(self): """Test that providers are bound to the catalogs.""" - self.assertIs(CatalogA.p11.bind.catalog, CatalogA) - self.assertEquals(CatalogA.p11.bind.name, 'p11') + self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11)) + self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p11), 'p11') - self.assertIs(CatalogA.p12.bind.catalog, CatalogA) - self.assertEquals(CatalogA.p12.bind.name, 'p12') + self.assertTrue(CatalogA.is_provider_bound(CatalogA.p12)) + self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p12), 'p12') - def test_provider_rebinding(self): - """Test that provider could not be bound twice.""" - self.assertRaises(di.Error, type, 'TestCatalog', - (di.DeclarativeCatalog,), - dict(some_name=CatalogA.p11)) + def test_provider_binding_to_different_catalogs(self): + """Test that provider could be bound to different catalogs.""" + p11 = CatalogA.p11 + p12 = CatalogA.p12 + + class CatalogD(di.DeclarativeCatalog): + """Test catalog.""" + + pd1 = p11 + pd2 = p12 + + class CatalogE(di.DeclarativeCatalog): + """Test catalog.""" + + pe1 = p11 + pe2 = p12 + + self.assertTrue(CatalogA.is_provider_bound(p11)) + self.assertTrue(CatalogD.is_provider_bound(p11)) + self.assertTrue(CatalogE.is_provider_bound(p11)) + self.assertEquals(CatalogA.get_provider_bind_name(p11), 'p11') + self.assertEquals(CatalogD.get_provider_bind_name(p11), 'pd1') + self.assertEquals(CatalogE.get_provider_bind_name(p11), 'pe1') + + self.assertTrue(CatalogA.is_provider_bound(p12)) + self.assertTrue(CatalogD.is_provider_bound(p12)) + self.assertTrue(CatalogE.is_provider_bound(p12)) + self.assertEquals(CatalogA.get_provider_bind_name(p12), 'p12') + self.assertEquals(CatalogD.get_provider_bind_name(p12), 'pd2') + self.assertEquals(CatalogE.get_provider_bind_name(p12), 'pe2') + + def test_provider_rebinding_to_the_same_catalog(self): + """Test provider rebinding to the same catalog.""" + with self.assertRaises(di.Error): + class TestCatalog(di.DeclarativeCatalog): + """Test catalog.""" + + p1 = di.Provider() + p2 = p1 + + def test_provider_rebinding_to_the_same_catalogs_hierarchy(self): + """Test provider rebinding to the same catalogs hierarchy.""" + class TestCatalog1(di.DeclarativeCatalog): + """Test catalog.""" + + p1 = di.Provider() + + with self.assertRaises(di.Error): + class TestCatalog2(TestCatalog1): + """Test catalog.""" + + p2 = TestCatalog1.p1 class CatalogBundleTests(unittest.TestCase):