From ca93504660156ce21b6f795831fd200a530ea832 Mon Sep 17 00:00:00 2001
From: Kevin Humphreys
Date: Thu, 2 Sep 2021 06:58:05 -0400
Subject: [PATCH] Pass alignments to Matcher callbacks (#9001)

* pass alignments to callbacks

* refactor for single callback loop

* Update spacy/matcher/matcher.pyx

Co-authored-by: Sofie Van Landeghem
---
 spacy/matcher/matcher.pyx               | 31 +++++++++++--------------
 spacy/tests/matcher/test_matcher_api.py | 10 ++++++++
 2 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx
index be45dcaad..05c55c9a7 100644
--- a/spacy/matcher/matcher.pyx
+++ b/spacy/matcher/matcher.pyx
@@ -281,28 +281,19 @@ cdef class Matcher:
                     final_matches.append((key, *match))
                     # Mark tokens that have matched
                     memset(&matched[start], 1, span_len * sizeof(matched[0]))
-        if with_alignments:
-            final_matches_with_alignments = final_matches
-            final_matches = [(key, start, end) for key, start, end, alignments in final_matches]
-        # perform the callbacks on the filtered set of results
-        for i, (key, start, end) in enumerate(final_matches):
-            on_match = self._callbacks.get(key, None)
-            if on_match is not None:
-                on_match(self, doc, i, final_matches)
         if as_spans:
-            spans = []
-            for key, start, end in final_matches:
+            final_results = []
+            for key, start, end, *_ in final_matches:
                 if isinstance(doclike, Span):
                     start += doclike.start
                     end += doclike.start
-                spans.append(Span(doc, start, end, label=key))
-            return spans
+                final_results.append(Span(doc, start, end, label=key))
         elif with_alignments:
             # convert alignments List[Dict[str, int]] --> List[int]
-            final_matches = []
             # when multiple alignment (belongs to the same length) is found,
             # keeps the alignment that has largest token_idx
-            for key, start, end, alignments in final_matches_with_alignments:
+            final_results = []
+            for key, start, end, alignments in final_matches:
                 sorted_alignments = sorted(alignments, key=lambda x: (x['length'], x['token_idx']), reverse=False)
                 alignments = [0] * (end-start)
                 for align in sorted_alignments:
@@ -311,10 +302,16 @@ cdef class Matcher:
                     # Since alignments are sorted in order of (length, token_idx)
                     # this overwrites smaller token_idx when they have same length.
                     alignments[align['length']] = align['token_idx']
-                final_matches.append((key, start, end, alignments))
-            return final_matches
+                final_results.append((key, start, end, alignments))
+            final_matches = final_results  # for callbacks
         else:
-            return final_matches
+            final_results = final_matches
+        # perform the callbacks on the filtered set of results
+        for i, (key, *_) in enumerate(final_matches):
+            on_match = self._callbacks.get(key, None)
+            if on_match is not None:
+                on_match(self, doc, i, final_matches)
+        return final_results
 
     def _normalize_key(self, key):
         if isinstance(key, basestring):
diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py
index a42735eae..c02d65cdf 100644
--- a/spacy/tests/matcher/test_matcher_api.py
+++ b/spacy/tests/matcher/test_matcher_api.py
@@ -576,6 +576,16 @@ def test_matcher_callback(en_vocab):
     mock.assert_called_once_with(matcher, doc, 0, matches)
 
 
+def test_matcher_callback_with_alignments(en_vocab):
+    mock = Mock()
+    matcher = Matcher(en_vocab)
+    pattern = [{"ORTH": "test"}]
+    matcher.add("Rule", [pattern], on_match=mock)
+    doc = Doc(en_vocab, words=["This", "is", "a", "test", "."])
+    matches = matcher(doc, with_alignments=True)
+    mock.assert_called_once_with(matcher, doc, 0, matches)
+
+
 def test_matcher_span(matcher):
     text = "JavaScript is good but Java is better"
     doc = Doc(matcher.vocab, words=text.split())
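
For context, a minimal usage sketch (not part of the patch) of the behavior this change enables: with this patch applied, calling the Matcher with with_alignments=True passes the same (key, start, end, alignments) tuples to the on_match callback that the call itself returns. The "Rule" key, the pattern, and the blank "en" pipeline below are illustrative assumptions, not taken from the patch.

# Illustrative sketch; assumes spaCy with this patch applied.
import spacy
from spacy.matcher import Matcher

def on_match(matcher, doc, i, matches):
    # With this patch, each entry carries its per-token alignment list
    # as the fourth element when with_alignments=True is used.
    key, start, end, alignments = matches[i]
    print(doc.vocab.strings[key], doc[start:end].text, alignments)

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)
matcher.add("Rule", [[{"ORTH": "test"}]], on_match=on_match)
doc = nlp("This is a test.")
matches = matcher(doc, with_alignments=True)  # fires on_match once per match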