From 2236d77313aeac3322e4688bbb7a9945408187cc Mon Sep 17 00:00:00 2001 From: Roman Mogilatov Date: Thu, 12 Nov 2015 16:56:00 +0200 Subject: [PATCH] Refactor DeclarativeCatalog --- dependency_injector/catalogs.py | 40 +++++++++++++++++++++--- tests/test_catalogs.py | 54 ++++++++++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 6 deletions(-) diff --git a/dependency_injector/catalogs.py b/dependency_injector/catalogs.py index 843da8b0..b1bd26d6 100644 --- a/dependency_injector/catalogs.py +++ b/dependency_injector/catalogs.py @@ -178,6 +178,12 @@ class DynamicCatalog(object): """Check if there is provider with certain name.""" return name in self.providers + def unbind_provider(self, name): + """Remove provider binding.""" + provider = self.get_provider(name) + del self.providers[name] + del self.provider_names[provider] + def __getattr__(self, name): """Return provider with specified name or raise en error.""" return self.get_provider(name) @@ -192,6 +198,14 @@ class DynamicCatalog(object): return self.bind_provider(name, value) return super(DynamicCatalog, self).__setattr__(name, value) + def __delattr__(self, name): + """Handle deleting of catalog attibute. + + Deleting of attributes works as usual, but if value of attribute is + provider, this provider will be unbound from catalog correctly. + """ + self.unbind_provider(name) + def __repr__(self): """Return Python representation of catalog.""" return '<{0}({1})>'.format(self.name, @@ -270,6 +284,16 @@ class DeclarativeCatalogMetaClass(type): setattr(cls.catalog, name, value) return super(DeclarativeCatalogMetaClass, cls).__setattr__(name, value) + def __delattr__(cls, name): + """Handle deleting of catalog attibute. + + Deleting of attributes works as usual, but if value of attribute is + provider, this provider will be unbound from catalog correctly. + """ + if is_provider(getattr(cls, name)): + delattr(cls.catalog, name) + return super(DeclarativeCatalogMetaClass, cls).__delattr__(name) + def __repr__(cls): """Return string representation of the catalog.""" return '<{0}({1})>'.format(cls.name, @@ -360,12 +384,12 @@ class DeclarativeCatalog(object): @classmethod def reset_last_overriding(cls): """Reset last overriding catalog.""" - return cls.catalog.reset_last_overriding() + cls.catalog.reset_last_overriding() @classmethod def reset_override(cls): """Reset all overridings for all catalog providers.""" - return cls.catalog.reset_override() + cls.catalog.reset_override() @classmethod def get_provider(cls, name): @@ -375,17 +399,23 @@ class DeclarativeCatalog(object): @classmethod def bind_provider(cls, name, provider): """Bind provider to catalog with specified name.""" - return cls.catalog.bind_provider(name, provider) + setattr(cls, name, provider) @classmethod def bind_providers(cls, providers): """Bind providers dictionary to catalog.""" - return cls.catalog.bind_providers(providers) + for name, provider in six.iteritems(providers): + setattr(cls, name, provider) @classmethod def has_provider(cls, name): """Check if there is provider with certain name.""" - return cls.catalog.has_provider(name) + return hasattr(cls, name) + + @classmethod + def unbind_provider(cls, name): + """Remove provider binding.""" + delattr(cls, name) get = get_provider # Backward compatibility for versions < 0.11.* has = has_provider # Backward compatibility for versions < 0.11.* diff --git a/tests/test_catalogs.py b/tests/test_catalogs.py index 081ca80f..0f24788c 100644 --- a/tests/test_catalogs.py +++ b/tests/test_catalogs.py @@ -157,8 +157,45 @@ class DeclarativeCatalogTests(unittest.TestCase): p31=CatalogC.p31, p32=CatalogC.p32)) + def test_bind_provider(self): + """Test setting of provider via bind_provider() to catalog.""" + px = di.Provider() + py = di.Provider() + + CatalogC.bind_provider('px', px) + CatalogC.bind_provider('py', py) + + self.assertIs(CatalogC.px, px) + self.assertIs(CatalogC.get_provider('px'), px) + self.assertIs(CatalogC.catalog.px, px) + + self.assertIs(CatalogC.py, py) + self.assertIs(CatalogC.get_provider('py'), py) + self.assertIs(CatalogC.catalog.py, py) + + del CatalogC.px + del CatalogC.py + + def test_bind_providers(self): + """Test setting of provider via bind_providers() to catalog.""" + px = di.Provider() + py = di.Provider() + + CatalogC.bind_providers(dict(px=px, py=py)) + + self.assertIs(CatalogC.px, px) + self.assertIs(CatalogC.get_provider('px'), px) + self.assertIs(CatalogC.catalog.px, px) + + self.assertIs(CatalogC.py, py) + self.assertIs(CatalogC.get_provider('py'), py) + self.assertIs(CatalogC.catalog.py, py) + + del CatalogC.px + del CatalogC.py + def test_setattr(self): - """Test setting of provider attributes to catalog.""" + """Test setting of providers via attributes to catalog.""" px = di.Provider() py = di.Provider() @@ -173,6 +210,21 @@ class DeclarativeCatalogTests(unittest.TestCase): self.assertIs(CatalogC.get_provider('py'), py) self.assertIs(CatalogC.catalog.py, py) + del CatalogC.px + del CatalogC.py + + def test_unbind_provider(self): + """Test that catalog unbinds provider correct.""" + CatalogC.px = di.Provider() + CatalogC.unbind_provider('px') + self.assertFalse(CatalogC.has_provider('px')) + + def test_unbind_via_delattr(self): + """Test that catalog unbinds provider correct.""" + CatalogC.px = di.Provider() + del CatalogC.px + self.assertFalse(CatalogC.has_provider('px')) + def test_provider_is_bound(self): """Test that providers are bound to the catalogs.""" self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))