spaCy/spacy/pipeline/_edit_tree_internals/schemas.py
Adriane Boyd 85778dfcf4
Add edit tree lemmatizer (#10231)
* Add edit tree lemmatizer

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* Hide edit tree lemmatizer labels

* Use relative imports

* Switch to single quotes in error message

* Type annotation fixes

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Reformat edit_tree_lemmatizer with black

* EditTreeLemmatizer.predict: take Iterable

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Validate edit trees during deserialization

This change also changes the serialized representation. Rather than
mirroring the deep C structure, we use a simple flat union of the match
and substitution node types.

* Move edit_trees to _edit_tree_internals

* Fix invalid edit tree format error message

* edit_tree_lemmatizer: remove outdated TODO comment

* Rename factory name to trainable_lemmatizer

* Ignore type instead of casting truths to List[Union[Ints1d, Floats2d, List[int], List[str]]] for thinc v8.0.14

* Switch to Tagger.v2

* Add documentation for EditTreeLemmatizer

* docs: Fix 3.2 -> 3.3 somewhere

* trainable_lemmatizer documentation fixes

* docs: EditTreeLemmatizer is in edit_tree_lemmatizer.py

Co-authored-by: Daniël de Kok <me@danieldk.eu>
Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
2022-03-28 11:13:50 +02:00

45 lines
1.4 KiB
Python

from typing import Any, Dict, List, Union
from collections import defaultdict
from pydantic import BaseModel, Field, ValidationError
from pydantic.types import StrictBool, StrictInt, StrictStr


class MatchNodeSchema(BaseModel):
    prefix_len: StrictInt = Field(..., title="Prefix length")
    suffix_len: StrictInt = Field(..., title="Suffix length")
    prefix_tree: StrictInt = Field(..., title="Prefix tree")
    suffix_tree: StrictInt = Field(..., title="Suffix tree")

    class Config:
        extra = "forbid"


class SubstNodeSchema(BaseModel):
    orig: Union[int, StrictStr] = Field(..., title="Original substring")
    subst: Union[int, StrictStr] = Field(..., title="Replacement substring")

    class Config:
        extra = "forbid"


class EditTreeSchema(BaseModel):
    __root__: Union[MatchNodeSchema, SubstNodeSchema]


def validate_edit_tree(obj: Dict[str, Any]) -> List[str]:
    """Validate edit tree.

    obj (Dict[str, Any]): JSON-serializable data to validate.
    RETURNS (List[str]): A list of error messages, if available.
    """
    try:
        EditTreeSchema.parse_obj(obj)
        return []
    except ValidationError as e:
        errors = e.errors()
        data = defaultdict(list)
        for error in errors:
            err_loc = " -> ".join([str(p) for p in error.get("loc", [])])
            data[err_loc].append(error.get("msg"))
        return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]  # type: ignore[arg-type]
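
The error handling in `validate_edit_tree` groups messages by location before formatting them. The sketch below isolates that grouping step using only the standard library, so it can be run without pydantic. The helper name `format_errors` and the sample error records are hypothetical and written by hand for illustration; they only imitate the `loc`/`msg` keys that the code above reads from `e.errors()` and are not captured from a real validation run.

```python
from collections import defaultdict
from typing import Any, Dict, List


def format_errors(errors: List[Dict[str, Any]]) -> List[str]:
    """Group messages by location, mirroring the except branch above."""
    data = defaultdict(list)
    for error in errors:
        # Join the location path, e.g. ("__root__", "orig") -> "__root__ -> orig".
        err_loc = " -> ".join(str(p) for p in error.get("loc", []))
        data[err_loc].append(error.get("msg"))
    # One output line per location, with its messages comma-separated.
    return [f"[{loc}] {', '.join(msgs)}" for loc, msgs in data.items()]


# Hypothetical error records (assumed shape: "loc" tuple and "msg" string).
sample_errors = [
    {"loc": ("__root__", "prefix_len"), "msg": "field required"},
    {"loc": ("__root__", "prefix_len"), "msg": "value is not a valid integer"},
    {"loc": ("__root__", "orig"), "msg": "field required"},
]

formatted = format_errors(sample_errors)
for line in formatted:
    print(line)
```

Because the grouping dictionary preserves insertion order, locations appear in the order they were first reported, and an empty error list yields an empty result, matching the `return []` path for valid input.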