spaCy/spacy/schemas.py
Adriane Boyd 0fe43f40f1
Support registered vectors (#12492)
* Support registered vectors

* Format

* Auto-fill [nlp] on load from config and from bytes/disk

* Only auto-fill [nlp]

* Undo all changes to Language.from_disk

* Expand BaseVectors

These methods are needed in various places for training and vector
similarity.

* isort

* More linting

* Only fill [nlp.vectors]

* Update spacy/vocab.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Revert changes to test related to auto-filling [nlp]

* Add vectors registry

* Rephrase error about vocab methods for vectors

* Switch to dummy implementation for BaseVectors.to_ops

* Add initial draft of docs

* Remove example from BaseVectors docs

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/basevectors.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Fix type and lint bpemb example

* Update website/docs/api/basevectors.mdx

---------

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2023-08-01 15:46:08 +02:00

520 lines
21 KiB
Python

import inspect
import re
from collections import defaultdict
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from pydantic import (
BaseModel,
ConstrainedStr,
Field,
StrictBool,
StrictFloat,
StrictInt,
StrictStr,
ValidationError,
create_model,
validator,
)
from pydantic.main import ModelMetaclass
from thinc.api import ConfigValidationError, Model, Optimizer
from thinc.config import Promise
from .attrs import NAMES
from .compat import Literal
from .lookups import Lookups
from .util import is_cython_func
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401
from .training import Example # noqa: F401
from .vocab import Vocab # noqa: F401
# fmt: off
# Type variable standing in for whatever item type a batcher groups.
ItemT = TypeVar("ItemT")
# A batcher maps an iterable of items to an iterable of lists of items. Each
# alias below also admits a thinc.config Promise in place of the callable.
Batcher = Union[Callable[[Iterable[ItemT]], Iterable[List[ItemT]]], Promise]
# A reader takes a Language object and a string and yields Example objects.
Reader = Union[Callable[["Language", str], Iterable["Example"]], Promise]
# A logger setup function takes a Language object and returns a pair: a
# callable accepting a Dict[str, Any] (returning None) and a second callable.
Logger = Union[Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]], Promise]
# fmt: on
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
    """Check a dict of data against a pydantic schema and report problems.

    obj (Dict[str, Any]): JSON-serializable data to validate.
    schema (pydantic.BaseModel): The schema to validate against.
    RETURNS (List[str]): One formatted message per error location, or an
        empty list if the data is valid.
    """
    try:
        schema(**obj)
    except ValidationError as exc:
        # Group messages by location so each location is reported only once.
        messages_by_loc = defaultdict(list)
        for item in exc.errors():
            location = " -> ".join(str(part) for part in item.get("loc", []))
            messages_by_loc[location].append(item.get("msg"))
        return [f"[{location}] {', '.join(msgs)}" for location, msgs in messages_by_loc.items()]  # type: ignore[arg-type]
    return []
# Initialization
class ArgSchemaConfig:
    """Strict pydantic config for generated argument models: values for
    unknown argument names are rejected, and arbitrary (non-pydantic) types
    may be used as annotations.
    """

    arbitrary_types_allowed = True
    extra = "forbid"
class ArgSchemaConfigExtra:
    """Non-strict pydantic config for generated argument models.

    Used when the wrapped function accepts variable keyword arguments, when
    strict mode is switched off, or when the signature cannot be inspected.
    In these cases additional values must be passed through instead of being
    rejected, so extra fields are allowed here (in contrast to the strict
    ArgSchemaConfig, which forbids them).
    """

    # "allow" keeps values that are not declared on the model; with "forbid"
    # this class would be indistinguishable from the strict config.
    extra = "allow"
    arbitrary_types_allowed = True
