Matcher ID fixes (#4179)

* allow phrasematcher to link one match to multiple original patterns

* small fix for defining ent_id in the matcher (anti-ghost prevention)

* cleanup

* formatting
This commit is contained in:
Sofie Van Landeghem 2019-08-22 17:17:07 +02:00 committed by Matthew Honnibal
parent f5d3afb1a3
commit c417c380e3
4 changed files with 39 additions and 10 deletions

View File

@@ -327,7 +327,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
                 states[q].length += 1
                 q += 1
             else:
-                ent_id = get_ent_id(&state.pattern[1])
+                ent_id = get_ent_id(state.pattern)
                 if action == MATCH:
                     matches.push_back(
                         MatchC(pattern_id=ent_id, start=state.start,

View File

@@ -0,0 +1,5 @@
+from libcpp.vector cimport vector
+
+from ..typedefs cimport hash_t
+
+ctypedef vector[hash_t] hash_vec

View File

@@ -2,6 +2,7 @@
 # cython: profile=True
 from __future__ import unicode_literals
+from libcpp.vector cimport vector
 from cymem.cymem cimport Pool
 from murmurhash.mrmr cimport hash64
 from preshed.maps cimport PreshMap
@@ -37,6 +38,7 @@ cdef class PhraseMatcher:
     cdef Vocab vocab
     cdef Matcher matcher
     cdef PreshMap phrase_ids
+    cdef vector[hash_vec] ent_id_matrix
     cdef int max_length
     cdef attr_id_t attr
     cdef public object _callbacks
@@ -145,7 +147,23 @@ cdef class PhraseMatcher:
             lexeme.set_flag(tag, True)
             phrase_key[i] = lexeme.orth
         phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
-        self.phrase_ids.set(phrase_hash, <void*>ent_id)
+
+        if phrase_hash in self.phrase_ids:
+            phrase_index = self.phrase_ids[phrase_hash]
+            ent_id_list = self.ent_id_matrix[phrase_index]
+            ent_id_list.append(ent_id)
+            self.ent_id_matrix[phrase_index] = ent_id_list
+        else:
+            ent_id_list = hash_vec(1)
+            ent_id_list[0] = ent_id
+
+            new_index = self.ent_id_matrix.size()
+            if new_index == 0:
+                # PreshMaps can not contain 0 as value, so storing a dummy at 0
+                self.ent_id_matrix.push_back(hash_vec(0))
+                new_index = 1
+
+            self.ent_id_matrix.push_back(ent_id_list)
+            self.phrase_ids.set(phrase_hash, <void*>new_index)

     def __call__(self, Doc doc):
         """Find all sequences matching the supplied patterns on the `Doc`.
@@ -167,9 +185,10 @@ cdef class PhraseMatcher:
             words = [self.get_lex_value(doc, i) for i in range(len(doc))]
             match_doc = Doc(self.vocab, words=words)
             for _, start, end in self.matcher(match_doc):
-                ent_id = self.accept_match(match_doc, start, end)
-                if ent_id is not None:
-                    matches.append((ent_id, start, end))
+                ent_ids = self.accept_match(match_doc, start, end)
+                if ent_ids is not None:
+                    for ent_id in ent_ids:
+                        matches.append((ent_id, start, end))
         for i, (ent_id, start, end) in enumerate(matches):
             on_match = self._callbacks.get(ent_id)
             if on_match is not None:
@@ -216,11 +235,11 @@ cdef class PhraseMatcher:
         for i, j in enumerate(range(start, end)):
             phrase_key[i] = doc.c[j].lex.orth
         cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
-        ent_id = <hash_t>self.phrase_ids.get(key)
-        if ent_id == 0:
+        ent_index = <hash_t>self.phrase_ids.get(key)
+        if ent_index == 0:
             return None
-        else:
-            return ent_id
+        return self.ent_id_matrix[ent_index]

     def get_lex_value(self, Doc doc, int i):
         if self.attr == ORTH:

View File

@@ -6,7 +6,6 @@ from spacy.matcher import PhraseMatcher
 from spacy.tokens import Doc

-@pytest.mark.xfail
 def test_issue3972(en_vocab):
     """Test that the PhraseMatcher returns duplicates for duplicate match IDs.
     """
@@ -15,4 +14,10 @@ def test_issue3972(en_vocab):
     matcher.add("B", None, Doc(en_vocab, words=["New", "York"]))
     doc = Doc(en_vocab, words=["I", "live", "in", "New", "York"])
     matches = matcher(doc)
     assert len(matches) == 2
+
+    # We should have a match for each of the two rules
+    found_ids = [en_vocab.strings[ent_id] for (ent_id, _, _) in matches]
+    assert "A" in found_ids
+    assert "B" in found_ids