Check for is_tagged/is_parsed for Matcher attrs (#4163)

Check for relevant components in the pipeline when Matcher is called,
similar to the checks for PhraseMatcher in #4105.

* keep track of attributes seen in patterns

* when Matcher is called on a Doc, check that doc.is_tagged is set if the
patterns use LEMMA, TAG, or POS, and that doc.is_parsed is set if they use DEP
This commit is contained in:
adrianeboyd 2019-08-21 20:52:36 +02:00 committed by Matthew Honnibal
parent 4fe9329bfb
commit 2d17b047e2
4 changed files with 51 additions and 6 deletions

View File

@@ -441,11 +441,11 @@ class Errors(object):
"patterns. Please use the option validate=True with Matcher, "
"PhraseMatcher, or EntityRuler for more details.")
E155 = ("The pipeline needs to include a tagger in order to use "
"PhraseMatcher with the attributes POS, TAG, or LEMMA. Try using "
"nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
"list(nlp.tokenizer.pipe()).")
"Matcher or PhraseMatcher with the attributes POS, TAG, or LEMMA. "
"Try using nlp() instead of nlp.make_doc() or list(nlp.pipe()) "
"instead of list(nlp.tokenizer.pipe()).")
E156 = ("The pipeline needs to include a parser in order to use "
"PhraseMatcher with the attribute DEP. Try using "
"Matcher or PhraseMatcher with the attribute DEP. Try using "
"nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
"list(nlp.tokenizer.pipe()).")
E157 = ("Can't render negative values for dependency arc start or end. "

View File

@@ -67,3 +67,4 @@ cdef class Matcher:
cdef public object _callbacks
cdef public object _extensions
cdef public object _extra_predicates
cdef public object _seen_attrs

View File

@@ -15,7 +15,7 @@ from ..structs cimport TokenC
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc, get_token_attr
from ..tokens.token cimport Token
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
from ._schemas import TOKEN_PATTERN_SCHEMA
from ..util import get_json_validator, validate_json
@@ -45,7 +45,7 @@ cdef class Matcher:
self._patterns = {}
self._callbacks = {}
self._extensions = {}
self._extra_predicates = []
self._seen_attrs = set()
self.vocab = vocab
self.mem = Pool()
if validate:
@@ -116,6 +116,9 @@ cdef class Matcher:
specs = _preprocess_pattern(pattern, self.vocab.strings,
self._extensions, self._extra_predicates)
self.patterns.push_back(init_pattern(self.mem, key, specs))
for spec in specs:
for attr, _ in spec[1]:
self._seen_attrs.add(attr)
except OverflowError, AttributeError:
raise ValueError(Errors.E154.format())
self._patterns.setdefault(key, [])
@@ -180,6 +183,11 @@ cdef class Matcher:
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
and not doc.is_tagged:
raise ValueError(Errors.E155.format())
if DEP in self._seen_attrs and not doc.is_parsed:
raise ValueError(Errors.E156.format())
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
extensions=self._extensions,
predicates=self._extra_predicates)

View File

@@ -344,3 +344,39 @@ def test_dependency_matcher_compile(dependency_matcher):
# assert matches[0][1] == [[3, 1, 2]]
# assert matches[1][1] == [[4, 3, 3]]
# assert matches[2][1] == [[4, 3, 2]]
def test_attr_pipeline_checks(en_vocab):
    """Matcher should raise when the Doc lacks the pipeline annotation
    (is_tagged / is_parsed) that its patterns' attributes require."""
    parsed = Doc(en_vocab, words=["Test"])
    parsed.is_parsed = True
    tagged = Doc(en_vocab, words=["Test"])
    tagged.is_tagged = True
    plain = Doc(en_vocab, words=["Test"])
    # Patterns on DEP require doc.is_parsed
    matcher = Matcher(en_vocab)
    matcher.add("TEST", None, [{"DEP": "a"}])
    matcher(parsed)
    for bad_doc in (tagged, plain):
        with pytest.raises(ValueError):
            matcher(bad_doc)
    # Patterns on TAG, POS, or LEMMA require doc.is_tagged
    for attr in ("TAG", "POS", "LEMMA"):
        matcher = Matcher(en_vocab)
        matcher.add("TEST", None, [{attr: "a"}])
        matcher(tagged)
        for bad_doc in (parsed, plain):
            with pytest.raises(ValueError):
                matcher(bad_doc)
    # TEXT/ORTH patterns only need tokenization, so every Doc is fine
    for attr in ("ORTH", "TEXT"):
        matcher = Matcher(en_vocab)
        matcher.add("TEST", None, [{attr: "a"}])
        for any_doc in (parsed, tagged, plain):
            matcher(any_doc)