diff --git a/src/dependency_injector/schema.py b/src/dependency_injector/schema.py index b307775e..996fab22 100644 --- a/src/dependency_injector/schema.py +++ b/src/dependency_injector/schema.py @@ -9,123 +9,129 @@ from . import containers, providers Schema = Dict[Any, Any] -def build_schema(schema: Schema) -> Dict[str, providers.Provider]: - """Build provider schema.""" - container = containers.DynamicContainer() - _create_providers(container, schema['providers']) - _setup_injections(container, schema['providers']) - return container.providers +class SchemaProcessorV1: + def __init__(self, schema: Schema) -> None: + self._schema = schema + self._container = containers.DynamicContainer() -def _create_providers( - container: containers.Container, - providers_data: Dict[str, Any], -): - for provider_name, data in providers_data.items(): - provider_type = _get_provider_cls(data['provider']) - args = [] + def process(self): + """Process schema.""" + self._create_providers(self._schema) + self._setup_injections(self._schema) - provides = data.get('provides') - if provides: - provides = _import_string(provides) + def get_providers(self): + """Return providers.""" + return self._container.providers + + def _create_providers(self, schema: Schema, container: Optional[containers.Container] = None) -> None: + if container is None: + container = self._container + for provider_name, data in schema['providers'].items(): + provider_type = _get_provider_cls(data['provider']) + args = [] + + provides = data.get('provides') if provides: + provides = _import_string(provides) + if provides: + args.append(provides) + + if provider_type is providers.Container: + provides = containers.DynamicContainer args.append(provides) - if provider_type is providers.Container: - provides = containers.DynamicContainer - args.append(provides) + provider = provider_type(*args) + container.set_provider(provider_name, provider) - provider = provider_type(*args) - container.set_provider(provider_name, provider) + if isinstance(provider, providers.Container): + self._create_providers(schema=data, container=provider) - if isinstance(provider, providers.Container): - _create_providers(provider, data['providers']) + def _setup_injections(self, schema: Schema, container: Optional[containers.Container] = None) -> None: + if container is None: + container = self._container + for provider_name, data in schema['providers'].items(): + provider = getattr(container, provider_name) + args = [] + kwargs = {} -def _setup_injections( - container: containers.Container, - providers_data: Dict[str, Any], - *, - current_container: Optional[containers.Container] = None, -): - if not current_container: - current_container = container + arg_injections = data.get('args') + if arg_injections: + for arg in arg_injections: + injection = None - for provider_name, data in providers_data.items(): - provider = getattr(current_container, provider_name) - args = [] - kwargs = {} + if isinstance(arg, str): + injection = self._resolve_provider(arg) - arg_injections = data.get('args') - if arg_injections: - for arg in arg_injections: - injection = None + # TODO: add inline injections - if isinstance(arg, str): - injection = _resolve_provider(container, arg) + if not injection: + injection = arg - # TODO: add inline injections + args.append(injection) + if args: + provider.add_args(*args) - if not injection: - injection = arg + kwarg_injections = data.get('kwargs') + if kwarg_injections: + for name, arg in kwarg_injections.items(): + injection = None - args.append(injection) - if args: - provider.add_args(*args) + if isinstance(arg, str): + injection = self._resolve_provider(arg) - kwarg_injections = data.get('kwargs') - if kwarg_injections: - for name, arg in kwarg_injections.items(): - injection = None - - if isinstance(arg, str): - injection = _resolve_provider(container, arg) - - # TODO: refactoring - if isinstance(arg, dict): - provider_args = [] - provider_type = _get_provider_cls(arg.get('provider')) - provides = arg.get('provides') - if provides: - provides = _import_string(provides) + # TODO: refactoring + if isinstance(arg, dict): + provider_args = [] + provider_type = _get_provider_cls(arg.get('provider')) + provides = arg.get('provides') if provides: - provider_args.append(provides) - for provider_arg in arg.get('args', []): - provider_args.append(_resolve_provider(container, provider_arg)) - injection = provider_type(*provider_args) + provides = _import_string(provides) + if provides: + provider_args.append(provides) + for provider_arg in arg.get('args', []): + provider_args.append(self._resolve_provider(provider_arg)) + injection = provider_type(*provider_args) - if not injection: - injection = arg + if not injection: + injection = arg - kwargs[name] = injection - if kwargs: - provider.add_kwargs(**kwargs) + kwargs[name] = injection + if kwargs: + provider.add_kwargs(**kwargs) - if isinstance(provider, providers.Container): - _setup_injections(container, data['providers'], current_container=provider) + if isinstance(provider, providers.Container): + self._setup_injections(schema=data, container=provider) + + def _resolve_provider(self, name: str) -> Optional[providers.Provider]: + segments = name.split('.') + try: + provider = getattr(self._container, segments[0]) + except AttributeError: + return None + + for segment in segments[1:]: + if segment == 'as_int()': + provider = provider.as_int() + elif segment == 'as_float()': + provider = provider.as_float() + elif segment.startswith('is_'): # TODO + provider = provider.as_(str) + ... + else: + try: + provider = getattr(provider, segment) + except AttributeError: + return None + return provider -def _resolve_provider(container: containers.Container, name: str) -> Optional[providers.Provider]: - segments = name.split('.') - try: - provider = getattr(container, segments[0]) - except AttributeError: - return None - - for segment in segments[1:]: - if segment == 'as_int()': - provider = provider.as_int() - elif segment == 'as_float()': - provider = provider.as_float() - elif segment.startswith('is_'): # TODO - provider = provider.as_(str) - ... - else: - try: - provider = getattr(provider, segment) - except AttributeError: - return None - return provider +def build_schema(schema: Schema) -> Dict[str, providers.Provider]: + """Build provider schema.""" + schema_processor = SchemaProcessorV1(schema) + schema_processor.process() + return schema_processor.get_providers() def _get_provider_cls(provider_cls_name: str) -> Type[providers.Provider]: