mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 05:04:09 +03:00
Pass alignments to Matcher callbacks (#9001)
* pass alignments to callbacks * refactor for single callback loop * Update spacy/matcher/matcher.pyx Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
8895e3c9ad
commit
ca93504660
|
@ -281,28 +281,19 @@ cdef class Matcher:
|
|||
final_matches.append((key, *match))
|
||||
# Mark tokens that have matched
|
||||
memset(&matched[start], 1, span_len * sizeof(matched[0]))
|
||||
if with_alignments:
|
||||
final_matches_with_alignments = final_matches
|
||||
final_matches = [(key, start, end) for key, start, end, alignments in final_matches]
|
||||
# perform the callbacks on the filtered set of results
|
||||
for i, (key, start, end) in enumerate(final_matches):
|
||||
on_match = self._callbacks.get(key, None)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, final_matches)
|
||||
if as_spans:
|
||||
spans = []
|
||||
for key, start, end in final_matches:
|
||||
final_results = []
|
||||
for key, start, end, *_ in final_matches:
|
||||
if isinstance(doclike, Span):
|
||||
start += doclike.start
|
||||
end += doclike.start
|
||||
spans.append(Span(doc, start, end, label=key))
|
||||
return spans
|
||||
final_results.append(Span(doc, start, end, label=key))
|
||||
elif with_alignments:
|
||||
# convert alignments List[Dict[str, int]] --> List[int]
|
||||
final_matches = []
|
||||
# when multiple alignment (belongs to the same length) is found,
|
||||
# keeps the alignment that has largest token_idx
|
||||
for key, start, end, alignments in final_matches_with_alignments:
|
||||
final_results = []
|
||||
for key, start, end, alignments in final_matches:
|
||||
sorted_alignments = sorted(alignments, key=lambda x: (x['length'], x['token_idx']), reverse=False)
|
||||
alignments = [0] * (end-start)
|
||||
for align in sorted_alignments:
|
||||
|
@ -311,10 +302,16 @@ cdef class Matcher:
|
|||
# Since alignments are sorted in order of (length, token_idx)
|
||||
# this overwrites smaller token_idx when they have same length.
|
||||
alignments[align['length']] = align['token_idx']
|
||||
final_matches.append((key, start, end, alignments))
|
||||
return final_matches
|
||||
final_results.append((key, start, end, alignments))
|
||||
final_matches = final_results # for callbacks
|
||||
else:
|
||||
return final_matches
|
||||
final_results = final_matches
|
||||
# perform the callbacks on the filtered set of results
|
||||
for i, (key, *_) in enumerate(final_matches):
|
||||
on_match = self._callbacks.get(key, None)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, final_matches)
|
||||
return final_results
|
||||
|
||||
def _normalize_key(self, key):
|
||||
if isinstance(key, basestring):
|
||||
|
|
|
@ -576,6 +576,16 @@ def test_matcher_callback(en_vocab):
|
|||
mock.assert_called_once_with(matcher, doc, 0, matches)
|
||||
|
||||
|
||||
def test_matcher_callback_with_alignments(en_vocab):
|
||||
mock = Mock()
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{"ORTH": "test"}]
|
||||
matcher.add("Rule", [pattern], on_match=mock)
|
||||
doc = Doc(en_vocab, words=["This", "is", "a", "test", "."])
|
||||
matches = matcher(doc, with_alignments=True)
|
||||
mock.assert_called_once_with(matcher, doc, 0, matches)
|
||||
|
||||
|
||||
def test_matcher_span(matcher):
|
||||
text = "JavaScript is good but Java is better"
|
||||
doc = Doc(matcher.vocab, words=text.split())
|
||||
|
|
Loading…
Reference in New Issue
Block a user