def get_arg_model(
    func: Callable,
    *,
    exclude: Iterable[str] = tuple(),
    name: str = "ArgModel",
    strict: bool = True,
) -> ModelMetaclass:
    """Generate a pydantic model for function arguments.

    func (Callable): The function to generate the schema for.
    exclude (Iterable[str]): Parameter names to ignore.
    name (str): Name of created model class.
    strict (bool): Don't allow extra arguments if no variable keyword arguments
        are allowed on the function.
    RETURNS (ModelMetaclass): A pydantic model.
    """
    sig_args = {}
    try:
        sig = inspect.signature(func)
    except ValueError:
        # Typically happens if the method is part of a Cython module without
        # binding=True. Here we fall back to an empty model that uses the
        # non-strict ArgSchemaConfigExtra config.
        return create_model(name, __config__=ArgSchemaConfigExtra)  # type: ignore[arg-type, return-value]
    has_variable = False
    for param in sig.parameters.values():
        if param.name in exclude:
            continue
        if param.kind == param.VAR_KEYWORD:
            # The function allows variable keyword arguments so we shouldn't
            # include **kwargs etc. in the schema and switch to the
            # non-strict config
            has_variable = True
            continue
        # If no annotation is specified assume it's anything. The "empty"
        # marker is a sentinel, so it is checked by identity: an equality
        # comparison would call the annotation's/default's own __ne__, which
        # for array-like defaults does not return a plain bool.
        annotation = Any if param.annotation is param.empty else param.annotation
        # If no default value is specified assume that it's required. Cython
        # functions/methods will have param.empty for default value None so we
        # need to treat them differently
        default_empty = None if is_cython_func(func) else ...
        default = default_empty if param.default is param.empty else param.default
        sig_args[param.name] = (annotation, default)
    is_strict = strict and not has_variable
    sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra  # type: ignore[assignment]
    return create_model(name, **sig_args)  # type: ignore[call-overload, arg-type, return-value]
def validate_init_settings(
    func: Callable,
    settings: Dict[str, Any],
    *,
    section: Optional[str] = None,
    name: str = "",
    exclude: Iterable[str] = ("get_examples", "nlp"),
) -> Dict[str, Any]:
    """Check initialization settings against the arguments expected by a
    method's signature. Values are parsed where possible (e.g. int to string)
    and the resulting settings dict is returned. A ConfigValidationError is
    raised if types don't match or required values are missing.

    func (Callable): The initialize method of a given component etc.
    settings (Dict[str, Any]): The settings from the respective [initialize] block.
    section (str): Initialize section, for error message.
    name (str): Name of the block in the section.
    exclude (Iterable[str]): Parameter names to exclude from schema.
    RETURNS (Dict[str, Any]): The validated settings.
    """
    arg_model = get_arg_model(func, exclude=exclude, name="InitArgModel")
    try:
        return arg_model(**settings).dict()
    except ValidationError as exc:
        block = f"initialize.{section}" if section else "initialize"
        # The pydantic traceback is dropped on purpose (from None) so only
        # the config error is shown.
        raise ConfigValidationError(
            title=f"Error validating initialization settings in [{block}]",
            errors=exc.errors(),
            config=settings,
            parent=name,
        ) from None
# Matcher token patterns
def validate_token_pattern(obj: list) -> List[str]:
    """Validate a Matcher token pattern against TokenPatternSchema.

    obj (list): The pattern, expected to be a list of token dicts. Values of
        other types are passed to the schema unchanged.
    RETURNS (List[str]): A list of error messages, empty if the pattern is
        valid.
    """

    def get_key(key: Any) -> Any:
        # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"}).
        # The lower bound matters: a negative int would otherwise index NAMES
        # from the end and silently resolve to an unrelated attribute name.
        if isinstance(key, int) and 0 <= key < len(NAMES):
            return NAMES[key]
        return key

    if isinstance(obj, list):
        converted = []
        for pattern in obj:
            if isinstance(pattern, dict):
                pattern = {get_key(k): v for k, v in pattern.items()}
            converted.append(pattern)
        obj = converted
    return validate(TokenPatternSchema, {"pattern": obj})
