mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-16 03:20:34 +03:00
Replace dict trie with MapStruct trie
This commit is contained in:
parent
a7e9c0fd3e
commit
39540ed1ce
|
@ -0,0 +1,6 @@
|
||||||
|
from preshed.maps cimport key_t
|
||||||
|
|
||||||
|
cdef struct MatchStruct:
|
||||||
|
key_t match_id
|
||||||
|
int start
|
||||||
|
int end
|
|
@ -2,10 +2,20 @@
|
||||||
# cython: profile=True
|
# cython: profile=True
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from libc.stdint cimport uintptr_t
|
||||||
|
from libc.stdio cimport printf
|
||||||
|
from libcpp.vector cimport vector
|
||||||
|
|
||||||
|
from cymem.cymem cimport Pool
|
||||||
|
|
||||||
|
from preshed.maps cimport MapStruct, map_init, map_set, map_get_unless_missing
|
||||||
|
from preshed.maps cimport map_clear, map_iter, key_t, Result
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
|
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
|
||||||
from ..vocab cimport Vocab
|
from ..vocab cimport Vocab
|
||||||
|
from ..strings cimport hash_string
|
||||||
from ..tokens.doc cimport Doc, get_token_attr
|
from ..tokens.doc cimport Doc, get_token_attr
|
||||||
|
|
||||||
from ._schemas import TOKEN_PATTERN_SCHEMA
|
from ._schemas import TOKEN_PATTERN_SCHEMA
|
||||||
|
@ -25,14 +35,18 @@ cdef class PhraseMatcher:
|
||||||
Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
|
Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
|
||||||
"""
|
"""
|
||||||
cdef Vocab vocab
|
cdef Vocab vocab
|
||||||
cdef unicode _terminal
|
|
||||||
cdef object keyword_trie_dict
|
|
||||||
cdef attr_id_t attr
|
cdef attr_id_t attr
|
||||||
cdef object _callbacks
|
cdef object _callbacks
|
||||||
cdef object _keywords
|
cdef object _keywords
|
||||||
cdef object _docs
|
cdef object _docs
|
||||||
cdef bint _validate
|
cdef bint _validate
|
||||||
|
|
||||||
|
cdef MapStruct* c_map
|
||||||
|
cdef Pool mem
|
||||||
|
cdef key_t _terminal_node
|
||||||
|
|
||||||
|
cdef void find_matches(self, key_t* hash_array, int hash_array_len, vector[MatchStruct] *matches) nogil
|
||||||
|
|
||||||
def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
|
def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
|
||||||
"""Initialize the PhraseMatcher.
|
"""Initialize the PhraseMatcher.
|
||||||
|
|
||||||
|
@ -46,13 +60,16 @@ cdef class PhraseMatcher:
|
||||||
if max_length != 0:
|
if max_length != 0:
|
||||||
deprecation_warning(Warnings.W010)
|
deprecation_warning(Warnings.W010)
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self._terminal = '_terminal_'
|
|
||||||
self.keyword_trie_dict = dict()
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self._keywords = {}
|
self._keywords = {}
|
||||||
self._docs = {}
|
self._docs = {}
|
||||||
self._validate = validate
|
self._validate = validate
|
||||||
|
|
||||||
|
self.mem = Pool()
|
||||||
|
self.c_map = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||||
|
self._terminal_node = 1 # or random: np.random.randint(0, high=np.iinfo(np.uint64).max, dtype=np.uint64)
|
||||||
|
map_init(self.mem, self.c_map, 8)
|
||||||
|
|
||||||
if isinstance(attr, long):
|
if isinstance(attr, long):
|
||||||
self.attr = attr
|
self.attr = attr
|
||||||
else:
|
else:
|
||||||
|
@ -93,37 +110,58 @@ cdef class PhraseMatcher:
|
||||||
"""
|
"""
|
||||||
if key not in self._keywords:
|
if key not in self._keywords:
|
||||||
return
|
return
|
||||||
|
cdef MapStruct* current_node
|
||||||
|
cdef MapStruct* terminal_map
|
||||||
|
cdef MapStruct* node_pointer
|
||||||
|
cdef Result result
|
||||||
|
cdef key_t terminal_key
|
||||||
|
cdef void* value
|
||||||
|
cdef int c_i = 0
|
||||||
for keyword in self._keywords[key]:
|
for keyword in self._keywords[key]:
|
||||||
current_dict = self.keyword_trie_dict
|
current_node = self.c_map
|
||||||
token_trie_list = []
|
token_trie_list = []
|
||||||
for tokens in keyword:
|
for token in keyword:
|
||||||
if tokens in current_dict:
|
result = map_get_unless_missing(current_node, token)
|
||||||
token_trie_list.append((tokens, current_dict))
|
if result.found:
|
||||||
current_dict = current_dict[tokens]
|
token_trie_list.append((token, <uintptr_t>current_node))
|
||||||
|
current_node = <MapStruct*>result.value
|
||||||
else:
|
else:
|
||||||
# if token is not found, break out of the loop
|
# if token is not found, break out of the loop
|
||||||
current_dict = None
|
current_node = NULL
|
||||||
break
|
break
|
||||||
# remove the tokens from trie dict if there are no other
|
# remove the tokens from trie node if there are no other
|
||||||
# keywords with them
|
# keywords with them
|
||||||
if current_dict and self._terminal in current_dict:
|
result = map_get_unless_missing(current_node, self._terminal_node)
|
||||||
|
if current_node != NULL and result.found:
|
||||||
# if this is the only remaining key, remove unnecessary paths
|
# if this is the only remaining key, remove unnecessary paths
|
||||||
if current_dict[self._terminal] == [key]:
|
terminal_map = <MapStruct*>result.value
|
||||||
|
terminal_keys = []
|
||||||
|
c_i = 0
|
||||||
|
while map_iter(terminal_map, &c_i, &terminal_key, &value):
|
||||||
|
terminal_keys.append(self.vocab.strings[terminal_key])
|
||||||
|
# TODO: not working, fix remove for unused paths/maps
|
||||||
|
if False and terminal_keys == [key]:
|
||||||
# we found a complete match for input keyword
|
# we found a complete match for input keyword
|
||||||
token_trie_list.append((self._terminal, current_dict))
|
token_trie_list.append((self.vocab.strings[key], <uintptr_t>terminal_map))
|
||||||
token_trie_list.reverse()
|
token_trie_list.reverse()
|
||||||
for key_to_remove, dict_pointer in token_trie_list:
|
for key_to_remove, py_node_pointer in token_trie_list:
|
||||||
if len(dict_pointer.keys()) == 1:
|
node_pointer = <MapStruct*>py_node_pointer
|
||||||
dict_pointer.pop(key_to_remove)
|
result = map_get_unless_missing(node_pointer, key_to_remove)
|
||||||
|
if node_pointer.filled == 1:
|
||||||
|
map_clear(node_pointer, key_to_remove)
|
||||||
|
self.mem.free(result.value)
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
# more than one key means more than 1 path,
|
# more than one key means more than 1 path,
|
||||||
# delete not required path and keep the other
|
# delete not required path and keep the other
|
||||||
dict_pointer.pop(key_to_remove)
|
map_clear(node_pointer, key_to_remove)
|
||||||
|
self.mem.free(result.value)
|
||||||
break
|
break
|
||||||
# otherwise simply remove the key
|
# otherwise simply remove the key
|
||||||
else:
|
else:
|
||||||
if key in current_dict[self._terminal]:
|
result = map_get_unless_missing(current_node, self._terminal_node)
|
||||||
current_dict[self._terminal].remove(key)
|
if result.found:
|
||||||
|
map_clear(<MapStruct*>result.value, self.vocab.strings[key])
|
||||||
|
|
||||||
del self._keywords[key]
|
del self._keywords[key]
|
||||||
del self._callbacks[key]
|
del self._callbacks[key]
|
||||||
|
@ -146,6 +184,10 @@ cdef class PhraseMatcher:
|
||||||
self._docs.setdefault(key, set())
|
self._docs.setdefault(key, set())
|
||||||
self._docs[key].update(docs)
|
self._docs[key].update(docs)
|
||||||
|
|
||||||
|
cdef MapStruct* current_node
|
||||||
|
cdef MapStruct* internal_node
|
||||||
|
cdef Result result
|
||||||
|
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
if len(doc) == 0:
|
if len(doc) == 0:
|
||||||
continue
|
continue
|
||||||
|
@ -161,11 +203,23 @@ cdef class PhraseMatcher:
|
||||||
# keep track of keywords per key to make remove easier
|
# keep track of keywords per key to make remove easier
|
||||||
# (would use a set, but can't hash numpy arrays)
|
# (would use a set, but can't hash numpy arrays)
|
||||||
self._keywords[key].append(keyword)
|
self._keywords[key].append(keyword)
|
||||||
current_dict = self.keyword_trie_dict
|
|
||||||
|
current_node = self.c_map
|
||||||
for token in keyword:
|
for token in keyword:
|
||||||
current_dict = current_dict.setdefault(token, {})
|
result = map_get_unless_missing(current_node, token)
|
||||||
current_dict.setdefault(self._terminal, set())
|
if not result.found:
|
||||||
current_dict[self._terminal].add(key)
|
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||||
|
map_init(self.mem, internal_node, 8)
|
||||||
|
map_set(self.mem, current_node, token, internal_node)
|
||||||
|
result.value = internal_node
|
||||||
|
current_node = <MapStruct*>result.value
|
||||||
|
result = map_get_unless_missing(current_node, self._terminal_node)
|
||||||
|
if not result.found:
|
||||||
|
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||||
|
map_init(self.mem, internal_node, 8)
|
||||||
|
map_set(self.mem, current_node, self._terminal_node, internal_node)
|
||||||
|
result.value = internal_node
|
||||||
|
map_set(self.mem, <MapStruct*>result.value, hash_string(key), NULL)
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
"""Find all sequences matching the supplied patterns on the `Doc`.
|
"""Find all sequences matching the supplied patterns on the `Doc`.
|
||||||
|
@ -182,42 +236,62 @@ cdef class PhraseMatcher:
|
||||||
if doc_array is None or len(doc_array) == 0:
|
if doc_array is None or len(doc_array) == 0:
|
||||||
# if doc_array is empty or None just return empty list
|
# if doc_array is empty or None just return empty list
|
||||||
return matches
|
return matches
|
||||||
current_dict = self.keyword_trie_dict
|
|
||||||
start = 0
|
if not doc_array.flags['C_CONTIGUOUS']:
|
||||||
idx = 0
|
doc_array = np.ascontiguousarray(doc_array)
|
||||||
doc_array_len = len(doc_array)
|
cdef key_t[::1] doc_array_memview = doc_array
|
||||||
while idx < doc_array_len:
|
cdef vector[MatchStruct] c_matches
|
||||||
start = idx
|
self.find_matches(&doc_array_memview[0], doc_array_memview.shape[0], &c_matches)
|
||||||
token = doc_array[idx]
|
for i in range(c_matches.size()):
|
||||||
# look for sequences from this position
|
matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
|
||||||
if token in current_dict:
|
|
||||||
current_dict_continued = current_dict[token]
|
|
||||||
idy = idx + 1
|
|
||||||
while idy < doc_array_len:
|
|
||||||
if self._terminal in current_dict_continued:
|
|
||||||
ent_ids = current_dict_continued[self._terminal]
|
|
||||||
for ent_id in ent_ids:
|
|
||||||
matches.append((self.vocab.strings[ent_id], start, idy))
|
|
||||||
inner_token = doc_array[idy]
|
|
||||||
if inner_token in current_dict_continued:
|
|
||||||
current_dict_continued = current_dict_continued[inner_token]
|
|
||||||
idy += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# end of doc_array reached
|
|
||||||
if self._terminal in current_dict_continued:
|
|
||||||
ent_ids = current_dict_continued[self._terminal]
|
|
||||||
for ent_id in ent_ids:
|
|
||||||
matches.append((self.vocab.strings[ent_id], start, idy))
|
|
||||||
current_dict = self.keyword_trie_dict
|
|
||||||
idx += 1
|
|
||||||
for i, (ent_id, start, end) in enumerate(matches):
|
for i, (ent_id, start, end) in enumerate(matches):
|
||||||
on_match = self._callbacks.get(ent_id)
|
on_match = self._callbacks.get(ent_id)
|
||||||
if on_match is not None:
|
if on_match is not None:
|
||||||
on_match(self, doc, i, matches)
|
on_match(self, doc, i, matches)
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
|
cdef void find_matches(self, key_t* hash_array, int hash_array_len, vector[MatchStruct] *matches) nogil:
|
||||||
|
cdef MapStruct* current_node = self.c_map
|
||||||
|
cdef int start = 0
|
||||||
|
cdef int idx = 0
|
||||||
|
cdef int idy = 0
|
||||||
|
cdef key_t key
|
||||||
|
cdef void* value
|
||||||
|
cdef int i = 0
|
||||||
|
cdef MatchStruct ms
|
||||||
|
while idx < hash_array_len:
|
||||||
|
start = idx
|
||||||
|
token = hash_array[idx]
|
||||||
|
# look for sequences from this position
|
||||||
|
result = map_get_unless_missing(current_node, token)
|
||||||
|
if result.found:
|
||||||
|
current_node = <MapStruct*>result.value
|
||||||
|
idy = idx + 1
|
||||||
|
while idy < hash_array_len:
|
||||||
|
result = map_get_unless_missing(current_node, self._terminal_node)
|
||||||
|
if result.found:
|
||||||
|
i = 0
|
||||||
|
while map_iter(<MapStruct*>result.value, &i, &key, &value):
|
||||||
|
ms = make_matchstruct(key, start, idy)
|
||||||
|
matches.push_back(ms)
|
||||||
|
inner_token = hash_array[idy]
|
||||||
|
result = map_get_unless_missing(current_node, inner_token)
|
||||||
|
if result.found:
|
||||||
|
current_node = <MapStruct*>result.value
|
||||||
|
idy += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# end of hash_array reached
|
||||||
|
result = map_get_unless_missing(current_node, self._terminal_node)
|
||||||
|
if result.found:
|
||||||
|
i = 0
|
||||||
|
while map_iter(<MapStruct*>result.value, &i, &key, &value):
|
||||||
|
ms = make_matchstruct(key, start, idy)
|
||||||
|
matches.push_back(ms)
|
||||||
|
current_node = self.c_map
|
||||||
|
idx += 1
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False,
|
def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False,
|
||||||
as_tuples=False):
|
as_tuples=False):
|
||||||
"""Match a stream of documents, yielding them in turn.
|
"""Match a stream of documents, yielding them in turn.
|
||||||
|
@ -281,3 +355,11 @@ def unpickle_matcher(vocab, docs, callbacks):
|
||||||
callback = callbacks.get(key, None)
|
callback = callbacks.get(key, None)
|
||||||
matcher.add(key, callback, *specs)
|
matcher.add(key, callback, *specs)
|
||||||
return matcher
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
|
cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil:
|
||||||
|
cdef MatchStruct ms
|
||||||
|
ms.match_id = match_id
|
||||||
|
ms.start = start
|
||||||
|
ms.end = end
|
||||||
|
return ms
|
||||||
|
|
|
@ -67,18 +67,27 @@ def test_phrase_matcher_repeated_add(en_vocab):
|
||||||
|
|
||||||
def test_phrase_matcher_remove(en_vocab):
|
def test_phrase_matcher_remove(en_vocab):
|
||||||
matcher = PhraseMatcher(en_vocab)
|
matcher = PhraseMatcher(en_vocab)
|
||||||
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
|
matcher.add("TEST1", None, Doc(en_vocab, words=["like"]))
|
||||||
|
matcher.add("TEST2", None, Doc(en_vocab, words=["best"]))
|
||||||
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
|
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
|
||||||
assert "TEST" in matcher
|
assert "TEST1" in matcher
|
||||||
assert "TEST2" not in matcher
|
assert "TEST2" in matcher
|
||||||
|
assert "TEST3" not in matcher
|
||||||
|
assert len(matcher(doc)) == 2
|
||||||
|
matcher.remove("TEST1")
|
||||||
|
assert "TEST1" not in matcher
|
||||||
|
assert "TEST2" in matcher
|
||||||
|
assert "TEST3" not in matcher
|
||||||
assert len(matcher(doc)) == 1
|
assert len(matcher(doc)) == 1
|
||||||
matcher.remove("TEST")
|
|
||||||
assert "TEST" not in matcher
|
|
||||||
assert "TEST2" not in matcher
|
|
||||||
assert len(matcher(doc)) == 0
|
|
||||||
matcher.remove("TEST2")
|
matcher.remove("TEST2")
|
||||||
assert "TEST" not in matcher
|
assert "TEST1" not in matcher
|
||||||
assert "TEST2" not in matcher
|
assert "TEST2" not in matcher
|
||||||
|
assert "TEST3" not in matcher
|
||||||
|
assert len(matcher(doc)) == 0
|
||||||
|
matcher.remove("TEST3")
|
||||||
|
assert "TEST1" not in matcher
|
||||||
|
assert "TEST2" not in matcher
|
||||||
|
assert "TEST3" not in matcher
|
||||||
assert len(matcher(doc)) == 0
|
assert len(matcher(doc)) == 0
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user