Add better schemas and validation using Pydantic (#4831)

* Remove unicode declarations

* Remove Python 3.5 and 2.7 from CI

* Don't require pathlib

* Replace compat helpers

* Remove OrderedDict

* Use f-strings

* Set Cython compiler language level

* Fix typo

* Re-add OrderedDict for Table

* Update setup.cfg

* Revert CONTRIBUTING.md

* Add better schemas and validation using Pydantic

* Revert lookups.md

* Remove unused import

* Update spacy/schemas.py

Co-Authored-By: Sebastián Ramírez <tiangolo@gmail.com>

* Various small fixes

* Fix docstring

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Ines Montani 2019-12-25 12:39:49 +01:00 committed by GitHub
parent db55577c45
commit 33a2682d60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 217 additions and 563 deletions

View File

@ -12,8 +12,7 @@ numpy>=1.15.0
requests>=2.13.0,<3.0.0 requests>=2.13.0,<3.0.0
plac>=0.9.6,<1.2.0 plac>=0.9.6,<1.2.0
tqdm>=4.38.0,<5.0.0 tqdm>=4.38.0,<5.0.0
# Optional dependencies pydantic>=1.0.0,<2.0.0
jsonschema>=2.6.0,<3.1.0
# Development dependencies # Development dependencies
cython>=0.25 cython>=0.25
pytest>=4.6.5 pytest>=4.6.5

View File

@ -51,6 +51,7 @@ install_requires =
numpy>=1.15.0 numpy>=1.15.0
plac>=0.9.6,<1.2.0 plac>=0.9.6,<1.2.0
requests>=2.13.0,<3.0.0 requests>=2.13.0,<3.0.0
pydantic>=1.0.0,<2.0.0
[options.extras_require] [options.extras_require]
lookups = lookups =

View File

@ -1,217 +0,0 @@
# NB: This schema describes the new format of the training data, see #2928
# JSON schema for a list of training examples. Each example requires a "text"
# and may carry entity spans, sentence spans, text categories, token
# annotations and a free-form user space under "_".
TRAINING_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "title": "Training data for spaCy models",
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "text": {
                "title": "The text of the training example",
                "type": "string",
                "minLength": 1,
            },
            "ents": {
                "title": "Named entity spans in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "start": {
                            "title": "Start character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "label": {
                            "title": "Entity label",
                            "type": "string",
                            "minLength": 1,
                            "pattern": "^[A-Z0-9]*$",
                        },
                    },
                    "required": ["start", "end", "label"],
                },
            },
            "sents": {
                "title": "Sentence spans in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "start": {
                            "title": "Start character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                    },
                    "required": ["start", "end"],
                },
            },
            "cats": {
                "title": "Text categories for the text classifier",
                "type": "object",
                "patternProperties": {
                    # Keys of patternProperties are regular expressions. A
                    # bare "*" is not a valid regex (nothing to repeat), so
                    # match any property name explicitly.
                    "^.*$": {
                        "title": "A text category",
                        "oneOf": [
                            {"type": "boolean"},
                            {"type": "number", "minimum": 0},
                        ],
                    }
                },
                "propertyNames": {"pattern": "^[A-Z0-9]*$", "minLength": 1},
            },
            "tokens": {
                "title": "The tokens in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "minProperties": 1,
                    "properties": {
                        "id": {
                            "title": "Token ID, usually token index",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "start": {
                            "title": "Start character offset of the token",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the token",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "pos": {
                            "title": "Coarse-grained part-of-speech tag",
                            "type": "string",
                            "minLength": 1,
                        },
                        "tag": {
                            "title": "Fine-grained part-of-speech tag",
                            "type": "string",
                            "minLength": 1,
                        },
                        "dep": {
                            "title": "Dependency label",
                            "type": "string",
                            "minLength": 1,
                        },
                        "head": {
                            "title": "Index of the token's head",
                            "type": "integer",
                            "minimum": 0,
                        },
                    },
                    "required": ["start", "end"],
                },
            },
            "_": {"title": "Custom user space", "type": "object"},
        },
        "required": ["text"],
    },
}
# JSON schema for a model's meta data. Only "lang", "name" and "version" are
# required; everything else is optional descriptive information.
META_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "type": "object",
    "properties": {
        "lang": {
            "title": "Two-letter language code, e.g. 'en'",
            "type": "string",
            "minLength": 2,
            "maxLength": 2,
            "pattern": "^[a-z]*$",
        },
        "name": {
            "title": "Model name",
            "type": "string",
            "minLength": 1,
            "pattern": "^[a-z_]*$",
        },
        "version": {
            "title": "Model version",
            "type": "string",
            "minLength": 1,
            "pattern": "^[0-9a-z.-]*$",
        },
        "spacy_version": {
            "title": "Compatible spaCy version identifier",
            "type": "string",
            "minLength": 1,
            # The hyphen has to come last in the character class: written as
            # ".->" it forms a range from "." to ">" that excludes "-" itself
            # and lets "/", ":" and ";" through.
            "pattern": "^[0-9a-z.><=-]*$",
        },
        "parent_package": {
            "title": "Name of parent spaCy package, e.g. spacy or spacy-nightly",
            "type": "string",
            "minLength": 1,
            "default": "spacy",
        },
        "pipeline": {
            "title": "Names of pipeline components",
            "type": "array",
            "items": {"type": "string", "minLength": 1},
        },
        "description": {"title": "Model description", "type": "string"},
        "license": {"title": "Model license", "type": "string"},
        "author": {"title": "Model author name", "type": "string"},
        "email": {"title": "Model author email", "type": "string", "format": "email"},
        "url": {"title": "Model author URL", "type": "string", "format": "uri"},
        "sources": {
            "title": "Training data sources",
            "type": "array",
            "items": {"type": "string"},
        },
        "vectors": {
            "title": "Included word vectors",
            "type": "object",
            "properties": {
                "keys": {
                    "title": "Number of unique keys",
                    "type": "integer",
                    "minimum": 0,
                },
                "vectors": {
                    "title": "Number of unique vectors",
                    "type": "integer",
                    "minimum": 0,
                },
                "width": {
                    "title": "Number of dimensions",
                    "type": "integer",
                    "minimum": 0,
                },
            },
        },
        "accuracy": {
            "title": "Accuracy numbers",
            "type": "object",
            # patternProperties keys are regexes; "*" alone is not a valid one.
            "patternProperties": {"^.*$": {"type": "number", "minimum": 0.0}},
        },
        "speed": {
            "title": "Speed evaluation numbers",
            "type": "object",
            "patternProperties": {
                "^.*$": {
                    # anyOf rather than oneOf: every integer is also a
                    # "number", so oneOf would reject all integer values for
                    # matching both alternatives.
                    "anyOf": [
                        {"type": "number", "minimum": 0.0},
                        {"type": "integer", "minimum": 0},
                    ]
                }
            },
        },
    },
    "required": ["lang", "name", "version"],
}

