diff --git a/spacy/language.py b/spacy/language.py
index a5b78b178..20b7a7256 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
 from .tokens import Doc
 from .tokenizer import Tokenizer
 from .errors import Errors, Warnings
-from .schemas import ConfigSchema, ConfigSchemaNlp
+from .schemas import ConfigSchema, ConfigSchemaNlp, validate_init_settings
 from .git_info import GIT_VERSION
 from . import util
 from . import about
@@ -1162,6 +1162,7 @@ class Language:
         self,
         get_examples: Optional[Callable[[], Iterable[Example]]] = None,
         *,
+        settings: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
         sgd: Optional[Optimizer] = None,
         device: int = -1,
     ) -> Optimizer:
@@ -1207,10 +1208,28 @@ class Language:
         if sgd is None:
             sgd = create_default_optimizer()
         self._optimizer = sgd
+        if hasattr(self.tokenizer, "initialize"):
+            tok_settings = settings.get("tokenizer", {})
+            tok_settings = validate_init_settings(
+                self.tokenizer.initialize,
+                tok_settings,
+                section="tokenizer",
+                name="tokenizer",
+                # nlp is passed in below, so it must not be validated as a setting
+                exclude=("get_examples", "nlp"),
+            )
+            self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)
         for name, proc in self.pipeline:
             if hasattr(proc, "initialize"):
+                p_settings = settings.get(name, {})
+                p_settings = validate_init_settings(
+                    proc.initialize, p_settings, section="components", name=name
+                )
                 proc.initialize(
-                    get_examples, pipeline=self.pipeline, sgd=self._optimizer
+                    get_examples,
+                    pipeline=self.pipeline,
+                    sgd=self._optimizer,
+                    **p_settings,
                 )
         self._link_components()
         return self._optimizer
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 5a4503cf9..78e3422f6 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -1,4 +1,4 @@
-# cython: infer_types=True, cdivision=True, boundscheck=False
+# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
 from __future__ import print_function
 from cymem.cymem cimport Pool
 cimport numpy as np
diff --git a/spacy/schemas.py b/spacy/schemas.py
index b98498b8b..cdd8c11ed 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -1,11 +1,13 @@
 from typing import Dict, List, Union, Optional, Any, Callable, Type, Tuple
 from typing import Iterable, TypeVar, TYPE_CHECKING
 from enum import Enum
-from pydantic import BaseModel, Field, ValidationError, validator
+from pydantic import BaseModel, Field, ValidationError, validator, create_model
 from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
+from pydantic.main import ModelMetaclass
+from thinc.api import Optimizer, ConfigValidationError
 from thinc.config import Promise
 from collections import defaultdict
-from thinc.api import Optimizer
+import inspect
 
 from .attrs import NAMES
 from .lookups import Lookups
@@ -43,6 +45,93 @@ def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
     return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
 
 
+# Initialization
+
+
+class ArgSchemaConfig:
+    extra = "forbid"
+    arbitrary_types_allowed = True
+
+
+class ArgSchemaConfigExtra:
+    extra = "allow"  # non-strict mode: unknown settings are passed through
+    arbitrary_types_allowed = True
+
+
+def get_arg_model(
+    func: Callable,
+    *,
+    exclude: Iterable[str] = tuple(),
+    name: str = "ArgModel",
+    strict: bool = True,
+) -> ModelMetaclass:
+    """Generate a pydantic model for function arguments.
+
+    func (Callable): The function to generate the schema for.
+    exclude (Iterable[str]): Parameter names to ignore.
+    name (str): Name of created model class.
+    strict (bool): Don't allow extra arguments if no variable keyword arguments
+        are allowed on the function.
+    RETURNS (ModelMetaclass): A pydantic model.
+    """
+    sig_args = {}
+    try:
+        sig = inspect.signature(func)
+    except ValueError:
+        # Typically happens if the method is part of a Cython module without
+        # binding=True. Here we just use an empty model that allows everything.
+        return create_model(name, __config__=ArgSchemaConfigExtra)
+    has_variable = False
+    for param in sig.parameters.values():
+        if param.name in exclude:
+            continue
+        if param.kind == param.VAR_KEYWORD:
+            # The function allows variable keyword arguments so we shouldn't
+            # include **kwargs etc. in the schema and switch to non-strict
+            # mode and pass through all other values
+            has_variable = True
+            continue
+        # If no annotation is specified assume it's anything
+        annotation = param.annotation if param.annotation is not param.empty else Any
+        # If no default value is specified assume that it's required
+        default = param.default if param.default is not param.empty else ...
+        sig_args[param.name] = (annotation, default)
+    is_strict = strict and not has_variable
+    sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra
+    return create_model(name, **sig_args)
+
+
+def validate_init_settings(
+    func: Callable,
+    settings: Dict[str, Any],
+    *,
+    section: Optional[str] = None,
+    name: str = "",
+    exclude: Iterable[str] = ("get_examples", "pipeline", "sgd"),
+) -> Dict[str, Any]:
+    """Validate initialization settings against the expected arguments in
+    the method signature. Will parse values if possible (e.g. int to string)
+    and return the updated settings dict. Will raise a ConfigValidationError
+    if types don't match or required values are missing.
+
+    func (Callable): The initialize method of a given component etc.
+    settings (Dict[str, Any]): The settings from the respective [initialize] block.
+    section (str): Initialize section, for error message.
+    name (str): Name of the block in the section.
+    exclude (Iterable[str]): Parameter names to exclude from schema.
+    RETURNS (Dict[str, Any]): The validated settings.
+    """
+    schema = get_arg_model(func, exclude=exclude, name="InitArgModel")
+    try:
+        return schema(**settings).dict()
+    except ValidationError as e:
+        block = "initialize" if not section else f"initialize.{section}"
+        title = f"Error validating initialization settings in [{block}]"
+        raise ConfigValidationError(
+            title=title, errors=e.errors(), config=settings, parent=name,
+        ) from None
+
+
 # Matcher token patterns
 
 