mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Check for is_tagged/is_parsed for Matcher attrs (#4163)
Check for relevant components in the pipeline when Matcher is called, similar to the checks for PhraseMatcher in #4105. * keep track of attributes seen in patterns * when Matcher is called on a Doc, check for is_tagged for LEMMA, TAG, POS and for is_parsed for DEP
This commit is contained in:
parent
4fe9329bfb
commit
2d17b047e2
|
@ -441,11 +441,11 @@ class Errors(object):
|
||||||
"patterns. Please use the option validate=True with Matcher, "
|
"patterns. Please use the option validate=True with Matcher, "
|
||||||
"PhraseMatcher, or EntityRuler for more details.")
|
"PhraseMatcher, or EntityRuler for more details.")
|
||||||
E155 = ("The pipeline needs to include a tagger in order to use "
|
E155 = ("The pipeline needs to include a tagger in order to use "
|
||||||
"PhraseMatcher with the attributes POS, TAG, or LEMMA. Try using "
|
"Matcher or PhraseMatcher with the attributes POS, TAG, or LEMMA. "
|
||||||
"nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
|
"Try using nlp() instead of nlp.make_doc() or list(nlp.pipe()) "
|
||||||
"list(nlp.tokenizer.pipe()).")
|
"instead of list(nlp.tokenizer.pipe()).")
|
||||||
E156 = ("The pipeline needs to include a parser in order to use "
|
E156 = ("The pipeline needs to include a parser in order to use "
|
||||||
"PhraseMatcher with the attribute DEP. Try using "
|
"Matcher or PhraseMatcher with the attribute DEP. Try using "
|
||||||
"nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
|
"nlp() instead of nlp.make_doc() or list(nlp.pipe()) instead of "
|
||||||
"list(nlp.tokenizer.pipe()).")
|
"list(nlp.tokenizer.pipe()).")
|
||||||
E157 = ("Can't render negative values for dependency arc start or end. "
|
E157 = ("Can't render negative values for dependency arc start or end. "
|
||||||
|
|
|
@ -67,3 +67,4 @@ cdef class Matcher:
|
||||||
cdef public object _callbacks
|
cdef public object _callbacks
|
||||||
cdef public object _extensions
|
cdef public object _extensions
|
||||||
cdef public object _extra_predicates
|
cdef public object _extra_predicates
|
||||||
|
cdef public object _seen_attrs
|
||||||
|
|
|
@ -15,7 +15,7 @@ from ..structs cimport TokenC
|
||||||
from ..vocab cimport Vocab
|
from ..vocab cimport Vocab
|
||||||
from ..tokens.doc cimport Doc, get_token_attr
|
from ..tokens.doc cimport Doc, get_token_attr
|
||||||
from ..tokens.token cimport Token
|
from ..tokens.token cimport Token
|
||||||
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
|
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
|
||||||
|
|
||||||
from ._schemas import TOKEN_PATTERN_SCHEMA
|
from ._schemas import TOKEN_PATTERN_SCHEMA
|
||||||
from ..util import get_json_validator, validate_json
|
from ..util import get_json_validator, validate_json
|
||||||
|
@ -45,7 +45,7 @@ cdef class Matcher:
|
||||||
self._patterns = {}
|
self._patterns = {}
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self._extensions = {}
|
self._extensions = {}
|
||||||
self._extra_predicates = []
|
self._seen_attrs = set()
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
if validate:
|
if validate:
|
||||||
|
@ -116,6 +116,9 @@ cdef class Matcher:
|
||||||
specs = _preprocess_pattern(pattern, self.vocab.strings,
|
specs = _preprocess_pattern(pattern, self.vocab.strings,
|
||||||
self._extensions, self._extra_predicates)
|
self._extensions, self._extra_predicates)
|
||||||
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
||||||
|
for spec in specs:
|
||||||
|
for attr, _ in spec[1]:
|
||||||
|
self._seen_attrs.add(attr)
|
||||||
except OverflowError, AttributeError:
|
except OverflowError, AttributeError:
|
||||||
raise ValueError(Errors.E154.format())
|
raise ValueError(Errors.E154.format())
|
||||||
self._patterns.setdefault(key, [])
|
self._patterns.setdefault(key, [])
|
||||||
|
@ -180,6 +183,11 @@ cdef class Matcher:
|
||||||
describing the matches. A match tuple describes a span
|
describing the matches. A match tuple describes a span
|
||||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||||
"""
|
"""
|
||||||
|
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
||||||
|
and not doc.is_tagged:
|
||||||
|
raise ValueError(Errors.E155.format())
|
||||||
|
if DEP in self._seen_attrs and not doc.is_parsed:
|
||||||
|
raise ValueError(Errors.E156.format())
|
||||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
|
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
|
||||||
extensions=self._extensions,
|
extensions=self._extensions,
|
||||||
predicates=self._extra_predicates)
|
predicates=self._extra_predicates)
|
||||||
|
|
|
@ -344,3 +344,39 @@ def test_dependency_matcher_compile(dependency_matcher):
|
||||||
# assert matches[0][1] == [[3, 1, 2]]
|
# assert matches[0][1] == [[3, 1, 2]]
|
||||||
# assert matches[1][1] == [[4, 3, 3]]
|
# assert matches[1][1] == [[4, 3, 3]]
|
||||||
# assert matches[2][1] == [[4, 3, 2]]
|
# assert matches[2][1] == [[4, 3, 2]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_attr_pipeline_checks(en_vocab):
|
||||||
|
doc1 = Doc(en_vocab, words=["Test"])
|
||||||
|
doc1.is_parsed = True
|
||||||
|
doc2 = Doc(en_vocab, words=["Test"])
|
||||||
|
doc2.is_tagged = True
|
||||||
|
doc3 = Doc(en_vocab, words=["Test"])
|
||||||
|
# DEP requires is_parsed
|
||||||
|
matcher = Matcher(en_vocab)
|
||||||
|
matcher.add("TEST", None, [{"DEP": "a"}])
|
||||||
|
matcher(doc1)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
matcher(doc2)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
matcher(doc3)
|
||||||
|
# TAG, POS, LEMMA require is_tagged
|
||||||
|
for attr in ("TAG", "POS", "LEMMA"):
|
||||||
|
matcher = Matcher(en_vocab)
|
||||||
|
matcher.add("TEST", None, [{attr: "a"}])
|
||||||
|
matcher(doc2)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
matcher(doc1)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
matcher(doc3)
|
||||||
|
# TEXT/ORTH only require tokens
|
||||||
|
matcher = Matcher(en_vocab)
|
||||||
|
matcher.add("TEST", None, [{"ORTH": "a"}])
|
||||||
|
matcher(doc1)
|
||||||
|
matcher(doc2)
|
||||||
|
matcher(doc3)
|
||||||
|
matcher = Matcher(en_vocab)
|
||||||
|
matcher.add("TEST", None, [{"TEXT": "a"}])
|
||||||
|
matcher(doc1)
|
||||||
|
matcher(doc2)
|
||||||
|
matcher(doc3)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user