View File

@ -105,7 +105,6 @@ class Warnings(object):
"smaller JSON files instead.") "smaller JSON files instead.")
@add_codes @add_codes
class Errors(object): class Errors(object):
E001 = ("No component '{name}' found in pipeline. Available names: {opts}") E001 = ("No component '{name}' found in pipeline. Available names: {opts}")
@ -419,8 +418,6 @@ class Errors(object):
E134 = ("Entity '{entity}' is not defined in the Knowledge Base.") E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
E135 = ("If you meant to replace a built-in component, use `create_pipe`: " E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`") "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
E136 = ("This additional feature requires the jsonschema library to be "
"installed:\npip install jsonschema")
E137 = ("Expected 'dict' type, but got '{type}' from '{line}'. Make sure " E137 = ("Expected 'dict' type, but got '{type}' from '{line}'. Make sure "
"to provide a valid JSON object as input with either the `text` " "to provide a valid JSON object as input with either the `text` "
"or `tokens` key. For more info, see the docs:\n" "or `tokens` key. For more info, see the docs:\n"

View File

@ -1,197 +0,0 @@
# JSON schema for Matcher token patterns: a list of objects whose keys are
# token attribute names. String-, integer- and boolean-valued attributes each
# reference a shared value definition; the per-attribute entries are generated
# from (name, title) tables to keep the three groups easy to scan.
TOKEN_PATTERN_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "definitions": {
        "string_value": {
            "anyOf": [
                {"type": "string"},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {"type": "array", "items": {"type": "string"}},
                        "NOT_IN": {"type": "array", "items": {"type": "string"}},
                    },
                    "additionalProperties": False,
                },
            ]
        },
        "integer_value": {
            "anyOf": [
                {"type": "integer"},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {"type": "array", "items": {"type": "integer"}},
                        "NOT_IN": {"type": "array", "items": {"type": "integer"}},
                        # Comparison operators all take an integer operand
                        **{
                            operator: {"type": "integer"}
                            for operator in ("==", ">=", "<=", ">", "<")
                        },
                    },
                    "additionalProperties": False,
                },
            ]
        },
        "boolean_value": {"type": "boolean"},
        "underscore_value": {
            "anyOf": [
                {"type": ["string", "integer", "number", "array", "boolean", "null"]},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {
                            "type": "array",
                            "items": {"type": ["string", "integer"]},
                        },
                        "NOT_IN": {
                            "type": "array",
                            "items": {"type": ["string", "integer"]},
                        },
                        **{
                            operator: {"type": "integer"}
                            for operator in ("==", ">=", "<=", ">", "<")
                        },
                    },
                    "additionalProperties": False,
                },
            ]
        },
    },
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            # String-valued attributes
            **{
                attr: {"title": title, "$ref": "#/definitions/string_value"}
                for attr, title in (
                    ("ORTH", "Verbatim token text"),
                    ("TEXT", "Verbatim token text (spaCy v2.1+)"),
                    ("LOWER", "Lowercase form of token text"),
                    ("POS", "Coarse-grained part-of-speech tag"),
                    ("TAG", "Fine-grained part-of-speech tag"),
                    ("DEP", "Dependency label"),
                    ("LEMMA", "Lemma (base form)"),
                    ("SHAPE", "Abstract token shape"),
                    ("ENT_TYPE", "Entity label of single token"),
                    ("NORM", "Normalized form of the token text"),
                )
            },
            # Integer-valued attributes
            "LENGTH": {
                "title": "Token character length",
                "$ref": "#/definitions/integer_value",
            },
            # Boolean flags
            **{
                attr: {"title": title, "$ref": "#/definitions/boolean_value"}
                for attr, title in (
                    ("IS_ALPHA", "Token consists of alphabetic characters"),
                    ("IS_ASCII", "Token consists of ASCII characters"),
                    ("IS_DIGIT", "Token consists of digits"),
                    ("IS_LOWER", "Token is lowercase"),
                    ("IS_UPPER", "Token is uppercase"),
                    ("IS_TITLE", "Token is titlecase"),
                    ("IS_PUNCT", "Token is punctuation"),
                    ("IS_SPACE", "Token is whitespace"),
                    ("IS_BRACKET", "Token is a bracket"),
                    ("IS_QUOTE", "Token is a quotation mark"),
                    ("IS_LEFT_PUNCT", "Token is a left punctuation mark"),
                    ("IS_RIGHT_PUNCT", "Token is a right punctuation mark"),
                    ("IS_CURRENCY", "Token is a currency symbol"),
                    ("IS_STOP", "Token is stop word"),
                    ("IS_SENT_START", "Token is the first in a sentence"),
                    ("LIKE_NUM", "Token resembles a number"),
                    ("LIKE_URL", "Token resembles a URL"),
                    ("LIKE_EMAIL", "Token resembles an email address"),
                )
            },
            # Custom extension attributes and quantifiers
            "_": {
                "title": "Custom extension token attributes (token._.)",
                "type": "object",
                "patternProperties": {
                    "^.*$": {"$ref": "#/definitions/underscore_value"}
                },
            },
            "OP": {
                "title": "Operators / quantifiers",
                "type": "string",
                "enum": ["+", "*", "?", "!"],
            },
        },
        "additionalProperties": False,
    },
}

View File

@ -39,7 +39,8 @@ cdef class DependencyMatcher:
RETURNS (DependencyMatcher): The newly constructed object. RETURNS (DependencyMatcher): The newly constructed object.
""" """
size = 20 size = 20
self.token_matcher = Matcher(vocab) # TODO: make matcher work with validation
self.token_matcher = Matcher(vocab, validate=False)
self._keys_to_token = {} self._keys_to_token = {}
self._patterns = {} self._patterns = {}
self._root = {} self._root = {}
@ -129,7 +130,7 @@ cdef class DependencyMatcher:
# TODO: Better ways to hash edges in pattern? # TODO: Better ways to hash edges in pattern?
for j in range(len(_patterns[i])): for j in range(len(_patterns[i])):
k = self._normalize_key(unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j)) k = self._normalize_key(unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j))
self.token_matcher.add(k, None, _patterns[i][j]) self.token_matcher.add(k, [_patterns[i][j]])
_keys_to_token[k] = j _keys_to_token[k] = j
_keys_to_token_list.append(_keys_to_token) _keys_to_token_list.append(_keys_to_token)
self._keys_to_token.setdefault(key, []) self._keys_to_token.setdefault(key, [])

View File

@ -63,7 +63,7 @@ cdef class Matcher:
cdef Pool mem cdef Pool mem
cdef vector[TokenPatternC*] patterns cdef vector[TokenPatternC*] patterns
cdef readonly Vocab vocab cdef readonly Vocab vocab
cdef public object validator cdef public object validate
cdef public object _patterns cdef public object _patterns
cdef public object _callbacks cdef public object _callbacks
cdef public object _extensions cdef public object _extensions

View File

@ -15,8 +15,7 @@ from ..tokens.doc cimport Doc, get_token_attr
from ..tokens.token cimport Token from ..tokens.token cimport Token
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
from ._schemas import TOKEN_PATTERN_SCHEMA from ..schemas import validate_token_pattern
from ..util import get_json_validator, validate_json
from ..errors import Errors, MatchPatternError, Warnings, deprecation_warning from ..errors import Errors, MatchPatternError, Warnings, deprecation_warning
from ..strings import get_string_id from ..strings import get_string_id
from ..attrs import IDS from ..attrs import IDS
@ -32,7 +31,7 @@ cdef class Matcher:
USAGE: https://spacy.io/usage/rule-based-matching USAGE: https://spacy.io/usage/rule-based-matching
""" """
def __init__(self, vocab, validate=False): def __init__(self, vocab, validate=True):
"""Create the Matcher. """Create the Matcher.
vocab (Vocab): The vocabulary object, which must be shared with the vocab (Vocab): The vocabulary object, which must be shared with the
@ -46,10 +45,7 @@ cdef class Matcher:
self._seen_attrs = set() self._seen_attrs = set()
self.vocab = vocab self.vocab = vocab
self.mem = Pool() self.mem = Pool()
if validate: self.validate = validate
self.validator = get_json_validator(TOKEN_PATTERN_SCHEMA)
else:
self.validator = None
def __reduce__(self): def __reduce__(self):
data = (self.vocab, self._patterns, self._callbacks) data = (self.vocab, self._patterns, self._callbacks)
@ -119,8 +115,8 @@ cdef class Matcher:
raise ValueError(Errors.E012.format(key=key)) raise ValueError(Errors.E012.format(key=key))
if not isinstance(pattern, list): if not isinstance(pattern, list):
raise ValueError(Errors.E178.format(pat=pattern, key=key)) raise ValueError(Errors.E178.format(pat=pattern, key=key))
if self.validator: if self.validate:
errors[i] = validate_json(pattern, self.validator) errors[i] = validate_token_pattern(pattern)
if any(err for err in errors.values()): if any(err for err in errors.values()):
raise MatchPatternError(key, errors) raise MatchPatternError(key, errors)
key = self._normalize_key(key) key = self._normalize_key(key)
@ -668,8 +664,6 @@ def _get_attr_values(spec, string_store):
continue continue
if attr == "TEXT": if attr == "TEXT":
attr = "ORTH" attr = "ORTH"
if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
raise ValueError(Errors.E152.format(attr=attr))
attr = IDS.get(attr) attr = IDS.get(attr)
if isinstance(value, basestring): if isinstance(value, basestring):
value = string_store.add(value) value = string_store.add(value)
@ -684,7 +678,7 @@ def _get_attr_values(spec, string_store):
if attr is not None: if attr is not None:
attr_values.append((attr, value)) attr_values.append((attr, value))
else: else:
# should be caught above using TOKEN_PATTERN_SCHEMA # should be caught in validation
raise ValueError(Errors.E152.format(attr=attr)) raise ValueError(Errors.E152.format(attr=attr))
return attr_values return attr_values

View File

@ -9,7 +9,7 @@ from ..structs cimport TokenC
from ..tokens.token cimport Token from ..tokens.token cimport Token
from ..typedefs cimport attr_t from ..typedefs cimport attr_t
from ._schemas import TOKEN_PATTERN_SCHEMA from ..schemas import TokenPattern
from ..errors import Errors, Warnings, deprecation_warning, user_warning from ..errors import Errors, Warnings, deprecation_warning, user_warning
@ -54,7 +54,7 @@ cdef class PhraseMatcher:
attr = attr.upper() attr = attr.upper()
if attr == "TEXT": if attr == "TEXT":
attr = "ORTH" attr = "ORTH"
if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]: if attr.lower() not in TokenPattern().dict():
raise ValueError(Errors.E152.format(attr=attr)) raise ValueError(Errors.E152.format(attr=attr))
self.attr = self.vocab.strings[attr] self.attr = self.vocab.strings[attr]

188
spacy/schemas.py Normal file
View File

@ -0,0 +1,188 @@
from typing import Dict, List, Union, Optional
from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
from collections import defaultdict
from .attrs import NAMES
def validate(schema, obj):
    """Check a dict of data against a pydantic schema.

    schema (pydantic.BaseModel): The schema to validate against.
    obj (dict): JSON-serializable data, passed to the schema as keyword
        arguments.
    RETURNS (list): One message per error location; empty if the data is valid.
    """
    try:
        schema(**obj)
    except ValidationError as exc:
        # Group the messages by location so every location is reported once,
        # in the order the locations were first seen.
        grouped = {}
        for item in exc.errors():
            location = " -> ".join(str(part) for part in item.get("loc", []))
            grouped.setdefault(location, []).append(item.get("msg"))
        return [f"[{loc}] {', '.join(msg)}" for loc, msg in grouped.items()]
    return []
# Matcher token patterns
def validate_token_pattern(obj):
    """Validate a Matcher token pattern against TokenPatternSchema.

    obj (list): The pattern, usually a list of dicts describing one token
        each. Anything else is passed to the schema unchanged so that the
        schema reports the type error.
    RETURNS (list): A list of error messages, empty if the pattern is valid.
    """

    def get_key(key):
        # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"}).
        # Only genuine, in-range attribute IDs are converted: bool is a
        # subclass of int, and a negative index would silently wrap around to
        # the end of NAMES, so both are left as they are and end up being
        # flagged by the schema.
        is_attr_id = isinstance(key, int) and not isinstance(key, bool)
        if is_attr_id and 0 <= key < len(NAMES):
            return NAMES[key]
        return key

    if isinstance(obj, list):
        obj = [
            {get_key(k): v for k, v in pattern.items()}
            if isinstance(pattern, dict)
            else pattern
            for pattern in obj
        ]
    return validate(TokenPatternSchema, {"pattern": obj})
class TokenPatternString(BaseModel):
    # Extended (dict-style) value of a string attribute in a token pattern,
    # e.g. {"REGEX": "^X"} or {"IN": ["a", "b"]}. All keys are optional and
    # unknown keys are rejected.
    # NOTE: documented with comments rather than a docstring, since pydantic
    # copies a model's docstring into the generated JSON schema.
    REGEX: Optional[StrictStr] = None
    IN: Optional[List[StrictStr]] = None
    NOT_IN: Optional[List[StrictStr]] = None

    class Config:
        extra = "forbid"

    # The validator runs on the whole field value, which is pydantic's default
    # (each_item=False). The deprecated `whole=True` keyword meant the same
    # thing but triggers a DeprecationWarning.
    @validator("*", pre=True)
    def raise_for_none(cls, v):
        # An explicit None / null is an error, unlike an omitted key
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
class TokenPatternNumber(BaseModel):
    # Extended (dict-style) value of a numeric attribute in a token pattern,
    # e.g. {">=": 5} or {"IN": [1, 2]}. The comparison operators are exposed
    # through aliases because "==", ">=" etc. are not valid field names.
    # NOTE: documented with comments rather than a docstring, since pydantic
    # copies a model's docstring into the generated JSON schema.
    REGEX: Optional[StrictStr] = None
    IN: Optional[List[StrictInt]] = None
    NOT_IN: Optional[List[StrictInt]] = None
    EQ: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias="==")
    GEQ: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias=">=")
    LEQ: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias="<=")
    GT: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias=">")
    LT: Optional[Union[StrictInt, StrictFloat]] = Field(None, alias="<")

    class Config:
        extra = "forbid"

    # The validator runs on the whole field value, which is pydantic's default
    # (each_item=False). The deprecated `whole=True` keyword meant the same
    # thing but triggers a DeprecationWarning.
    @validator("*", pre=True)
    def raise_for_none(cls, v):
        # An explicit None / null is an error, unlike an omitted key
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
class TokenPatternOperator(str, Enum):
    # Quantifiers accepted under the "OP" key of a token pattern. Members are
    # plain assignments: annotations on enum members have no effect at runtime
    # and are discouraged for type checkers.
    plus = "+"
    # NOTE(review): "start" is presumably a typo for "star"; the name is kept
    # so existing references to TokenPatternOperator.start keep working.
    start = "*"
    question = "?"
    exclamation = "!"
# Value of a string attribute: a plain string or an extended pattern dict.
StringValue = Union[TokenPatternString, StrictStr]
# Value of a numeric attribute: a plain number or an extended pattern dict.
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
# Value of a custom extension attribute (entries of the "_" dict): an extended
# pattern dict or a plain, non-strict Python value.
UnderscoreValue = Union[
    TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
]
class TokenPattern(BaseModel):
    # Description of a single token in a Matcher pattern. Fields are declared
    # in lowercase; the alias generator exposes them under their uppercase
    # names ("ORTH", "LOWER", ...) and population by field name stays enabled,
    # so both spellings are accepted. Unknown keys are rejected.
    # NOTE: documented with comments rather than a docstring, since pydantic
    # copies a model's docstring into the generated JSON schema.

    # String attributes
    orth: Optional[StringValue] = None
    text: Optional[StringValue] = None
    lower: Optional[StringValue] = None
    pos: Optional[StringValue] = None
    tag: Optional[StringValue] = None
    dep: Optional[StringValue] = None
    lemma: Optional[StringValue] = None
    shape: Optional[StringValue] = None
    ent_type: Optional[StringValue] = None
    norm: Optional[StringValue] = None
    # Numeric attributes
    length: Optional[NumberValue] = None
    # Boolean flags
    is_alpha: Optional[StrictBool] = None
    is_ascii: Optional[StrictBool] = None
    is_digit: Optional[StrictBool] = None
    is_lower: Optional[StrictBool] = None
    is_upper: Optional[StrictBool] = None
    is_title: Optional[StrictBool] = None
    is_punct: Optional[StrictBool] = None
    is_space: Optional[StrictBool] = None
    is_bracket: Optional[StrictBool] = None
    is_quote: Optional[StrictBool] = None
    is_left_punct: Optional[StrictBool] = None
    is_right_punct: Optional[StrictBool] = None
    is_currency: Optional[StrictBool] = None
    is_stop: Optional[StrictBool] = None
    is_sent_start: Optional[StrictBool] = None
    like_num: Optional[StrictBool] = None
    like_url: Optional[StrictBool] = None
    like_email: Optional[StrictBool] = None
    # Quantifier and custom extension attributes; "_" is not a usable field
    # name, so it is mapped via an explicit alias.
    op: Optional[TokenPatternOperator] = None
    underscore: Optional[Dict[StrictStr, UnderscoreValue]] = Field(None, alias="_")

    class Config:
        extra = "forbid"
        allow_population_by_field_name = True
        # Uppercase every field name to get its alias
        alias_generator = str.upper

    @validator("*", pre=True)
    def raise_for_none(cls, v):
        # Only an explicit None / null is rejected; omitted keys are fine
        if v is not None:
            return v
        raise ValueError("None / null is not allowed")
class TokenPatternSchema(BaseModel):
    # Wrapper model used to validate a whole pattern, i.e. a non-empty list of
    # token descriptions.
    # `min_items` is the keyword pydantic turns into a validation constraint
    # (and into "minItems" in the generated JSON schema). A camel-case
    # `minItems` keyword is only stored as extra schema metadata and enforces
    # nothing.
    pattern: List[TokenPattern] = Field(..., min_items=1)

    class Config:
        extra = "forbid"
# Model meta
class ModelMetaSchema(BaseModel):
    # Schema for a model's meta data. "lang", "name" and "version" are
    # required, all remaining fields are optional.
    # NOTE: documented with comments rather than a docstring, since pydantic
    # copies a model's docstring into the generated JSON schema.

    # Required identification
    lang: StrictStr = Field(..., title="Two-letter language code, e.g. 'en'")
    name: StrictStr = Field(..., title="Model name")
    version: StrictStr = Field(..., title="Model version")
    # Compatibility and pipeline
    spacy_version: Optional[StrictStr] = Field(
        None, title="Compatible spaCy version identifier"
    )
    parent_package: Optional[StrictStr] = Field(
        "spacy", title="Name of parent spaCy package, e.g. spacy or spacy-nightly"
    )
    pipeline: Optional[List[StrictStr]] = Field(
        [], title="Names of pipeline components"
    )
    # Descriptive information
    description: Optional[StrictStr] = Field(None, title="Model description")
    license: Optional[StrictStr] = Field(None, title="Model license")
    author: Optional[StrictStr] = Field(None, title="Model author name")
    email: Optional[StrictStr] = Field(None, title="Model author email")
    url: Optional[StrictStr] = Field(None, title="Model author URL")
    sources: Optional[Union[List[StrictStr], Dict[str, str]]] = Field(
        None, title="Training data sources"
    )
    # Vectors and evaluation figures
    vectors: Optional[Dict[str, int]] = Field(None, title="Included word vectors")
    accuracy: Optional[Dict[str, Union[float, int]]] = Field(
        None, title="Accuracy numbers"
    )
    speed: Optional[Dict[str, Union[float, int]]] = Field(
        None, title="Speed evaluation numbers"
    )
# Training data object in "simple training style"
class SimpleTrainingSchema(BaseModel):
    # Placeholder: no fields are declared yet, and extra keys are forbidden.
    # TODO: write

    class Config:
        extra = "forbid"
        title = "Schema for training data dict in passed to nlp.update"
# JSON training format
class TrainingSchema(BaseModel):
    # Placeholder: no fields are declared yet, and extra keys are forbidden.
    # TODO: write

    class Config:
        extra = "forbid"
        title = "Schema for training data in spaCy's JSON format"

View File

@ -1,6 +1,4 @@
import pytest import pytest
from spacy.cli._schemas import TRAINING_SCHEMA
from spacy.util import get_json_validator, validate_json
from spacy.tokens import Doc from spacy.tokens import Doc
from ..util import get_doc from ..util import get_doc
@ -55,10 +53,3 @@ def test_doc_to_json_underscore_error_serialize(doc):
Doc.set_extension("json_test4", method=lambda doc: doc.text) Doc.set_extension("json_test4", method=lambda doc: doc.text)
with pytest.raises(ValueError): with pytest.raises(ValueError):
doc.to_json(underscore=["json_test4"]) doc.to_json(underscore=["json_test4"])
def test_doc_to_json_valid_training(doc):
json_doc = doc.to_json()
validator = get_json_validator(TRAINING_SCHEMA)
errors = validate_json([json_doc], validator)
assert not errors

View File

@ -179,7 +179,7 @@ def test_matcher_match_one_plus(matcher):
doc = Doc(control.vocab, words=["Philippe", "Philippe"]) doc = Doc(control.vocab, words=["Philippe", "Philippe"])
m = control(doc) m = control(doc)
assert len(m) == 2 assert len(m) == 2
pattern = [{"ORTH": "Philippe", "OP": "1"}, {"ORTH": "Philippe", "OP": "+"}] pattern = [{"ORTH": "Philippe"}, {"ORTH": "Philippe", "OP": "+"}]
matcher.add("KleenePhilippe", [pattern]) matcher.add("KleenePhilippe", [pattern])
m = matcher(doc) m = matcher(doc)
assert len(m) == 1 assert len(m) == 1

View File

@ -6,18 +6,18 @@ from spacy.matcher import Matcher
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
pattern1 = [{"ORTH": "A", "OP": "1"}, {"ORTH": "A", "OP": "*"}] pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A", "OP": "1"}] pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}]
pattern3 = [{"ORTH": "A", "OP": "1"}, {"ORTH": "A", "OP": "1"}] pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
pattern4 = [ pattern4 = [
{"ORTH": "B", "OP": "1"}, {"ORTH": "B"},
{"ORTH": "A", "OP": "*"}, {"ORTH": "A", "OP": "*"},
{"ORTH": "B", "OP": "1"}, {"ORTH": "B"},
] ]
pattern5 = [ pattern5 = [
{"ORTH": "B", "OP": "*"}, {"ORTH": "B", "OP": "*"},
{"ORTH": "A", "OP": "*"}, {"ORTH": "A", "OP": "*"},
{"ORTH": "B", "OP": "1"}, {"ORTH": "B"},
] ]
re_pattern1 = "AA*" re_pattern1 = "AA*"

View File

@ -1,8 +1,7 @@
import pytest import pytest
from spacy.matcher import Matcher from spacy.matcher import Matcher
from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
from spacy.errors import MatchPatternError from spacy.errors import MatchPatternError
from spacy.util import get_json_validator, validate_json from spacy.schemas import validate_token_pattern
# (pattern, num errors with validation, num errors identified with minimal # (pattern, num errors with validation, num errors identified with minimal
# checks) # checks)
@ -15,12 +14,12 @@ TEST_PATTERNS = [
('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1), ('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1),
([1, 2, 3], 3, 1), ([1, 2, 3], 3, 1),
# Bad patterns flagged outside of Matcher # Bad patterns flagged outside of Matcher
([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1, 0), ([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 2, 0), # prev: (1, 0)
# Bad patterns not flagged with minimal checks # Bad patterns not flagged with minimal checks
([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 0), ([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 0),
([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2, 0), ([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 4, 0), # prev: (2, 0)
([{"LENGTH": {"VALUE": 5}}], 1, 0), ([{"LENGTH": {"VALUE": 5}}], 2, 0), # prev: (1, 0)
([{"TEXT": {"VALUE": "foo"}}], 1, 0), ([{"TEXT": {"VALUE": "foo"}}], 2, 0), # prev: (1, 0)
([{"IS_DIGIT": -1}], 1, 0), ([{"IS_DIGIT": -1}], 1, 0),
([{"ORTH": -1}], 1, 0), ([{"ORTH": -1}], 1, 0),
# Good patterns # Good patterns
@ -31,15 +30,9 @@ TEST_PATTERNS = [
([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0, 0), ([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0, 0),
([{"NORM": "a"}, {"POS": {"IN": ["NOUN"]}}], 0, 0), ([{"NORM": "a"}, {"POS": {"IN": ["NOUN"]}}], 0, 0),
([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0, 0), ([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0, 0),
([{"orth": "foo"}], 0, 0), # prev: xfail
] ]
XFAIL_TEST_PATTERNS = [([{"orth": "foo"}], 0, 0)]
@pytest.fixture
def validator():
return get_json_validator(TOKEN_PATTERN_SCHEMA)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pattern", [[{"XX": "y"}, {"LENGTH": "2"}, {"TEXT": {"IN": 5}}]] "pattern", [[{"XX": "y"}, {"LENGTH": "2"}, {"TEXT": {"IN": 5}}]]
@ -51,15 +44,8 @@ def test_matcher_pattern_validation(en_vocab, pattern):
@pytest.mark.parametrize("pattern,n_errors,_", TEST_PATTERNS) @pytest.mark.parametrize("pattern,n_errors,_", TEST_PATTERNS)
def test_pattern_validation(validator, pattern, n_errors, _): def test_pattern_validation(pattern, n_errors, _):
errors = validate_json(pattern, validator) errors = validate_token_pattern(pattern)
assert len(errors) == n_errors
@pytest.mark.xfail
@pytest.mark.parametrize("pattern,n_errors,_", XFAIL_TEST_PATTERNS)
def test_xfail_pattern_validation(validator, pattern, n_errors, _):
errors = validate_json(pattern, validator)
assert len(errors) == n_errors assert len(errors) == n_errors

View File

@ -1,47 +0,0 @@
from spacy.util import get_json_validator, validate_json, validate_schema
from spacy.cli._schemas import META_SCHEMA, TRAINING_SCHEMA
from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
import pytest
@pytest.fixture(scope="session")
def training_schema_validator():
    """Validator for TRAINING_SCHEMA, created once per test session."""
    validator = get_json_validator(TRAINING_SCHEMA)
    return validator
def test_validate_schema():
    """A well-formed schema is accepted, a malformed one raises."""
    well_formed = {"type": "object"}
    validate_schema(well_formed)
    # A callable is not a valid value for "type"
    malformed = {"type": lambda x: x}
    with pytest.raises(Exception):
        validate_schema(malformed)
@pytest.mark.parametrize(
    "schema",
    [
        TRAINING_SCHEMA,
        META_SCHEMA,
        TOKEN_PATTERN_SCHEMA,
    ],
)
def test_schemas(schema):
    """Each schema shipped with the library must itself be a valid schema."""
    validate_schema(schema)
@pytest.mark.parametrize(
    "data",
    [
        {"text": "Hello world"},
        {"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "TEST"}]},
    ],
)
def test_json_schema_training_valid(data, training_schema_validator):
    """Valid training examples produce no validation errors."""
    examples = [data]
    assert not validate_json(examples, training_schema_validator)
@pytest.mark.parametrize(
    "data,n_errors",
    [
        ({"spans": []}, 1),
        ({"text": "Hello", "ents": [{"start": "0", "end": "5", "label": "TEST"}]}, 2),
        ({"text": "Hello", "ents": [{"start": 0, "end": 5}]}, 1),
        ({"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "test"}]}, 1),
        ({"text": "spaCy", "tokens": [{"pos": "PROPN"}]}, 2),
    ],
)
def test_json_schema_training_invalid(data, n_errors, training_schema_validator):
    """Invalid training examples produce the expected number of errors."""
    examples = [data]
    found = validate_json(examples, training_schema_validator)
    assert len(found) == n_errors

View File

@ -13,11 +13,6 @@ import srsly
import catalogue import catalogue
import sys import sys
try:
import jsonschema
except ImportError:
jsonschema = None
try: try:
import cupy.random import cupy.random
except ImportError: except ImportError:
@ -705,43 +700,6 @@ def fix_random_seed(seed=0):
cupy.random.seed(seed) cupy.random.seed(seed)
def get_json_validator(schema):
    """Create a JSON schema validator for the given schema.

    schema (dict): The JSON schema.
    RETURNS (jsonschema.Draft4Validator): The validator.
    """
    # Validators are always created through this helper, so the implementation
    # (e.g. a different draft) can be swapped in one place instead of all
    # across the codebase.
    # TODO: replace with (stable) Draft6Validator, if available
    if jsonschema is not None:
        return jsonschema.Draft4Validator(schema)
    # The optional jsonschema dependency is not installed
    raise ValueError(Errors.E136)
def validate_schema(schema):
    """Check that a schema is itself valid; no data is validated here.

    schema (dict): The JSON schema to check.
    """
    get_json_validator(schema).check_schema(schema)
def validate_json(data, validator):
    """Validate data against a given JSON schema (see https://json-schema.org).

    data: JSON-serializable data to validate.
    validator (jsonschema.DraftXValidator): The validator.
    RETURNS (list): A list of error messages, if available.
    """
    errors = []
    for err in sorted(validator.iter_errors(data), key=lambda e: e.path):
        if err.path:
            err_path = "[{}]".format(" -> ".join([str(p) for p in err.path]))
        else:
            err_path = ""
        msg = err.message + " " + err_path
        if err.context:  # Error has suberrors, e.g. if schema uses anyOf
            suberrs = [f" - {suberr.message}" for suberr in err.context]
            # One bullet per line; joining with "" ran all suberrors together
            # on a single line.
            msg += ":\n" + "\n".join(suberrs)
        errors.append(msg)
    return errors
def get_serialization_exclude(serializers, exclude, kwargs): def get_serialization_exclude(serializers, exclude, kwargs):
"""Helper function to validate serialization args and manage transition from """Helper function to validate serialization args and manage transition from
keyword arguments (pre v2.1) to exclude argument. keyword arguments (pre v2.1) to exclude argument.