💫 Add token match pattern validation via JSON schemas (#3244)

* Add custom MatchPatternError

* Improve validators and add validation option to Matcher

* Adjust formatting

* Never validate in Matcher within PhraseMatcher

If we do decide to make validate default to True, the PhraseMatcher's Matcher shouldn't ever validate. Here, we create the patterns automatically anyways (and it's currently unclear whether the validation has performance impacts at a very large scale).
This commit is contained in:
Ines Montani 2019-02-12 15:47:26 +01:00 committed by Matthew Honnibal
parent ad2a514cdf
commit 483dddc9bc
10 changed files with 282 additions and 21 deletions

View File

@ -74,8 +74,8 @@ def debug_data(
# Validate data format using the JSON schema # Validate data format using the JSON schema
# TODO: update once the new format is ready # TODO: update once the new format is ready
train_data_errors = [] # TODO: validate_json(train_data, schema) train_data_errors = [] # TODO: validate_json
dev_data_errors = [] # TODO: validate_json(dev_data, schema) dev_data_errors = [] # TODO: validate_json
if not train_data_errors: if not train_data_errors:
msg.good("Training data JSON format is valid") msg.good("Training data JSON format is valid")
if not dev_data_errors: if not dev_data_errors:

View File

@ -325,6 +325,21 @@ class TempErrors(object):
# fmt: on # fmt: on
class MatchPatternError(ValueError):
def __init__(self, key, errors):
"""Custom error for validating match patterns.
key (unicode): The name of the matcher rule.
errors (dict): Validation errors (sequence of strings) mapped to pattern
ID, i.e. the index of the added pattern.
"""
msg = "Invalid token patterns for matcher rule '{}'\n".format(key)
for pattern_idx, error_msgs in errors.items():
pattern_errors = "\n".join(["- {}".format(e) for e in error_msgs])
msg += "\nPattern {}:\n{}\n".format(pattern_idx, pattern_errors)
ValueError.__init__(self, msg)
class ModelsWarning(UserWarning): class ModelsWarning(UserWarning):
pass pass

172
spacy/matcher/_schemas.py Normal file
View File

@ -0,0 +1,172 @@
# coding: utf8
from __future__ import unicode_literals
TOKEN_PATTERN_SCHEMA = {
"$schema": "http://json-schema.org/draft-06/schema",
"definitions": {
"string_value": {
"anyOf": [
{"type": "string"},
{
"type": "object",
"properties": {
"REGEX": {"type": "string"},
"IN": {"type": "array", "items": {"type": "string"}},
"NOT_IN": {"type": "array", "items": {"type": "string"}},
},
"additionalProperties": False,
},
]
},
"integer_value": {
"anyOf": [
{"type": "integer"},
{
"type": "object",
"properties": {
"REGEX": {"type": "string"},
"IN": {"type": "array", "items": {"type": "integer"}},
"NOT_IN": {"type": "array", "items": {"type": "integer"}},
"==": {"type": "integer"},
">=": {"type": "integer"},
"<=": {"type": "integer"},
">": {"type": "integer"},
"<": {"type": "integer"},
},
"additionalProperties": False,
},
]
},
"boolean_value": {"type": "boolean"},
"underscore_value": {
"anyOf": [
{"type": ["string", "integer", "number", "array", "boolean", "null"]},
{
"type": "object",
"properties": {
"REGEX": {"type": "string"},
"IN": {
"type": "array",
"items": {"type": ["string", "integer"]},
},
"NOT_IN": {
"type": "array",
"items": {"type": ["string", "integer"]},
},
"==": {"type": "integer"},
">=": {"type": "integer"},
"<=": {"type": "integer"},
">": {"type": "integer"},
"<": {"type": "integer"},
},
"additionalProperties": False,
},
]
},
},
"type": "array",
"items": {
"type": "object",
"properties": {
"ORTH": {
"title": "Verbatim token text",
"$ref": "#/definitions/string_value",
},
"TEXT": {
"title": "Verbatim token text (spaCy v2.1+)",
"$ref": "#/definitions/string_value",
},
"LOWER": {
"title": "Lowercase form of token text",
"$ref": "#/definitions/string_value",
},
"POS": {
"title": "Coarse-grained part-of-speech tag",
"$ref": "#/definitions/string_value",
},
"TAG": {
"title": "Fine-grained part-of-speech tag",
"$ref": "#/definitions/string_value",
},
"DEP": {"title": "Dependency label", "$ref": "#/definitions/string_value"},
"LEMMA": {
"title": "Lemma (base form)",
"$ref": "#/definitions/string_value",
},
"SHAPE": {
"title": "Abstract token shape",
"$ref": "#/definitions/string_value",
},
"ENT_TYPE": {
"title": "Entity label of single token",
"$ref": "#/definitions/string_value",
},
"LENGTH": {
"title": "Token character length",
"$ref": "#/definitions/integer_value",
},
"IS_ALPHA": {
"title": "Token consists of alphanumeric characters",
"$ref": "#/definitions/boolean_value",
},
"IS_ASCII": {
"title": "Token consists of ASCII characters",
"$ref": "#/definitions/boolean_value",
},
"IS_DIGIT": {
"title": "Token consists of digits",
"$ref": "#/definitions/boolean_value",
},
"IS_LOWER": {
"title": "Token is lowercase",
"$ref": "#/definitions/boolean_value",
},
"IS_UPPER": {
"title": "Token is uppercase",
"$ref": "#/definitions/boolean_value",
},
"IS_TITLE": {
"title": "Token is titlecase",
"$ref": "#/definitions/boolean_value",
},
"IS_PUNCT": {
"title": "Token is punctuation",
"$ref": "#/definitions/boolean_value",
},
"IS_SPACE": {
"title": "Token is whitespace",
"$ref": "#/definitions/boolean_value",
},
"IS_STOP": {
"title": "Token is stop word",
"$ref": "#/definitions/boolean_value",
},
"LIKE_NUM": {
"title": "Token resembles a number",
"$ref": "#/definitions/boolean_value",
},
"LIKE_URL": {
"title": "Token resembles a URL",
"$ref": "#/definitions/boolean_value",
},
"LIKE_EMAIL": {
"title": "Token resembles an email address",
"$ref": "#/definitions/boolean_value",
},
"_": {
"title": "Custom extension token attributes (token._.)",
"type": "object",
"patternProperties": {
"^.*$": {"$ref": "#/definitions/underscore_value"}
},
},
"OP": {
"title": "Operators / quantifiers",
"type": "string",
"enum": ["+", "*", "?", "!"],
},
},
"additionalProperties": False,
},
}

View File

@ -62,6 +62,7 @@ cdef class Matcher:
cdef Pool mem cdef Pool mem
cdef vector[TokenPatternC*] patterns cdef vector[TokenPatternC*] patterns
cdef readonly Vocab vocab cdef readonly Vocab vocab
cdef public object validator
cdef public object _patterns cdef public object _patterns
cdef public object _callbacks cdef public object _callbacks
cdef public object _extensions cdef public object _extensions

View File

@ -17,7 +17,9 @@ from ..tokens.doc cimport Doc, get_token_attr
from ..tokens.token cimport Token from ..tokens.token cimport Token
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
from ..errors import Errors from ._schemas import TOKEN_PATTERN_SCHEMA
from ..util import get_json_validator, validate_json
from ..errors import Errors, MatchPatternError
from ..strings import get_string_id from ..strings import get_string_id
from ..attrs import IDS from ..attrs import IDS
@ -579,7 +581,7 @@ def _get_extensions(spec, string_store, name2index):
cdef class Matcher: cdef class Matcher:
"""Match sequences of tokens, based on pattern rules.""" """Match sequences of tokens, based on pattern rules."""
def __init__(self, vocab): def __init__(self, vocab, validate=False):
"""Create the Matcher. """Create the Matcher.
vocab (Vocab): The vocabulary object, which must be shared with the vocab (Vocab): The vocabulary object, which must be shared with the
@ -593,6 +595,7 @@ cdef class Matcher:
self._extra_predicates = [] self._extra_predicates = []
self.vocab = vocab self.vocab = vocab
self.mem = Pool() self.mem = Pool()
self.validator = get_json_validator(TOKEN_PATTERN_SCHEMA) if validate else None
def __reduce__(self): def __reduce__(self):
data = (self.vocab, self._patterns, self._callbacks) data = (self.vocab, self._patterns, self._callbacks)
@ -643,9 +646,14 @@ cdef class Matcher:
on_match (callable): Callback executed on match. on_match (callable): Callback executed on match.
*patterns (list): List of token descriptions. *patterns (list): List of token descriptions.
""" """
for pattern in patterns: errors = {}
for i, pattern in enumerate(patterns):
if len(pattern) == 0: if len(pattern) == 0:
raise ValueError(Errors.E012.format(key=key)) raise ValueError(Errors.E012.format(key=key))
if self.validator:
errors[i] = validate_json(pattern, self.validator)
if errors:
raise MatchPatternError(key, errors)
key = self._normalize_key(key) key = self._normalize_key(key)
for pattern in patterns: for pattern in patterns:
specs = _preprocess_pattern(pattern, self.vocab.strings, specs = _preprocess_pattern(pattern, self.vocab.strings,

View File

@ -41,7 +41,7 @@ cdef class PhraseMatcher:
self.mem = Pool() self.mem = Pool()
self.max_length = max_length self.max_length = max_length
self.vocab = vocab self.vocab = vocab
self.matcher = Matcher(self.vocab) self.matcher = Matcher(self.vocab, validate=False)
if isinstance(attr, long): if isinstance(attr, long):
self.attr = attr self.attr = attr
else: else:

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import pytest import pytest
from spacy.cli._schemas import TRAINING_SCHEMA from spacy.cli._schemas import TRAINING_SCHEMA
from spacy.util import validate_json from spacy.util import get_json_validator, validate_json
from spacy.tokens import Doc from spacy.tokens import Doc
from ..util import get_doc from ..util import get_doc
@ -62,5 +62,6 @@ def test_doc_to_json_underscore_error_serialize(doc):
def test_doc_to_json_valid_training(doc): def test_doc_to_json_valid_training(doc):
json_doc = doc.to_json() json_doc = doc.to_json()
errors = validate_json([json_doc], TRAINING_SCHEMA) validator = get_json_validator(TRAINING_SCHEMA)
errors = validate_json([json_doc], validator)
assert not errors assert not errors

View File

@ -0,0 +1,48 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
from spacy.matcher import Matcher
from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
from spacy.errors import MatchPatternError
from spacy.util import get_json_validator, validate_json
@pytest.fixture
def validator():
return get_json_validator(TOKEN_PATTERN_SCHEMA)
@pytest.mark.parametrize(
"pattern", [[{"XX": "y"}, {"LENGTH": "2"}, {"TEXT": {"IN": 5}}]]
)
def test_matcher_pattern_validation(en_vocab, pattern):
matcher = Matcher(en_vocab, validate=True)
with pytest.raises(MatchPatternError):
matcher.add("TEST", None, pattern)
@pytest.mark.parametrize(
"pattern,n_errors",
[
# Bad patterns
([{"XX": "foo"}], 1),
([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2),
([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2),
([{"IS_ALPHA": {"==": True}}, {"LIKE_NUM": None}], 2),
([{"TEXT": {"VALUE": "foo"}}], 1),
([{"LENGTH": {"VALUE": 5}}], 1),
([{"_": "foo"}], 1),
([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1),
([{"IS_PUNCT": True, "OP": "$"}], 1),
# Good patterns
([{"TEXT": "foo"}, {"LOWER": "bar"}], 0),
([{"LEMMA": {"IN": ["love", "like"]}}, {"POS": "DET", "OP": "?"}], 0),
([{"LIKE_NUM": True, "LENGTH": {">=": 5}}], 0),
([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0),
([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0),
],
)
def test_pattern_validation(validator, pattern, n_errors):
errors = validate_json(pattern, validator)
assert len(errors) == n_errors

View File

@ -1,18 +1,24 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
from spacy.util import validate_json, validate_schema from spacy.util import get_json_validator, validate_json, validate_schema
from spacy.cli._schemas import META_SCHEMA, TRAINING_SCHEMA from spacy.cli._schemas import META_SCHEMA, TRAINING_SCHEMA
from spacy.matcher._schemas import TOKEN_PATTERN_SCHEMA
import pytest import pytest
@pytest.fixture(scope="session")
def training_schema_validator():
return get_json_validator(TRAINING_SCHEMA)
def test_validate_schema(): def test_validate_schema():
validate_schema({"type": "object"}) validate_schema({"type": "object"})
with pytest.raises(Exception): with pytest.raises(Exception):
validate_schema({"type": lambda x: x}) validate_schema({"type": lambda x: x})
@pytest.mark.parametrize("schema", [TRAINING_SCHEMA, META_SCHEMA]) @pytest.mark.parametrize("schema", [TRAINING_SCHEMA, META_SCHEMA, TOKEN_PATTERN_SCHEMA])
def test_schemas(schema): def test_schemas(schema):
validate_schema(schema) validate_schema(schema)
@ -24,8 +30,8 @@ def test_schemas(schema):
{"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "TEST"}]}, {"text": "Hello", "ents": [{"start": 0, "end": 5, "label": "TEST"}]},
], ],
) )
def test_json_schema_training_valid(data): def test_json_schema_training_valid(data, training_schema_validator):
errors = validate_json([data], TRAINING_SCHEMA) errors = validate_json([data], training_schema_validator)
assert not errors assert not errors
@ -39,6 +45,6 @@ def test_json_schema_training_valid(data):
({"text": "spaCy", "tokens": [{"pos": "PROPN"}]}, 2), ({"text": "spaCy", "tokens": [{"pos": "PROPN"}]}, 2),
], ],
) )
def test_json_schema_training_invalid(data, n_errors): def test_json_schema_training_invalid(data, n_errors, training_schema_validator):
errors = validate_json([data], TRAINING_SCHEMA) errors = validate_json([data], training_schema_validator)
assert len(errors) == n_errors assert len(errors) == n_errors

View File

@ -627,28 +627,38 @@ def fix_random_seed(seed=0):
cupy.random.seed(seed) cupy.random.seed(seed)
def validate_schema(schema): def get_json_validator(schema):
# We're using a helper function here to make it easier to change the
# validator that's used (e.g. different draft implementation), without
# having to change it all across the codebase.
# TODO: replace with (stable) Draft6Validator, if available # TODO: replace with (stable) Draft6Validator, if available
validator = Draft4Validator(schema) return Draft4Validator(schema)
def validate_schema(schema):
"""Validate a given schema. This just checks if the schema itself is valid."""
validator = get_json_validator(schema)
validator.check_schema(schema) validator.check_schema(schema)
def validate_json(data, schema): def validate_json(data, validator):
"""Validate data against a given JSON schema (see https://json-schema.org). """Validate data against a given JSON schema (see https://json-schema.org).
data: JSON-serializable data to validate. data: JSON-serializable data to validate.
schema (dict): The JSON schema. validator (jsonschema.DraftXValidator): The validator.
RETURNS (list): A list of error messages, if available. RETURNS (list): A list of error messages, if available.
""" """
# TODO: replace with (stable) Draft6Validator, if available
validator = Draft4Validator(schema)
errors = [] errors = []
for err in sorted(validator.iter_errors(data), key=lambda e: e.path): for err in sorted(validator.iter_errors(data), key=lambda e: e.path):
if err.path: if err.path:
err_path = "[{}]".format(" -> ".join([str(p) for p in err.path])) err_path = "[{}]".format(" -> ".join([str(p) for p in err.path]))
else: else:
err_path = "" err_path = ""
errors.append(err.message + " " + err_path) msg = err.message + " " + err_path
if err.context: # Error has suberrors, e.g. if schema uses anyOf
suberrs = [" - {}".format(suberr.message) for suberr in err.context]
msg += ":\n{}".format("".join(suberrs))
errors.append(msg)
return errors return errors