Make PhraseMatcher API like Matcher API

This commit is contained in:
Matthew Honnibal 2017-09-20 22:20:35 +02:00
parent 43ad250dd5
commit cc408fc189
2 changed files with 46 additions and 29 deletions

View File

@ -421,52 +421,67 @@ cdef class PhraseMatcher:
cdef int max_length cdef int max_length
cdef attr_t* _phrase_key cdef attr_t* _phrase_key
def __init__(self, Vocab vocab, phrases, max_length=10): cdef public object _callbacks
def __init__(self, Vocab vocab, max_length=10):
self.mem = Pool() self.mem = Pool()
self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t)) self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
self.max_length = max_length self.max_length = max_length
self.vocab = vocab self.vocab = vocab
self.matcher = Matcher(self.vocab) self.matcher = Matcher(self.vocab)
self.phrase_ids = PreshMap() self.phrase_ids = PreshMap()
for phrase in phrases:
if len(phrase) < max_length:
self.add(phrase)
abstract_patterns = [] abstract_patterns = []
for length in range(1, max_length): for length in range(1, max_length):
abstract_patterns.append([{tag: True} for tag in get_bilou(length)]) abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
self.matcher.add('Candidate', None, *abstract_patterns) self.matcher.add('Candidate', None, *abstract_patterns)
self._callbacks = {}
def add(self, Doc tokens): def add(self, key, on_match, *docs):
cdef int length = tokens.length cdef Doc doc
assert length < self.max_length for doc in docs:
tags = get_bilou(length) if len(doc) >= self.max_length:
assert len(tags) == length, length msg = (
"Pattern length (%d) >= phrase_matcher.max_length (%d). "
"Length can be set on initialization, up to 10."
)
raise ValueError(msg % (len(doc), self.max_length))
cdef hash_t ent_id = self.matcher._normalize_key(key)
self._callbacks[ent_id] = on_match
cdef int length
cdef int i cdef int i
cdef hash_t phrase_hash
for doc in docs:
length = doc.length
tags = get_bilou(length)
for i in range(self.max_length): for i in range(self.max_length):
self._phrase_key[i] = 0 self._phrase_key[i] = 0
for i, tag in enumerate(tags): for i, tag in enumerate(tags):
lexeme = self.vocab[tokens.c[i].lex.orth] lexeme = self.vocab[doc.c[i].lex.orth]
lexeme.set_flag(tag, True) lexeme.set_flag(tag, True)
self._phrase_key[i] = lexeme.orth self._phrase_key[i] = lexeme.orth
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0) phrase_hash = hash64(self._phrase_key,
self.phrase_ids[key] = True self.max_length * sizeof(attr_t), 0)
self.phrase_ids[phrase_hash] = ent_id
def __call__(self, Doc doc): def __call__(self, Doc doc):
matches = self.matcher(doc) matches = []
accepted = [] for _, start, end in self.matcher(doc):
for ent_id, start, end in matches: ent_id = self.accept_match(doc, start, end)
if self.accept_match(doc, ent_id, start, end): if ent_id is not None:
accepted.append((ent_id, start, end)) matches.append((ent_id, start, end))
return accepted for i, (ent_id, start, end) in enumerate(matches):
on_match = self._callbacks.get(ent_id)
if on_match is not None:
on_match(self, doc, i, matches)
return matches
def pipe(self, stream, batch_size=1000, n_threads=2): def pipe(self, stream, batch_size=1000, n_threads=2):
for doc in stream: for doc in stream:
self(doc) self(doc)
yield doc yield doc
def accept_match(self, Doc doc, attr_t ent_id, int start, int end): def accept_match(self, Doc doc, int start, int end):
assert (end - start) < self.max_length assert (end - start) < self.max_length
cdef int i, j cdef int i, j
for i in range(self.max_length): for i in range(self.max_length):
@ -474,7 +489,8 @@ cdef class PhraseMatcher:
for i, j in enumerate(range(start, end)): for i, j in enumerate(range(start, end)):
self._phrase_key[i] = doc.c[j].lex.orth self._phrase_key[i] = doc.c[j].lex.orth
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0) cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
if self.phrase_ids.get(key): ent_id = <hash_t>self.phrase_ids.get(key)
return True if ent_id == 0:
return None
else: else:
return False return ent_id

View File

@ -101,7 +101,8 @@ def test_matcher_match_multi(matcher):
def test_matcher_phrase_matcher(en_vocab): def test_matcher_phrase_matcher(en_vocab):
words = ["Google", "Now"] words = ["Google", "Now"]
doc = get_doc(en_vocab, words) doc = get_doc(en_vocab, words)
matcher = PhraseMatcher(en_vocab, [doc]) matcher = PhraseMatcher(en_vocab)
matcher.add('COMPANY', None, doc)
words = ["I", "like", "Google", "Now", "best"] words = ["I", "like", "Google", "Now", "best"]
doc = get_doc(en_vocab, words) doc = get_doc(en_vocab, words)
assert len(matcher(doc)) == 1 assert len(matcher(doc)) == 1