From 2d17b047e22be4500bbd2f83898fbca2168917f4 Mon Sep 17 00:00:00 2001
From: adrianeboyd
Date: Wed, 21 Aug 2019 20:52:36 +0200
Subject: [PATCH] Check for is_tagged/is_parsed for Matcher attrs (#4163)

Check for relevant components in the pipeline when Matcher is called,
similar to the checks for PhraseMatcher in #4105.

* keep track of attributes seen in patterns
* when Matcher is called on a Doc, check for is_tagged for LEMMA, TAG,
  POS and for is_parsed for DEP
---
 spacy/errors.py                         |  8 +++---
 spacy/matcher/matcher.pxd               |  1 +
 spacy/matcher/matcher.pyx               | 11 ++++++++-
 spacy/tests/matcher/test_matcher_api.py | 36 +++++++++++++++++++++++++
 4 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/spacy/errors.py b/spacy/errors.py
index 7811374a7..489f70ca7 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -441,11 +441,11 @@ class Errors(object):
             "patterns. Please use the option validate=True with Matcher, "
             "PhraseMatcher, or EntityRuler for more details.")
     E155 = ("The pipeline needs to include a tagger in order to use "
-            "PhraseMatcher with the attributes POS, TAG, or LEMMA. Try using "
-            "nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
-            "list(nlp.tokenizer.pipe()).")
+            "Matcher or PhraseMatcher with the attributes POS, TAG, or LEMMA. "
+            "Try using nlp() instead of nlp.make_doc() or list(nlp.pipe()) "
+            "instead of list(nlp.tokenizer.pipe()).")
     E156 = ("The pipeline needs to include a parser in order to use "
-            "PhraseMatcher with the attribute DEP. Try using "
+            "Matcher or PhraseMatcher with the attribute DEP. Try using "
             "nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
             "list(nlp.tokenizer.pipe()).")
     E157 = ("Can't render negative values for dependency arc start or end. "
diff --git a/spacy/matcher/matcher.pxd b/spacy/matcher/matcher.pxd
index 2e46cb376..a146cd107 100644
--- a/spacy/matcher/matcher.pxd
+++ b/spacy/matcher/matcher.pxd
@@ -67,3 +67,4 @@ cdef class Matcher:
     cdef public object _callbacks
     cdef public object _extensions
     cdef public object _extra_predicates
+    cdef public object _seen_attrs
diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx
index 3226dd124..a79542929 100644
--- a/spacy/matcher/matcher.pyx
+++ b/spacy/matcher/matcher.pyx
@@ -15,7 +15,7 @@ from ..structs cimport TokenC
 from ..vocab cimport Vocab
 from ..tokens.doc cimport Doc, get_token_attr
 from ..tokens.token cimport Token
-from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
+from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
 
 from ._schemas import TOKEN_PATTERN_SCHEMA
 from ..util import get_json_validator, validate_json
@@ -45,7 +45,8 @@ cdef class Matcher:
         self._patterns = {}
         self._callbacks = {}
         self._extensions = {}
         self._extra_predicates = []
+        self._seen_attrs = set()
         self.vocab = vocab
         self.mem = Pool()
         if validate:
@@ -116,6 +117,9 @@ cdef class Matcher:
                 specs = _preprocess_pattern(pattern, self.vocab.strings,
                                             self._extensions, self._extra_predicates)
                 self.patterns.push_back(init_pattern(self.mem, key, specs))
+                for spec in specs:
+                    for attr, _ in spec[1]:
+                        self._seen_attrs.add(attr)
             except OverflowError, AttributeError:
                 raise ValueError(Errors.E154.format())
         self._patterns.setdefault(key, [])
@@ -180,6 +184,11 @@ cdef class Matcher:
             describing the matches. A match tuple describes a span
             `doc[start:end]`. The `label_id` and `key` are both integers.
""" + if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \ + and not doc.is_tagged: + raise ValueError(Errors.E155.format()) + if DEP in self._seen_attrs and not doc.is_parsed: + raise ValueError(Errors.E156.format()) matches = find_matches(&self.patterns[0], self.patterns.size(), doc, extensions=self._extensions, predicates=self._extra_predicates) diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index 013700d52..401b7c928 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -344,3 +344,39 @@ def test_dependency_matcher_compile(dependency_matcher): # assert matches[0][1] == [[3, 1, 2]] # assert matches[1][1] == [[4, 3, 3]] # assert matches[2][1] == [[4, 3, 2]] + + +def test_attr_pipeline_checks(en_vocab): + doc1 = Doc(en_vocab, words=["Test"]) + doc1.is_parsed = True + doc2 = Doc(en_vocab, words=["Test"]) + doc2.is_tagged = True + doc3 = Doc(en_vocab, words=["Test"]) + # DEP requires is_parsed + matcher = Matcher(en_vocab) + matcher.add("TEST", None, [{"DEP": "a"}]) + matcher(doc1) + with pytest.raises(ValueError): + matcher(doc2) + with pytest.raises(ValueError): + matcher(doc3) + # TAG, POS, LEMMA require is_tagged + for attr in ("TAG", "POS", "LEMMA"): + matcher = Matcher(en_vocab) + matcher.add("TEST", None, [{attr: "a"}]) + matcher(doc2) + with pytest.raises(ValueError): + matcher(doc1) + with pytest.raises(ValueError): + matcher(doc3) + # TEXT/ORTH only require tokens + matcher = Matcher(en_vocab) + matcher.add("TEST", None, [{"ORTH": "a"}]) + matcher(doc1) + matcher(doc2) + matcher(doc3) + matcher = Matcher(en_vocab) + matcher.add("TEST", None, [{"TEXT": "a"}]) + matcher(doc1) + matcher(doc2) + matcher(doc3)