mirror of
				https://github.com/ets-labs/python-dependency-injector.git
				synced 2025-11-04 09:57:37 +03:00 
			
		
		
		
	Change provider-to-catalog binding restrictions
This commit is contained in:
		
							parent
							
								
									9f526a3ceb
								
							
						
					
					
						commit
						c5602cf88b
					
				| 
						 | 
				
			
			@ -21,14 +21,14 @@ class CatalogBundle(object):
 | 
			
		|||
 | 
			
		||||
    def __init__(self, *providers):
 | 
			
		||||
        """Initializer."""
 | 
			
		||||
        self.providers = dict((provider.bind.name, provider)
 | 
			
		||||
                              for provider in providers
 | 
			
		||||
                              if self._ensure_provider_is_bound(provider))
 | 
			
		||||
        self.providers = dict((self.catalog.get_provider_bind_name(provider),
 | 
			
		||||
                               provider)
 | 
			
		||||
                              for provider in providers)
 | 
			
		||||
        self.__dict__.update(self.providers)
 | 
			
		||||
        super(CatalogBundle, self).__init__()
 | 
			
		||||
 | 
			
		||||
    def get(self, name):
 | 
			
		||||
        """Return provider with specified name or raises error."""
 | 
			
		||||
        """Return provider with specified name or raise an error."""
 | 
			
		||||
        try:
 | 
			
		||||
            return self.providers[name]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
| 
						 | 
				
			
			@ -38,16 +38,6 @@ class CatalogBundle(object):
 | 
			
		|||
        """Check if there is provider with certain name."""
 | 
			
		||||
        return name in self.providers
 | 
			
		||||
 | 
			
		||||
    def _ensure_provider_is_bound(self, provider):
 | 
			
		||||
        """Check that provider is bound to the bundle's catalog."""
 | 
			
		||||
        if not provider.is_bound:
 | 
			
		||||
            raise Error('Provider {0} is not bound to '
 | 
			
		||||
                        'any catalog'.format(provider))
 | 
			
		||||
        if provider is not self.catalog.get(provider.bind.name):
 | 
			
		||||
            raise Error('{0} can contain providers from '
 | 
			
		||||
                        'catalog {0}'.format(self.__class__, self.catalog))
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def _raise_undefined_provider_error(self, name):
 | 
			
		||||
        """Raise error for cases when there is no such provider in bundle."""
 | 
			
		||||
        raise Error('Provider "{0}" is not a part of {1}'.format(name, self))
 | 
			
		||||