class TokenPatternString(BaseModel):
    """Schema for the dict form of a string-valued token pattern attribute.

    Every key is optional. Each field can be populated either by its
    upper-case field name or by its lower-case alias, unknown keys are
    rejected, and explicit None values raise a validation error.
    """

    # REGEX and the FUZZY* keys accept either a plain string or a nested
    # TokenPatternString.
    REGEX: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="regex")
    # The set-style keys take lists of strict strings.
    IN: Optional[List[StrictStr]] = Field(None, alias="in")
    NOT_IN: Optional[List[StrictStr]] = Field(None, alias="not_in")
    IS_SUBSET: Optional[List[StrictStr]] = Field(None, alias="is_subset")
    IS_SUPERSET: Optional[List[StrictStr]] = Field(None, alias="is_superset")
    INTERSECTS: Optional[List[StrictStr]] = Field(None, alias="intersects")
    FUZZY: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy")
    FUZZY1: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy1"
    )
    FUZZY2: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy2"
    )
    FUZZY3: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy3"
    )
    FUZZY4: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy4"
    )
    FUZZY5: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy5"
    )
    FUZZY6: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy6"
    )
    FUZZY7: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy7"
    )
    FUZZY8: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy8"
    )
    FUZZY9: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
        None, alias="fuzzy9"
    )

    class Config:
        extra = "forbid"
        allow_population_by_field_name = True  # allow alias and field name

    # Runs before type validation (pre=True) on every field ("*") and on
    # each item of a collection value (each_item=True).
    @validator("*", pre=True, each_item=True, allow_reuse=True)
    def raise_for_none(cls, v):
        """Reject None so that null cannot be used as a pattern value."""
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
class TokenPatternNumber(BaseModel):
    """Schema for the dict form of a numeric token pattern attribute.

    Every key is optional. Each field can be populated either by its field
    name or by its alias (the comparison fields are aliased to the operator
    symbols), unknown keys are rejected, and explicit None values raise a
    validation error.
    """

    REGEX: Optional[StrictStr] = Field(None, alias="regex")
    # The set-style keys take lists of strict integers.
    IN: Optional[List[StrictInt]] = Field(None, alias="in")
    NOT_IN: Optional[List[StrictInt]] = Field(None, alias="not_in")
    IS_SUBSET: Optional[List[StrictInt]] = Field(None, alias="is_subset")
    IS_SUPERSET: Optional[List[StrictInt]] = Field(None, alias="is_superset")
    INTERSECTS: Optional[List[StrictInt]] = Field(None, alias="intersects")
    # Comparison keys accept a strict int or a strict float.
    EQ: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias="==")
    NEQ: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias="!=")
    GEQ: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias=">=")
    LEQ: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias="<=")
    GT: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias=">")
    LT: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias="<")

    class Config:
        extra = "forbid"
        allow_population_by_field_name = True  # allow alias and field name

    # Runs before type validation (pre=True) on every field ("*") and on
    # each item of a collection value (each_item=True).
    @validator("*", pre=True, each_item=True, allow_reuse=True)
    def raise_for_none(cls, v):
        """Reject None so that null cannot be used as a pattern value."""
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
class TokenPatternOperatorSimple(str, Enum):
    """The single-character operator symbols accepted for a token pattern's
    "op" value: "+", "*", "?" and "!".
    """

    plus: StrictStr = StrictStr("+")
    star: StrictStr = StrictStr("*")
    question: StrictStr = StrictStr("?")
    exclamation: StrictStr = StrictStr("!")
class TokenPatternOperatorMinMax(ConstrainedStr):
    """Constrained string for brace-delimited operators. The whole string
    must have one of the forms "{n}", "{n,}", "{n,m}" or "{,m}", where n and
    m are sequences of digits.
    """

    regex = re.compile(r"^({\d+}|{\d+,\d*}|{\d*,\d+})$")
