Mirror of https://github.com/explosion/spaCy.git
Add better schemas and validation using Pydantic (#4831)
* Remove unicode declarations
* Remove Python 3.5 and 2.7 from CI
* Don't require pathlib
* Replace compat helpers
* Remove OrderedDict
* Use f-strings
* Set Cython compiler language level
* Fix typo
* Re-add OrderedDict for Table
* Update setup.cfg
* Revert CONTRIBUTING.md
* Add better schemas and validation using Pydantic
* Revert lookups.md
* Remove unused import
* Update spacy/schemas.py
  Co-Authored-By: Sebastián Ramírez <tiangolo@gmail.com>
* Various small fixes
* Fix docstring

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
Parent: db55577c45
Commit: 33a2682d60
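In short, the change replaces the jsonschema-based helpers with pydantic models: token pattern, model meta and (stub) training schemas now live in spacy/schemas.py. A minimal usage sketch of the new entry point (the printed message is illustrative, not verbatim):

    from spacy.schemas import validate_token_pattern

    # "OP" only admits "+", "*", "?" and "!" under the new schema, so the
    # old '"OP": "1"' style used in some tests below now fails validation.
    errors = validate_token_pattern([{"ORTH": "Philippe", "OP": "1"}])
    print(errors)  # e.g. ["[pattern -> 0 -> OP] value is not a valid enumeration member"]
    # A valid pattern returns an empty list of errors.
    assert validate_token_pattern([{"ORTH": "Philippe", "OP": "+"}]) == []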
requirements.txt
@@ -12,8 +12,7 @@ numpy>=1.15.0
 requests>=2.13.0,<3.0.0
 plac>=0.9.6,<1.2.0
 tqdm>=4.38.0,<5.0.0
 # Optional dependencies
-jsonschema>=2.6.0,<3.1.0
 pydantic>=1.0.0,<2.0.0
 # Development dependencies
 cython>=0.25
 pytest>=4.6.5
setup.cfg
@@ -51,6 +51,7 @@ install_requires =
 numpy>=1.15.0
 plac>=0.9.6,<1.2.0
 requests>=2.13.0,<3.0.0
+pydantic>=1.0.0,<2.0.0

 [options.extras_require]
 lookups =
spacy/cli/_schemas.py (file removed)
@@ -1,217 +0,0 @@
# NB: This schema describes the new format of the training data, see #2928
TRAINING_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "title": "Training data for spaCy models",
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "text": {
                "title": "The text of the training example",
                "type": "string",
                "minLength": 1,
            },
            "ents": {
                "title": "Named entity spans in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "start": {
                            "title": "Start character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "label": {
                            "title": "Entity label",
                            "type": "string",
                            "minLength": 1,
                            "pattern": "^[A-Z0-9]*$",
                        },
                    },
                    "required": ["start", "end", "label"],
                },
            },
            "sents": {
                "title": "Sentence spans in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "start": {
                            "title": "Start character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the span",
                            "type": "integer",
                            "minimum": 0,
                        },
                    },
                    "required": ["start", "end"],
                },
            },
            "cats": {
                "title": "Text categories for the text classifier",
                "type": "object",
                "patternProperties": {
                    "*": {
                        "title": "A text category",
                        "oneOf": [
                            {"type": "boolean"},
                            {"type": "number", "minimum": 0},
                        ],
                    }
                },
                "propertyNames": {"pattern": "^[A-Z0-9]*$", "minLength": 1},
            },
            "tokens": {
                "title": "The tokens in the text",
                "type": "array",
                "items": {
                    "type": "object",
                    "minProperties": 1,
                    "properties": {
                        "id": {
                            "title": "Token ID, usually token index",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "start": {
                            "title": "Start character offset of the token",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "end": {
                            "title": "End character offset of the token",
                            "type": "integer",
                            "minimum": 0,
                        },
                        "pos": {
                            "title": "Coarse-grained part-of-speech tag",
                            "type": "string",
                            "minLength": 1,
                        },
                        "tag": {
                            "title": "Fine-grained part-of-speech tag",
                            "type": "string",
                            "minLength": 1,
                        },
                        "dep": {
                            "title": "Dependency label",
                            "type": "string",
                            "minLength": 1,
                        },
                        "head": {
                            "title": "Index of the token's head",
                            "type": "integer",
                            "minimum": 0,
                        },
                    },
                    "required": ["start", "end"],
                },
            },
            "_": {"title": "Custom user space", "type": "object"},
        },
        "required": ["text"],
    },
}

META_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "type": "object",
    "properties": {
        "lang": {
            "title": "Two-letter language code, e.g. 'en'",
            "type": "string",
            "minLength": 2,
            "maxLength": 2,
            "pattern": "^[a-z]*$",
        },
        "name": {
            "title": "Model name",
            "type": "string",
            "minLength": 1,
            "pattern": "^[a-z_]*$",
        },
        "version": {
            "title": "Model version",
            "type": "string",
            "minLength": 1,
            "pattern": "^[0-9a-z.-]*$",
        },
        "spacy_version": {
            "title": "Compatible spaCy version identifier",
            "type": "string",
            "minLength": 1,
            "pattern": "^[0-9a-z.-><=]*$",
        },
        "parent_package": {
            "title": "Name of parent spaCy package, e.g. spacy or spacy-nightly",
            "type": "string",
            "minLength": 1,
            "default": "spacy",
        },
        "pipeline": {
            "title": "Names of pipeline components",
            "type": "array",
            "items": {"type": "string", "minLength": 1},
        },
        "description": {"title": "Model description", "type": "string"},
        "license": {"title": "Model license", "type": "string"},
        "author": {"title": "Model author name", "type": "string"},
        "email": {"title": "Model author email", "type": "string", "format": "email"},
        "url": {"title": "Model author URL", "type": "string", "format": "uri"},
        "sources": {
            "title": "Training data sources",
            "type": "array",
            "items": {"type": "string"},
        },
        "vectors": {
            "title": "Included word vectors",
            "type": "object",
            "properties": {
                "keys": {
                    "title": "Number of unique keys",
                    "type": "integer",
                    "minimum": 0,
                },
                "vectors": {
                    "title": "Number of unique vectors",
                    "type": "integer",
                    "minimum": 0,
                },
                "width": {
                    "title": "Number of dimensions",
                    "type": "integer",
                    "minimum": 0,
                },
            },
        },
        "accuracy": {
            "title": "Accuracy numbers",
            "type": "object",
            "patternProperties": {"*": {"type": "number", "minimum": 0.0}},
        },
        "speed": {
            "title": "Speed evaluation numbers",
            "type": "object",
            "patternProperties": {
                "*": {
                    "oneOf": [
                        {"type": "number", "minimum": 0.0},
                        {"type": "integer", "minimum": 0},
                    ]
                }
            },
        },
    },
    "required": ["lang", "name", "version"],
}
spacy/errors.py
@@ -105,7 +105,6 @@ class Warnings(object):
            "smaller JSON files instead.")
-


 @add_codes
 class Errors(object):
     E001 = ("No component '{name}' found in pipeline. Available names: {opts}")

@@ -419,8 +418,6 @@ class Errors(object):
     E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
     E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
             "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
-    E136 = ("This additional feature requires the jsonschema library to be "
-            "installed:\npip install jsonschema")
     E137 = ("Expected 'dict' type, but got '{type}' from '{line}'. Make sure "
             "to provide a valid JSON object as input with either the `text` "
             "or `tokens` key. For more info, see the docs:\n"
spacy/matcher/_schemas.py (file removed)
@@ -1,197 +0,0 @@
TOKEN_PATTERN_SCHEMA = {
    "$schema": "http://json-schema.org/draft-06/schema",
    "definitions": {
        "string_value": {
            "anyOf": [
                {"type": "string"},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {"type": "array", "items": {"type": "string"}},
                        "NOT_IN": {"type": "array", "items": {"type": "string"}},
                    },
                    "additionalProperties": False,
                },
            ]
        },
        "integer_value": {
            "anyOf": [
                {"type": "integer"},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {"type": "array", "items": {"type": "integer"}},
                        "NOT_IN": {"type": "array", "items": {"type": "integer"}},
                        "==": {"type": "integer"},
                        ">=": {"type": "integer"},
                        "<=": {"type": "integer"},
                        ">": {"type": "integer"},
                        "<": {"type": "integer"},
                    },
                    "additionalProperties": False,
                },
            ]
        },
        "boolean_value": {"type": "boolean"},
        "underscore_value": {
            "anyOf": [
                {"type": ["string", "integer", "number", "array", "boolean", "null"]},
                {
                    "type": "object",
                    "properties": {
                        "REGEX": {"type": "string"},
                        "IN": {
                            "type": "array",
                            "items": {"type": ["string", "integer"]},
                        },
                        "NOT_IN": {
                            "type": "array",
                            "items": {"type": ["string", "integer"]},
                        },
                        "==": {"type": "integer"},
                        ">=": {"type": "integer"},
                        "<=": {"type": "integer"},
                        ">": {"type": "integer"},
                        "<": {"type": "integer"},
                    },
                    "additionalProperties": False,
                },
            ]
        },
    },
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "ORTH": {
                "title": "Verbatim token text",
                "$ref": "#/definitions/string_value",
            },
            "TEXT": {
                "title": "Verbatim token text (spaCy v2.1+)",
                "$ref": "#/definitions/string_value",
            },
            "LOWER": {
                "title": "Lowercase form of token text",
                "$ref": "#/definitions/string_value",
            },
            "POS": {
                "title": "Coarse-grained part-of-speech tag",
                "$ref": "#/definitions/string_value",
            },
            "TAG": {
                "title": "Fine-grained part-of-speech tag",
                "$ref": "#/definitions/string_value",
            },
            "DEP": {"title": "Dependency label", "$ref": "#/definitions/string_value"},
            "LEMMA": {
                "title": "Lemma (base form)",
                "$ref": "#/definitions/string_value",
            },
            "SHAPE": {
                "title": "Abstract token shape",
                "$ref": "#/definitions/string_value",
            },
            "ENT_TYPE": {
                "title": "Entity label of single token",
                "$ref": "#/definitions/string_value",
            },
            "NORM": {
                "title": "Normalized form of the token text",
                "$ref": "#/definitions/string_value",
            },
            "LENGTH": {
                "title": "Token character length",
                "$ref": "#/definitions/integer_value",
            },
            "IS_ALPHA": {
                "title": "Token consists of alphabetic characters",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_ASCII": {
                "title": "Token consists of ASCII characters",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_DIGIT": {
                "title": "Token consists of digits",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_LOWER": {
                "title": "Token is lowercase",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_UPPER": {
                "title": "Token is uppercase",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_TITLE": {
                "title": "Token is titlecase",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_PUNCT": {
                "title": "Token is punctuation",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_SPACE": {
                "title": "Token is whitespace",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_BRACKET": {
                "title": "Token is a bracket",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_QUOTE": {
                "title": "Token is a quotation mark",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_LEFT_PUNCT": {
                "title": "Token is a left punctuation mark",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_RIGHT_PUNCT": {
                "title": "Token is a right punctuation mark",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_CURRENCY": {
                "title": "Token is a currency symbol",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_STOP": {
                "title": "Token is stop word",
                "$ref": "#/definitions/boolean_value",
            },
            "IS_SENT_START": {
                "title": "Token is the first in a sentence",
                "$ref": "#/definitions/boolean_value",
            },
            "LIKE_NUM": {
                "title": "Token resembles a number",
                "$ref": "#/definitions/boolean_value",
            },
            "LIKE_URL": {
                "title": "Token resembles a URL",
                "$ref": "#/definitions/boolean_value",
            },
            "LIKE_EMAIL": {
                "title": "Token resembles an email address",
                "$ref": "#/definitions/boolean_value",
            },
            "_": {
                "title": "Custom extension token attributes (token._.)",
                "type": "object",
                "patternProperties": {
                    "^.*$": {"$ref": "#/definitions/underscore_value"}
                },
            },
            "OP": {
                "title": "Operators / quantifiers",
                "type": "string",
                "enum": ["+", "*", "?", "!"],
            },
        },
        "additionalProperties": False,
    },
}
spacy/matcher/dependencymatcher.pyx
@@ -39,7 +39,8 @@ cdef class DependencyMatcher:
         RETURNS (DependencyMatcher): The newly constructed object.
         """
         size = 20
-        self.token_matcher = Matcher(vocab)
+        # TODO: make matcher work with validation
+        self.token_matcher = Matcher(vocab, validate=False)
         self._keys_to_token = {}
         self._patterns = {}
         self._root = {}

@@ -129,7 +130,7 @@ cdef class DependencyMatcher:
                 # TODO: Better ways to hash edges in pattern?
                 for j in range(len(_patterns[i])):
                     k = self._normalize_key(unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j))
-                    self.token_matcher.add(k, None, _patterns[i][j])
+                    self.token_matcher.add(k, [_patterns[i][j]])
                     _keys_to_token[k] = j
                 _keys_to_token_list.append(_keys_to_token)
             self._keys_to_token.setdefault(key, [])
spacy/matcher/matcher.pxd
@@ -63,7 +63,7 @@ cdef class Matcher:
     cdef Pool mem
     cdef vector[TokenPatternC*] patterns
     cdef readonly Vocab vocab
-    cdef public object validator
+    cdef public object validate
     cdef public object _patterns
     cdef public object _callbacks
     cdef public object _extensions
spacy/matcher/matcher.pyx
@@ -15,8 +15,7 @@ from ..tokens.doc cimport Doc, get_token_attr
 from ..tokens.token cimport Token
 from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA

-from ._schemas import TOKEN_PATTERN_SCHEMA
-from ..util import get_json_validator, validate_json
+from ..schemas import validate_token_pattern
 from ..errors import Errors, MatchPatternError, Warnings, deprecation_warning
 from ..strings import get_string_id
 from ..attrs import IDS

@@ -32,7 +31,7 @@ cdef class Matcher:
     USAGE: https://spacy.io/usage/rule-based-matching
     """

-    def __init__(self, vocab, validate=False):
+    def __init__(self, vocab, validate=True):
         """Create the Matcher.

         vocab (Vocab): The vocabulary object, which must be shared with the

@@ -46,10 +45,7 @@ cdef class Matcher:
         self._seen_attrs = set()
         self.vocab = vocab
         self.mem = Pool()
-        if validate:
-            self.validator = get_json_validator(TOKEN_PATTERN_SCHEMA)
-        else:
-            self.validator = None
+        self.validate = validate

     def __reduce__(self):
         data = (self.vocab, self._patterns, self._callbacks)

@@ -119,8 +115,8 @@ cdef class Matcher:
                 raise ValueError(Errors.E012.format(key=key))
             if not isinstance(pattern, list):
                 raise ValueError(Errors.E178.format(pat=pattern, key=key))
-            if self.validator:
-                errors[i] = validate_json(pattern, self.validator)
+            if self.validate:
+                errors[i] = validate_token_pattern(pattern)
         if any(err for err in errors.values()):
             raise MatchPatternError(key, errors)
         key = self._normalize_key(key)

@@ -668,8 +664,6 @@ def _get_attr_values(spec, string_store):
                 continue
            if attr == "TEXT":
                 attr = "ORTH"
-            if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
-                raise ValueError(Errors.E152.format(attr=attr))
             attr = IDS.get(attr)
         if isinstance(value, basestring):
             value = string_store.add(value)

@@ -684,7 +678,7 @@ def _get_attr_values(spec, string_store):
         if attr is not None:
             attr_values.append((attr, value))
         else:
-            # should be caught above using TOKEN_PATTERN_SCHEMA
+            # should be caught in validation
             raise ValueError(Errors.E152.format(attr=attr))
     return attr_values
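Since validate=True is now the Matcher default, a malformed pattern surfaces as a MatchPatternError at add() time rather than failing silently. A hedged usage sketch (error wording illustrative):

    from spacy.lang.en import English
    from spacy.errors import MatchPatternError
    from spacy.matcher import Matcher

    nlp = English()
    matcher = Matcher(nlp.vocab)  # validation now on by default
    try:
        # "VALUE" is not a key the token pattern schema knows about
        matcher.add("BAD", [[{"TEXT": {"VALUE": "foo"}}]])
    except MatchPatternError as e:
        print(e)  # one message per pattern index and error location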
spacy/matcher/phrasematcher.pyx
@@ -9,7 +9,7 @@ from ..structs cimport TokenC
 from ..tokens.token cimport Token
 from ..typedefs cimport attr_t

-from ._schemas import TOKEN_PATTERN_SCHEMA
+from ..schemas import TokenPattern
 from ..errors import Errors, Warnings, deprecation_warning, user_warning

@@ -54,7 +54,7 @@ cdef class PhraseMatcher:
         attr = attr.upper()
         if attr == "TEXT":
             attr = "ORTH"
-        if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
+        if attr.lower() not in TokenPattern().dict():
             raise ValueError(Errors.E152.format(attr=attr))
         self.attr = self.vocab.strings[attr]
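Why the new membership test works, sketched under pydantic v1 semantics (an assumption worth noting: validators are skipped for unset defaults, so TokenPattern() instantiates cleanly despite raise_for_none):

    from spacy.schemas import TokenPattern

    valid_attrs = TokenPattern().dict()  # {"orth": None, "text": None, ...}
    assert "orth" in valid_attrs and "ent_type" in valid_attrs
    assert "foo" not in valid_attrs  # unknown attrs fall through to E152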
spacy/schemas.py (new file, 188 lines)
@@ -0,0 +1,188 @@
from typing import Dict, List, Union, Optional
from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
from collections import defaultdict

from .attrs import NAMES


def validate(schema, obj):
    """Validate data against a given pydantic schema.

    obj (dict): JSON-serializable data to validate.
    schema (pydantic.BaseModel): The schema to validate against.
    RETURNS (list): A list of error messages, if available.
    """
    try:
        schema(**obj)
        return []
    except ValidationError as e:
        errors = e.errors()
        data = defaultdict(list)
        for error in errors:
            err_loc = " -> ".join([str(p) for p in error.get("loc", [])])
            data[err_loc].append(error.get("msg"))
        return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]


# Matcher token patterns


def validate_token_pattern(obj):
    # Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"})
    get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k
    if isinstance(obj, list):
        converted = []
        for pattern in obj:
            if isinstance(pattern, dict):
                pattern = {get_key(k): v for k, v in pattern.items()}
            converted.append(pattern)
        obj = converted
    return validate(TokenPatternSchema, {"pattern": obj})


class TokenPatternString(BaseModel):
    REGEX: Optional[StrictStr]
    IN: Optional[List[StrictStr]]
    NOT_IN: Optional[List[StrictStr]]

    class Config:
        extra = "forbid"

    @validator("*", pre=True, whole=True)
    def raise_for_none(cls, v):
        if v is None:
            raise ValueError("None / null is not allowed")
        return v


class TokenPatternNumber(BaseModel):
    REGEX: Optional[StrictStr] = None
    IN: Optional[List[StrictInt]] = None
    NOT_IN: Optional[List[StrictInt]] = None
    EQ: Union[StrictInt, StrictFloat] = Field(None, alias="==")
    GEQ: Union[StrictInt, StrictFloat] = Field(None, alias=">=")
    LEQ: Union[StrictInt, StrictFloat] = Field(None, alias="<=")
    GT: Union[StrictInt, StrictFloat] = Field(None, alias=">")
    LT: Union[StrictInt, StrictFloat] = Field(None, alias="<")

    class Config:
        extra = "forbid"

    @validator("*", pre=True, whole=True)
    def raise_for_none(cls, v):
        if v is None:
            raise ValueError("None / null is not allowed")
        return v


class TokenPatternOperator(str, Enum):
    plus: StrictStr = "+"
    start: StrictStr = "*"
    question: StrictStr = "?"
    exclamation: StrictStr = "!"


StringValue = Union[TokenPatternString, StrictStr]
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
UnderscoreValue = Union[
    TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
]


class TokenPattern(BaseModel):
    orth: Optional[StringValue] = None
    text: Optional[StringValue] = None
    lower: Optional[StringValue] = None
    pos: Optional[StringValue] = None
    tag: Optional[StringValue] = None
    dep: Optional[StringValue] = None
    lemma: Optional[StringValue] = None
    shape: Optional[StringValue] = None
    ent_type: Optional[StringValue] = None
    norm: Optional[StringValue] = None
    length: Optional[NumberValue] = None
    is_alpha: Optional[StrictBool] = None
    is_ascii: Optional[StrictBool] = None
    is_digit: Optional[StrictBool] = None
    is_lower: Optional[StrictBool] = None
    is_upper: Optional[StrictBool] = None
    is_title: Optional[StrictBool] = None
    is_punct: Optional[StrictBool] = None
    is_space: Optional[StrictBool] = None
    is_bracket: Optional[StrictBool] = None
    is_quote: Optional[StrictBool] = None
    is_left_punct: Optional[StrictBool] = None
    is_right_punct: Optional[StrictBool] = None
    is_currency: Optional[StrictBool] = None
    is_stop: Optional[StrictBool] = None
    is_sent_start: Optional[StrictBool] = None
    like_num: Optional[StrictBool] = None
    like_url: Optional[StrictBool] = None
    like_email: Optional[StrictBool] = None
    op: Optional[TokenPatternOperator] = None
    underscore: Optional[Dict[StrictStr, UnderscoreValue]] = Field(None, alias="_")

    class Config:
        extra = "forbid"
        allow_population_by_field_name = True
        alias_generator = lambda value: value.upper()

    @validator("*", pre=True)
    def raise_for_none(cls, v):
        if v is None:
            raise ValueError("None / null is not allowed")
        return v


class TokenPatternSchema(BaseModel):
    pattern: List[TokenPattern] = Field(..., minItems=1)

    class Config:
        extra = "forbid"


# Model meta


class ModelMetaSchema(BaseModel):
    # fmt: off
    lang: StrictStr = Field(..., title="Two-letter language code, e.g. 'en'")
    name: StrictStr = Field(..., title="Model name")
    version: StrictStr = Field(..., title="Model version")
    spacy_version: Optional[StrictStr] = Field(None, title="Compatible spaCy version identifier")
    parent_package: Optional[StrictStr] = Field("spacy", title="Name of parent spaCy package, e.g. spacy or spacy-nightly")
    pipeline: Optional[List[StrictStr]] = Field([], title="Names of pipeline components")
    description: Optional[StrictStr] = Field(None, title="Model description")
    license: Optional[StrictStr] = Field(None, title="Model license")
    author: Optional[StrictStr] = Field(None, title="Model author name")
    email: Optional[StrictStr] = Field(None, title="Model author email")
    url: Optional[StrictStr] = Field(None, title="Model author URL")
    sources: Optional[Union[List[StrictStr], Dict[str, str]]] = Field(None, title="Training data sources")
    vectors: Optional[Dict[str, int]] = Field(None, title="Included word vectors")
    accuracy: Optional[Dict[str, Union[float, int]]] = Field(None, title="Accuracy numbers")
    speed: Optional[Dict[str, Union[float, int]]] = Field(None, title="Speed evaluation numbers")
    # fmt: on


# Training data object in "simple training style"


class SimpleTrainingSchema(BaseModel):
    # TODO: write

    class Config:
        title = "Schema for training data dict in passed to nlp.update"
        extra = "forbid"


# JSON training format


class TrainingSchema(BaseModel):
    # TODO: write

    class Config:
        title = "Schema for training data in spaCy's JSON format"
        extra = "forbid"
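The generic validate() helper above works for any of these models. A small sketch against ModelMetaSchema (meta values made up for illustration; the exact error string may differ):

    from spacy.schemas import ModelMetaSchema, validate

    meta = {"lang": "en", "name": "example_model", "version": "0.0.1"}
    assert validate(ModelMetaSchema, meta) == []

    bad_meta = {"lang": "en", "name": "example_model"}  # no "version"
    print(validate(ModelMetaSchema, bad_meta))  # e.g. ["[version] field required"]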
spacy/tests/doc/test_to_json.py
@@ -1,6 +1,4 @@
 import pytest
-from spacy.cli._schemas import TRAINING_SCHEMA
-from spacy.util import get_json_validator, validate_json
 from spacy.tokens import Doc
 from ..util import get_doc

@@ -55,10 +53,3 @@ def test_doc_to_json_underscore_error_serialize(doc):
     Doc.set_extension("json_test4", method=lambda doc: doc.text)
     with pytest.raises(ValueError):
         doc.to_json(underscore=["json_test4"])
-
-
-def test_doc_to_json_valid_training(doc):
-    json_doc = doc.to_json()
-    validator = get_json_validator(TRAINING_SCHEMA)
-    errors = validate_json([json_doc], validator)
-    assert not errors
spacy/tests/matcher/test_matcher_api.py
@@ -179,7 +179,7 @@ def test_matcher_match_one_plus(matcher):
     doc = Doc(control.vocab, words=["Philippe", "Philippe"])
     m = control(doc)
     assert len(m) == 2
-    pattern = [{"ORTH": "Philippe", "OP": "1"}, {"ORTH": "Philippe", "OP": "+"}]
+    pattern = [{"ORTH": "Philippe"}, {"ORTH": "Philippe", "OP": "+"}]
     matcher.add("KleenePhilippe", [pattern])
     m = matcher(doc)
     assert len(m) == 1
spacy/tests/matcher/test_matcher_logic.py
@@ -6,18 +6,18 @@ from spacy.matcher import Matcher
 from spacy.tokens import Doc, Span


-pattern1 = [{"ORTH": "A", "OP": "1"}, {"ORTH": "A", "OP": "*"}]
-pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A", "OP": "1"}]
-pattern3 = [{"ORTH": "A", "OP": "1"}, {"ORTH": "A", "OP": "1"}]
+pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
+pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}]
+pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
 pattern4 = [
-    {"ORTH": "B", "OP": "1"},
+    {"ORTH": "B"},
     {"ORTH": "A", "OP": "*"},
-    {"ORTH": "B", "OP": "1"},
+    {"ORTH": "B"},
 ]
 pattern5 = [
     {"ORTH": "B", "OP": "*"},
     {"ORTH": "A", "OP": "*"},
-    {"ORTH": "B", "OP": "1"},
+    {"ORTH": "B"},
 ]

 re_pattern1 = "AA*"
spacy/tests/matcher/test_pattern_validation.py
@@ -1,8 +1,7 @@
 import pytest
 from spacy.matcher import Matcher
-from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
 from spacy.errors import MatchPatternError
-from spacy.util import get_json_validator, validate_json
+from spacy.schemas import validate_token_pattern

 # (pattern, num errors with validation, num errors identified with minimal
 # checks)

@@ -15,12 +14,12 @@ TEST_PATTERNS = [
     ('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1),
     ([1, 2, 3], 3, 1),
     # Bad patterns flagged outside of Matcher
-    ([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1, 0),
+    ([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 2, 0),  # prev: (1, 0)
     # Bad patterns not flagged with minimal checks
     ([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 0),
-    ([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2, 0),
-    ([{"LENGTH": {"VALUE": 5}}], 1, 0),
-    ([{"TEXT": {"VALUE": "foo"}}], 1, 0),
+    ([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 4, 0),  # prev: (2, 0)
+    ([{"LENGTH": {"VALUE": 5}}], 2, 0),  # prev: (1, 0)
+    ([{"TEXT": {"VALUE": "foo"}}], 2, 0),  # prev: (1, 0)
     ([{"IS_DIGIT": -1}], 1, 0),
     ([{"ORTH": -1}], 1, 0),
     # Good patterns

@@ -31,15 +30,9 @@ TEST_PATTERNS = [
     ([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0, 0),
     ([{"NORM": "a"}, {"POS": {"IN": ["NOUN"]}}], 0, 0),
     ([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0, 0),
+    ([{"orth": "foo"}], 0, 0),  # prev: xfail
 ]

-XFAIL_TEST_PATTERNS = [([{"orth": "foo"}], 0, 0)]
-
-
-@pytest.fixture
-def validator():
-    return get_json_validator(TOKEN_PATTERN_SCHEMA)
-

 @pytest.mark.parametrize(
     "pattern", [[{"XX": "y"}, {"LENGTH": "2"}, {"TEXT": {"IN": 5}}]]

@@ -51,15 +44,8 @@ def test_matcher_pattern_validation(en_vocab, pattern):


 @pytest.mark.parametrize("pattern,n_errors,_", TEST_PATTERNS)
-def test_pattern_validation(validator, pattern, n_errors, _):
-    errors = validate_json(pattern, validator)
-    assert len(errors) == n_errors
-
-
-@pytest.mark.xfail
-@pytest.mark.parametrize("pattern,n_errors,_", XFAIL_TEST_PATTERNS)
-def test_xfail_pattern_validation(validator, pattern, n_errors, _):
-    errors = validate_json(pattern, validator)
+def test_pattern_validation(pattern, n_errors, _):
+    errors = validate_token_pattern(pattern)
     assert len(errors) == n_errors
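The bumped counts in TEST_PATTERNS above ("prev: ...") follow from how pydantic reports failed unions: every rejected branch of e.g. StringValue = Union[TokenPatternString, StrictStr] contributes its own message, where jsonschema's anyOf produced a single one. A sketch:

    from spacy.schemas import validate_token_pattern

    # {"VALUE": "foo"} fails both branches of StringValue: it is not a
    # TokenPatternString (extra "VALUE" field) and not a plain string.
    errors = validate_token_pattern([{"TEXT": {"VALUE": "foo"}}])
    assert len(errors) == 2  # the jsonschema version reported 1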
spacy/tests/test_json_schemas.py (file removed)
@@ -1,47 +0,0 @@
from spacy.util import get_json_validator, validate_json, validate_schema
from spacy.cli._schemas import META_SCHEMA, TRAINING_SCHEMA
from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
import pytest


@pytest.fixture(scope="session")
def training_schema_validator():
    return get_json_validator(TRAINING_SCHEMA)


def test_validate_schema():
    validate_schema({"type": "object"})
    with pytest.raises(Exception):
        validate_schema({"type": lambda x: x})


@pytest.mark.parametrize("schema", [TRAINING_SCHEMA, META_SCHEMA, TOKEN_PATTERN_SCHEMA])
def test_schemas(schema):
    validate_schema(schema)


@pytest.mark.parametrize(
    "data",
    [
        {"text": "Hello world"},
        {"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "TEST"}]},
    ],
)
def test_json_schema_training_valid(data, training_schema_validator):
    errors = validate_json([data], training_schema_validator)
    assert not errors


@pytest.mark.parametrize(
    "data,n_errors",
    [
        ({"spans": []}, 1),
        ({"text": "Hello", "ents": [{"start": "0", "end": "5", "label": "TEST"}]}, 2),
        ({"text": "Hello", "ents": [{"start": 0, "end": 5}]}, 1),
        ({"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "test"}]}, 1),
        ({"text": "spaCy", "tokens": [{"pos": "PROPN"}]}, 2),
    ],
)
def test_json_schema_training_invalid(data, n_errors, training_schema_validator):
    errors = validate_json([data], training_schema_validator)
    assert len(errors) == n_errors
spacy/util.py
@@ -13,11 +13,6 @@ import srsly
 import catalogue
 import sys

-try:
-    import jsonschema
-except ImportError:
-    jsonschema = None
-
 try:
     import cupy.random
 except ImportError:

@@ -705,43 +700,6 @@ def fix_random_seed(seed=0):
         cupy.random.seed(seed)


-def get_json_validator(schema):
-    # We're using a helper function here to make it easier to change the
-    # validator that's used (e.g. different draft implementation), without
-    # having to change it all across the codebase.
-    # TODO: replace with (stable) Draft6Validator, if available
-    if jsonschema is None:
-        raise ValueError(Errors.E136)
-    return jsonschema.Draft4Validator(schema)
-
-
-def validate_schema(schema):
-    """Validate a given schema. This just checks if the schema itself is valid."""
-    validator = get_json_validator(schema)
-    validator.check_schema(schema)
-
-
-def validate_json(data, validator):
-    """Validate data against a given JSON schema (see https://json-schema.org).
-
-    data: JSON-serializable data to validate.
-    validator (jsonschema.DraftXValidator): The validator.
-    RETURNS (list): A list of error messages, if available.
-    """
-    errors = []
-    for err in sorted(validator.iter_errors(data), key=lambda e: e.path):
-        if err.path:
-            err_path = "[{}]".format(" -> ".join([str(p) for p in err.path]))
-        else:
-            err_path = ""
-        msg = err.message + " " + err_path
-        if err.context:  # Error has suberrors, e.g. if schema uses anyOf
-            suberrs = [f" - {suberr.message}" for suberr in err.context]
-            msg += f":\n{''.join(suberrs)}"
-        errors.append(msg)
-    return errors
-
-
 def get_serialization_exclude(serializers, exclude, kwargs):
     """Helper function to validate serialization args and manage transition from
     keyword arguments (pre v2.1) to exclude argument.