| 
						 | 
				
			
			@ -67,15 +57,6 @@ class CatalogBundle(object):
 | 
			
		|||
    __str__ = __repr__
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Catalog(object):
 | 
			
		||||
    """Catalog of providers."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name, **providers):
 | 
			
		||||
        """Initializer."""
 | 
			
		||||
        self.name = name
 | 
			
		||||
        self.providers = providers
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@six.python_2_unicode_compatible
 | 
			
		||||
class DeclarativeCatalogMetaClass(type):
 | 
			
		||||
    """Declarative catalog meta class."""
 | 
			
		||||
| 
						 | 
				
			
			@ -105,13 +86,13 @@ class DeclarativeCatalogMetaClass(type):
 | 
			
		|||
 | 
			
		||||
        cls.Bundle = mcs.bundle_cls_factory(cls)
 | 
			
		||||
 | 
			
		||||
        for name, provider in six.iteritems(cls_providers):
 | 
			
		||||
            if provider.is_bound:
 | 
			
		||||
                raise Error('Provider {0} has been already bound to catalog'
 | 
			
		||||
                            '{1} as "{2}"'.format(provider,
 | 
			
		||||
                                                  provider.bind.catalog,
 | 
			
		||||
                                                  provider.bind.name))
 | 
			
		||||
            provider.bind = ProviderBinding(cls, name)
 | 
			
		||||
        cls.provider_names = dict()
 | 
			
		||||
        for name, provider in six.iteritems(providers):
 | 
			
		||||
            if provider in cls.provider_names:
 | 
			
		||||
                raise Error('Provider {0} could not be bound to the same '
 | 
			
		||||
                            'catalog (or catalogs hierarchy) more '
 | 
			
		||||
                            'than once'.format(provider))
 | 
			
		||||
            cls.provider_names[provider] = name
 | 
			
		||||
 | 
			
		||||
        return cls
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -151,6 +132,10 @@ class DeclarativeCatalog(object):
 | 
			
		|||
    :param providers: Dict of all catalog providers, including inherited from
 | 
			
		||||
        parent catalogs
 | 
			
		||||
 | 
			
		||||
    :type provider_names: dict[dependency_injector.Provider, str]
 | 
			
		||||
    :param provider_names: Dict of all catalog providers, including inherited
 | 
			
		||||
        from parent catalogs
 | 
			
		||||
 | 
			
		||||
    :type cls_providers: dict[str, dependency_injector.Provider]
 | 
			
		||||
    :param cls_providers: Dict of current catalog providers
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -174,6 +159,7 @@ class DeclarativeCatalog(object):
 | 
			
		|||
    cls_providers = dict()
 | 
			
		||||
    inherited_providers = dict()
 | 
			
		||||
    providers = dict()
 | 
			
		||||
    provider_names = dict()
 | 
			
		||||
 | 
			
		||||
    overridden_by = tuple()
 | 
			
		||||
    is_overridden = bool
 | 
			
		||||
| 
						 | 
				
			
			@ -186,6 +172,19 @@ class DeclarativeCatalog(object):
 | 
			
		|||
        """Check if catalog is bundle owner."""
 | 
			
		||||
        return ensure_is_catalog_bundle(bundle) and bundle.catalog is cls
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_provider_bind_name(cls, provider):
 | 
			
		||||
        """Return provider's name in catalog."""
 | 
			
		||||
        if not cls.is_provider_bound(provider):
 | 
			
		||||
            raise Error('Can not find bind name for {0} in catalog {1}'.format(
 | 
			
		||||
                provider, cls))
 | 
			
		||||
        return cls.provider_names[provider]
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def is_provider_bound(cls, provider):
 | 
			
		||||
        """Check if provider is bound to the catalog."""
 | 
			
		||||
        return provider in cls.provider_names
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def filter(cls, provider_type):
 | 
			
		||||
        """Return dict of providers, that are instance of provided type."""
 | 
			
		||||
| 
						 | 
				
			
			@ -238,17 +237,6 @@ class DeclarativeCatalog(object):
 | 
			
		|||
AbstractCatalog = DeclarativeCatalog
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProviderBinding(object):
 | 
			
		||||
    """Catalog provider binding."""
 | 
			
		||||
 | 
			
		||||
    __slots__ = ('catalog', 'name')
 | 
			
		||||
 | 
			
		||||
    def __init__(self, catalog, name):
 | 
			
		||||
        """Initializer."""
 | 
			
		||||
        self.catalog = catalog
 | 
			
		||||
        self.name = name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def override(catalog):
 | 
			
		||||
    """Catalog overriding decorator."""
 | 
			
		||||
    def decorator(overriding_catalog):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,12 +19,11 @@ class Provider(object):
 | 
			
		|||
    """Base provider class."""
 | 
			
		||||
 | 
			
		||||
    __IS_PROVIDER__ = True
 | 
			
		||||
    __slots__ = ('overridden_by', 'bind')
 | 
			
		||||
    __slots__ = ('overridden_by',)
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        """Initializer."""
 | 
			
		||||
        self.overridden_by = None
 | 
			
		||||
        self.bind = None
 | 
			
		||||
 | 
			
		||||
    def __call__(self, *args, **kwargs):
 | 
			
		||||
        """Return provided instance."""
 | 
			
		||||
