mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Check for is_tagged/is_parsed for Matcher attrs (#4163)
Check for relevant components in the pipeline when Matcher is called, similar to the checks for PhraseMatcher in #4105. * keep track of attributes seen in patterns * when Matcher is called on a Doc, check for is_tagged for LEMMA, TAG, POS and for is_parsed for DEP
This commit is contained in:
parent
4fe9329bfb
commit
2d17b047e2
|
@ -441,11 +441,11 @@ class Errors(object):
|
|||
"patterns. Please use the option validate=True with Matcher, "
|
||||
"PhraseMatcher, or EntityRuler for more details.")
|
||||
E155 = ("The pipeline needs to include a tagger in order to use "
|
||||
"PhraseMatcher with the attributes POS, TAG, or LEMMA. Try using "
|
||||
"nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
|
||||
"list(nlp.tokenizer.pipe()).")
|
||||
"Matcher or PhraseMatcher with the attributes POS, TAG, or LEMMA. "
|
||||
"Try using nlp() instead of nlp.make_doc() or list(nlp.pipe()) "
|
||||
"instead of list(nlp.tokenizer.pipe()).")
|
||||
E156 = ("The pipeline needs to include a parser in order to use "
|
||||
"PhraseMatcher with the attribute DEP. Try using "
|
||||
"Matcher or PhraseMatcher with the attribute DEP. Try using "
|
||||
"nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
|
||||
"list(nlp.tokenizer.pipe()).")
|
||||
E157 = ("Can't render negative values for dependency arc start or end. "
|
||||
|
|
|
@ -67,3 +67,4 @@ cdef class Matcher:
|
|||
cdef public object _callbacks
|
||||
cdef public object _extensions
|
||||
cdef public object _extra_predicates
|
||||
cdef public object _seen_attrs
|
||||
|
|
|
@ -15,7 +15,7 @@ from ..structs cimport TokenC
|
|||
from ..vocab cimport Vocab
|
||||
from ..tokens.doc cimport Doc, get_token_attr
|
||||
from ..tokens.token cimport Token
|
||||
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
|
||||
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
|
||||
|
||||
from ._schemas import TOKEN_PATTERN_SCHEMA
|
||||
from ..util import get_json_validator, validate_json
|
||||
|
@ -45,7 +45,7 @@ cdef class Matcher:
|
|||
self._patterns = {}
|
||||
self._callbacks = {}
|
||||
self._extensions = {}
|
||||
self._extra_predicates = []
|
||||
self._seen_attrs = set()
|
||||
self.vocab = vocab
|
||||
self.mem = Pool()
|
||||
if validate:
|
||||
|
@ -116,6 +116,9 @@ cdef class Matcher:
|
|||
specs = _preprocess_pattern(pattern, self.vocab.strings,
|
||||
self._extensions, self._extra_predicates)
|
||||
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
||||
for spec in specs:
|
||||
for attr, _ in spec[1]:
|
||||
self._seen_attrs.add(attr)
|
||||
except OverflowError, AttributeError:
|
||||
raise ValueError(Errors.E154.format())
|
||||
self._patterns.setdefault(key, [])
|
||||
|
@ -180,6 +183,11 @@ cdef class Matcher:
|
|||
describing the matches. A match tuple describes a span
|
||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||
"""
|
||||
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
||||
and not doc.is_tagged:
|
||||
raise ValueError(Errors.E155.format())
|
||||
if DEP in self._seen_attrs and not doc.is_parsed:
|
||||
raise ValueError(Errors.E156.format())
|
||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
|
||||
extensions=self._extensions,
|
||||
predicates=self._extra_predicates)
|
||||
|
|
|
@ -344,3 +344,39 @@ def test_dependency_matcher_compile(dependency_matcher):
|
|||
# assert matches[0][1] == [[3, 1, 2]]
|
||||
# assert matches[1][1] == [[4, 3, 3]]
|
||||
# assert matches[2][1] == [[4, 3, 2]]
|
||||
|
||||
|
||||
def test_attr_pipeline_checks(en_vocab):
|
||||
doc1 = Doc(en_vocab, words=["Test"])
|
||||
doc1.is_parsed = True
|
||||
doc2 = Doc(en_vocab, words=["Test"])
|
||||
doc2.is_tagged = True
|
||||
doc3 = Doc(en_vocab, words=["Test"])
|
||||
# DEP requires is_parsed
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add("TEST", None, [{"DEP": "a"}])
|
||||
matcher(doc1)
|
||||
with pytest.raises(ValueError):
|
||||
matcher(doc2)
|
||||
with pytest.raises(ValueError):
|
||||
matcher(doc3)
|
||||
# TAG, POS, LEMMA require is_tagged
|
||||
for attr in ("TAG", "POS", "LEMMA"):
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add("TEST", None, [{attr: "a"}])
|
||||
matcher(doc2)
|
||||
with pytest.raises(ValueError):
|
||||
matcher(doc1)
|
||||
with pytest.raises(ValueError):
|
||||
matcher(doc3)
|
||||
# TEXT/ORTH only require tokens
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add("TEST", None, [{"ORTH": "a"}])
|
||||
matcher(doc1)
|
||||
matcher(doc2)
|
||||
matcher(doc3)
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add("TEST", None, [{"TEXT": "a"}])
|
||||
matcher(doc1)
|
||||
matcher(doc2)
|
||||
matcher(doc3)
|
||||
|
|
Loading…
Reference in New Issue
Block a user