# An operator is either one of the simple enum symbols or a brace-delimited
# min/max string.
TokenPatternOperator = Union[TokenPatternOperatorSimple, TokenPatternOperatorMinMax]
# A string-valued attribute is given either as a predicate dict or as a plain
# strict string.
StringValue = Union[TokenPatternString, StrictStr]
# A numeric attribute is given either as a predicate dict or as a strict
# int/float.
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
# Value types accepted inside the dict stored under a token pattern's "_" key.
UnderscoreValue = Union[
    TokenPatternString, TokenPatternNumber, str, int, float, list, bool
]
# The literal values accepted for a token pattern's ent_iob attribute.
IobValue = Literal["", "I", "O", "B", 0, 1, 2, 3]
class TokenPattern(BaseModel):
    """Schema for a single token dict inside a Matcher pattern.

    All attributes are optional. Aliases are generated by upper-casing the
    field name, and population by field name is enabled as well, so both
    spellings are accepted. Unknown keys are rejected and explicit None
    values raise a validation error.
    """

    orth: Optional[StringValue] = None
    text: Optional[StringValue] = None
    lower: Optional[StringValue] = None
    pos: Optional[StringValue] = None
    tag: Optional[StringValue] = None
    morph: Optional[StringValue] = None
    dep: Optional[StringValue] = None
    lemma: Optional[StringValue] = None
    shape: Optional[StringValue] = None
    ent_type: Optional[StringValue] = None
    ent_iob: Optional[IobValue] = None
    ent_id: Optional[StringValue] = None
    ent_kb_id: Optional[StringValue] = None
    norm: Optional[StringValue] = None
    length: Optional[NumberValue] = None
    spacy: Optional[StrictBool] = None
    is_alpha: Optional[StrictBool] = None
    is_ascii: Optional[StrictBool] = None
    is_digit: Optional[StrictBool] = None
    is_lower: Optional[StrictBool] = None
    is_upper: Optional[StrictBool] = None
    is_title: Optional[StrictBool] = None
    is_punct: Optional[StrictBool] = None
    is_space: Optional[StrictBool] = None
    is_bracket: Optional[StrictBool] = None
    is_quote: Optional[StrictBool] = None
    is_left_punct: Optional[StrictBool] = None
    is_right_punct: Optional[StrictBool] = None
    is_currency: Optional[StrictBool] = None
    is_stop: Optional[StrictBool] = None
    is_sent_start: Optional[StrictBool] = None
    sent_start: Optional[StrictBool] = None
    like_num: Optional[StrictBool] = None
    like_url: Optional[StrictBool] = None
    like_email: Optional[StrictBool] = None
    op: Optional[TokenPatternOperator] = None
    # Explicit alias "_" is set here instead of relying on the generated one.
    underscore: Optional[Dict[StrictStr, UnderscoreValue]] = Field(None, alias="_")

    class Config:
        extra = "forbid"
        allow_population_by_field_name = True
        # Aliases are the upper-cased field names, e.g. "orth" -> "ORTH".
        alias_generator = lambda value: value.upper()

    # Runs before type validation (pre=True) on every field ("*").
    @validator("*", pre=True, allow_reuse=True)
    def raise_for_none(cls, v):
        """Reject None so that null cannot be used as an attribute value."""
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
class TokenPatternSchema(BaseModel):
    """Schema for a complete token pattern: a list of at least one
    TokenPattern, with no other keys allowed.
    """

    pattern: List[TokenPattern] = Field(..., min_items=1)

    class Config:
        extra = "forbid"
# Model meta
class ModelMetaSchema(BaseModel):
    """Schema for model meta data. Only lang, name and version are required;
    every other field has a default value.
    """

    # fmt: off
    lang: StrictStr = Field(..., title="Two-letter language code, e.g. 'en'")
    name: StrictStr = Field(..., title="Model name")
    version: StrictStr = Field(..., title="Model version")
    spacy_version: StrictStr = Field("", title="Compatible spaCy version identifier")
    parent_package: StrictStr = Field("spacy", title="Name of parent spaCy package, e.g. spacy or spacy-nightly")
    requirements: List[StrictStr] = Field([], title="Additional Python package dependencies, used for the Python package setup")
    pipeline: List[StrictStr] = Field([], title="Names of pipeline components")
    description: StrictStr = Field("", title="Model description")
    license: StrictStr = Field("", title="Model license")
    author: StrictStr = Field("", title="Model author name")
    email: StrictStr = Field("", title="Model author email")
    url: StrictStr = Field("", title="Model author URL")
    # Sources may be given as a list of strings or as a list of string dicts.
    sources: Optional[Union[List[StrictStr], List[Dict[str, str]]]] = Field(None, title="Training data sources")
    vectors: Dict[str, Any] = Field({}, title="Included word vectors")
    labels: Dict[str, List[str]] = Field({}, title="Component labels, keyed by component name")
    performance: Dict[str, Any] = Field({}, title="Accuracy and speed numbers")
    spacy_git_version: StrictStr = Field("", title="Commit of spaCy version used")
    # fmt: on