| 
						 | 
				
			
			@ -75,11 +74,6 @@ class Provider(object):
 | 
			
		|||
        """Reset all overriding providers."""
 | 
			
		||||
        self.overridden_by = None
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def is_bound(self):
 | 
			
		||||
        """Check if provider is bound to any catalog."""
 | 
			
		||||
        return bool(self.bind)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Delegate(Provider):
 | 
			
		||||
    """Provider's delegate."""
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -76,17 +76,64 @@ class CatalogProvidersBindingTests(unittest.TestCase):
 | 
			
		|||
 | 
			
		||||
    def test_provider_is_bound(self):
 | 
			
		||||
        """Test that providers are bound to the catalogs."""
 | 
			
		||||
        self.assertIs(CatalogA.p11.bind.catalog, CatalogA)
 | 
			
		||||
        self.assertEquals(CatalogA.p11.bind.name, 'p11')
 | 
			
		||||
        self.assertTrue(CatalogA.is_provider_bound(CatalogA.p11))
 | 
			
		||||
        self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p11), 'p11')
 | 
			
		||||
 | 
			
		||||
        self.assertIs(CatalogA.p12.bind.catalog, CatalogA)
 | 
			
		||||
        self.assertEquals(CatalogA.p12.bind.name, 'p12')
 | 
			
		||||
        self.assertTrue(CatalogA.is_provider_bound(CatalogA.p12))
 | 
			
		||||
        self.assertEquals(CatalogA.get_provider_bind_name(CatalogA.p12), 'p12')
 | 
			
		||||
 | 
			
		||||
    def test_provider_rebinding(self):
 | 
			
		||||
        """Test that provider could not be bound twice."""
 | 
			
		||||
        self.assertRaises(di.Error, type, 'TestCatalog',
 | 
			
		||||
                          (di.DeclarativeCatalog,),
 | 
			
		||||
                          dict(some_name=CatalogA.p11))
 | 
			
		||||
    def test_provider_binding_to_different_catalogs(self):
 | 
			
		||||
        """Test that provider could be bound to different catalogs."""
 | 
			
		||||
        p11 = CatalogA.p11
 | 
			
		||||
        p12 = CatalogA.p12
 | 
			
		||||
 | 
			
		||||
        class CatalogD(di.DeclarativeCatalog):
 | 
			
		||||
            """Test catalog."""
 | 
			
		||||
 | 
			
		||||
            pd1 = p11
 | 
			
		||||
            pd2 = p12
 | 
			
		||||
 | 
			
		||||
        class CatalogE(di.DeclarativeCatalog):
 | 
			
		||||
            """Test catalog."""
 | 
			
		||||
 | 
			
		||||
            pe1 = p11
 | 
			
		||||
            pe2 = p12
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(CatalogA.is_provider_bound(p11))
 | 
			
		||||
        self.assertTrue(CatalogD.is_provider_bound(p11))
 | 
			
		||||
        self.assertTrue(CatalogE.is_provider_bound(p11))
 | 
			
		||||
        self.assertEquals(CatalogA.get_provider_bind_name(p11), 'p11')
 | 
			
		||||
        self.assertEquals(CatalogD.get_provider_bind_name(p11), 'pd1')
 | 
			
		||||
        self.assertEquals(CatalogE.get_provider_bind_name(p11), 'pe1')
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(CatalogA.is_provider_bound(p12))
 | 
			
		||||
        self.assertTrue(CatalogD.is_provider_bound(p12))
 | 
			
		||||
        self.assertTrue(CatalogE.is_provider_bound(p12))
 | 
			
		||||
        self.assertEquals(CatalogA.get_provider_bind_name(p12), 'p12')
 | 
			
		||||
        self.assertEquals(CatalogD.get_provider_bind_name(p12), 'pd2')
 | 
			
		||||
        self.assertEquals(CatalogE.get_provider_bind_name(p12), 'pe2')
 | 
			
		||||
 | 
			
		||||
    def test_provider_rebinding_to_the_same_catalog(self):
 | 
			
		||||
        """Test provider rebinding to the same catalog."""
 | 
			
		||||
        with self.assertRaises(di.Error):
 | 
			
		||||
            class TestCatalog(di.DeclarativeCatalog):
 | 
			
		||||
                """Test catalog."""
 | 
			
		||||
 | 
			
		||||
                p1 = di.Provider()
 | 
			
		||||
                p2 = p1
 | 
			
		||||
 | 
			
		||||
    def test_provider_rebinding_to_the_same_catalogs_hierarchy(self):
 | 
			
		||||
        """Test provider rebinding to the same catalogs hierarchy."""
 | 
			
		||||
        class TestCatalog1(di.DeclarativeCatalog):
 | 
			
		||||
            """Test catalog."""
 | 
			
		||||
 | 
			
		||||
            p1 = di.Provider()
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(di.Error):
 | 
			
		||||
            class TestCatalog2(TestCatalog1):
 | 
			
		||||
                """Test catalog."""
 | 
			
		||||
 | 
			
		||||
                p2 = TestCatalog1.p1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CatalogBundleTests(unittest.TestCase):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user