mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Fix PhraseMatcher for spaCy 2
This commit is contained in:
parent
78301b2d29
commit
828cc91545
|
@ -426,7 +426,7 @@ cdef class PhraseMatcher:
|
||||||
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
|
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.matcher = Matcher(self.vocab, {})
|
self.matcher = Matcher(self.vocab)
|
||||||
self.phrase_ids = PreshMap()
|
self.phrase_ids = PreshMap()
|
||||||
for phrase in phrases:
|
for phrase in phrases:
|
||||||
if len(phrase) < max_length:
|
if len(phrase) < max_length:
|
||||||
|
@ -435,7 +435,7 @@ cdef class PhraseMatcher:
|
||||||
abstract_patterns = []
|
abstract_patterns = []
|
||||||
for length in range(1, max_length):
|
for length in range(1, max_length):
|
||||||
abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
|
abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
|
||||||
self.matcher.add('Candidate', 'MWE', {}, abstract_patterns, acceptor=self.accept_match)
|
self.matcher.add('Candidate', None, *abstract_patterns)
|
||||||
|
|
||||||
def add(self, Doc tokens):
|
def add(self, Doc tokens):
|
||||||
cdef int length = tokens.length
|
cdef int length = tokens.length
|
||||||
|
@ -454,22 +454,19 @@ cdef class PhraseMatcher:
|
||||||
self.phrase_ids[key] = True
|
self.phrase_ids[key] = True
|
||||||
|
|
||||||
def __call__(self, Doc doc):
|
def __call__(self, Doc doc):
|
||||||
matches = []
|
matches = self.matcher(doc)
|
||||||
for ent_id, label, start, end in self.matcher(doc):
|
accepted = []
|
||||||
cand = doc[start : end]
|
for ent_id, start, end in matches:
|
||||||
start = cand[0].idx
|
if self.accept_match(doc, ent_id, start, end):
|
||||||
end = cand[-1].idx + len(cand[-1])
|
accepted.append((ent_id, start, end))
|
||||||
matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
|
return accepted
|
||||||
for match in matches:
|
|
||||||
doc.merge(*match)
|
|
||||||
return matches
|
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=1000, n_threads=2):
|
def pipe(self, stream, batch_size=1000, n_threads=2):
|
||||||
for doc in stream:
|
for doc in stream:
|
||||||
self(doc)
|
self(doc)
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
def accept_match(self, Doc doc, attr_t ent_id, attr_t label, int start, int end):
|
def accept_match(self, Doc doc, attr_t ent_id, int start, int end):
|
||||||
assert (end - start) < self.max_length
|
assert (end - start) < self.max_length
|
||||||
cdef int i, j
|
cdef int i, j
|
||||||
for i in range(self.max_length):
|
for i in range(self.max_length):
|
||||||
|
@ -478,6 +475,6 @@ cdef class PhraseMatcher:
|
||||||
self._phrase_key[i] = doc.c[j].lex.orth
|
self._phrase_key[i] = doc.c[j].lex.orth
|
||||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||||
if self.phrase_ids.get(key):
|
if self.phrase_ids.get(key):
|
||||||
return (ent_id, label, start, end)
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
Loading…
Reference in New Issue
Block a user