Replace PhraseMatcher with trie-based search (#4309)

* Replace PhraseMatcher with Aho-Corasick

Replace PhraseMatcher with the Aho-Corasick algorithm over numpy arrays
of the hash values for the relevant attribute. The implementation is
based on FlashText.

The speed should be similar to the previous PhraseMatcher. It is now
possible to easily remove match IDs and matches don't go missing with
large keyword lists / vocabularies.

Fixes #4308.

* Restore support for pickling

* Fix internal keyword add/remove for numpy arrays

* Add missing loop for match ID set in search loop

* Remove cruft in matching loop for partial matches

There was a bit of unnecessary code left over from FlashText in the
matching loop to handle partial token matches, which we don't have with
PhraseMatcher.

* Replace dict trie with MapStruct trie

* Fix how match ID hash is stored/added

* Update fix for match ID vocab

* Switch from map_get_unless_missing to map_get

* Switch from numpy array to Token.get_struct_attr

Access token attributes directly in Doc instead of making a copy of the
relevant values in a numpy array.

Add unsatisfactory warning for hash collision with reserved terminal
hash key. (Ideally it would change the reserved terminal hash and redo
the whole trie, but for now, I'm hoping there won't be collisions.)

* Restructure imports to export find_matches

* Implement full remove()

Remove unnecessary trie paths and free unused maps.

Parallel to Matcher, raise KeyError when attempting to remove a match ID
that has not been added.

* Store docs internally only as attr lists

* Reduces size for pickle

* Remove duplicate keywords store

Now that docs are stored as lists of attr hashes, there's no need to
have the duplicate _keywords store.
This commit is contained in:
adrianeboyd 2019-09-27 16:22:34 +02:00 committed by Matthew Honnibal
parent d844030fd8
commit c23edf302b
4 changed files with 295 additions and 144 deletions

View File

@ -86,6 +86,8 @@ class Warnings(object):
"previously loaded vectors. See Issue #3853.")
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
"loaded. (Shape: {shape})")
W021 = ("Unexpected hash collision in PhraseMatcher. Matches may be "
"incorrect. Modify PhraseMatcher._terminal_hash to fix.")
@add_codes

View File

@ -1,5 +1,27 @@
from libcpp.vector cimport vector
from ..typedefs cimport hash_t
from cymem.cymem cimport Pool
from preshed.maps cimport key_t, MapStruct
ctypedef vector[hash_t] hash_vec
from ..attrs cimport attr_id_t
from ..tokens.doc cimport Doc
from ..vocab cimport Vocab
cdef class PhraseMatcher:
cdef Vocab vocab
cdef attr_id_t attr
cdef object _callbacks
cdef object _docs
cdef bint _validate
cdef MapStruct* c_map
cdef Pool mem
cdef key_t _terminal_hash
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil
cdef struct MatchStruct:
key_t match_id
int start
int end

View File

