Forbid OP matching 2+ tokens in DependencyMatcher (#6824)

Instead of silently using only the first token in each matched span:

* Forbid `OP: ?/*/+` through `DependencyMatcher` validation
* As a fail-safe, add a warning if a token pattern finds a match that is
not exactly one token long.
This commit is contained in:
Adriane Boyd 2021-01-29 01:52:01 +01:00 committed by GitHub
parent 24a697abb8
commit fcce3600ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 2 deletions

View File

@ -132,6 +132,11 @@ class Warnings:
"'morphologizer'.") "'morphologizer'.")
W109 = ("Unable to save user hooks while serializing the doc. Re-add any " W109 = ("Unable to save user hooks while serializing the doc. Re-add any "
"required user hooks to the doc after processing.") "required user hooks to the doc after processing.")
W110 = ("The DependencyMatcher token pattern {pattern} matched a span "
"{tokens} that is 2+ tokens long. Only the first token in the span "
"will be included in the results. For better results, token "
"patterns should return matches that are each exactly one token "
"long.")
@add_codes @add_codes
@ -751,6 +756,10 @@ class Errors:
"file.json .`.") "file.json .`.")
E1015 = ("Can't initialize model from config: no {value} found. For more " E1015 = ("Can't initialize model from config: no {value} found. For more "
"information, run: python -m spacy debug config config.cfg") "information, run: python -m spacy debug config config.cfg")
E1016 = ("The operators 'OP': '?', '*', and '+' are not supported in "
"DependencyMatcher token patterns. The token pattern in "
"RIGHT_ATTR should return matches that are each exactly one token "
"long. Invalid pattern:\n{node}")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -9,8 +9,9 @@ from .matcher cimport Matcher
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..errors import Errors from ..errors import Errors, Warnings
from ..tokens import Span from ..tokens import Span
from ..util import logger
DELIMITER = "||" DELIMITER = "||"
@ -137,6 +138,8 @@ cdef class DependencyMatcher:
raise ValueError(Errors.E1007.format(op=relation["REL_OP"])) raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
visited_nodes[relation["RIGHT_ID"]] = True visited_nodes[relation["RIGHT_ID"]] = True
visited_nodes[relation["LEFT_ID"]] = True visited_nodes[relation["LEFT_ID"]] = True
if relation["RIGHT_ATTRS"].get("OP", "") in ("?", "*", "+"):
raise ValueError(Errors.E1016.format(node=relation))
idx = idx + 1 idx = idx + 1
def _get_matcher_key(self, key, pattern_idx, token_idx): def _get_matcher_key(self, key, pattern_idx, token_idx):
@ -277,7 +280,9 @@ cdef class DependencyMatcher:
e.g. keys_to_position_maps[root_index][match_id] = [...] e.g. keys_to_position_maps[root_index][match_id] = [...]
""" """
keys_to_position_maps = defaultdict(lambda: defaultdict(list)) keys_to_position_maps = defaultdict(lambda: defaultdict(list))
for match_id, start, _ in self._matcher(doc): for match_id, start, end in self._matcher(doc):
if start + 1 != end:
logger.warning(Warnings.W110.format(tokens=[t.text for t in doc[start:end]], pattern=self._matcher.get(match_id)[1][0][0]))
token = doc[start] token = doc[start]
root = ([token] + list(token.ancestors))[-1] root = ([token] + list(token.ancestors))[-1]
keys_to_position_maps[root.i][match_id].append(start) keys_to_position_maps[root.i][match_id].append(start)

View File

@ -2,6 +2,7 @@ import pytest
import pickle import pickle
import re import re
import copy import copy
import logging
from mock import Mock from mock import Mock
from spacy.matcher import DependencyMatcher from spacy.matcher import DependencyMatcher
from spacy.tokens import Doc from spacy.tokens import Doc
@ -334,3 +335,14 @@ def test_dependency_matcher_ops(en_vocab, doc, left, right, op, num_matches):
matcher.add("pattern", [pattern]) matcher.add("pattern", [pattern])
matches = matcher(doc) matches = matcher(doc)
assert len(matches) == num_matches assert len(matches) == num_matches
def test_dependency_matcher_long_matches(en_vocab, doc):
    """Check that a token pattern using a multi-token operator is rejected.

    The pattern's RIGHT_ATTRS contains "OP": "+", and adding it to a
    DependencyMatcher is expected to raise a ValueError. The `doc` fixture
    is requested in the signature but is not used by the test body.
    """
    pattern = [
        {"RIGHT_ID": "quick", "RIGHT_ATTRS": {"DEP": "amod", "OP": "+"}},
    ]
    matcher = DependencyMatcher(en_vocab)
    # The error is expected at add() time, before any doc is matched.
    with pytest.raises(ValueError):
        matcher.add("pattern", [pattern])