mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 05:04:09 +03:00
Pass alignments to Matcher callbacks (#9001)
* pass alignments to callbacks
* refactor for single callback loop
* Update spacy/matcher/matcher.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
8895e3c9ad
commit
ca93504660
|
@ -281,28 +281,19 @@ cdef class Matcher:
|
||||||
final_matches.append((key, *match))
|
final_matches.append((key, *match))
|
||||||
# Mark tokens that have matched
|
# Mark tokens that have matched
|
||||||
memset(&matched[start], 1, span_len * sizeof(matched[0]))
|
memset(&matched[start], 1, span_len * sizeof(matched[0]))
|
||||||
if with_alignments:
|
|
||||||
final_matches_with_alignments = final_matches
|
|
||||||
final_matches = [(key, start, end) for key, start, end, alignments in final_matches]
|
|
||||||
# perform the callbacks on the filtered set of results
|
|
||||||
for i, (key, start, end) in enumerate(final_matches):
|
|
||||||
on_match = self._callbacks.get(key, None)
|
|
||||||
if on_match is not None:
|
|
||||||
on_match(self, doc, i, final_matches)
|
|
||||||
if as_spans:
|
if as_spans:
|
||||||
spans = []
|
final_results = []
|
||||||
for key, start, end in final_matches:
|
for key, start, end, *_ in final_matches:
|
||||||
if isinstance(doclike, Span):
|
if isinstance(doclike, Span):
|
||||||
start += doclike.start
|
start += doclike.start
|
||||||
end += doclike.start
|
end += doclike.start
|
||||||
spans.append(Span(doc, start, end, label=key))
|
final_results.append(Span(doc, start, end, label=key))
|
||||||
return spans
|
|
||||||
elif with_alignments:
|
elif with_alignments:
|
||||||
# convert alignments List[Dict[str, int]] --> List[int]
|
# convert alignments List[Dict[str, int]] --> List[int]
|
||||||
final_matches = []
|
|
||||||
# when multiple alignment (belongs to the same length) is found,
|
# when multiple alignment (belongs to the same length) is found,
|
||||||
# keeps the alignment that has largest token_idx
|
# keeps the alignment that has largest token_idx
|
||||||
for key, start, end, alignments in final_matches_with_alignments:
|
final_results = []
|
||||||
|
for key, start, end, alignments in final_matches:
|
||||||
sorted_alignments = sorted(alignments, key=lambda x: (x['length'], x['token_idx']), reverse=False)
|
sorted_alignments = sorted(alignments, key=lambda x: (x['length'], x['token_idx']), reverse=False)
|
||||||
alignments = [0] * (end-start)
|
alignments = [0] * (end-start)
|
||||||
for align in sorted_alignments:
|
for align in sorted_alignments:
|
||||||
|
@ -311,10 +302,16 @@ cdef class Matcher:
|
||||||
# Since alignments are sorted in order of (length, token_idx)
|
# Since alignments are sorted in order of (length, token_idx)
|
||||||
# this overwrites smaller token_idx when they have same length.
|
# this overwrites smaller token_idx when they have same length.
|
||||||
alignments[align['length']] = align['token_idx']
|
alignments[align['length']] = align['token_idx']
|
||||||
final_matches.append((key, start, end, alignments))
|
final_results.append((key, start, end, alignments))
|
||||||
return final_matches
|
final_matches = final_results # for callbacks
|
||||||
else:
|
else:
|
||||||
return final_matches
|
final_results = final_matches
|
||||||
|
# perform the callbacks on the filtered set of results
|
||||||
|
for i, (key, *_) in enumerate(final_matches):
|
||||||
|
on_match = self._callbacks.get(key, None)
|
||||||
|
if on_match is not None:
|
||||||
|
on_match(self, doc, i, final_matches)
|
||||||
|
return final_results
|
||||||
|
|
||||||
def _normalize_key(self, key):
|
def _normalize_key(self, key):
|
||||||
if isinstance(key, basestring):
|
if isinstance(key, basestring):
|
||||||
|
|
|
@ -576,6 +576,16 @@ def test_matcher_callback(en_vocab):
|
||||||
mock.assert_called_once_with(matcher, doc, 0, matches)
|
mock.assert_called_once_with(matcher, doc, 0, matches)
|
||||||
|
|
||||||
|
|
||||||
|
def test_matcher_callback_with_alignments(en_vocab):
|
||||||
|
mock = Mock()
|
||||||
|
matcher = Matcher(en_vocab)
|
||||||
|
pattern = [{"ORTH": "test"}]
|
||||||
|
matcher.add("Rule", [pattern], on_match=mock)
|
||||||
|
doc = Doc(en_vocab, words=["This", "is", "a", "test", "."])
|
||||||
|
matches = matcher(doc, with_alignments=True)
|
||||||
|
mock.assert_called_once_with(matcher, doc, 0, matches)
|
||||||
|
|
||||||
|
|
||||||
def test_matcher_span(matcher):
|
def test_matcher_span(matcher):
|
||||||
text = "JavaScript is good but Java is better"
|
text = "JavaScript is good but Java is better"
|
||||||
doc = Doc(matcher.vocab, words=text.split())
|
doc = Doc(matcher.vocab, words=text.split())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user