@ -2,28 +2,16 @@
# cython: profile=True
from __future__ import unicode_literals
from libcpp.vector cimport vector
from cymem.cymem cimport Pool
from murmurhash.mrmr cimport hash64
from preshed.maps cimport PreshMap
from libc.stdint cimport uintptr_t
from .matcher cimport Matcher
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc, get_token_attr
from ..typedefs cimport attr_t, hash_t
from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
from ..structs cimport TokenC
from ..tokens.token cimport Token
from ._schemas import TOKEN_PATTERN_SCHEMA
from ..errors import Errors, Warnings, deprecation_warning, user_warning
from ..attrs import FLAG61 as U_ENT
from ..attrs import FLAG60 as B2_ENT
from ..attrs import FLAG59 as B3_ENT
from ..attrs import FLAG58 as B4_ENT
from ..attrs import FLAG43 as L2_ENT
from ..attrs import FLAG42 as L3_ENT
from ..attrs import FLAG41 as L4_ENT
from ..attrs import FLAG42 as I3_ENT
from ..attrs import FLAG41 as I4_ENT
cdef class PhraseMatcher:
@ -33,18 +21,11 @@ cdef class PhraseMatcher:
DOCS: https://spacy.io/api/phrasematcher
USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher
Adapted from FlashText: https://github.com/vi3k6i5/flashtext
MIT License (see `LICENSE`)
Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
"""
cdef Pool mem
cdef Vocab vocab
cdef Matcher matcher
cdef PreshMap phrase_ids
cdef vector[hash_vec] ent_id_matrix
cdef int max_length
cdef attr_id_t attr
cdef public object _callbacks
cdef public object _patterns
cdef public object _docs
cdef public object _validate
def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
"""Initialize the PhraseMatcher.
@ -58,10 +39,16 @@ cdef class PhraseMatcher:
"""
if max_length != 0:
deprecation_warning(Warnings.W010)
self.mem = Pool()
self.max_length = max_length
self.vocab = vocab
self.matcher = Matcher(self.vocab, validate=False)
self._callbacks = {}
self._docs = {}
self._validate = validate
self.mem = Pool()
self.c_map = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
self._terminal_hash = 826361138722620965
map_init(self.mem, self.c_map, 8)
if isinstance(attr, long):
self.attr = attr
else:
@ -71,28 +58,15 @@ cdef class PhraseMatcher:
if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
raise ValueError(Errors.E152.format(attr=attr))
self.attr = self.vocab.strings[attr]
self.phrase_ids = PreshMap()
abstract_patterns = [
[{U_ENT: True}],
[{B2_ENT: True}, {L2_ENT: True}],
[{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}],
[{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}],
]
self.matcher.add("Candidate", None, *abstract_patterns)
self._callbacks = {}
self._docs = {}
self._validate = validate
def __len__(self):
"""Get the number of rules added to the matcher. Note that this only
returns the number of rules (identical with the number of IDs), not the
number of individual patterns.
"""Get the number of match IDs added to the matcher.
RETURNS (int): The number of rules.
DOCS: https://spacy.io/api/phrasematcher#len
"""
return len(self._docs)
return len(self._callbacks)
def __contains__(self, key):
"""Check whether the matcher contains rules for a match ID.
@ -102,13 +76,77 @@ cdef class PhraseMatcher:
DOCS: https://spacy.io/api/phrasematcher#contains
"""
cdef hash_t ent_id = self.matcher._normalize_key(key)
return ent_id in self._callbacks
return key in self._callbacks
def __reduce__(self):
data = (self.vocab, self._docs, self._callbacks)
return (unpickle_matcher, data, None, None)
def remove(self, key):
"""Remove a rule from the matcher by match ID. A KeyError is raised if
the key does not exist.
key (unicode): The match ID.
"""
if key not in self._docs:
raise KeyError(key)
cdef MapStruct* current_node
cdef MapStruct* terminal_map
cdef MapStruct* node_pointer
cdef void* result
cdef key_t terminal_key
cdef void* value
cdef int c_i = 0
cdef vector[MapStruct*] path_nodes
cdef vector[key_t] path_keys
cdef key_t key_to_remove
for keyword in self._docs[key]:
current_node = self.c_map
for token in keyword:
result = map_get(current_node, token)
if result:
path_nodes.push_back(current_node)
path_keys.push_back(token)
current_node = <MapStruct*>result
else:
# if token is not found, break out of the loop
current_node = NULL
break
# remove the tokens from trie node if there are no other
# keywords with them
result = map_get(current_node, self._terminal_hash)
if current_node != NULL and result:
terminal_map = <MapStruct*>result
terminal_keys = []
c_i = 0
while map_iter(terminal_map, &c_i, &terminal_key, &value):
terminal_keys.append(self.vocab.strings[terminal_key])
# if this is the only remaining key, remove unnecessary paths
if terminal_keys == [key]:
while not path_nodes.empty():
node_pointer = path_nodes.back()
path_nodes.pop_back()
key_to_remove = path_keys.back()
path_keys.pop_back()
result = map_get(node_pointer, key_to_remove)
if node_pointer.filled == 1:
map_clear(node_pointer, key_to_remove)
self.mem.free(result)
else:
# more than one key means more than 1 path,
# delete not required path and keep the others
map_clear(node_pointer, key_to_remove)
self.mem.free(result)
break
# otherwise simply remove the key
else:
result = map_get(current_node, self._terminal_hash)
if result:
map_clear(<MapStruct*>result, self.vocab.strings[key])
del self._callbacks[key]
del self._docs[key]
def add(self, key, on_match, *docs):
"""Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
key, an on_match callback, and one or more patterns.
@ -119,17 +157,17 @@ cdef class PhraseMatcher:
DOCS: https://spacy.io/api/phrasematcher#add
"""
cdef Doc doc
cdef hash_t ent_id = self.matcher._normalize_key(key)
self._callbacks[ent_id] = on_match
self._docs[ent_id] = docs
cdef int length
cdef int i
cdef hash_t phrase_hash
cdef Pool mem = Pool()
_ = self.vocab[key]
self._callbacks[key] = on_match
self._docs.setdefault(key, set())
cdef MapStruct* current_node
cdef MapStruct* internal_node
cdef void* result
for doc in docs:
length = doc.length
if length == 0:
if len(doc) == 0:
continue
if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
raise ValueError(Errors.E155.format())
@ -139,33 +177,33 @@ cdef class PhraseMatcher:
and self.attr not in (DEP, POS, TAG, LEMMA):
string_attr = self.vocab.strings[self.attr]
user_warning(Warnings.W012.format(key=key, attr=string_attr))
tags = get_biluo(length)
phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
for i, tag in enumerate(tags):
attr_value = self.get_lex_value(doc, i)
lexeme = self.vocab[attr_value]
lexeme.set_flag(tag, True)
phrase_key[i] = lexeme.orth
phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
if phrase_hash in self.phrase_ids:
phrase_index = self.phrase_ids[phrase_hash]
ent_id_list = self.ent_id_matrix[phrase_index]
ent_id_list.append(ent_id)
self.ent_id_matrix[phrase_index] = ent_id_list
if isinstance(doc, Doc):
keyword = self._convert_to_array(doc)
else:
ent_id_list = hash_vec(1)
ent_id_list[0] = ent_id
new_index = self.ent_id_matrix.size()
if new_index == 0:
# PreshMaps can not contain 0 as value, so storing a dummy at 0
self.ent_id_matrix.push_back(hash_vec(0))
new_index = 1
self.ent_id_matrix.push_back(ent_id_list)
self.phrase_ids.set(phrase_hash, <void*>new_index)
keyword = doc
self._docs[key].add(tuple(keyword))
def __call__(self, Doc doc):
current_node = self.c_map
for token in keyword:
if token == self._terminal_hash:
user_warning(Warnings.W021)
break
result = <MapStruct*>map_get(current_node, token)
if not result:
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
map_init(self.mem, internal_node, 8)
map_set(self.mem, current_node, token, internal_node)
result = internal_node
current_node = <MapStruct*>result
result = <MapStruct*>map_get(current_node, self._terminal_hash)
if not result:
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
map_init(self.mem, internal_node, 8)
map_set(self.mem, current_node, self._terminal_hash, internal_node)
result = internal_node
map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)
def __call__(self, doc):
"""Find all sequences matching the supplied patterns on the `Doc`.
doc (Doc): The document to match over.
@ -176,25 +214,63 @@ cdef class PhraseMatcher:
DOCS: https://spacy.io/api/phrasematcher#call
"""
matches = []
if self.attr == ORTH:
match_doc = doc
else:
# If we're not matching on the ORTH, match_doc will be a Doc whose
# token.orth values are the attribute values we're matching on,
# e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
words = [self.get_lex_value(doc, i) for i in range(len(doc))]
match_doc = Doc(self.vocab, words=words)
for _, start, end in self.matcher(match_doc):
ent_ids = self.accept_match(match_doc, start, end)
if ent_ids is not None:
for ent_id in ent_ids:
matches.append((ent_id, start, end))
if doc is None or len(doc) == 0:
# if doc is empty or None just return empty list
return matches
cdef vector[MatchStruct] c_matches
self.find_matches(doc, &c_matches)
for i in range(c_matches.size()):
matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
for i, (ent_id, start, end) in enumerate(matches):
on_match = self._callbacks.get(ent_id)
if on_match is not None:
on_match(self, doc, i, matches)
return matches
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil:
cdef MapStruct* current_node = self.c_map
cdef int start = 0
cdef int idx = 0
cdef int idy = 0
cdef key_t key
cdef void* value
cdef int i = 0
cdef MatchStruct ms
cdef void* result
while idx < doc.length:
start = idx
token = Token.get_struct_attr(&doc.c[idx], self.attr)
# look for sequences from this position
result = map_get(current_node, token)
if result:
current_node = <MapStruct*>result
idy = idx + 1
while idy < doc.length:
result = map_get(current_node, self._terminal_hash)
if result:
i = 0
while map_iter(<MapStruct*>result, &i, &key, &value):
ms = make_matchstruct(key, start, idy)
matches.push_back(ms)
inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
result = map_get(current_node, inner_token)
if result:
current_node = <MapStruct*>result
idy += 1
else:
break
else:
# end of doc reached
result = map_get(current_node, self._terminal_hash)
if result:
i = 0
while map_iter(<MapStruct*>result, &i, &key, &value):
ms = make_matchstruct(key, start, idy)
matches.push_back(ms)
current_node = self.c_map
idx += 1
def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False,
as_tuples=False):
"""Match a stream of documents, yielding them in turn.
@ -228,48 +304,8 @@ cdef class PhraseMatcher:
else:
yield doc
def accept_match(self, Doc doc, int start, int end):
cdef int i, j
cdef Pool mem = Pool()
phrase_key = <attr_t*>mem.alloc(end-start, sizeof(attr_t))
for i, j in enumerate(range(start, end)):
phrase_key[i] = doc.c[j].lex.orth
cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
ent_index = <hash_t>self.phrase_ids.get(key)
if ent_index == 0:
return None
return self.ent_id_matrix[ent_index]
def get_lex_value(self, Doc doc, int i):
if self.attr == ORTH:
# Return the regular orth value of the lexeme
return doc.c[i].lex.orth
# Get the attribute value instead, e.g. token.pos
attr_value = get_token_attr(&doc.c[i], self.attr)
if attr_value in (0, 1):
# Value is boolean, convert to string
string_attr_value = str(attr_value)
else:
string_attr_value = self.vocab.strings[attr_value]
string_attr_name = self.vocab.strings[self.attr]
# Concatenate the attr name and value to not pollute lexeme space
# e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
# create false positive matches
return "matcher:{}-{}".format(string_attr_name, string_attr_value)
def get_biluo(length):
if length == 0:
raise ValueError(Errors.E127)
elif length == 1:
return [U_ENT]
elif length == 2:
return [B2_ENT, L2_ENT]
elif length == 3:
return [B3_ENT, I3_ENT, L3_ENT]
else:
return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]
def _convert_to_array(self, Doc doc):
return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))]
def unpickle_matcher(vocab, docs, callbacks):
@ -278,3 +314,11 @@ def unpickle_matcher(vocab, docs, callbacks):
callback = callbacks.get(key, None)
matcher.add(key, callback, *specs)
return matcher
cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil:
cdef MatchStruct ms
ms.match_id = match_id
ms.start = start
ms.end = end
return ms

View File

@ -8,10 +8,31 @@ from ..util import get_doc
def test_matcher_phrase_matcher(en_vocab):
doc = Doc(en_vocab, words=["Google", "Now"])
matcher = PhraseMatcher(en_vocab)
matcher.add("COMPANY", None, doc)
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
# intermediate phrase
pattern = Doc(en_vocab, words=["Google", "Now"])
matcher = PhraseMatcher(en_vocab)
matcher.add("COMPANY", None, pattern)
assert len(matcher(doc)) == 1
# initial token
pattern = Doc(en_vocab, words=["I"])
matcher = PhraseMatcher(en_vocab)
matcher.add("I", None, pattern)
assert len(matcher(doc)) == 1
# initial phrase
pattern = Doc(en_vocab, words=["I", "like"])
matcher = PhraseMatcher(en_vocab)
matcher.add("ILIKE", None, pattern)
assert len(matcher(doc)) == 1
# final token
pattern = Doc(en_vocab, words=["best"])
matcher = PhraseMatcher(en_vocab)
matcher.add("BEST", None, pattern)
assert len(matcher(doc)) == 1
# final phrase
pattern = Doc(en_vocab, words=["Now", "best"])
matcher = PhraseMatcher(en_vocab)
matcher.add("NOWBEST", None, pattern)
assert len(matcher(doc)) == 1
@ -31,6 +52,68 @@ def test_phrase_matcher_contains(en_vocab):
assert "TEST2" not in matcher
def test_phrase_matcher_repeated_add(en_vocab):
matcher = PhraseMatcher(en_vocab)
# match ID only gets added once
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
assert "TEST" in matcher
assert "TEST2" not in matcher
assert len(matcher(doc)) == 1
def test_phrase_matcher_remove(en_vocab):
matcher = PhraseMatcher(en_vocab)
matcher.add("TEST1", None, Doc(en_vocab, words=["like"]))
matcher.add("TEST2", None, Doc(en_vocab, words=["best"]))
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
assert "TEST1" in matcher
assert "TEST2" in matcher
assert "TEST3" not in matcher
assert len(matcher(doc)) == 2
matcher.remove("TEST1")
assert "TEST1" not in matcher
assert "TEST2" in matcher
assert "TEST3" not in matcher
assert len(matcher(doc)) == 1
matcher.remove("TEST2")
assert "TEST1" not in matcher
assert "TEST2" not in matcher
assert "TEST3" not in matcher
assert len(matcher(doc)) == 0
with pytest.raises(KeyError):
matcher.remove("TEST3")
assert "TEST1" not in matcher
assert "TEST2" not in matcher
assert "TEST3" not in matcher
assert len(matcher(doc)) == 0
def test_phrase_matcher_overlapping_with_remove(en_vocab):
matcher = PhraseMatcher(en_vocab)
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
# TEST2 is added alongside TEST
matcher.add("TEST2", None, Doc(en_vocab, words=["like"]))
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
assert "TEST" in matcher
assert len(matcher) == 2
assert len(matcher(doc)) == 2
# removing TEST does not remove the entry for TEST2
matcher.remove("TEST")
assert "TEST" not in matcher
assert len(matcher) == 1
assert len(matcher(doc)) == 1
assert matcher(doc)[0][0] == en_vocab.strings["TEST2"]
# removing TEST2 removes all
matcher.remove("TEST2")
assert "TEST2" not in matcher
assert len(matcher) == 0
assert len(matcher(doc)) == 0
def test_phrase_matcher_string_attrs(en_vocab):
words1 = ["I", "like", "cats"]
pos1 = ["PRON", "VERB", "NOUN"]