diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 0aa3a5a2..5137e929 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -136,26 +136,10 @@ class ProvidersMap: try: provider = getattr(provider, segment) except AttributeError: - return + return None - if isinstance(modifier, TypeModifier): - provider = provider.as_(modifier.type_) - elif isinstance(modifier, RequiredModifier): - provider = provider.required() - if modifier.type_modifier: - provider = provider.as_(modifier.type_modifier.type_) - elif isinstance(modifier, InvariantModifier): - invariant_segment = self._resolve_string_id(modifier.id) - provider = provider[invariant_segment] - elif isinstance(modifier, ProvidedInstance): - provider = provider.provided - for type_, value in modifier.segments: - if type_ == ProvidedInstance.TYPE_ATTRIBUTE: - provider = getattr(provider, value) - elif type_ == ProvidedInstance.TYPE_ITEM: - provider = provider[value] - elif type_ == ProvidedInstance.TYPE_CALL: - provider = provider.call() + if modifier: + provider = modifier.modify(provider, providers_map=self) return provider def _resolve_provided_instance( @@ -230,7 +214,7 @@ class ProvidersMap: try: return self._map[original] except KeyError: - pass + return None @classmethod def _create_providers_map( @@ -563,13 +547,19 @@ def _is_declarative_container(instance: Any) -> bool: class Modifier: - ... + + def modify(self, provider: providers.ConfigurationOption, providers_map: ProvidersMap) -> providers.Provider: + ... class TypeModifier(Modifier): + def __init__(self, type_: Type): self.type_ = type_ + def modify(self, provider: providers.ConfigurationOption, providers_map: ProvidersMap) -> providers.Provider: + return provider.as_(self.type_) + def as_int() -> TypeModifier: """Return int type modifier.""" @@ -587,6 +577,7 @@ def as_(type_: Type) -> TypeModifier: class RequiredModifier(Modifier): + def __init__(self): self.type_modifier = None @@ -604,6 +595,12 @@ class RequiredModifier(Modifier): self.type_modifier = TypeModifier(type_) return self + def modify(self, provider: providers.ConfigurationOption, providers_map: ProvidersMap) -> providers.Provider: + provider = provider.required() + if self.type_modifier: + provider = provider.as_(self.type_modifier.type_) + return provider + def required() -> RequiredModifier: """Return required modifier.""" @@ -611,9 +608,14 @@ def required() -> RequiredModifier: class InvariantModifier(Modifier): + def __init__(self, id: str) -> None: self.id = id + def modify(self, provider: providers.ConfigurationOption, providers_map: ProvidersMap) -> providers.Provider: + invariant_segment = providers_map.resolve_provider(self.id) + return provider[invariant_segment] + def invariant(id: str) -> InvariantModifier: """Return invariant modifier.""" @@ -641,6 +643,17 @@ class ProvidedInstance(Modifier): self.segments.append((self.TYPE_CALL, None)) return self + def modify(self, provider: providers.ConfigurationOption, providers_map: ProvidersMap) -> providers.Provider: + provider = provider.provided + for type_, value in self.segments: + if type_ == ProvidedInstance.TYPE_ATTRIBUTE: + provider = getattr(provider, value) + elif type_ == ProvidedInstance.TYPE_ITEM: + provider = provider[value] + elif type_ == ProvidedInstance.TYPE_CALL: + provider = provider.call() + return provider + def provided() -> ProvidedInstance: """Return provided instance modifier."""