Fix PhraseMatcher for spaCy 2

This commit is contained in:
Matthew Honnibal 2017-09-20 21:54:31 +02:00
parent 78301b2d29
commit 828cc91545

View File

@@ -426,7 +426,7 @@ cdef class PhraseMatcher:
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t)) self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
self.max_length = max_length self.max_length = max_length
self.vocab = vocab self.vocab = vocab
self.matcher = Matcher(self.vocab, {}) self.matcher = Matcher(self.vocab)
self.phrase_ids = PreshMap() self.phrase_ids = PreshMap()
for phrase in phrases: for phrase in phrases:
if len(phrase) < max_length: if len(phrase) < max_length:
@@ -435,7 +435,7 @@ cdef class PhraseMatcher:
abstract_patterns = [] abstract_patterns = []
for length in range(1, max_length): for length in range(1, max_length):
abstract_patterns.append([{tag: True} for tag in get_bilou(length)]) abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
self.matcher.add('Candidate', 'MWE', {}, abstract_patterns, acceptor=self.accept_match) self.matcher.add('Candidate', None, *abstract_patterns)
def add(self, Doc tokens): def add(self, Doc tokens):
cdef int length = tokens.length cdef int length = tokens.length
@@ -454,22 +454,19 @@ cdef class PhraseMatcher:
self.phrase_ids[key] = True self.phrase_ids[key] = True
def __call__(self, Doc doc): def __call__(self, Doc doc):
matches = [] matches = self.matcher(doc)
for ent_id, label, start, end in self.matcher(doc): accepted = []
cand = doc[start : end] for ent_id, start, end in matches:
start = cand[0].idx if self.accept_match(doc, ent_id, start, end):
end = cand[-1].idx + len(cand[-1]) accepted.append((ent_id, start, end))
matches.append((start, end, cand.root.tag_, cand.text, 'MWE')) return accepted
for match in matches:
doc.merge(*match)
return matches
def pipe(self, stream, batch_size=1000, n_threads=2): def pipe(self, stream, batch_size=1000, n_threads=2):
for doc in stream: for doc in stream:
self(doc) self(doc)
yield doc yield doc
def accept_match(self, Doc doc, attr_t ent_id, attr_t label, int start, int end): def accept_match(self, Doc doc, attr_t ent_id, int start, int end):
assert (end - start) < self.max_length assert (end - start) < self.max_length
cdef int i, j cdef int i, j
for i in range(self.max_length): for i in range(self.max_length):
@@ -478,6 +475,6 @@ cdef class PhraseMatcher:
self._phrase_key[i] = doc.c[j].lex.orth self._phrase_key[i] = doc.c[j].lex.orth
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0) cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
if self.phrase_ids.get(key): if self.phrase_ids.get(key):
return (ent_id, label, start, end) return True
else: else:
return False return False