Replace PhraseMatcher with Aho-Corasick

Replace the PhraseMatcher implementation with an Aho-Corasick-style search
over numpy arrays of the hash values for the relevant token attribute. The
implementation is based on FlashText.

Speed should be similar to the previous PhraseMatcher. Match IDs can now be
removed easily, and matches no longer go missing with large keyword lists /
vocabularies.

Fixes #4308.
Adriane Boyd 2019-09-19 16:36:12 +02:00
parent 31c683d87d
commit 0d9740e826
3 changed files with 189 additions and 135 deletions
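
For orientation, here is a minimal Python sketch of the FlashText-style idea the commit message describes: patterns are stored as paths of per-token hash values in a nested dict, and matching walks that dict while scanning the document. The class and names (TrieMatcher, _TERMINAL) are illustrative assumptions only, not the committed spaCy implementation, which works on numpy arrays of vocabulary hashes inside Cython.

    # Illustrative sketch only: a simplified FlashText-style trie matcher over
    # per-token hash values. Names (TrieMatcher, _TERMINAL) are hypothetical;
    # this is not the committed spaCy code.
    class TrieMatcher(object):
        _TERMINAL = "_terminal_"  # marks the end of a stored keyword path

        def __init__(self):
            self.trie = {}  # nested dicts keyed by token hash

        def add(self, match_id, token_hashes):
            # store one keyword as a path of token hashes ending in a terminal set
            node = self.trie
            for token in token_hashes:
                node = node.setdefault(token, {})
            node.setdefault(self._TERMINAL, set()).add(match_id)

        def __call__(self, doc_hashes):
            # return (match_id, start, end) for every stored keyword in doc_hashes
            matches = []
            n = len(doc_hashes)
            for start in range(n):
                node = self.trie
                idx = start
                while idx < n and doc_hashes[idx] in node:
                    node = node[doc_hashes[idx]]
                    idx += 1
                    for match_id in node.get(self._TERMINAL, ()):
                        matches.append((match_id, start, idx))
            return matches

    # usage: hash() stands in for the vocabulary hash of the matched attribute
    matcher = TrieMatcher()
    matcher.add("GOOGLE", [hash("Google")])
    matcher.add("GOOGLE_NOW", [hash("Google"), hash("Now")])
    doc = [hash(w) for w in ["I", "like", "Google", "Now", "best"]]
    print(matcher(doc))  # [('GOOGLE', 2, 3), ('GOOGLE_NOW', 2, 4)]

Because every pattern lives in one shared dict keyed by hash values, growing the keyword list does not change the per-pattern data structures, and removing a match ID reduces to pruning its terminal entries (see the sketch after the phrasematcher.pyx diff below).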

--- a/spacy/matcher/phrasematcher.pxd
+++ /dev/null
@@ -1,5 +0,0 @@
-from libcpp.vector cimport vector
-
-from ..typedefs cimport hash_t
-
-ctypedef vector[hash_t] hash_vec

--- a/spacy/matcher/phrasematcher.pyx
+++ b/spacy/matcher/phrasematcher.pyx
@@ -2,28 +2,14 @@
 # cython: profile=True
 from __future__ import unicode_literals

-from libcpp.vector cimport vector
-from cymem.cymem cimport Pool
-from murmurhash.mrmr cimport hash64
-from preshed.maps cimport PreshMap
+import numpy as np

-from .matcher cimport Matcher
 from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
 from ..vocab cimport Vocab
 from ..tokens.doc cimport Doc, get_token_attr
-from ..typedefs cimport attr_t, hash_t
 from ._schemas import TOKEN_PATTERN_SCHEMA
 from ..errors import Errors, Warnings, deprecation_warning, user_warning
-from ..attrs import FLAG61 as U_ENT
-from ..attrs import FLAG60 as B2_ENT
-from ..attrs import FLAG59 as B3_ENT
-from ..attrs import FLAG58 as B4_ENT
-from ..attrs import FLAG43 as L2_ENT
-from ..attrs import FLAG42 as L3_ENT
-from ..attrs import FLAG41 as L4_ENT
-from ..attrs import FLAG42 as I3_ENT
-from ..attrs import FLAG41 as I4_ENT


 cdef class PhraseMatcher:
@@ -33,18 +19,18 @@ cdef class PhraseMatcher:

     DOCS: https://spacy.io/api/phrasematcher
     USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher
+
+    Adapted from FlashText: https://github.com/vi3k6i5/flashtext
+    MIT License (see `LICENSE`)
+    Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
     """
-    cdef Pool mem
     cdef Vocab vocab
-    cdef Matcher matcher
-    cdef PreshMap phrase_ids
-    cdef vector[hash_vec] ent_id_matrix
-    cdef int max_length
+    cdef unicode _terminal
+    cdef object keyword_trie_dict
     cdef attr_id_t attr
-    cdef public object _callbacks
-    cdef public object _patterns
-    cdef public object _docs
-    cdef public object _validate
+    cdef object _callbacks
+    cdef object _keywords
+    cdef bint _validate

     def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
         """Initialize the PhraseMatcher.
@@ -58,10 +44,13 @@ cdef class PhraseMatcher:
         """
         if max_length != 0:
             deprecation_warning(Warnings.W010)
-        self.mem = Pool()
-        self.max_length = max_length
         self.vocab = vocab
-        self.matcher = Matcher(self.vocab, validate=False)
+        self._terminal = '_terminal_'
+        self.keyword_trie_dict = dict()
+        self._callbacks = {}
+        self._keywords = {}
+        self._validate = validate
         if isinstance(attr, long):
             self.attr = attr
         else:
@@ -71,28 +60,15 @@ cdef class PhraseMatcher:
             if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
                 raise ValueError(Errors.E152.format(attr=attr))
             self.attr = self.vocab.strings[attr]
-        self.phrase_ids = PreshMap()
-        abstract_patterns = [
-            [{U_ENT: True}],
-            [{B2_ENT: True}, {L2_ENT: True}],
-            [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}],
-            [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}],
-        ]
-        self.matcher.add("Candidate", None, *abstract_patterns)
-        self._callbacks = {}
-        self._docs = {}
-        self._validate = validate

     def __len__(self):
-        """Get the number of rules added to the matcher. Note that this only
-        returns the number of rules (identical with the number of IDs), not the
-        number of individual patterns.
+        """Get the number of match IDs added to the matcher.

         RETURNS (int): The number of rules.

         DOCS: https://spacy.io/api/phrasematcher#len
         """
-        return len(self._docs)
+        return len(self._callbacks)

     def __contains__(self, key):
         """Check whether the matcher contains rules for a match ID.
@@ -102,12 +78,48 @@ cdef class PhraseMatcher:

         DOCS: https://spacy.io/api/phrasematcher#contains
         """
-        cdef hash_t ent_id = self.matcher._normalize_key(key)
-        return ent_id in self._callbacks
+        return key in self._callbacks

-    def __reduce__(self):
-        data = (self.vocab, self._docs, self._callbacks)
-        return (unpickle_matcher, data, None, None)
+    def remove(self, key):
+        """Remove a match-rule from the matcher by match ID.
+
+        key (unicode): The match ID.
+        """
+        if key not in self._keywords:
+            return
+        for keyword in self._keywords[key]:
+            current_dict = self.keyword_trie_dict
+            token_trie_list = []
+            for tokens in keyword:
+                if tokens in current_dict:
+                    token_trie_list.append((tokens, current_dict))
+                    current_dict = current_dict[tokens]
+                else:
+                    # if token is not found, break out of the loop
+                    current_dict = None
+                    break
+            # remove the tokens from trie dict if there are no other
+            # keywords with them
+            if current_dict and self._terminal in current_dict:
+                # if this is the only remaining key, remove unnecessary paths
+                if current_dict[self._terminal] == [key]:
+                    # we found a complete match for input keyword
+                    token_trie_list.append((self._terminal, current_dict))
+                    token_trie_list.reverse()
+                    for key_to_remove, dict_pointer in token_trie_list:
+                        if len(dict_pointer.keys()) == 1:
+                            dict_pointer.pop(key_to_remove)
+                        else:
+                            # more than one key means more than 1 path,
+                            # delete not required path and keep the other
+                            dict_pointer.pop(key_to_remove)
+                            break
+                # otherwise simply remove the key
+                else:
+                    current_dict[self._terminal].remove(key)
+        del self._keywords[key]
+        del self._callbacks[key]

     def add(self, key, on_match, *docs):
         """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
@@ -119,17 +131,13 @@ cdef class PhraseMatcher:

         DOCS: https://spacy.io/api/phrasematcher#add
         """
-        cdef Doc doc
-        cdef hash_t ent_id = self.matcher._normalize_key(key)
-        self._callbacks[ent_id] = on_match
-        self._docs[ent_id] = docs
-        cdef int length
-        cdef int i
-        cdef hash_t phrase_hash
-        cdef Pool mem = Pool()
+        _ = self.vocab[key]
+        self._callbacks[key] = on_match
+        self._keywords.setdefault(key, [])
         for doc in docs:
-            length = doc.length
-            if length == 0:
+            if len(doc) == 0:
                 continue
             if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
                 raise ValueError(Errors.E155.format())
@@ -139,33 +147,18 @@ cdef class PhraseMatcher:
                     and self.attr not in (DEP, POS, TAG, LEMMA):
                 string_attr = self.vocab.strings[self.attr]
                 user_warning(Warnings.W012.format(key=key, attr=string_attr))
-            tags = get_biluo(length)
-            phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
-            for i, tag in enumerate(tags):
-                attr_value = self.get_lex_value(doc, i)
-                lexeme = self.vocab[attr_value]
-                lexeme.set_flag(tag, True)
-                phrase_key[i] = lexeme.orth
-            phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
-            if phrase_hash in self.phrase_ids:
-                phrase_index = self.phrase_ids[phrase_hash]
-                ent_id_list = self.ent_id_matrix[phrase_index]
-                ent_id_list.append(ent_id)
-                self.ent_id_matrix[phrase_index] = ent_id_list
-            else:
-                ent_id_list = hash_vec(1)
-                ent_id_list[0] = ent_id
-                new_index = self.ent_id_matrix.size()
-                if new_index == 0:
-                    # PreshMaps can not contain 0 as value, so storing a dummy at 0
-                    self.ent_id_matrix.push_back(hash_vec(0))
-                    new_index = 1
-                self.ent_id_matrix.push_back(ent_id_list)
-                self.phrase_ids.set(phrase_hash, <void*>new_index)
+            keyword = self._convert_to_array(doc)
+            # keep track of keywords per key to make remove easier
+            # (would use a set, but can't hash numpy arrays)
+            if keyword not in self._keywords[key]:
+                self._keywords[key].append(keyword)
+            current_dict = self.keyword_trie_dict
+            for token in keyword:
+                current_dict = current_dict.setdefault(token, {})
+            current_dict.setdefault(self._terminal, set())
+            current_dict[self._terminal].add(key)

-    def __call__(self, Doc doc):
+    def __call__(self, doc):
         """Find all sequences matching the supplied patterns on the `Doc`.

         doc (Doc): The document to match over.
@@ -175,20 +168,62 @@ cdef class PhraseMatcher:

         DOCS: https://spacy.io/api/phrasematcher#call
         """
+        doc_array = self._convert_to_array(doc)
         matches = []
-        if self.attr == ORTH:
-            match_doc = doc
-        else:
-            # If we're not matching on the ORTH, match_doc will be a Doc whose
-            # token.orth values are the attribute values we're matching on,
-            # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
-            words = [self.get_lex_value(doc, i) for i in range(len(doc))]
-            match_doc = Doc(self.vocab, words=words)
-        for _, start, end in self.matcher(match_doc):
-            ent_ids = self.accept_match(match_doc, start, end)
-            if ent_ids is not None:
-                for ent_id in ent_ids:
-                    matches.append((ent_id, start, end))
+        if doc_array is None or len(doc_array) == 0:
+            # if doc_array is empty or None just return empty list
+            return matches
+        current_dict = self.keyword_trie_dict
+        start = 0
+        reset_current_dict = False
+        idx = 0
+        doc_array_len = len(doc_array)
+        while idx < doc_array_len:
+            token = doc_array[idx]
+            # if end is present in current_dict
+            if self._terminal in current_dict or token in current_dict:
+                if self._terminal in current_dict:
+                    ent_id = current_dict[self._terminal]
+                    matches.append((self.vocab.strings[ent_id], start, idx))
+                # look for longer sequences from this position
+                if token in current_dict:
+                    current_dict_continued = current_dict[token]
+                    idy = idx + 1
+                    while idy < doc_array_len:
+                        inner_token = doc_array[idy]
+                        if self._terminal in current_dict_continued:
+                            ent_ids = current_dict_continued[self._terminal]
+                            for ent_id in ent_ids:
+                                matches.append((self.vocab.strings[ent_id], start, idy))
+                        if inner_token in current_dict_continued:
+                            current_dict_continued = current_dict_continued[inner_token]
+                        else:
+                            break
+                        idy += 1
+                    else:
+                        # end of doc_array reached
+                        if self._terminal in current_dict_continued:
+                            ent_ids = current_dict_continued[self._terminal]
+                            for ent_id in ent_ids:
+                                matches.append((self.vocab.strings[ent_id], start, idy))
+                current_dict = self.keyword_trie_dict
+                reset_current_dict = True
+            else:
+                # we reset current_dict
+                current_dict = self.keyword_trie_dict
+                reset_current_dict = True
+            # if we are end of doc_array and have a sequence discovered
+            if idx + 1 >= doc_array_len:
+                if self._terminal in current_dict:
+                    ent_ids = current_dict[self._terminal]
+                    for ent_id in ent_ids:
+                        matches.append((self.vocab.strings[ent_id], start, doc_array_len))
+            idx += 1
+            if reset_current_dict:
+                reset_current_dict = False
+                start = idx
         for i, (ent_id, start, end) in enumerate(matches):
             on_match = self._callbacks.get(ent_id)
             if on_match is not None:
@@ -228,19 +263,6 @@ cdef class PhraseMatcher:
             else:
                 yield doc

-    def accept_match(self, Doc doc, int start, int end):
-        cdef int i, j
-        cdef Pool mem = Pool()
-        phrase_key = <attr_t*>mem.alloc(end-start, sizeof(attr_t))
-        for i, j in enumerate(range(start, end)):
-            phrase_key[i] = doc.c[j].lex.orth
-        cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
-        ent_index = <hash_t>self.phrase_ids.get(key)
-        if ent_index == 0:
-            return None
-        return self.ent_id_matrix[ent_index]
-
     def get_lex_value(self, Doc doc, int i):
         if self.attr == ORTH:
             # Return the regular orth value of the lexeme
@@ -256,25 +278,10 @@ cdef class PhraseMatcher:
         # Concatenate the attr name and value to not pollute lexeme space
         # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
         # create false positive matches
-        return "matcher:{}-{}".format(string_attr_name, string_attr_value)
+        matcher_attr_string = "matcher:{}-{}".format(string_attr_name, string_attr_value)
+        # Add new string to vocab
+        _ = self.vocab[matcher_attr_string]
+        return self.vocab.strings[matcher_attr_string]

-
-def get_biluo(length):
-    if length == 0:
-        raise ValueError(Errors.E127)
-    elif length == 1:
-        return [U_ENT]
-    elif length == 2:
-        return [B2_ENT, L2_ENT]
-    elif length == 3:
-        return [B3_ENT, I3_ENT, L3_ENT]
-    else:
-        return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]
-
-
-def unpickle_matcher(vocab, docs, callbacks):
-    matcher = PhraseMatcher(vocab)
-    for key, specs in docs.items():
-        callback = callbacks.get(key, None)
-        matcher.add(key, callback, *specs)
-    return matcher
+    def _convert_to_array(self, Doc doc):
+        return np.array([self.get_lex_value(doc, i) for i in range(len(doc))], dtype=np.uint64)
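
To illustrate the removal path that the new remove() method above implements, here is a hedged plain-Python sketch of the same trie pruning, written against the hypothetical TrieMatcher layout from the sketch near the top of this page (a nested dict with a terminal sentinel); it is not the committed Cython code.

    # Illustrative sketch only; mirrors the pruning idea of remove() in plain
    # Python. `trie` is a nested dict as built by the add() sketch above and
    # `terminal` is the sentinel key marking the end of a keyword.
    def remove_keyword(trie, terminal, match_id, token_hashes):
        # walk the keyword's path, remembering (token, parent_node) for pruning
        path = []
        node = trie
        for token in token_hashes:
            if token not in node:
                return  # keyword was never added
            path.append((token, node))
            node = node[token]
        terminals = node.get(terminal)
        if not terminals or match_id not in terminals:
            return
        terminals.discard(match_id)
        if terminals:
            return  # other match IDs still end here; keep the path intact
        del node[terminal]
        # delete trie nodes that no longer lead to any keyword
        for token, parent in reversed(path):
            if parent[token]:
                break  # node still has other children or terminals
            del parent[token]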

--- a/spacy/tests/matcher/test_phrase_matcher.py
+++ b/spacy/tests/matcher/test_phrase_matcher.py
@@ -31,6 +31,58 @@ def test_phrase_matcher_contains(en_vocab):
     assert "TEST2" not in matcher


+def test_phrase_matcher_repeated_add(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    # match ID only gets added once
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
+    assert "TEST" in matcher
+    assert "TEST2" not in matcher
+    assert len(matcher(doc)) == 1
+
+
+def test_phrase_matcher_remove(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
+    assert "TEST" in matcher
+    assert "TEST2" not in matcher
+    assert len(matcher(doc)) == 1
+    matcher.remove("TEST")
+    assert "TEST" not in matcher
+    assert "TEST2" not in matcher
+    assert len(matcher(doc)) == 0
+    matcher.remove("TEST2")
+    assert "TEST" not in matcher
+    assert "TEST2" not in matcher
+    assert len(matcher(doc)) == 0
+
+
+def test_phrase_matcher_overlapping_with_remove(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    # TEST2 is added alongside TEST
+    matcher.add("TEST2", None, Doc(en_vocab, words=["like"]))
+    doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
+    assert "TEST" in matcher
+    assert len(matcher) == 2
+    assert len(matcher(doc)) == 2
+    # removing TEST does not remove the entry for TEST2
+    matcher.remove("TEST")
+    assert "TEST" not in matcher
+    assert len(matcher) == 1
+    assert len(matcher(doc)) == 1
+    assert matcher(doc)[0][0] == en_vocab.strings["TEST2"]
+    # removing TEST2 removes all
+    matcher.remove("TEST2")
+    assert "TEST2" not in matcher
+    assert len(matcher) == 0
+    assert len(matcher(doc)) == 0
+
+
 def test_phrase_matcher_string_attrs(en_vocab):
     words1 = ["I", "like", "cats"]
     pos1 = ["PRON", "VERB", "NOUN"]