Forbid OP matching 2+ tokens in DependencyMatcher (#6824)

Instead of silently using only the first token in each matched span:

* Forbid `OP: ?/*/+` through `DependencyMatcher` validation
* As a fail-safe, log a warning if a token pattern produces a match that is
not exactly one token long.
This commit is contained in:
Adriane Boyd 2021-01-29 01:52:01 +01:00 committed by GitHub
parent 24a697abb8
commit fcce3600ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 2 deletions

View File

@ -132,6 +132,11 @@ class Warnings:
"'morphologizer'.")
W109 = ("Unable to save user hooks while serializing the doc. Re-add any "
"required user hooks to the doc after processing.")
W110 = ("The DependencyMatcher token pattern {pattern} matched a span "
"{tokens} that is 2+ tokens long. Only the first token in the span "
"will be included in the results. For better results, token "
"patterns should return matches that are each exactly one token "
"long.")
@add_codes
@ -751,6 +756,10 @@ class Errors:
"file.json .`.")
E1015 = ("Can't initialize model from config: no {value} found. For more "
"information, run: python -m spacy debug config config.cfg")
# Used when a DependencyMatcher pattern sets a quantifier operator in its
# "RIGHT_ATTRS" token pattern; {node} is filled with the offending pattern.
E1016 = ("The operators 'OP': '?', '*', and '+' are not supported in "
         "DependencyMatcher token patterns. The token pattern in "
         "RIGHT_ATTRS should return matches that are each exactly one token "
         "long. Invalid pattern:\n{node}")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -9,8 +9,9 @@ from .matcher cimport Matcher
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc
from ..errors import Errors
from ..errors import Errors, Warnings
from ..tokens import Span
from ..util import logger
DELIMITER = "||"
@ -137,6 +138,8 @@ cdef class DependencyMatcher:
raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
visited_nodes[relation["RIGHT_ID"]] = True
visited_nodes[relation["LEFT_ID"]] = True
if relation["RIGHT_ATTRS"].get("OP", "") in ("?", "*", "+"):
raise ValueError(Errors.E1016.format(node=relation))
idx = idx + 1
def _get_matcher_key(self, key, pattern_idx, token_idx):
@ -277,7 +280,9 @@ cdef class DependencyMatcher:
e.g. keys_to_position_maps[root_index][match_id] = [...]
"""
keys_to_position_maps = defaultdict(lambda: defaultdict(list))
for match_id, start, _ in self._matcher(doc):
for match_id, start, end in self._matcher(doc):
if start + 1 != end:
logger.warning(Warnings.W110.format(tokens=[t.text for t in doc[start:end]], pattern=self._matcher.get(match_id)[1][0][0]))
token = doc[start]
root = ([token] + list(token.ancestors))[-1]
keys_to_position_maps[root.i][match_id].append(start)

View File

@ -2,6 +2,7 @@ import pytest
import pickle
import re
import copy
import logging
from mock import Mock
from spacy.matcher import DependencyMatcher
from spacy.tokens import Doc
@ -334,3 +335,14 @@ def test_dependency_matcher_ops(en_vocab, doc, left, right, op, num_matches):
matcher.add("pattern", [pattern])
matches = matcher(doc)
assert len(matches) == num_matches
def test_dependency_matcher_long_matches(en_vocab, doc):
    """Check that a quantifier operator in RIGHT_ATTRS is rejected.

    DependencyMatcher token patterns must match exactly one token, so
    adding a pattern that uses "OP": "+" is expected to raise ValueError.
    """
    pattern = [
        {"RIGHT_ID": "quick", "RIGHT_ATTRS": {"DEP": "amod", "OP": "+"}},
    ]
    matcher = DependencyMatcher(en_vocab)
    # The invalid operator should be caught when the pattern is added,
    # i.e. before the matcher is ever run on a doc.
    with pytest.raises(ValueError):
        matcher.add("pattern", [pattern])