# Config schema
# We're not setting any defaults here (which is too messy) and are making all
# fields required, so we can raise validation errors for missing values. To
# provide a default, we include a separate .cfg file with all values and
# check that against this schema in the test suite to make sure it's always
# up to date.
class ConfigSchemaTraining(BaseModel):
    """Schema for the "training" section of the config (see CONFIG_SCHEMAS).

    Every field is required (declared with ...); fields typed as Optional
    must still be present but may be None. Unknown keys are rejected.
    """

    # fmt: off
    dev_corpus: StrictStr = Field(..., title="Path in the config to the dev data")
    train_corpus: StrictStr = Field(..., title="Path in the config to the training data")
    batcher: Batcher = Field(..., title="Batcher for the training data")
    dropout: StrictFloat = Field(..., title="Dropout rate")
    patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
    max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
    eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
    seed: Optional[StrictInt] = Field(..., title="Random seed")
    gpu_allocator: Optional[StrictStr] = Field(..., title="Memory allocator when running on GPU")
    accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
    score_weights: Dict[StrictStr, Optional[Union[StrictFloat, StrictInt]]] = Field(..., title="Scores to report and their weights for selecting final model")
    optimizer: Optimizer = Field(..., title="The optimizer to use")
    logger: Logger = Field(..., title="The logger to track training progress")
    frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
    annotating_components: List[str] = Field(..., title="Pipeline components that should set annotations during training")
    before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
    before_update: Optional[Callable[["Language", Dict[str, Any]], None]] = Field(..., title="Optional callback that is invoked at the start of each training step")
    # fmt: on

    class Config:
        extra = "forbid"
        # Needed because some fields use non-pydantic types (e.g. Optimizer).
        arbitrary_types_allowed = True
class ConfigSchemaNlp(BaseModel):
    """Schema for the "nlp" section of the config (see CONFIG_SCHEMAS).

    Every field is required (declared with ...); fields typed as Optional
    must still be present but may be None. Unknown keys are rejected.
    """

    # fmt: off
    lang: StrictStr = Field(..., title="The base language to use")
    pipeline: List[StrictStr] = Field(..., title="The pipeline component names in order")
    disabled: List[StrictStr] = Field(..., title="Pipeline components to disable by default")
    tokenizer: Callable = Field(..., title="The tokenizer to use")
    before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")
    after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
    after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
    batch_size: Optional[int] = Field(..., title="Default batch size")
    vectors: Callable = Field(..., title="Vectors implementation")
    # fmt: on

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True
class ConfigSchemaPretrainEmpty(BaseModel):
    """Schema with no fields that rejects any key, i.e. it only validates an
    empty section. Used as the alternative to ConfigSchemaPretrain for the
    "pretraining" field of ConfigSchema.
    """

    class Config:
        extra = "forbid"
class ConfigSchemaPretrain(BaseModel):
    """Schema for the "pretraining" section of the config (see
    CONFIG_SCHEMAS).

    Every field is required (declared with ...); fields typed as Optional
    must still be present but may be None. Unknown keys are rejected.
    """

    # fmt: off
    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
    dropout: StrictFloat = Field(..., title="Dropout rate")
    n_save_every: Optional[StrictInt] = Field(..., title="Saving additional temporary model after n batches within an epoch")
    n_save_epoch: Optional[StrictInt] = Field(..., title="Saving model after every n epoch")
    optimizer: Optimizer = Field(..., title="The optimizer to use")
    corpus: StrictStr = Field(..., title="Path in the config to the training data")
    batcher: Batcher = Field(..., title="Batcher for the training data")
    component: str = Field(..., title="Component to find the layer to pretrain")
    layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
    objective: Callable[["Vocab", Model], Model] = Field(..., title="A function that creates the pretraining objective.")
    # fmt: on

    class Config:
        extra = "forbid"
        # Needed because some fields use non-pydantic types (e.g. Optimizer).
        arbitrary_types_allowed = True
