Mirror of https://github.com/explosion/spaCy.git
Add better schemas and validation using Pydantic (#4831)
* Remove unicode declarations
* Remove Python 3.5 and 2.7 from CI
* Don't require pathlib
* Replace compat helpers
* Remove OrderedDict
* Use f-strings
* Set Cython compiler language level
* Fix typo
* Re-add OrderedDict for Table
* Update setup.cfg
* Revert CONTRIBUTING.md
* Add better schemas and validation using Pydantic
* Revert lookups.md
* Remove unused import
* Update spacy/schemas.py (Co-Authored-By: Sebastián Ramírez <tiangolo@gmail.com>)
* Various small fixes
* Fix docstring

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
parent: db55577c45
commit: 33a2682d60
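What the change means in practice (a minimal usage sketch, assuming spaCy with this commit applied; the snippet is illustrative and not part of the diff): Matcher pattern validation no longer depends on the optional jsonschema package and is now on by default, so malformed patterns fail fast when they are added.

    from spacy.lang.en import English
    from spacy.matcher import Matcher

    matcher = Matcher(English().vocab)  # validate=True is now the default
    try:
        # "LENGTH" expects an integer value, so this pattern is rejected on add()
        matcher.add("BAD", [[{"LENGTH": "2"}]])
    except ValueError as err:  # MatchPatternError is a ValueError subclass
        print(err)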
requirements.txt
@@ -12,8 +12,7 @@ numpy>=1.15.0
 requests>=2.13.0,<3.0.0
 plac>=0.9.6,<1.2.0
 tqdm>=4.38.0,<5.0.0
-# Optional dependencies
-jsonschema>=2.6.0,<3.1.0
+pydantic>=1.0.0,<2.0.0
 # Development dependencies
 cython>=0.25
 pytest>=4.6.5

setup.cfg
@@ -51,6 +51,7 @@ install_requires =
 numpy>=1.15.0
 plac>=0.9.6,<1.2.0
 requests>=2.13.0,<3.0.0
+pydantic>=1.0.0,<2.0.0

 [options.extras_require]
 lookups =
spacy/cli/_schemas.py (deleted)
@@ -1,217 +0,0 @@

# NB: This schema describes the new format of the training data, see #2928
TRAINING_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "title": "Training data for spaCy models",
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "text": {
                "title": "The text of the training example",
                "type": "string",
                "minLength": 1,
            },
            "ents": {
                "title": "Named entity spans in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "start": {
                            "title": "Start character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "label": {
                            "title": "Entity label",
                            "type": "string",
                            "minLength": 1,
                            "pattern": "^[A-Z0-9]*$",
                        },
                    },
                    "required": ["start", "end", "label"],
                },
            },
            "sents": {
                "title": "Sentence spans in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "start": {
                            "title": "Start character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                    },
                    "required": ["start", "end"],
                },
            },
            "cats": {
                "title": "Text categories for the text classifier",
                "type": "object",
                "patternProperties": {
                    "*": {
                        "title": "A text category",
                        "oneOf": [
                            {"type": "boolean"},
                            {"type": "number", "minimum": 0},
                        ],
                    }
                },
                "propertyNames": {"pattern": "^[A-Z0-9]*$", "minLength": 1},
            },
            "tokens": {
                "title": "The tokens in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "minProperties": 1,
                    "properties": {
                        "id": {
                            "title": "Token ID, usually token index",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "start": {
                            "title": "Start character offset of the token",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the token",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "pos": {
                            "title": "Coarse-grained part-of-speech tag",
                            "type": "string",
                            "minLength": 1,
                        },
                        "tag": {
                            "title": "Fine-grained part-of-speech tag",
                            "type": "string",
                            "minLength": 1,
                        },
                        "dep": {
                            "title": "Dependency label",
                            "type": "string",
                            "minLength": 1,
                        },
                        "head": {
                            "title": "Index of the token's head",
                            "type": "integer",
                            "minimum": 0,
                        },
                    },
                    "required": ["start", "end"],
                },
            },
            "_": {"title": "Custom user space", "type": "object"},
        },
        "required": ["text"],
    },
}

META_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "type": "object",
    "properties": {
        "lang": {
            "title": "Two-letter language code, e.g. 'en'",
            "type": "string",
            "minLength": 2,
            "maxLength": 2,
            "pattern": "^[a-z]*$",
        },
        "name": {
            "title": "Model name",
            "type": "string",
            "minLength": 1,
            "pattern": "^[a-z_]*$",
        },
        "version": {
            "title": "Model version",
            "type": "string",
            "minLength": 1,
            "pattern": "^[0-9a-z.-]*$",
        },
        "spacy_version": {
            "title": "Compatible spaCy version identifier",
            "type": "string",
            "minLength": 1,
            "pattern": "^[0-9a-z.-><=]*$",
        },
        "parent_package": {
            "title": "Name of parent spaCy package, e.g. spacy or spacy-nightly",
            "type": "string",
            "minLength": 1,
            "default": "spacy",
        },
        "pipeline": {
            "title": "Names of pipeline components",
            "type": "array",
            "items": {"type": "string", "minLength": 1},
        },
        "description": {"title": "Model description", "type": "string"},
        "license": {"title": "Model license", "type": "string"},
        "author": {"title": "Model author name", "type": "string"},
        "email": {"title": "Model author email", "type": "string", "format": "email"},
        "url": {"title": "Model author URL", "type": "string", "format": "uri"},
        "sources": {
            "title": "Training data sources",
            "type": "array",
            "items": {"type": "string"},
        },
        "vectors": {
            "title": "Included word vectors",
            "type": "object",
            "properties": {
                "keys": {
                    "title": "Number of unique keys",
                    "type": "integer",
                    "minimum": 0,
                },
                "vectors": {
                    "title": "Number of unique vectors",
                    "type": "integer",
                    "minimum": 0,
                },
                "width": {
                    "title": "Number of dimensions",
                    "type": "integer",
                    "minimum": 0,
                },
            },
        },
        "accuracy": {
            "title": "Accuracy numbers",
            "type": "object",
            "patternProperties": {"*": {"type": "number", "minimum": 0.0}},
        },
        "speed": {
            "title": "Speed evaluation numbers",
            "type": "object",
            "patternProperties": {
                "*": {
                    "oneOf": [
                        {"type": "number", "minimum": 0.0},
                        {"type": "integer", "minimum": 0},
                    ]
                }
            },
        },
    },
    "required": ["lang", "name", "version"],
}
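For reference, a minimal example of the training-data shape the removed TRAINING_SCHEMA described (taken from the test data further down in this diff):

    # One annotated example in the "new" training format (see #2928)
    example = {
        "text": "Hello",
        "ents": [{"start": 0, "end": 5, "label": "TEST"}],
    }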
spacy/errors.py
@@ -105,7 +105,6 @@ class Warnings(object):
             "smaller JSON files instead.")

-

 @add_codes
 class Errors(object):
     E001 = ("No component '{name}' found in pipeline. Available names: {opts}")

@@ -419,8 +418,6 @@ class Errors(object):
     E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
     E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
             "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
-    E136 = ("This additional feature requires the jsonschema library to be "
-            "installed:\npip install jsonschema")
     E137 = ("Expected 'dict' type, but got '{type}' from '{line}'. Make sure "
             "to provide a valid JSON object as input with either the `text` "
             "or `tokens` key. For more info, see the docs:\n"
spacy/matcher/_schemas.py (deleted)
@@ -1,197 +0,0 @@

TOKEN_PATTERN_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "definitions": {
        "string_value": {
            "anyOf": [
                {"type": "string"},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {"type": "array", "items": {"type": "string"}},
                        "NOT_IN": {"type": "array", "items": {"type": "string"}},
                    },
                    "additionalProperties": False,
                },
            ]
        },
        "integer_value": {
            "anyOf": [
                {"type": "integer"},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {"type": "array", "items": {"type": "integer"}},
                        "NOT_IN": {"type": "array", "items": {"type": "integer"}},
                        "==": {"type": "integer"},
                        ">=": {"type": "integer"},
                        "<=": {"type": "integer"},
                        ">": {"type": "integer"},
                        "<": {"type": "integer"},
                    },
                    "additionalProperties": False,
                },
            ]
        },
        "boolean_value": {"type": "boolean"},
        "underscore_value": {
            "anyOf": [
                {"type": ["string", "integer", "number", "array", "boolean", "null"]},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {
                            "type": "array",
                            "items": {"type": ["string", "integer"]},
                        },
                        "NOT_IN": {
                            "type": "array",
                            "items": {"type": ["string", "integer"]},
                        },
                        "==": {"type": "integer"},
                        ">=": {"type": "integer"},
                        "<=": {"type": "integer"},
                        ">": {"type": "integer"},
                        "<": {"type": "integer"},
                    },
                    "additionalProperties": False,
                },
            ]
        },
    },
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "ORTH": {
                "title": "Verbatim token text",
                "$ref": "#/definitions/string_value",
            },
            "TEXT": {
                "title": "Verbatim token text (spaCy v2.1+)",
                "$ref": "#/definitions/string_value",
            },
            "LOWER": {
                "title": "Lowercase form of token text",
                "$ref": "#/definitions/string_value",
            },
            "POS": {
                "title": "Coarse-grained part-of-speech tag",
                "$ref": "#/definitions/string_value",
            },
            "TAG": {
                "title": "Fine-grained part-of-speech tag",
                "$ref": "#/definitions/string_value",
            },
            "DEP": {"title": "Dependency label", "$ref": "#/definitions/string_value"},
            "LEMMA": {
                "title": "Lemma (base form)",
                "$ref": "#/definitions/string_value",
            },
            "SHAPE": {
                "title": "Abstract token shape",
                "$ref": "#/definitions/string_value",
            },
            "ENT_TYPE": {
                "title": "Entity label of single token",
                "$ref": "#/definitions/string_value",
            },
            "NORM": {
                "title": "Normalized form of the token text",
                "$ref": "#/definitions/string_value",
            },
            "LENGTH": {
                "title": "Token character length",
                "$ref": "#/definitions/integer_value",
            },
            "IS_ALPHA": {
                "title": "Token consists of alphabetic characters",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_ASCII": {
                "title": "Token consists of ASCII characters",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_DIGIT": {
                "title": "Token consists of digits",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_LOWER": {
                "title": "Token is lowercase",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_UPPER": {
                "title": "Token is uppercase",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_TITLE": {
                "title": "Token is titlecase",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_PUNCT": {
                "title": "Token is punctuation",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_SPACE": {
                "title": "Token is whitespace",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_BRACKET": {
                "title": "Token is a bracket",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_QUOTE": {
                "title": "Token is a quotation mark",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_LEFT_PUNCT": {
                "title": "Token is a left punctuation mark",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_RIGHT_PUNCT": {
                "title": "Token is a right punctuation mark",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_CURRENCY": {
                "title": "Token is a currency symbol",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_STOP": {
                "title": "Token is stop word",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_SENT_START": {
                "title": "Token is the first in a sentence",
                "$ref": "#/definitions/boolean_value",
            },
            "LIKE_NUM": {
                "title": "Token resembles a number",
                "$ref": "#/definitions/boolean_value",
            },
            "LIKE_URL": {
                "title": "Token resembles a URL",
                "$ref": "#/definitions/boolean_value",
            },
            "LIKE_EMAIL": {
                "title": "Token resembles an email address",
                "$ref": "#/definitions/boolean_value",
            },
            "_": {
                "title": "Custom extension token attributes (token._.)",
                "type": "object",
                "patternProperties": {
                    "^.*$": {"$ref": "#/definitions/underscore_value"}
                },
            },
            "OP": {
                "title": "Operators / quantifiers",
                "type": "string",
                "enum": ["+", "*", "?", "!"],
            },
        },
        "additionalProperties": False,
    },
}
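A quick sketch of what this schema accepted: attribute values are either plain strings/ints/bools or rich {"REGEX"/"IN"/"NOT_IN"/comparison} objects, the same shapes the pydantic models below reproduce.

    # Match a token whose lowercase form starts with "x",
    # followed by a noun or verb
    pattern = [
        {"LOWER": {"REGEX": "^x"}},
        {"POS": {"IN": ["NOUN", "VERB"]}},
    ]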
spacy/matcher/dependencymatcher.pyx
@@ -39,7 +39,8 @@ cdef class DependencyMatcher:
         RETURNS (DependencyMatcher): The newly constructed object.
         """
         size = 20
-        self.token_matcher = Matcher(vocab)
+        # TODO: make matcher work with validation
+        self.token_matcher = Matcher(vocab, validate=False)
         self._keys_to_token = {}
         self._patterns = {}
         self._root = {}

@@ -129,7 +130,7 @@ cdef class DependencyMatcher:
         # TODO: Better ways to hash edges in pattern?
         for j in range(len(_patterns[i])):
             k = self._normalize_key(unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j))
-            self.token_matcher.add(k, None, _patterns[i][j])
+            self.token_matcher.add(k, [_patterns[i][j]])
             _keys_to_token[k] = j
         _keys_to_token_list.append(_keys_to_token)
         self._keys_to_token.setdefault(key, [])
spacy/matcher/matcher.pxd
@@ -63,7 +63,7 @@ cdef class Matcher:
     cdef Pool mem
     cdef vector[TokenPatternC*] patterns
     cdef readonly Vocab vocab
-    cdef public object validator
+    cdef public object validate
     cdef public object _patterns
     cdef public object _callbacks
     cdef public object _extensions
spacy/matcher/matcher.pyx
@@ -15,8 +15,7 @@ from ..tokens.doc cimport Doc, get_token_attr
 from ..tokens.token cimport Token
 from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA

-from ._schemas import TOKEN_PATTERN_SCHEMA
-from ..util import get_json_validator, validate_json
+from ..schemas import validate_token_pattern
 from ..errors import Errors, MatchPatternError, Warnings, deprecation_warning
 from ..strings import get_string_id
 from ..attrs import IDS

@@ -32,7 +31,7 @@ cdef class Matcher:
     USAGE: https://spacy.io/usage/rule-based-matching
     """

-    def __init__(self, vocab, validate=False):
+    def __init__(self, vocab, validate=True):
         """Create the Matcher.

         vocab (Vocab): The vocabulary object, which must be shared with the

@@ -46,10 +45,7 @@ cdef class Matcher:
         self._seen_attrs = set()
         self.vocab = vocab
         self.mem = Pool()
-        if validate:
-            self.validator = get_json_validator(TOKEN_PATTERN_SCHEMA)
-        else:
-            self.validator = None
+        self.validate = validate

     def __reduce__(self):
         data = (self.vocab, self._patterns, self._callbacks)

@@ -119,8 +115,8 @@ cdef class Matcher:
             raise ValueError(Errors.E012.format(key=key))
         if not isinstance(pattern, list):
             raise ValueError(Errors.E178.format(pat=pattern, key=key))
-        if self.validator:
-            errors[i] = validate_json(pattern, self.validator)
+        if self.validate:
+            errors[i] = validate_token_pattern(pattern)
         if any(err for err in errors.values()):
             raise MatchPatternError(key, errors)
         key = self._normalize_key(key)

@@ -668,8 +664,6 @@ def _get_attr_values(spec, string_store):
             continue
         if attr == "TEXT":
             attr = "ORTH"
-        if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
-            raise ValueError(Errors.E152.format(attr=attr))
         attr = IDS.get(attr)
         if isinstance(value, basestring):
             value = string_store.add(value)

@@ -684,7 +678,7 @@ def _get_attr_values(spec, string_store):
         if attr is not None:
             attr_values.append((attr, value))
         else:
-            # should be caught above using TOKEN_PATTERN_SCHEMA
+            # should be caught in validation
             raise ValueError(Errors.E152.format(attr=attr))
     return attr_values
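Hedged sketch of the new failure mode in Matcher.add() after this change (illustrative, not part of the diff):

    from spacy.lang.en import English
    from spacy.matcher import Matcher
    from spacy.errors import MatchPatternError

    matcher = Matcher(English().vocab, validate=True)
    try:
        # the string-value model forbids extra fields, so "VALUE" is rejected
        matcher.add("X", [[{"TEXT": {"VALUE": "foo"}}]])
    except MatchPatternError as err:
        # per-pattern error messages collected from validate_token_pattern
        print(err)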
spacy/matcher/phrasematcher.pyx
@@ -9,7 +9,7 @@ from ..structs cimport TokenC
 from ..tokens.token cimport Token
 from ..typedefs cimport attr_t

-from ._schemas import TOKEN_PATTERN_SCHEMA
+from ..schemas import TokenPattern
 from ..errors import Errors, Warnings, deprecation_warning, user_warning

@@ -54,7 +54,7 @@ cdef class PhraseMatcher:
         attr = attr.upper()
         if attr == "TEXT":
             attr = "ORTH"
-        if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
+        if attr.lower() not in TokenPattern().dict():
             raise ValueError(Errors.E152.format(attr=attr))
         self.attr = self.vocab.strings[attr]
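The new attribute check can be reproduced directly (a sketch): the model's field names are the lower-cased pattern keys, so a default TokenPattern instance doubles as the list of valid attributes.

    from spacy.schemas import TokenPattern

    valid_attrs = TokenPattern().dict()
    print("orth" in valid_attrs)  # True  -> PhraseMatcher(attr="ORTH") is accepted
    print("foo" in valid_attrs)   # False -> raises E152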
spacy/schemas.py (new file, 188 lines)
@@ -0,0 +1,188 @@
from typing import Dict, List, Union, Optional
from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
from collections import defaultdict

from .attrs import NAMES


def validate(schema, obj):
    """Validate data against a given pydantic schema.

    obj (dict): JSON-serializable data to validate.
    schema (pydantic.BaseModel): The schema to validate against.
    RETURNS (list): A list of error messages, if available.
    """
    try:
        schema(**obj)
        return []
    except ValidationError as e:
        errors = e.errors()
        data = defaultdict(list)
        for error in errors:
            err_loc = " -> ".join([str(p) for p in error.get("loc", [])])
            data[err_loc].append(error.get("msg"))
        return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
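Usage sketch for the validate() helper above (the exact error strings depend on the pydantic version; the shown output is indicative only):

    from spacy.schemas import validate, TokenPatternSchema

    errors = validate(TokenPatternSchema, {"pattern": [{"TEXT": 1}]})
    print(errors)  # e.g. ["[pattern -> 0 -> TEXT] str type expected, ..."]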
# Matcher token patterns


def validate_token_pattern(obj):
    # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"})
    get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k
    if isinstance(obj, list):
        converted = []
        for pattern in obj:
            if isinstance(pattern, dict):
                pattern = {get_key(k): v for k, v in pattern.items()}
            converted.append(pattern)
        obj = converted
    return validate(TokenPatternSchema, {"pattern": obj})
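The key conversion above means patterns written with spaCy's internal attribute IDs validate the same as string keys (sketch):

    from spacy.attrs import ORTH
    from spacy.schemas import validate_token_pattern

    assert validate_token_pattern([{ORTH: "foo"}]) == []    # int key, converted via NAMES
    assert validate_token_pattern([{"ORTH": "foo"}]) == []  # string key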
class TokenPatternString(BaseModel):
    REGEX: Optional[StrictStr]
    IN: Optional[List[StrictStr]]
    NOT_IN: Optional[List[StrictStr]]

    class Config:
        extra = "forbid"

    @validator("*", pre=True, whole=True)
    def raise_for_none(cls, v):
        if v is None:
            raise ValueError("None / null is not allowed")
        return v


class TokenPatternNumber(BaseModel):
    REGEX: Optional[StrictStr] = None
    IN: Optional[List[StrictInt]] = None
    NOT_IN: Optional[List[StrictInt]] = None
    EQ: Union[StrictInt, StrictFloat] = Field(None, alias="==")
    GEQ: Union[StrictInt, StrictFloat] = Field(None, alias=">=")
    LEQ: Union[StrictInt, StrictFloat] = Field(None, alias="<=")
    GT: Union[StrictInt, StrictFloat] = Field(None, alias=">")
    LT: Union[StrictInt, StrictFloat] = Field(None, alias="<")

    class Config:
        extra = "forbid"

    @validator("*", pre=True, whole=True)
    def raise_for_none(cls, v):
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
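The "==", ">=", "<=", ">" and "<" aliases let numeric attributes carry comparison constraints, and the strict types reject stringified numbers (sketch):

    from spacy.schemas import validate_token_pattern

    print(validate_token_pattern([{"LENGTH": {">=": 2}}]))    # []  (valid)
    print(validate_token_pattern([{"LENGTH": {">=": "2"}}]))  # type errors (StrictInt/StrictFloat)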
class TokenPatternOperator(str, Enum):
    plus: StrictStr = "+"
    start: StrictStr = "*"
    question: StrictStr = "?"
    exclamation: StrictStr = "!"


StringValue = Union[TokenPatternString, StrictStr]
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
UnderscoreValue = Union[
    TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
]


class TokenPattern(BaseModel):
    orth: Optional[StringValue] = None
    text: Optional[StringValue] = None
    lower: Optional[StringValue] = None
    pos: Optional[StringValue] = None
    tag: Optional[StringValue] = None
    dep: Optional[StringValue] = None
    lemma: Optional[StringValue] = None
    shape: Optional[StringValue] = None
    ent_type: Optional[StringValue] = None
    norm: Optional[StringValue] = None
    length: Optional[NumberValue] = None
    is_alpha: Optional[StrictBool] = None
    is_ascii: Optional[StrictBool] = None
    is_digit: Optional[StrictBool] = None
    is_lower: Optional[StrictBool] = None
    is_upper: Optional[StrictBool] = None
    is_title: Optional[StrictBool] = None
    is_punct: Optional[StrictBool] = None
    is_space: Optional[StrictBool] = None
    is_bracket: Optional[StrictBool] = None
    is_quote: Optional[StrictBool] = None
    is_left_punct: Optional[StrictBool] = None
    is_right_punct: Optional[StrictBool] = None
    is_currency: Optional[StrictBool] = None
    is_stop: Optional[StrictBool] = None
    is_sent_start: Optional[StrictBool] = None
    like_num: Optional[StrictBool] = None
    like_url: Optional[StrictBool] = None
    like_email: Optional[StrictBool] = None
    op: Optional[TokenPatternOperator] = None
    underscore: Optional[Dict[StrictStr, UnderscoreValue]] = Field(None, alias="_")

    class Config:
        extra = "forbid"
        allow_population_by_field_name = True
        alias_generator = lambda value: value.upper()

    @validator("*", pre=True)
    def raise_for_none(cls, v):
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
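Because alias_generator upper-cases the field names and population by field name is allowed, lower-cased pattern keys now validate too, previously an xfail in the test suite (see the test changes below):

    from spacy.schemas import validate_token_pattern

    assert validate_token_pattern([{"orth": "foo"}]) == []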
class TokenPatternSchema(BaseModel):
    pattern: List[TokenPattern] = Field(..., minItems=1)

    class Config:
        extra = "forbid"


# Model meta


class ModelMetaSchema(BaseModel):
    # fmt: off
    lang: StrictStr = Field(..., title="Two-letter language code, e.g. 'en'")
    name: StrictStr = Field(..., title="Model name")
    version: StrictStr = Field(..., title="Model version")
    spacy_version: Optional[StrictStr] = Field(None, title="Compatible spaCy version identifier")
    parent_package: Optional[StrictStr] = Field("spacy", title="Name of parent spaCy package, e.g. spacy or spacy-nightly")
    pipeline: Optional[List[StrictStr]] = Field([], title="Names of pipeline components")
    description: Optional[StrictStr] = Field(None, title="Model description")
    license: Optional[StrictStr] = Field(None, title="Model license")
    author: Optional[StrictStr] = Field(None, title="Model author name")
    email: Optional[StrictStr] = Field(None, title="Model author email")
    url: Optional[StrictStr] = Field(None, title="Model author URL")
    sources: Optional[Union[List[StrictStr], Dict[str, str]]] = Field(None, title="Training data sources")
    vectors: Optional[Dict[str, int]] = Field(None, title="Included word vectors")
    accuracy: Optional[Dict[str, Union[float, int]]] = Field(None, title="Accuracy numbers")
    speed: Optional[Dict[str, Union[float, int]]] = Field(None, title="Speed evaluation numbers")
    # fmt: on
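Sketch of validating a model meta dict against the new schema; lang, name and version are the only required fields, as in the removed META_SCHEMA (the model name below is just an example):

    from spacy.schemas import validate, ModelMetaSchema

    meta = {"lang": "en", "name": "core_web_sm", "version": "2.2.0"}
    assert validate(ModelMetaSchema, meta) == []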
# Training data object in "simple training style"


class SimpleTrainingSchema(BaseModel):
    # TODO: write

    class Config:
        title = "Schema for training data dict passed to nlp.update"
        extra = "forbid"


# JSON training format


class TrainingSchema(BaseModel):
    # TODO: write

    class Config:
        title = "Schema for training data in spaCy's JSON format"
        extra = "forbid"
spacy/tests/doc/test_to_json.py
@@ -1,6 +1,4 @@
 import pytest
-from spacy.cli._schemas import TRAINING_SCHEMA
-from spacy.util import get_json_validator, validate_json
 from spacy.tokens import Doc
 from ..util import get_doc

@@ -55,10 +53,3 @@ def test_doc_to_json_underscore_error_serialize(doc):
     Doc.set_extension("json_test4", method=lambda doc: doc.text)
     with pytest.raises(ValueError):
         doc.to_json(underscore=["json_test4"])
-
-
-def test_doc_to_json_valid_training(doc):
-    json_doc = doc.to_json()
-    validator = get_json_validator(TRAINING_SCHEMA)
-    errors = validate_json([json_doc], validator)
-    assert not errors
spacy/tests/matcher/test_matcher_api.py
@@ -179,7 +179,7 @@ def test_matcher_match_one_plus(matcher):
     doc = Doc(control.vocab, words=["Philippe", "Philippe"])
     m = control(doc)
     assert len(m) == 2
-    pattern = [{"ORTH": "Philippe", "OP": "1"}, {"ORTH": "Philippe", "OP": "+"}]
+    pattern = [{"ORTH": "Philippe"}, {"ORTH": "Philippe", "OP": "+"}]
     matcher.add("KleenePhilippe", [pattern])
     m = matcher(doc)
     assert len(m) == 1
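The "OP": "1" spelling was dropped because the new TokenPatternOperator enum only admits "+", "*", "?" and "!" (sketch):

    from spacy.schemas import validate_token_pattern

    print(validate_token_pattern([{"ORTH": "A", "OP": "1"}]))  # non-empty: not a valid enum value
    print(validate_token_pattern([{"ORTH": "A", "OP": "+"}]))  # []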
spacy/tests/matcher/test_matcher_logic.py
@@ -6,18 +6,18 @@ from spacy.matcher import Matcher
 from spacy.tokens import Doc, Span


-pattern1 = [{"ORTH": "A", "OP": "1"}, {"ORTH": "A", "OP": "*"}]
-pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A", "OP": "1"}]
-pattern3 = [{"ORTH": "A", "OP": "1"}, {"ORTH": "A", "OP": "1"}]
+pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
+pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}]
+pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
 pattern4 = [
-    {"ORTH": "B", "OP": "1"},
+    {"ORTH": "B"},
     {"ORTH": "A", "OP": "*"},
-    {"ORTH": "B", "OP": "1"},
+    {"ORTH": "B"},
 ]
 pattern5 = [
     {"ORTH": "B", "OP": "*"},
     {"ORTH": "A", "OP": "*"},
-    {"ORTH": "B", "OP": "1"},
+    {"ORTH": "B"},
 ]

 re_pattern1 = "AA*"
spacy/tests/matcher/test_pattern_validation.py
@@ -1,8 +1,7 @@
 import pytest
 from spacy.matcher import Matcher
-from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
 from spacy.errors import MatchPatternError
-from spacy.util import get_json_validator, validate_json
+from spacy.schemas import validate_token_pattern

 # (pattern, num errors with validation, num errors identified with minimal
 # checks)

@@ -15,12 +14,12 @@ TEST_PATTERNS = [
     ('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1),
     ([1, 2, 3], 3, 1),
     # Bad patterns flagged outside of Matcher
-    ([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1, 0),
+    ([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 2, 0),  # prev: (1, 0)
     # Bad patterns not flagged with minimal checks
     ([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 0),
-    ([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2, 0),
-    ([{"LENGTH": {"VALUE": 5}}], 1, 0),
-    ([{"TEXT": {"VALUE": "foo"}}], 1, 0),
+    ([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 4, 0),  # prev: (2, 0)
+    ([{"LENGTH": {"VALUE": 5}}], 2, 0),  # prev: (1, 0)
+    ([{"TEXT": {"VALUE": "foo"}}], 2, 0),  # prev: (1, 0)
     ([{"IS_DIGIT": -1}], 1, 0),
     ([{"ORTH": -1}], 1, 0),
     # Good patterns

@@ -31,15 +30,9 @@ TEST_PATTERNS = [
     ([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0, 0),
     ([{"NORM": "a"}, {"POS": {"IN": ["NOUN"]}}], 0, 0),
     ([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0, 0),
+    ([{"orth": "foo"}], 0, 0),  # prev: xfail
 ]

-XFAIL_TEST_PATTERNS = [([{"orth": "foo"}], 0, 0)]
-
-
-@pytest.fixture
-def validator():
-    return get_json_validator(TOKEN_PATTERN_SCHEMA)
-

 @pytest.mark.parametrize(
     "pattern", [[{"XX": "y"}, {"LENGTH": "2"}, {"TEXT": {"IN": 5}}]]

@@ -51,15 +44,8 @@ def test_matcher_pattern_validation(en_vocab, pattern):


 @pytest.mark.parametrize("pattern,n_errors,_", TEST_PATTERNS)
-def test_pattern_validation(validator, pattern, n_errors, _):
-    errors = validate_json(pattern, validator)
+def test_pattern_validation(pattern, n_errors, _):
+    errors = validate_token_pattern(pattern)
     assert len(errors) == n_errors
-
-
-@pytest.mark.xfail
-@pytest.mark.parametrize("pattern,n_errors,_", XFAIL_TEST_PATTERNS)
-def test_xfail_pattern_validation(validator, pattern, n_errors, _):
-    errors = validate_json(pattern, validator)
-    assert len(errors) == n_errors
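Why several expected error counts doubled: attribute values are now validated against a Union (e.g. TokenPatternString or a strict string), and when every branch fails, pydantic reports one error per branch (sketch):

    from spacy.schemas import validate_token_pattern

    # fails TokenPatternString (extra field "VALUE" forbidden) and StrictStr -> 2 errors
    print(len(validate_token_pattern([{"TEXT": {"VALUE": "foo"}}])))  # 2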
spacy/tests/test_json_schemas.py (deleted)
@@ -1,47 +0,0 @@

from spacy.util import get_json_validator, validate_json, validate_schema
from spacy.cli._schemas import META_SCHEMA, TRAINING_SCHEMA
from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
import pytest


@pytest.fixture(scope="session")
def training_schema_validator():
    return get_json_validator(TRAINING_SCHEMA)


def test_validate_schema():
    validate_schema({"type": "object"})
    with pytest.raises(Exception):
        validate_schema({"type": lambda x: x})


@pytest.mark.parametrize("schema", [TRAINING_SCHEMA, META_SCHEMA, TOKEN_PATTERN_SCHEMA])
def test_schemas(schema):
    validate_schema(schema)


@pytest.mark.parametrize(
    "data",
    [
        {"text": "Hello world"},
        {"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "TEST"}]},
    ],
)
def test_json_schema_training_valid(data, training_schema_validator):
    errors = validate_json([data], training_schema_validator)
    assert not errors


@pytest.mark.parametrize(
    "data,n_errors",
    [
        ({"spans": []}, 1),
        ({"text": "Hello", "ents": [{"start": "0", "end": "5", "label": "TEST"}]}, 2),
        ({"text": "Hello", "ents": [{"start": 0, "end": 5}]}, 1),
        ({"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "test"}]}, 1),
        ({"text": "spaCy", "tokens": [{"pos": "PROPN"}]}, 2),
    ],
)
def test_json_schema_training_invalid(data, n_errors, training_schema_validator):
    errors = validate_json([data], training_schema_validator)
    assert len(errors) == n_errors
spacy/util.py
@@ -13,11 +13,6 @@ import srsly
 import catalogue
 import sys

-try:
-    import jsonschema
-except ImportError:
-    jsonschema = None
-
 try:
     import cupy.random
 except ImportError:

@@ -705,43 +700,6 @@ def fix_random_seed(seed=0):
         cupy.random.seed(seed)


-def get_json_validator(schema):
-    # We're using a helper function here to make it easier to change the
-    # validator that's used (e.g. different draft implementation), without
-    # having to change it all across the codebase.
-    # TODO: replace with (stable) Draft6Validator, if available
-    if jsonschema is None:
-        raise ValueError(Errors.E136)
-    return jsonschema.Draft4Validator(schema)
-
-
-def validate_schema(schema):
-    """Validate a given schema. This just checks if the schema itself is valid."""
-    validator = get_json_validator(schema)
-    validator.check_schema(schema)
-
-
-def validate_json(data, validator):
-    """Validate data against a given JSON schema (see https://json-schema.org).
-
-    data: JSON-serializable data to validate.
-    validator (jsonschema.DraftXValidator): The validator.
-    RETURNS (list): A list of error messages, if available.
-    """
-    errors = []
-    for err in sorted(validator.iter_errors(data), key=lambda e: e.path):
-        if err.path:
-            err_path = "[{}]".format(" -> ".join([str(p) for p in err.path]))
-        else:
-            err_path = ""
-        msg = err.message + " " + err_path
-        if err.context:  # Error has suberrors, e.g. if schema uses anyOf
-            suberrs = [f" - {suberr.message}" for suberr in err.context]
-            msg += f":\n{''.join(suberrs)}"
-        errors.append(msg)
-    return errors
-
-
 def get_serialization_exclude(serializers, exclude, kwargs):
     """Helper function to validate serialization args and manage transition from
     keyword arguments (pre v2.1) to exclude argument.