Matcher support for Span as well as Doc (#5113)

* Matcher support for Span, as well as Doc #5056

* Removes an unused import

* Signed contributors agreement

* Code optimization and better test

* Add error message for bad Matcher call argument

* Fix merging
This commit is contained in:
Paolo Arduin 2020-04-15 13:51:33 +02:00 committed by GitHub
parent 1eef60c658
commit 1ca32d8f9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 17 deletions

View File

@ -5,7 +5,7 @@ This spaCy Contributor Agreement (**"SCA"**) is based on the
The SCA applies to any contribution that you make to any product or project The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI UG (haftungsbeschränkt)](https://explosion.ai/legal). The term [ExplosionAI GmbH](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below. **"you"** shall mean the person or entity identified below.
If you agree to be bound by these terms, fill in the information requested If you agree to be bound by these terms, fill in the information requested

View File

@ -556,6 +556,7 @@ class Errors(object):
"({new_dim}) is not the same as the current vector dimension " "({new_dim}) is not the same as the current vector dimension "
"({curr_dim}).") "({curr_dim}).")
E194 = ("Unable to aligned mismatched text '{text}' and words '{words}'.") E194 = ("Unable to aligned mismatched text '{text}' and words '{words}'.")
E195 = ("Matcher can be called on {good} only, got {got}.")
@add_codes @add_codes

View File

@ -14,6 +14,7 @@ from ..typedefs cimport attr_t
from ..structs cimport TokenC from ..structs cimport TokenC
from ..vocab cimport Vocab from ..vocab cimport Vocab
from ..tokens.doc cimport Doc, get_token_attr from ..tokens.doc cimport Doc, get_token_attr
from ..tokens.span cimport Span
from ..tokens.token cimport Token from ..tokens.token cimport Token
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA
@ -211,22 +212,29 @@ cdef class Matcher:
else: else:
yield doc yield doc
def __call__(self, Doc doc): def __call__(self, object doc_or_span):
"""Find all token sequences matching the supplied pattern. """Find all token sequences matching the supplied pattern.
doc (Doc): The document to match over. doc_or_span (Doc or Span): The document to match over.
RETURNS (list): A list of `(key, start, end)` tuples, RETURNS (list): A list of `(key, start, end)` tuples,
describing the matches. A match tuple describes a span describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers. `doc[start:end]`. The `label_id` and `key` are both integers.
""" """
if isinstance(doc_or_span, Doc):
doc = doc_or_span
length = len(doc)
elif isinstance(doc_or_span, Span):
doc = doc_or_span.doc
length = doc_or_span.end - doc_or_span.start
else:
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doc_or_span).__name__))
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \ if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
and not doc.is_tagged: and not doc.is_tagged:
raise ValueError(Errors.E155.format()) raise ValueError(Errors.E155.format())
if DEP in self._seen_attrs and not doc.is_parsed: if DEP in self._seen_attrs and not doc.is_parsed:
raise ValueError(Errors.E156.format()) raise ValueError(Errors.E156.format())
matches = find_matches(&self.patterns[0], self.patterns.size(), doc, matches = find_matches(&self.patterns[0], self.patterns.size(), doc_or_span, length,
extensions=self._extensions, extensions=self._extensions, predicates=self._extra_predicates)
predicates=self._extra_predicates)
for i, (key, start, end) in enumerate(matches): for i, (key, start, end) in enumerate(matches):
on_match = self._callbacks.get(key, None) on_match = self._callbacks.get(key, None)
if on_match is not None: if on_match is not None:
@ -248,9 +256,7 @@ def unpickle_matcher(vocab, patterns, callbacks):
return matcher return matcher
cdef find_matches(TokenPatternC** patterns, int n, object doc_or_span, int length, extensions=None, predicates=tuple()):
cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
predicates=tuple()):
"""Find matches in a doc, with a compiled array of patterns. Matches are """Find matches in a doc, with a compiled array of patterns. Matches are
returned as a list of (id, start, end) tuples. returned as a list of (id, start, end) tuples.
@ -268,18 +274,18 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
cdef int i, j, nr_extra_attr cdef int i, j, nr_extra_attr
cdef Pool mem = Pool() cdef Pool mem = Pool()
output = [] output = []
if doc.length == 0: if length == 0:
# avoid any processing or mem alloc if the document is empty # avoid any processing or mem alloc if the document is empty
return output return output
if len(predicates) > 0: if len(predicates) > 0:
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char)) predicate_cache = <char*>mem.alloc(length * len(predicates), sizeof(char))
if extensions is not None and len(extensions) >= 1: if extensions is not None and len(extensions) >= 1:
nr_extra_attr = max(extensions.values()) + 1 nr_extra_attr = max(extensions.values()) + 1
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t)) extra_attr_values = <attr_t*>mem.alloc(length * nr_extra_attr, sizeof(attr_t))
else: else:
nr_extra_attr = 0 nr_extra_attr = 0
extra_attr_values = <attr_t*>mem.alloc(doc.length, sizeof(attr_t)) extra_attr_values = <attr_t*>mem.alloc(length, sizeof(attr_t))
for i, token in enumerate(doc): for i, token in enumerate(doc_or_span):
for name, index in extensions.items(): for name, index in extensions.items():
value = token._.get(name) value = token._.get(name)
if isinstance(value, basestring): if isinstance(value, basestring):
@ -287,11 +293,11 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
extra_attr_values[i * nr_extra_attr + index] = value extra_attr_values[i * nr_extra_attr + index] = value
# Main loop # Main loop
cdef int nr_predicate = len(predicates) cdef int nr_predicate = len(predicates)
for i in range(doc.length): for i in range(length):
for j in range(n): for j in range(n):
states.push_back(PatternStateC(patterns[j], i, 0)) states.push_back(PatternStateC(patterns[j], i, 0))
transition_states(states, matches, predicate_cache, transition_states(states, matches, predicate_cache,
doc[i], extra_attr_values, predicates) doc_or_span[i], extra_attr_values, predicates)
extra_attr_values += nr_extra_attr extra_attr_values += nr_extra_attr
predicate_cache += len(predicates) predicate_cache += len(predicates)
# Handle matches that end in 0-width patterns # Handle matches that end in 0-width patterns

View File

@ -6,7 +6,6 @@ import re
from mock import Mock from mock import Mock
from spacy.matcher import Matcher, DependencyMatcher from spacy.matcher import Matcher, DependencyMatcher
from spacy.tokens import Doc, Token from spacy.tokens import Doc, Token
from ..doc.test_underscore import clean_underscore # noqa: F401 from ..doc.test_underscore import clean_underscore # noqa: F401
@ -470,3 +469,13 @@ def test_matcher_callback(en_vocab):
doc = Doc(en_vocab, words=["This", "is", "a", "test", "."]) doc = Doc(en_vocab, words=["This", "is", "a", "test", "."])
matches = matcher(doc) matches = matcher(doc)
mock.assert_called_once_with(matcher, doc, 0, matches) mock.assert_called_once_with(matcher, doc, 0, matches)
def test_matcher_span(matcher):
text = "JavaScript is good but Java is better"
doc = Doc(matcher.vocab, words=text.split())
span_js = doc[:3]
span_java = doc[4:]
assert len(matcher(doc)) == 2
assert len(matcher(span_js)) == 1
assert len(matcher(span_java)) == 1