Preserve user data for DependencyMatcher on spans (#7528)

* Preserve user data for DependencyMatcher on spans

* Clean underscore in test

* Modify test to use extensions stored in user data
Adriane Boyd authored on 2021-03-30 12:26:22 +02:00, committed by GitHub
parent 921feee092
commit 348d1829c7
2 changed files with 27 additions and 2 deletions

spacy/matcher/dependencymatcher.pyx

@@ -299,7 +299,7 @@ cdef class DependencyMatcher:
         if isinstance(doclike, Doc):
             doc = doclike
         elif isinstance(doclike, Span):
-            doc = doclike.as_doc()
+            doc = doclike.as_doc(copy_user_data=True)
         else:
             raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))

spacy/tests/matcher/test_dependency_matcher.py

@@ -4,7 +4,9 @@ import re
 import copy
 from mock import Mock
 from spacy.matcher import DependencyMatcher
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Token
+
+from ..doc.test_underscore import clean_underscore  # noqa: F401


 @pytest.fixture
@@ -344,3 +346,26 @@ def test_dependency_matcher_long_matches(en_vocab, doc):
     matcher = DependencyMatcher(en_vocab)
     with pytest.raises(ValueError):
         matcher.add("pattern", [pattern])
+
+
+@pytest.mark.usefixtures("clean_underscore")
+def test_dependency_matcher_span_user_data(en_tokenizer):
+    doc = en_tokenizer("a b c d e")
+    for token in doc:
+        token.head = doc[0]
+        token.dep_ = "a"
+    get_is_c = lambda token: token.text in ("c",)
+    Token.set_extension("is_c", default=False)
+    doc[2]._.is_c = True
+    pattern = [
+        {"RIGHT_ID": "c", "RIGHT_ATTRS": {"_": {"is_c": True}}},
+    ]
+    matcher = DependencyMatcher(en_tokenizer.vocab)
+    matcher.add("C", [pattern])
+    doc_matches = matcher(doc)
+    offset = 1
+    span_matches = matcher(doc[offset:])
+    for doc_match, span_match in zip(sorted(doc_matches), sorted(span_matches)):
+        assert doc_match[0] == span_match[0]
+        for doc_t_i, span_t_i in zip(doc_match[1], span_match[1]):
+            assert doc_t_i == span_t_i + offset