mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Make PhraseMatcher API like Matcher API
This commit is contained in:
parent
43ad250dd5
commit
cc408fc189
|
@ -421,52 +421,67 @@ cdef class PhraseMatcher:
|
||||||
cdef int max_length
|
cdef int max_length
|
||||||
cdef attr_t* _phrase_key
|
cdef attr_t* _phrase_key
|
||||||
|
|
||||||
def __init__(self, Vocab vocab, phrases, max_length=10):
|
cdef public object _callbacks
|
||||||
|
|
||||||
|
def __init__(self, Vocab vocab, max_length=10):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
|
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.matcher = Matcher(self.vocab)
|
self.matcher = Matcher(self.vocab)
|
||||||
self.phrase_ids = PreshMap()
|
self.phrase_ids = PreshMap()
|
||||||
for phrase in phrases:
|
|
||||||
if len(phrase) < max_length:
|
|
||||||
self.add(phrase)
|
|
||||||
|
|
||||||
abstract_patterns = []
|
abstract_patterns = []
|
||||||
for length in range(1, max_length):
|
for length in range(1, max_length):
|
||||||
abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
|
abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
|
||||||
self.matcher.add('Candidate', None, *abstract_patterns)
|
self.matcher.add('Candidate', None, *abstract_patterns)
|
||||||
|
self._callbacks = {}
|
||||||
|
|
||||||
def add(self, Doc tokens):
|
def add(self, key, on_match, *docs):
|
||||||
cdef int length = tokens.length
|
cdef Doc doc
|
||||||
assert length < self.max_length
|
for doc in docs:
|
||||||
tags = get_bilou(length)
|
if len(doc) >= self.max_length:
|
||||||
assert len(tags) == length, length
|
msg = (
|
||||||
|
"Pattern length (%d) >= phrase_matcher.max_length (%d). "
|
||||||
|
"Length can be set on initialization, up to 10."
|
||||||
|
)
|
||||||
|
raise ValueError(msg % (len(doc), self.max_length))
|
||||||
|
cdef hash_t ent_id = self.matcher._normalize_key(key)
|
||||||
|
self._callbacks[ent_id] = on_match
|
||||||
|
|
||||||
|
cdef int length
|
||||||
cdef int i
|
cdef int i
|
||||||
|
cdef hash_t phrase_hash
|
||||||
|
for doc in docs:
|
||||||
|
length = doc.length
|
||||||
|
tags = get_bilou(length)
|
||||||
for i in range(self.max_length):
|
for i in range(self.max_length):
|
||||||
self._phrase_key[i] = 0
|
self._phrase_key[i] = 0
|
||||||
for i, tag in enumerate(tags):
|
for i, tag in enumerate(tags):
|
||||||
lexeme = self.vocab[tokens.c[i].lex.orth]
|
lexeme = self.vocab[doc.c[i].lex.orth]
|
||||||
lexeme.set_flag(tag, True)
|
lexeme.set_flag(tag, True)
|
||||||
self._phrase_key[i] = lexeme.orth
|
self._phrase_key[i] = lexeme.orth
|
||||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
phrase_hash = hash64(self._phrase_key,
|
||||||
self.phrase_ids[key] = True
|
self.max_length * sizeof(attr_t), 0)
|
||||||
|
self.phrase_ids[phrase_hash] = ent_id
|
||||||
|
|
||||||
def __call__(self, Doc doc):
|
def __call__(self, Doc doc):
|
||||||
matches = self.matcher(doc)
|
matches = []
|
||||||
accepted = []
|
for _, start, end in self.matcher(doc):
|
||||||
for ent_id, start, end in matches:
|
ent_id = self.accept_match(doc, start, end)
|
||||||
if self.accept_match(doc, ent_id, start, end):
|
if ent_id is not None:
|
||||||
accepted.append((ent_id, start, end))
|
matches.append((ent_id, start, end))
|
||||||
return accepted
|
for i, (ent_id, start, end) in enumerate(matches):
|
||||||
|
on_match = self._callbacks.get(ent_id)
|
||||||
|
if on_match is not None:
|
||||||
|
on_match(self, doc, i, matches)
|
||||||
|
return matches
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=1000, n_threads=2):
|
def pipe(self, stream, batch_size=1000, n_threads=2):
|
||||||
for doc in stream:
|
for doc in stream:
|
||||||
self(doc)
|
self(doc)
|
||||||
yield doc
|
yield doc
|
||||||
|
|
||||||
def accept_match(self, Doc doc, attr_t ent_id, int start, int end):
|
def accept_match(self, Doc doc, int start, int end):
|
||||||
assert (end - start) < self.max_length
|
assert (end - start) < self.max_length
|
||||||
cdef int i, j
|
cdef int i, j
|
||||||
for i in range(self.max_length):
|
for i in range(self.max_length):
|
||||||
|
@ -474,7 +489,8 @@ cdef class PhraseMatcher:
|
||||||
for i, j in enumerate(range(start, end)):
|
for i, j in enumerate(range(start, end)):
|
||||||
self._phrase_key[i] = doc.c[j].lex.orth
|
self._phrase_key[i] = doc.c[j].lex.orth
|
||||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||||
if self.phrase_ids.get(key):
|
ent_id = <hash_t>self.phrase_ids.get(key)
|
||||||
return True
|
if ent_id == 0:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
return False
|
return ent_id
|
||||||
|
|
|
@ -101,7 +101,8 @@ def test_matcher_match_multi(matcher):
|
||||||
def test_matcher_phrase_matcher(en_vocab):
|
def test_matcher_phrase_matcher(en_vocab):
|
||||||
words = ["Google", "Now"]
|
words = ["Google", "Now"]
|
||||||
doc = get_doc(en_vocab, words)
|
doc = get_doc(en_vocab, words)
|
||||||
matcher = PhraseMatcher(en_vocab, [doc])
|
matcher = PhraseMatcher(en_vocab)
|
||||||
|
matcher.add('COMPANY', None, doc)
|
||||||
words = ["I", "like", "Google", "Now", "best"]
|
words = ["I", "like", "Google", "Now", "best"]
|
||||||
doc = get_doc(en_vocab, words)
|
doc = get_doc(en_vocab, words)
|
||||||
assert len(matcher(doc)) == 1
|
assert len(matcher(doc)) == 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user