class ConfigSchemaInit(BaseModel):
    """Schema for the "initialize" section of the config (see
    CONFIG_SCHEMAS).

    Every field is required (declared with ...); fields typed as Optional
    must still be present but may be None. Unknown keys are rejected.
    """

    # fmt: off
    vocab_data: Optional[StrictStr] = Field(..., title="Path to JSON-formatted vocabulary file")
    lookups: Optional[Lookups] = Field(..., title="Vocabulary lookups, e.g. lexeme normalization")
    vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
    init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
    # NOTE(review): the next two fields pass help= while every other field in
    # this file passes title= -- confirm whether that difference is intended.
    tokenizer: Dict[StrictStr, Any] = Field(..., help="Arguments to be passed into Tokenizer.initialize")
    components: Dict[StrictStr, Dict[StrictStr, Any]] = Field(..., help="Arguments for TrainablePipe.initialize methods of pipeline components, keyed by component")
    before_init: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object before initialization")
    after_init: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after initialization")
    # fmt: on

    class Config:
        extra = "forbid"
        # Needed because some fields use non-pydantic types (e.g. Lookups).
        arbitrary_types_allowed = True
class ConfigSchema(BaseModel):
    """Top-level schema for a full config. Unlike the section schemas, extra
    top-level keys are allowed here.
    """

    training: ConfigSchemaTraining
    nlp: ConfigSchemaNlp
    # The only section with a default: it validates either as a full
    # pretraining block or as an empty one.
    pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}  # type: ignore[assignment]
    components: Dict[str, Dict[str, Any]]
    corpora: Dict[str, Reader]
    initialize: ConfigSchemaInit

    class Config:
        extra = "allow"
        arbitrary_types_allowed = True
# Maps a top-level config section name to the schema that validates it. Note
# that "pretraining" maps to the full ConfigSchemaPretrain, not the empty
# variant.
CONFIG_SCHEMAS = {
    "nlp": ConfigSchemaNlp,
    "training": ConfigSchemaTraining,
    "pretraining": ConfigSchemaPretrain,
    "initialize": ConfigSchemaInit,
}
# Recommendations for init config workflows
class RecommendationTrfItem(BaseModel):
    """One transformer recommendation entry: a name and an integer size
    factor. Both fields are required.
    """

    name: str
    size_factor: int
class RecommendationTrf(BaseModel):
    """Pair of transformer recommendations, one under "efficiency" and one
    under "accuracy". Both are required.
    """

    efficiency: RecommendationTrfItem
    accuracy: RecommendationTrfItem
class RecommendationSchema(BaseModel):
    """Schema for one recommendation entry used by the init config
    workflows. All fields are optional and have defaults.
    """

    word_vectors: Optional[str] = None
    transformer: Optional[RecommendationTrf] = None
    has_letters: bool = True
class DocJSONSchema(BaseModel):
    """
    JSON/dict format for JSON representation of Doc objects.

    Only "text" and "tokens" are required; every other field defaults to
    None.
    """

    cats: Optional[Dict[StrictStr, StrictFloat]] = Field(
        None, title="Categories with corresponding probabilities"
    )
    ents: Optional[List[Dict[StrictStr, Union[StrictInt, StrictStr]]]] = Field(
        None, title="Information on entities"
    )
    sents: Optional[List[Dict[StrictStr, StrictInt]]] = Field(
        None, title="Indices of sentences' start and end indices"
    )
    text: StrictStr = Field(..., title="Document text")
    # Keyed by a string, each value being a list of span dicts.
    spans: Optional[
        Dict[StrictStr, List[Dict[StrictStr, Union[StrictStr, StrictInt]]]]
    ] = Field(None, title="Span information - end/start indices, label, KB ID")
    tokens: List[Dict[StrictStr, Union[StrictStr, StrictInt]]] = Field(
        ..., title="Token information - ID, start, annotations"
    )
    # Populated from the "_" key via the alias; the two underscore_* fields
    # below have no alias and use their own names.
    underscore_doc: Optional[Dict[StrictStr, Any]] = Field(
        None,
        title="Any custom data stored in the document's _ attribute",
        alias="_",
    )
    underscore_token: Optional[Dict[StrictStr, List[Dict[StrictStr, Any]]]] = Field(
        None, title="Any custom data stored in the token's _ attribute"
    )
    underscore_span: Optional[Dict[StrictStr, List[Dict[StrictStr, Any]]]] = Field(
        None, title="Any custom data stored in the span's _ attribute"
    )