diff --git a/spacy/matcher/phrasematcher.pxd b/spacy/matcher/phrasematcher.pxd index e69de29bb..1a550989d 100644 --- a/spacy/matcher/phrasematcher.pxd +++ b/spacy/matcher/phrasematcher.pxd @@ -0,0 +1,6 @@ +from preshed.maps cimport key_t + +cdef struct MatchStruct: + key_t match_id + int start + int end diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 048b5bd68..172a271b0 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -2,10 +2,20 @@ # cython: profile=True from __future__ import unicode_literals +from libc.stdint cimport uintptr_t +from libc.stdio cimport printf +from libcpp.vector cimport vector + +from cymem.cymem cimport Pool + +from preshed.maps cimport MapStruct, map_init, map_set, map_get_unless_missing +from preshed.maps cimport map_clear, map_iter, key_t, Result + import numpy as np from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t from ..vocab cimport Vocab +from ..strings cimport hash_string from ..tokens.doc cimport Doc, get_token_attr from ._schemas import TOKEN_PATTERN_SCHEMA @@ -25,14 +35,18 @@ cdef class PhraseMatcher: Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com) """ cdef Vocab vocab - cdef unicode _terminal - cdef object keyword_trie_dict cdef attr_id_t attr cdef object _callbacks cdef object _keywords cdef object _docs cdef bint _validate + cdef MapStruct* c_map + cdef Pool mem + cdef key_t _terminal_node + + cdef void find_matches(self, key_t* hash_array, int hash_array_len, vector[MatchStruct] *matches) nogil + def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False): """Initialize the PhraseMatcher. @@ -46,13 +60,16 @@ cdef class PhraseMatcher: if max_length != 0: deprecation_warning(Warnings.W010) self.vocab = vocab - self._terminal = '_terminal_' - self.keyword_trie_dict = dict() self._callbacks = {} self._keywords = {} self._docs = {} self._validate = validate + self.mem = Pool() + self.c_map = self.mem.alloc(1, sizeof(MapStruct)) + self._terminal_node = 1 # or random: np.random.randint(0, high=np.iinfo(np.uint64).max, dtype=np.uint64) + map_init(self.mem, self.c_map, 8) + if isinstance(attr, long): self.attr = attr else: @@ -93,37 +110,58 @@ cdef class PhraseMatcher: """ if key not in self._keywords: return + cdef MapStruct* current_node + cdef MapStruct* terminal_map + cdef MapStruct* node_pointer + cdef Result result + cdef key_t terminal_key + cdef void* value + cdef int c_i = 0 for keyword in self._keywords[key]: - current_dict = self.keyword_trie_dict + current_node = self.c_map token_trie_list = [] - for tokens in keyword: - if tokens in current_dict: - token_trie_list.append((tokens, current_dict)) - current_dict = current_dict[tokens] + for token in keyword: + result = map_get_unless_missing(current_node, token) + if result.found: + token_trie_list.append((token, current_node)) + current_node = result.value else: # if token is not found, break out of the loop - current_dict = None + current_node = NULL break - # remove the tokens from trie dict if there are no other + # remove the tokens from trie node if there are no other # keywords with them - if current_dict and self._terminal in current_dict: + result = map_get_unless_missing(current_node, self._terminal_node) + if current_node != NULL and result.found: # if this is the only remaining key, remove unnecessary paths - if current_dict[self._terminal] == [key]: + terminal_map = result.value + terminal_keys = [] + c_i = 0 + while map_iter(terminal_map, &c_i, &terminal_key, &value): + terminal_keys.append(self.vocab.strings[terminal_key]) + # TODO: not working, fix remove for unused paths/maps + if False and terminal_keys == [key]: # we found a complete match for input keyword - token_trie_list.append((self._terminal, current_dict)) + token_trie_list.append((self.vocab.strings[key], terminal_map)) token_trie_list.reverse() - for key_to_remove, dict_pointer in token_trie_list: - if len(dict_pointer.keys()) == 1: - dict_pointer.pop(key_to_remove) + for key_to_remove, py_node_pointer in token_trie_list: + node_pointer = py_node_pointer + result = map_get_unless_missing(node_pointer, key_to_remove) + if node_pointer.filled == 1: + map_clear(node_pointer, key_to_remove) + self.mem.free(result.value) + pass else: # more than one key means more than 1 path, # delete not required path and keep the other - dict_pointer.pop(key_to_remove) + map_clear(node_pointer, key_to_remove) + self.mem.free(result.value) break # otherwise simply remove the key else: - if key in current_dict[self._terminal]: - current_dict[self._terminal].remove(key) + result = map_get_unless_missing(current_node, self._terminal_node) + if result.found: + map_clear(result.value, self.vocab.strings[key]) del self._keywords[key] del self._callbacks[key] @@ -146,6 +184,10 @@ cdef class PhraseMatcher: self._docs.setdefault(key, set()) self._docs[key].update(docs) + cdef MapStruct* current_node + cdef MapStruct* internal_node + cdef Result result + for doc in docs: if len(doc) == 0: continue @@ -161,11 +203,23 @@ cdef class PhraseMatcher: # keep track of keywords per key to make remove easier # (would use a set, but can't hash numpy arrays) self._keywords[key].append(keyword) - current_dict = self.keyword_trie_dict + + current_node = self.c_map for token in keyword: - current_dict = current_dict.setdefault(token, {}) - current_dict.setdefault(self._terminal, set()) - current_dict[self._terminal].add(key) + result = map_get_unless_missing(current_node, token) + if not result.found: + internal_node = self.mem.alloc(1, sizeof(MapStruct)) + map_init(self.mem, internal_node, 8) + map_set(self.mem, current_node, token, internal_node) + result.value = internal_node + current_node = result.value + result = map_get_unless_missing(current_node, self._terminal_node) + if not result.found: + internal_node = self.mem.alloc(1, sizeof(MapStruct)) + map_init(self.mem, internal_node, 8) + map_set(self.mem, current_node, self._terminal_node, internal_node) + result.value = internal_node + map_set(self.mem, result.value, hash_string(key), NULL) def __call__(self, doc): """Find all sequences matching the supplied patterns on the `Doc`. @@ -182,42 +236,62 @@ cdef class PhraseMatcher: if doc_array is None or len(doc_array) == 0: # if doc_array is empty or None just return empty list return matches - current_dict = self.keyword_trie_dict - start = 0 - idx = 0 - doc_array_len = len(doc_array) - while idx < doc_array_len: - start = idx - token = doc_array[idx] - # look for sequences from this position - if token in current_dict: - current_dict_continued = current_dict[token] - idy = idx + 1 - while idy < doc_array_len: - if self._terminal in current_dict_continued: - ent_ids = current_dict_continued[self._terminal] - for ent_id in ent_ids: - matches.append((self.vocab.strings[ent_id], start, idy)) - inner_token = doc_array[idy] - if inner_token in current_dict_continued: - current_dict_continued = current_dict_continued[inner_token] - idy += 1 - else: - break - else: - # end of doc_array reached - if self._terminal in current_dict_continued: - ent_ids = current_dict_continued[self._terminal] - for ent_id in ent_ids: - matches.append((self.vocab.strings[ent_id], start, idy)) - current_dict = self.keyword_trie_dict - idx += 1 + + if not doc_array.flags['C_CONTIGUOUS']: + doc_array = np.ascontiguousarray(doc_array) + cdef key_t[::1] doc_array_memview = doc_array + cdef vector[MatchStruct] c_matches + self.find_matches(&doc_array_memview[0], doc_array_memview.shape[0], &c_matches) + for i in range(c_matches.size()): + matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end)) for i, (ent_id, start, end) in enumerate(matches): on_match = self._callbacks.get(ent_id) if on_match is not None: on_match(self, doc, i, matches) return matches + cdef void find_matches(self, key_t* hash_array, int hash_array_len, vector[MatchStruct] *matches) nogil: + cdef MapStruct* current_node = self.c_map + cdef int start = 0 + cdef int idx = 0 + cdef int idy = 0 + cdef key_t key + cdef void* value + cdef int i = 0 + cdef MatchStruct ms + while idx < hash_array_len: + start = idx + token = hash_array[idx] + # look for sequences from this position + result = map_get_unless_missing(current_node, token) + if result.found: + current_node = result.value + idy = idx + 1 + while idy < hash_array_len: + result = map_get_unless_missing(current_node, self._terminal_node) + if result.found: + i = 0 + while map_iter(result.value, &i, &key, &value): + ms = make_matchstruct(key, start, idy) + matches.push_back(ms) + inner_token = hash_array[idy] + result = map_get_unless_missing(current_node, inner_token) + if result.found: + current_node = result.value + idy += 1 + else: + break + else: + # end of hash_array reached + result = map_get_unless_missing(current_node, self._terminal_node) + if result.found: + i = 0 + while map_iter(result.value, &i, &key, &value): + ms = make_matchstruct(key, start, idy) + matches.push_back(ms) + current_node = self.c_map + idx += 1 + def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False, as_tuples=False): """Match a stream of documents, yielding them in turn. @@ -281,3 +355,11 @@ def unpickle_matcher(vocab, docs, callbacks): callback = callbacks.get(key, None) matcher.add(key, callback, *specs) return matcher + + +cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil: + cdef MatchStruct ms + ms.match_id = match_id + ms.start = start + ms.end = end + return ms diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py index 3eb454c39..7d65d0007 100644 --- a/spacy/tests/matcher/test_phrase_matcher.py +++ b/spacy/tests/matcher/test_phrase_matcher.py @@ -67,18 +67,27 @@ def test_phrase_matcher_repeated_add(en_vocab): def test_phrase_matcher_remove(en_vocab): matcher = PhraseMatcher(en_vocab) - matcher.add("TEST", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST1", None, Doc(en_vocab, words=["like"])) + matcher.add("TEST2", None, Doc(en_vocab, words=["best"])) doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) - assert "TEST" in matcher - assert "TEST2" not in matcher + assert "TEST1" in matcher + assert "TEST2" in matcher + assert "TEST3" not in matcher + assert len(matcher(doc)) == 2 + matcher.remove("TEST1") + assert "TEST1" not in matcher + assert "TEST2" in matcher + assert "TEST3" not in matcher assert len(matcher(doc)) == 1 - matcher.remove("TEST") - assert "TEST" not in matcher - assert "TEST2" not in matcher - assert len(matcher(doc)) == 0 matcher.remove("TEST2") - assert "TEST" not in matcher + assert "TEST1" not in matcher assert "TEST2" not in matcher + assert "TEST3" not in matcher + assert len(matcher(doc)) == 0 + matcher.remove("TEST3") + assert "TEST1" not in matcher + assert "TEST2" not in matcher + assert "TEST3" not in matcher assert len(matcher(doc)) == 0