Pass alignments to Matcher callbacks (#9001)

* pass alignments to callbacks

* refactor for single callback loop

* Update spacy/matcher/matcher.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Kevin Humphreys 2021-09-02 06:58:05 -04:00 committed by GitHub
parent 8895e3c9ad
commit ca93504660
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 17 deletions

View File

@@ -281,28 +281,19 @@ cdef class Matcher:
final_matches.append((key, *match)) final_matches.append((key, *match))
# Mark tokens that have matched # Mark tokens that have matched
memset(&matched[start], 1, span_len * sizeof(matched[0])) memset(&matched[start], 1, span_len * sizeof(matched[0]))
if with_alignments:
final_matches_with_alignments = final_matches
final_matches = [(key, start, end) for key, start, end, alignments in final_matches]
# perform the callbacks on the filtered set of results
for i, (key, start, end) in enumerate(final_matches):
on_match = self._callbacks.get(key, None)
if on_match is not None:
on_match(self, doc, i, final_matches)
if as_spans: if as_spans:
spans = [] final_results = []
for key, start, end in final_matches: for key, start, end, *_ in final_matches:
if isinstance(doclike, Span): if isinstance(doclike, Span):
start += doclike.start start += doclike.start
end += doclike.start end += doclike.start
spans.append(Span(doc, start, end, label=key)) final_results.append(Span(doc, start, end, label=key))
return spans
elif with_alignments: elif with_alignments:
# convert alignments List[Dict[str, int]] --> List[int] # convert alignments List[Dict[str, int]] --> List[int]
final_matches = []
# when multiple alignment (belongs to the same length) is found, # when multiple alignment (belongs to the same length) is found,
# keeps the alignment that has largest token_idx # keeps the alignment that has largest token_idx
for key, start, end, alignments in final_matches_with_alignments: final_results = []
for key, start, end, alignments in final_matches:
sorted_alignments = sorted(alignments, key=lambda x: (x['length'], x['token_idx']), reverse=False) sorted_alignments = sorted(alignments, key=lambda x: (x['length'], x['token_idx']), reverse=False)
alignments = [0] * (end-start) alignments = [0] * (end-start)
for align in sorted_alignments: for align in sorted_alignments:
@@ -311,10 +302,16 @@ cdef class Matcher:
# Since alignments are sorted in order of (length, token_idx) # Since alignments are sorted in order of (length, token_idx)
# this overwrites smaller token_idx when they have same length. # this overwrites smaller token_idx when they have same length.
alignments[align['length']] = align['token_idx'] alignments[align['length']] = align['token_idx']
final_matches.append((key, start, end, alignments)) final_results.append((key, start, end, alignments))
return final_matches final_matches = final_results # for callbacks
else: else:
return final_matches final_results = final_matches
# perform the callbacks on the filtered set of results
for i, (key, *_) in enumerate(final_matches):
on_match = self._callbacks.get(key, None)
if on_match is not None:
on_match(self, doc, i, final_matches)
return final_results
def _normalize_key(self, key): def _normalize_key(self, key):
if isinstance(key, basestring): if isinstance(key, basestring):

View File

@@ -576,6 +576,16 @@ def test_matcher_callback(en_vocab):
mock.assert_called_once_with(matcher, doc, 0, matches) mock.assert_called_once_with(matcher, doc, 0, matches)
def test_matcher_callback_with_alignments(en_vocab):
mock = Mock()
matcher = Matcher(en_vocab)
pattern = [{"ORTH": "test"}]
matcher.add("Rule", [pattern], on_match=mock)
doc = Doc(en_vocab, words=["This", "is", "a", "test", "."])
matches = matcher(doc, with_alignments=True)
mock.assert_called_once_with(matcher, doc, 0, matches)
def test_matcher_span(matcher): def test_matcher_span(matcher):
text = "JavaScript is good but Java is better" text = "JavaScript is good but Java is better"
doc = Doc(matcher.vocab, words=text.split()) doc = Doc(matcher.vocab, words=text.split())