diff --git a/spacy/errors.py b/spacy/errors.py index 02656e0e7..30c7a5f48 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -86,6 +86,8 @@ class Warnings(object): "previously loaded vectors. See Issue #3853.") W020 = ("Unnamed vectors. This won't allow multiple vectors models to be " "loaded. (Shape: {shape})") + W021 = ("Unexpected hash collision in PhraseMatcher. Matches may be " + "incorrect. Modify PhraseMatcher._terminal_hash to fix.") @add_codes diff --git a/spacy/matcher/phrasematcher.pxd b/spacy/matcher/phrasematcher.pxd index 3aba1686f..753b2da74 100644 --- a/spacy/matcher/phrasematcher.pxd +++ b/spacy/matcher/phrasematcher.pxd @@ -1,5 +1,27 @@ from libcpp.vector cimport vector -from ..typedefs cimport hash_t +from cymem.cymem cimport Pool +from preshed.maps cimport key_t, MapStruct -ctypedef vector[hash_t] hash_vec +from ..attrs cimport attr_id_t +from ..tokens.doc cimport Doc +from ..vocab cimport Vocab + + +cdef class PhraseMatcher: + cdef Vocab vocab + cdef attr_id_t attr + cdef object _callbacks + cdef object _docs + cdef bint _validate + cdef MapStruct* c_map + cdef Pool mem + cdef key_t _terminal_hash + + cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil + + +cdef struct MatchStruct: + key_t match_id + int start + int end diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 9e8801cc1..a93c34288 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -2,28 +2,16 @@ # cython: profile=True from __future__ import unicode_literals -from libcpp.vector cimport vector -from cymem.cymem cimport Pool -from murmurhash.mrmr cimport hash64 -from preshed.maps cimport PreshMap +from libc.stdint cimport uintptr_t -from .matcher cimport Matcher -from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t -from ..vocab cimport Vocab -from ..tokens.doc cimport Doc, get_token_attr -from ..typedefs cimport attr_t, hash_t +from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter + +from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA +from ..structs cimport TokenC +from ..tokens.token cimport Token from ._schemas import TOKEN_PATTERN_SCHEMA from ..errors import Errors, Warnings, deprecation_warning, user_warning -from ..attrs import FLAG61 as U_ENT -from ..attrs import FLAG60 as B2_ENT -from ..attrs import FLAG59 as B3_ENT -from ..attrs import FLAG58 as B4_ENT -from ..attrs import FLAG43 as L2_ENT -from ..attrs import FLAG42 as L3_ENT -from ..attrs import FLAG41 as L4_ENT -from ..attrs import FLAG42 as I3_ENT -from ..attrs import FLAG41 as I4_ENT cdef class PhraseMatcher: @@ -33,18 +21,11 @@ cdef class PhraseMatcher: DOCS: https://spacy.io/api/phrasematcher USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher + + Adapted from FlashText: https://github.com/vi3k6i5/flashtext + MIT License (see `LICENSE`) + Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com) """ - cdef Pool mem - cdef Vocab vocab - cdef Matcher matcher - cdef PreshMap phrase_ids - cdef vector[hash_vec] ent_id_matrix - cdef int max_length - cdef attr_id_t attr - cdef public object _callbacks - cdef public object _patterns - cdef public object _docs - cdef public object _validate def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False): """Initialize the PhraseMatcher. @@ -58,10 +39,16 @@ cdef class PhraseMatcher: """ if max_length != 0: deprecation_warning(Warnings.W010) - self.mem = Pool() - self.max_length = max_length self.vocab = vocab - self.matcher = Matcher(self.vocab, validate=False) + self._callbacks = {} + self._docs = {} + self._validate = validate + + self.mem = Pool() + self.c_map = self.mem.alloc(1, sizeof(MapStruct)) + self._terminal_hash = 826361138722620965 + map_init(self.mem, self.c_map, 8) + if isinstance(attr, long): self.attr = attr else: @@ -71,28 +58,15 @@ cdef class PhraseMatcher: if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]: raise ValueError(Errors.E152.format(attr=attr)) self.attr = self.vocab.strings[attr] - self.phrase_ids = PreshMap() - abstract_patterns = [ - [{U_ENT: True}], - [{B2_ENT: True}, {L2_ENT: True}], - [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}], - [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}], - ] - self.matcher.add("Candidate", None, *abstract_patterns) - self._callbacks = {} - self._docs = {} - self._validate = validate def __len__(self): - """Get the number of rules added to the matcher. Note that this only - returns the number of rules (identical with the number of IDs), not the - number of individual patterns. + """Get the number of match IDs added to the matcher. RETURNS (int): The number of rules. DOCS: https://spacy.io/api/phrasematcher#len """ - return len(self._docs) + return len(self._callbacks) def __contains__(self, key): """Check whether the matcher contains rules for a match ID. @@ -102,13 +76,77 @@ cdef class PhraseMatcher: DOCS: https://spacy.io/api/phrasematcher#contains """ - cdef hash_t ent_id = self.matcher._normalize_key(key) - return ent_id in self._callbacks + return key in self._callbacks def __reduce__(self): data = (self.vocab, self._docs, self._callbacks) return (unpickle_matcher, data, None, None) + def remove(self, key): + """Remove a rule from the matcher by match ID. A KeyError is raised if + the key does not exist. + + key (unicode): The match ID. + """ + if key not in self._docs: + raise KeyError(key) + cdef MapStruct* current_node + cdef MapStruct* terminal_map + cdef MapStruct* node_pointer + cdef void* result + cdef key_t terminal_key + cdef void* value + cdef int c_i = 0 + cdef vector[MapStruct*] path_nodes + cdef vector[key_t] path_keys + cdef key_t key_to_remove + for keyword in self._docs[key]: + current_node = self.c_map + for token in keyword: + result = map_get(current_node, token) + if result: + path_nodes.push_back(current_node) + path_keys.push_back(token) + current_node = result + else: + # if token is not found, break out of the loop + current_node = NULL + break + # remove the tokens from trie node if there are no other + # keywords with them + result = map_get(current_node, self._terminal_hash) + if current_node != NULL and result: + terminal_map = result + terminal_keys = [] + c_i = 0 + while map_iter(terminal_map, &c_i, &terminal_key, &value): + terminal_keys.append(self.vocab.strings[terminal_key]) + # if this is the only remaining key, remove unnecessary paths + if terminal_keys == [key]: + while not path_nodes.empty(): + node_pointer = path_nodes.back() + path_nodes.pop_back() + key_to_remove = path_keys.back() + path_keys.pop_back() + result = map_get(node_pointer, key_to_remove) + if node_pointer.filled == 1: + map_clear(node_pointer, key_to_remove) + self.mem.free(result) + else: + # more than one key means more than 1 path, + # delete not required path and keep the others + map_clear(node_pointer, key_to_remove) + self.mem.free(result) + break + # otherwise simply remove the key + else: + result = map_get(current_node, self._terminal_hash) + if result: + map_clear(result, self.vocab.strings[key]) + + del self._callbacks[key] + del self._docs[key] + def add(self, key, on_match, *docs): """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID key, an on_match callback, and one or more patterns. @@ -119,17 +157,17 @@ cdef class PhraseMatcher: DOCS: https://spacy.io/api/phrasematcher#add """ - cdef Doc doc - cdef hash_t ent_id = self.matcher._normalize_key(key) - self._callbacks[ent_id] = on_match - self._docs[ent_id] = docs - cdef int length - cdef int i - cdef hash_t phrase_hash - cdef Pool mem = Pool() + + _ = self.vocab[key] + self._callbacks[key] = on_match + self._docs.setdefault(key, set()) + + cdef MapStruct* current_node + cdef MapStruct* internal_node + cdef void* result + for doc in docs: - length = doc.length - if length == 0: + if len(doc) == 0: continue if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged: raise ValueError(Errors.E155.format()) @@ -139,33 +177,33 @@ cdef class PhraseMatcher: and self.attr not in (DEP, POS, TAG, LEMMA): string_attr = self.vocab.strings[self.attr] user_warning(Warnings.W012.format(key=key, attr=string_attr)) - tags = get_biluo(length) - phrase_key = mem.alloc(length, sizeof(attr_t)) - for i, tag in enumerate(tags): - attr_value = self.get_lex_value(doc, i) - lexeme = self.vocab[attr_value] - lexeme.set_flag(tag, True) - phrase_key[i] = lexeme.orth - phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0) - - if phrase_hash in self.phrase_ids: - phrase_index = self.phrase_ids[phrase_hash] - ent_id_list = self.ent_id_matrix[phrase_index] - ent_id_list.append(ent_id) - self.ent_id_matrix[phrase_index] = ent_id_list - + if isinstance(doc, Doc): + keyword = self._convert_to_array(doc) else: - ent_id_list = hash_vec(1) - ent_id_list[0] = ent_id - new_index = self.ent_id_matrix.size() - if new_index == 0: - # PreshMaps can not contain 0 as value, so storing a dummy at 0 - self.ent_id_matrix.push_back(hash_vec(0)) - new_index = 1 - self.ent_id_matrix.push_back(ent_id_list) - self.phrase_ids.set(phrase_hash, new_index) + keyword = doc + self._docs[key].add(tuple(keyword)) - def __call__(self, Doc doc): + current_node = self.c_map + for token in keyword: + if token == self._terminal_hash: + user_warning(Warnings.W021) + break + result = map_get(current_node, token) + if not result: + internal_node = self.mem.alloc(1, sizeof(MapStruct)) + map_init(self.mem, internal_node, 8) + map_set(self.mem, current_node, token, internal_node) + result = internal_node + current_node = result + result = map_get(current_node, self._terminal_hash) + if not result: + internal_node = self.mem.alloc(1, sizeof(MapStruct)) + map_init(self.mem, internal_node, 8) + map_set(self.mem, current_node, self._terminal_hash, internal_node) + result = internal_node + map_set(self.mem, result, self.vocab.strings[key], NULL) + + def __call__(self, doc): """Find all sequences matching the supplied patterns on the `Doc`. doc (Doc): The document to match over. @@ -176,25 +214,63 @@ cdef class PhraseMatcher: DOCS: https://spacy.io/api/phrasematcher#call """ matches = [] - if self.attr == ORTH: - match_doc = doc - else: - # If we're not matching on the ORTH, match_doc will be a Doc whose - # token.orth values are the attribute values we're matching on, - # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc]) - words = [self.get_lex_value(doc, i) for i in range(len(doc))] - match_doc = Doc(self.vocab, words=words) - for _, start, end in self.matcher(match_doc): - ent_ids = self.accept_match(match_doc, start, end) - if ent_ids is not None: - for ent_id in ent_ids: - matches.append((ent_id, start, end)) + if doc is None or len(doc) == 0: + # if doc is empty or None just return empty list + return matches + + cdef vector[MatchStruct] c_matches + self.find_matches(doc, &c_matches) + for i in range(c_matches.size()): + matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end)) for i, (ent_id, start, end) in enumerate(matches): on_match = self._callbacks.get(ent_id) if on_match is not None: on_match(self, doc, i, matches) return matches + cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil: + cdef MapStruct* current_node = self.c_map + cdef int start = 0 + cdef int idx = 0 + cdef int idy = 0 + cdef key_t key + cdef void* value + cdef int i = 0 + cdef MatchStruct ms + cdef void* result + while idx < doc.length: + start = idx + token = Token.get_struct_attr(&doc.c[idx], self.attr) + # look for sequences from this position + result = map_get(current_node, token) + if result: + current_node = result + idy = idx + 1 + while idy < doc.length: + result = map_get(current_node, self._terminal_hash) + if result: + i = 0 + while map_iter(result, &i, &key, &value): + ms = make_matchstruct(key, start, idy) + matches.push_back(ms) + inner_token = Token.get_struct_attr(&doc.c[idy], self.attr) + result = map_get(current_node, inner_token) + if result: + current_node = result + idy += 1 + else: + break + else: + # end of doc reached + result = map_get(current_node, self._terminal_hash) + if result: + i = 0 + while map_iter(result, &i, &key, &value): + ms = make_matchstruct(key, start, idy) + matches.push_back(ms) + current_node = self.c_map + idx += 1 + def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False, as_tuples=False): """Match a stream of documents, yielding them in turn. @@ -228,48 +304,8 @@ cdef class PhraseMatcher: else: yield doc - def accept_match(self, Doc doc, int start, int end): - cdef int i, j - cdef Pool mem = Pool() - phrase_key = mem.alloc(end-start, sizeof(attr_t)) - for i, j in enumerate(range(start, end)): - phrase_key[i] = doc.c[j].lex.orth - cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0) - - ent_index = self.phrase_ids.get(key) - if ent_index == 0: - return None - return self.ent_id_matrix[ent_index] - - def get_lex_value(self, Doc doc, int i): - if self.attr == ORTH: - # Return the regular orth value of the lexeme - return doc.c[i].lex.orth - # Get the attribute value instead, e.g. token.pos - attr_value = get_token_attr(&doc.c[i], self.attr) - if attr_value in (0, 1): - # Value is boolean, convert to string - string_attr_value = str(attr_value) - else: - string_attr_value = self.vocab.strings[attr_value] - string_attr_name = self.vocab.strings[self.attr] - # Concatenate the attr name and value to not pollute lexeme space - # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise - # create false positive matches - return "matcher:{}-{}".format(string_attr_name, string_attr_value) - - -def get_biluo(length): - if length == 0: - raise ValueError(Errors.E127) - elif length == 1: - return [U_ENT] - elif length == 2: - return [B2_ENT, L2_ENT] - elif length == 3: - return [B3_ENT, I3_ENT, L3_ENT] - else: - return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT] + def _convert_to_array(self, Doc doc): + return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))] def unpickle_matcher(vocab, docs, callbacks): @@ -278,3 +314,11 @@ def unpickle_matcher(vocab, docs, callbacks): callback = callbacks.get(key, None) matcher.add(key, callback, *specs) return matcher + + +cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil: + cdef MatchStruct ms + ms.match_id = match_id + ms.start = start + ms.end = end + return ms diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py index b82f9a058..486cbb984 100644 --- a/spacy/tests/matcher/test_phrase_matcher.py +++ b/spacy/tests/matcher/test_phrase_matcher.py @@ -8,10 +8,31 @@ from ..util import get_doc def test_matcher_phrase_matcher(en_vocab): - doc = Doc(en_vocab, words=["Google", "Now"]) - matcher = PhraseMatcher(en_vocab) - matcher.add("COMPANY", None, doc) doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) + # intermediate phrase + pattern = Doc(en_vocab, words=["Google", "Now"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("COMPANY", None, pattern) + assert len(matcher(doc)) == 1 + # initial token + pattern = Doc(en_vocab, words=["I"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("I", None, pattern) + assert len(matcher(doc)) == 1 + # initial phrase + pattern = Doc(en_vocab, words=["I", "like"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("ILIKE", None, pattern) + assert len(matcher(doc)) == 1 + # final token + pattern = Doc(en_vocab, words=["best"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("BEST", None, pattern) + assert len(matcher(doc)) == 1 + # final phrase + pattern = Doc(en_vocab, words=["Now", "best"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("NOWBEST", None, pattern) assert len(matcher(doc)) == 1 @@ -31,6 +52,68 @@ def test_phrase_matcher_contains(en_vocab): assert "TEST2" not in matcher +def test_phrase_matcher_repeated_add(en_vocab): + matcher = PhraseMatcher(en_vocab) + # match ID only gets added once + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) + assert "TEST" in matcher + assert "TEST2" not in matcher + assert len(matcher(doc)) == 1 + + +def test_phrase_matcher_remove(en_vocab): + matcher = PhraseMatcher(en_vocab) + matcher.add("TEST1", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST2", None, Doc(en_vocab, words=["best"])) + doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) + assert "TEST1" in matcher + assert "TEST2" in matcher + assert "TEST3" not in matcher + assert len(matcher(doc)) == 2 + matcher.remove("TEST1") + assert "TEST1" not in matcher + assert "TEST2" in matcher + assert "TEST3" not in matcher + assert len(matcher(doc)) == 1 + matcher.remove("TEST2") + assert "TEST1" not in matcher + assert "TEST2" not in matcher + assert "TEST3" not in matcher + assert len(matcher(doc)) == 0 + with pytest.raises(KeyError): + matcher.remove("TEST3") + assert "TEST1" not in matcher + assert "TEST2" not in matcher + assert "TEST3" not in matcher + assert len(matcher(doc)) == 0 + + +def test_phrase_matcher_overlapping_with_remove(en_vocab): + matcher = PhraseMatcher(en_vocab) + matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + # TEST2 is added alongside TEST + matcher.add("TEST2", None, Doc(en_vocab, words=["like"])) + doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) + assert "TEST" in matcher + assert len(matcher) == 2 + assert len(matcher(doc)) == 2 + # removing TEST does not remove the entry for TEST2 + matcher.remove("TEST") + assert "TEST" not in matcher + assert len(matcher) == 1 + assert len(matcher(doc)) == 1 + assert matcher(doc)[0][0] == en_vocab.strings["TEST2"] + # removing TEST2 removes all + matcher.remove("TEST2") + assert "TEST2" not in matcher + assert len(matcher) == 0 + assert len(matcher(doc)) == 0 + + def test_phrase_matcher_string_attrs(en_vocab): words1 = ["I", "like", "cats"] pos1 = ["PRON", "VERB", "NOUN"]