mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Replace PhraseMatcher with trie-based search (#4309)
* Replace PhraseMatcher with Aho-Corasick Replace PhraseMatcher with the Aho-Corasick algorithm over numpy arrays of the hash values for the relevant attribute. The implementation is based on FlashText. The speed should be similar to the previous PhraseMatcher. It is now possible to easily remove match IDs and matches don't go missing with large keyword lists / vocabularies. Fixes #4308. * Restore support for pickling * Fix internal keyword add/remove for numpy arrays * Add missing loop for match ID set in search loop * Remove cruft in matching loop for partial matches There was a bit of unnecessary code left over from FlashText in the matching loop to handle partial token matches, which we don't have with PhraseMatcher. * Replace dict trie with MapStruct trie * Fix how match ID hash is stored/added * Update fix for match ID vocab * Switch from map_get_unless_missing to map_get * Switch from numpy array to Token.get_struct_attr Access token attributes directly in Doc instead of making a copy of the relevant values in a numpy array. Add unsatisfactory warning for hash collision with reserved terminal hash key. (Ideally it would change the reserved terminal hash and redo the whole trie, but for now, I'm hoping there won't be collisions.) * Restructure imports to export find_matches * Implement full remove() Remove unnecessary trie paths and free unused maps. Parallel to Matcher, raise KeyError when attempting to remove a match ID that has not been added. * Store docs internally only as attr lists * Reduces size for pickle * Remove duplicate keywords store Now that docs are stored as lists of attr hashes, there's no need to have the duplicate _keywords store.
This commit is contained in:
parent
d844030fd8
commit
c23edf302b
|
@ -86,6 +86,8 @@ class Warnings(object):
|
|||
"previously loaded vectors. See Issue #3853.")
|
||||
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
|
||||
"loaded. (Shape: {shape})")
|
||||
W021 = ("Unexpected hash collision in PhraseMatcher. Matches may be "
|
||||
"incorrect. Modify PhraseMatcher._terminal_hash to fix.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -1,5 +1,27 @@
|
|||
from libcpp.vector cimport vector
|
||||
|
||||
from ..typedefs cimport hash_t
|
||||
from cymem.cymem cimport Pool
|
||||
from preshed.maps cimport key_t, MapStruct
|
||||
|
||||
ctypedef vector[hash_t] hash_vec
|
||||
from ..attrs cimport attr_id_t
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..vocab cimport Vocab
|
||||
|
||||
|
||||
cdef class PhraseMatcher:
|
||||
cdef Vocab vocab
|
||||
cdef attr_id_t attr
|
||||
cdef object _callbacks
|
||||
cdef object _docs
|
||||
cdef bint _validate
|
||||
cdef MapStruct* c_map
|
||||
cdef Pool mem
|
||||
cdef key_t _terminal_hash
|
||||
|
||||
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil
|
||||
|
||||
|
||||
cdef struct MatchStruct:
|
||||
key_t match_id
|
||||
int start
|
||||
int end
|
||||
|
|
|
@ -2,28 +2,16 @@
|
|||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from libcpp.vector cimport vector
|
||||
from cymem.cymem cimport Pool
|
||||
from murmurhash.mrmr cimport hash64
|
||||
from preshed.maps cimport PreshMap
|
||||
from libc.stdint cimport uintptr_t
|
||||
|
||||
from .matcher cimport Matcher
|
||||
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
|
||||
from ..vocab cimport Vocab
|
||||
from ..tokens.doc cimport Doc, get_token_attr
|
||||
from ..typedefs cimport attr_t, hash_t
|
||||
from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
|
||||
|
||||
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
|
||||
from ..structs cimport TokenC
|
||||
from ..tokens.token cimport Token
|
||||
|
||||
from ._schemas import TOKEN_PATTERN_SCHEMA
|
||||
from ..errors import Errors, Warnings, deprecation_warning, user_warning
|
||||
from ..attrs import FLAG61 as U_ENT
|
||||
from ..attrs import FLAG60 as B2_ENT
|
||||
from ..attrs import FLAG59 as B3_ENT
|
||||
from ..attrs import FLAG58 as B4_ENT
|
||||
from ..attrs import FLAG43 as L2_ENT
|
||||
from ..attrs import FLAG42 as L3_ENT
|
||||
from ..attrs import FLAG41 as L4_ENT
|
||||
from ..attrs import FLAG42 as I3_ENT
|
||||
from ..attrs import FLAG41 as I4_ENT
|
||||
|
||||
|
||||
cdef class PhraseMatcher:
|
||||
|
@ -33,18 +21,11 @@ cdef class PhraseMatcher:
|
|||
|
||||
DOCS: https://spacy.io/api/phrasematcher
|
||||
USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher
|
||||
|
||||
Adapted from FlashText: https://github.com/vi3k6i5/flashtext
|
||||
MIT License (see `LICENSE`)
|
||||
Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
|
||||
"""
|
||||
cdef Pool mem
|
||||
cdef Vocab vocab
|
||||
cdef Matcher matcher
|
||||
cdef PreshMap phrase_ids
|
||||
cdef vector[hash_vec] ent_id_matrix
|
||||
cdef int max_length
|
||||
cdef attr_id_t attr
|
||||
cdef public object _callbacks
|
||||
cdef public object _patterns
|
||||
cdef public object _docs
|
||||
cdef public object _validate
|
||||
|
||||
def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
|
||||
"""Initialize the PhraseMatcher.
|
||||
|
@ -58,10 +39,16 @@ cdef class PhraseMatcher:
|
|||
"""
|
||||
if max_length != 0:
|
||||
deprecation_warning(Warnings.W010)
|
||||
self.mem = Pool()
|
||||
self.max_length = max_length
|
||||
self.vocab = vocab
|
||||
self.matcher = Matcher(self.vocab, validate=False)
|
||||
self._callbacks = {}
|
||||
self._docs = {}
|
||||
self._validate = validate
|
||||
|
||||
self.mem = Pool()
|
||||
self.c_map = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||
self._terminal_hash = 826361138722620965
|
||||
map_init(self.mem, self.c_map, 8)
|
||||
|
||||
if isinstance(attr, long):
|
||||
self.attr = attr
|
||||
else:
|
||||
|
@ -71,28 +58,15 @@ cdef class PhraseMatcher:
|
|||
if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
|
||||
raise ValueError(Errors.E152.format(attr=attr))
|
||||
self.attr = self.vocab.strings[attr]
|
||||
self.phrase_ids = PreshMap()
|
||||
abstract_patterns = [
|
||||
[{U_ENT: True}],
|
||||
[{B2_ENT: True}, {L2_ENT: True}],
|
||||
[{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}],
|
||||
[{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}],
|
||||
]
|
||||
self.matcher.add("Candidate", None, *abstract_patterns)
|
||||
self._callbacks = {}
|
||||
self._docs = {}
|
||||
self._validate = validate
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of rules added to the matcher. Note that this only
|
||||
returns the number of rules (identical with the number of IDs), not the
|
||||
number of individual patterns.
|
||||
"""Get the number of match IDs added to the matcher.
|
||||
|
||||
RETURNS (int): The number of rules.
|
||||
|
||||
DOCS: https://spacy.io/api/phrasematcher#len
|
||||
"""
|
||||
return len(self._docs)
|
||||
return len(self._callbacks)
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Check whether the matcher contains rules for a match ID.
|
||||
|
@ -102,13 +76,77 @@ cdef class PhraseMatcher:
|
|||
|
||||
DOCS: https://spacy.io/api/phrasematcher#contains
|
||||
"""
|
||||
cdef hash_t ent_id = self.matcher._normalize_key(key)
|
||||
return ent_id in self._callbacks
|
||||
return key in self._callbacks
|
||||
|
||||
def __reduce__(self):
|
||||
data = (self.vocab, self._docs, self._callbacks)
|
||||
return (unpickle_matcher, data, None, None)
|
||||
|
||||
def remove(self, key):
|
||||
"""Remove a rule from the matcher by match ID. A KeyError is raised if
|
||||
the key does not exist.
|
||||
|
||||
key (unicode): The match ID.
|
||||
"""
|
||||
if key not in self._docs:
|
||||
raise KeyError(key)
|
||||
cdef MapStruct* current_node
|
||||
cdef MapStruct* terminal_map
|
||||
cdef MapStruct* node_pointer
|
||||
cdef void* result
|
||||
cdef key_t terminal_key
|
||||
cdef void* value
|
||||
cdef int c_i = 0
|
||||
cdef vector[MapStruct*] path_nodes
|
||||
cdef vector[key_t] path_keys
|
||||
cdef key_t key_to_remove
|
||||
for keyword in self._docs[key]:
|
||||
current_node = self.c_map
|
||||
for token in keyword:
|
||||
result = map_get(current_node, token)
|
||||
if result:
|
||||
path_nodes.push_back(current_node)
|
||||
path_keys.push_back(token)
|
||||
current_node = <MapStruct*>result
|
||||
else:
|
||||
# if token is not found, break out of the loop
|
||||
current_node = NULL
|
||||
break
|
||||
# remove the tokens from trie node if there are no other
|
||||
# keywords with them
|
||||
result = map_get(current_node, self._terminal_hash)
|
||||
if current_node != NULL and result:
|
||||
terminal_map = <MapStruct*>result
|
||||
terminal_keys = []
|
||||
c_i = 0
|
||||
while map_iter(terminal_map, &c_i, &terminal_key, &value):
|
||||
terminal_keys.append(self.vocab.strings[terminal_key])
|
||||
# if this is the only remaining key, remove unnecessary paths
|
||||
if terminal_keys == [key]:
|
||||
while not path_nodes.empty():
|
||||
node_pointer = path_nodes.back()
|
||||
path_nodes.pop_back()
|
||||
key_to_remove = path_keys.back()
|
||||
path_keys.pop_back()
|
||||
result = map_get(node_pointer, key_to_remove)
|
||||
if node_pointer.filled == 1:
|
||||
map_clear(node_pointer, key_to_remove)
|
||||
self.mem.free(result)
|
||||
else:
|
||||
# more than one key means more than 1 path,
|
||||
# delete not required path and keep the others
|
||||
map_clear(node_pointer, key_to_remove)
|
||||
self.mem.free(result)
|
||||
break
|
||||
# otherwise simply remove the key
|
||||
else:
|
||||
result = map_get(current_node, self._terminal_hash)
|
||||
if result:
|
||||
map_clear(<MapStruct*>result, self.vocab.strings[key])
|
||||
|
||||
del self._callbacks[key]
|
||||
del self._docs[key]
|
||||
|
||||
def add(self, key, on_match, *docs):
|
||||
"""Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
|
||||
key, an on_match callback, and one or more patterns.
|
||||
|
@ -119,17 +157,17 @@ cdef class PhraseMatcher:
|
|||
|
||||
DOCS: https://spacy.io/api/phrasematcher#add
|
||||
"""
|
||||
cdef Doc doc
|
||||
cdef hash_t ent_id = self.matcher._normalize_key(key)
|
||||
self._callbacks[ent_id] = on_match
|
||||
self._docs[ent_id] = docs
|
||||
cdef int length
|
||||
cdef int i
|
||||
cdef hash_t phrase_hash
|
||||
cdef Pool mem = Pool()
|
||||
|
||||
_ = self.vocab[key]
|
||||
self._callbacks[key] = on_match
|
||||
self._docs.setdefault(key, set())
|
||||
|
||||
cdef MapStruct* current_node
|
||||
cdef MapStruct* internal_node
|
||||
cdef void* result
|
||||
|
||||
for doc in docs:
|
||||
length = doc.length
|
||||
if length == 0:
|
||||
if len(doc) == 0:
|
||||
continue
|
||||
if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
|
||||
raise ValueError(Errors.E155.format())
|
||||
|
@ -139,33 +177,33 @@ cdef class PhraseMatcher:
|
|||
and self.attr not in (DEP, POS, TAG, LEMMA):
|
||||
string_attr = self.vocab.strings[self.attr]
|
||||
user_warning(Warnings.W012.format(key=key, attr=string_attr))
|
||||
tags = get_biluo(length)
|
||||
phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
|
||||
for i, tag in enumerate(tags):
|
||||
attr_value = self.get_lex_value(doc, i)
|
||||
lexeme = self.vocab[attr_value]
|
||||
lexeme.set_flag(tag, True)
|
||||
phrase_key[i] = lexeme.orth
|
||||
phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
|
||||
|
||||
if phrase_hash in self.phrase_ids:
|
||||
phrase_index = self.phrase_ids[phrase_hash]
|
||||
ent_id_list = self.ent_id_matrix[phrase_index]
|
||||
ent_id_list.append(ent_id)
|
||||
self.ent_id_matrix[phrase_index] = ent_id_list
|
||||
|
||||
if isinstance(doc, Doc):
|
||||
keyword = self._convert_to_array(doc)
|
||||
else:
|
||||
ent_id_list = hash_vec(1)
|
||||
ent_id_list[0] = ent_id
|
||||
new_index = self.ent_id_matrix.size()
|
||||
if new_index == 0:
|
||||
# PreshMaps can not contain 0 as value, so storing a dummy at 0
|
||||
self.ent_id_matrix.push_back(hash_vec(0))
|
||||
new_index = 1
|
||||
self.ent_id_matrix.push_back(ent_id_list)
|
||||
self.phrase_ids.set(phrase_hash, <void*>new_index)
|
||||
keyword = doc
|
||||
self._docs[key].add(tuple(keyword))
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
current_node = self.c_map
|
||||
for token in keyword:
|
||||
if token == self._terminal_hash:
|
||||
user_warning(Warnings.W021)
|
||||
break
|
||||
result = <MapStruct*>map_get(current_node, token)
|
||||
if not result:
|
||||
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||
map_init(self.mem, internal_node, 8)
|
||||
map_set(self.mem, current_node, token, internal_node)
|
||||
result = internal_node
|
||||
current_node = <MapStruct*>result
|
||||
result = <MapStruct*>map_get(current_node, self._terminal_hash)
|
||||
if not result:
|
||||
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
||||
map_init(self.mem, internal_node, 8)
|
||||
map_set(self.mem, current_node, self._terminal_hash, internal_node)
|
||||
result = internal_node
|
||||
map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)
|
||||
|
||||
def __call__(self, doc):
|
||||
"""Find all sequences matching the supplied patterns on the `Doc`.
|
||||
|
||||
doc (Doc): The document to match over.
|
||||
|
@ -176,25 +214,63 @@ cdef class PhraseMatcher:
|
|||
DOCS: https://spacy.io/api/phrasematcher#call
|
||||
"""
|
||||
matches = []
|
||||
if self.attr == ORTH:
|
||||
match_doc = doc
|
||||
else:
|
||||
# If we're not matching on the ORTH, match_doc will be a Doc whose
|
||||
# token.orth values are the attribute values we're matching on,
|
||||
# e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
|
||||
words = [self.get_lex_value(doc, i) for i in range(len(doc))]
|
||||
match_doc = Doc(self.vocab, words=words)
|
||||
for _, start, end in self.matcher(match_doc):
|
||||
ent_ids = self.accept_match(match_doc, start, end)
|
||||
if ent_ids is not None:
|
||||
for ent_id in ent_ids:
|
||||
matches.append((ent_id, start, end))
|
||||
if doc is None or len(doc) == 0:
|
||||
# if doc is empty or None just return empty list
|
||||
return matches
|
||||
|
||||
cdef vector[MatchStruct] c_matches
|
||||
self.find_matches(doc, &c_matches)
|
||||
for i in range(c_matches.size()):
|
||||
matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
|
||||
for i, (ent_id, start, end) in enumerate(matches):
|
||||
on_match = self._callbacks.get(ent_id)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matches)
|
||||
return matches
|
||||
|
||||
cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil:
|
||||
cdef MapStruct* current_node = self.c_map
|
||||
cdef int start = 0
|
||||
cdef int idx = 0
|
||||
cdef int idy = 0
|
||||
cdef key_t key
|
||||
cdef void* value
|
||||
cdef int i = 0
|
||||
cdef MatchStruct ms
|
||||
cdef void* result
|
||||
while idx < doc.length:
|
||||
start = idx
|
||||
token = Token.get_struct_attr(&doc.c[idx], self.attr)
|
||||
# look for sequences from this position
|
||||
result = map_get(current_node, token)
|
||||
if result:
|
||||
current_node = <MapStruct*>result
|
||||
idy = idx + 1
|
||||
while idy < doc.length:
|
||||
result = map_get(current_node, self._terminal_hash)
|
||||
if result:
|
||||
i = 0
|
||||
while map_iter(<MapStruct*>result, &i, &key, &value):
|
||||
ms = make_matchstruct(key, start, idy)
|
||||
matches.push_back(ms)
|
||||
inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
|
||||
result = map_get(current_node, inner_token)
|
||||
if result:
|
||||
current_node = <MapStruct*>result
|
||||
idy += 1
|
||||
else:
|
||||
break
|
||||
else:
|
||||
# end of doc reached
|
||||
result = map_get(current_node, self._terminal_hash)
|
||||
if result:
|
||||
i = 0
|
||||
while map_iter(<MapStruct*>result, &i, &key, &value):
|
||||
ms = make_matchstruct(key, start, idy)
|
||||
matches.push_back(ms)
|
||||
current_node = self.c_map
|
||||
idx += 1
|
||||
|
||||
def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False,
|
||||
as_tuples=False):
|
||||
"""Match a stream of documents, yielding them in turn.
|
||||
|
@ -228,48 +304,8 @@ cdef class PhraseMatcher:
|
|||
else:
|
||||
yield doc
|
||||
|
||||
def accept_match(self, Doc doc, int start, int end):
|
||||
cdef int i, j
|
||||
cdef Pool mem = Pool()
|
||||
phrase_key = <attr_t*>mem.alloc(end-start, sizeof(attr_t))
|
||||
for i, j in enumerate(range(start, end)):
|
||||
phrase_key[i] = doc.c[j].lex.orth
|
||||
cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
|
||||
|
||||
ent_index = <hash_t>self.phrase_ids.get(key)
|
||||
if ent_index == 0:
|
||||
return None
|
||||
return self.ent_id_matrix[ent_index]
|
||||
|
||||
def get_lex_value(self, Doc doc, int i):
|
||||
if self.attr == ORTH:
|
||||
# Return the regular orth value of the lexeme
|
||||
return doc.c[i].lex.orth
|
||||
# Get the attribute value instead, e.g. token.pos
|
||||
attr_value = get_token_attr(&doc.c[i], self.attr)
|
||||
if attr_value in (0, 1):
|
||||
# Value is boolean, convert to string
|
||||
string_attr_value = str(attr_value)
|
||||
else:
|
||||
string_attr_value = self.vocab.strings[attr_value]
|
||||
string_attr_name = self.vocab.strings[self.attr]
|
||||
# Concatenate the attr name and value to not pollute lexeme space
|
||||
# e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
|
||||
# create false positive matches
|
||||
return "matcher:{}-{}".format(string_attr_name, string_attr_value)
|
||||
|
||||
|
||||
def get_biluo(length):
|
||||
if length == 0:
|
||||
raise ValueError(Errors.E127)
|
||||
elif length == 1:
|
||||
return [U_ENT]
|
||||
elif length == 2:
|
||||
return [B2_ENT, L2_ENT]
|
||||
elif length == 3:
|
||||
return [B3_ENT, I3_ENT, L3_ENT]
|
||||
else:
|
||||
return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]
|
||||
def _convert_to_array(self, Doc doc):
|
||||
return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))]
|
||||
|
||||
|
||||
def unpickle_matcher(vocab, docs, callbacks):
|
||||
|
@ -278,3 +314,11 @@ def unpickle_matcher(vocab, docs, callbacks):
|
|||
callback = callbacks.get(key, None)
|
||||
matcher.add(key, callback, *specs)
|
||||
return matcher
|
||||
|
||||
|
||||
cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil:
|
||||
cdef MatchStruct ms
|
||||
ms.match_id = match_id
|
||||
ms.start = start
|
||||
ms.end = end
|
||||
return ms
|
||||
|
|
|
@ -8,10 +8,31 @@ from ..util import get_doc
|
|||
|
||||
|
||||
def test_matcher_phrase_matcher(en_vocab):
|
||||
doc = Doc(en_vocab, words=["Google", "Now"])
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("COMPANY", None, doc)
|
||||
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
|
||||
# intermediate phrase
|
||||
pattern = Doc(en_vocab, words=["Google", "Now"])
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("COMPANY", None, pattern)
|
||||
assert len(matcher(doc)) == 1
|
||||
# initial token
|
||||
pattern = Doc(en_vocab, words=["I"])
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("I", None, pattern)
|
||||
assert len(matcher(doc)) == 1
|
||||
# initial phrase
|
||||
pattern = Doc(en_vocab, words=["I", "like"])
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("ILIKE", None, pattern)
|
||||
assert len(matcher(doc)) == 1
|
||||
# final token
|
||||
pattern = Doc(en_vocab, words=["best"])
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("BEST", None, pattern)
|
||||
assert len(matcher(doc)) == 1
|
||||
# final phrase
|
||||
pattern = Doc(en_vocab, words=["Now", "best"])
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("NOWBEST", None, pattern)
|
||||
assert len(matcher(doc)) == 1
|
||||
|
||||
|
||||
|
@ -31,6 +52,68 @@ def test_phrase_matcher_contains(en_vocab):
|
|||
assert "TEST2" not in matcher
|
||||
|
||||
|
||||
def test_phrase_matcher_repeated_add(en_vocab):
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
# match ID only gets added once
|
||||
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
|
||||
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
|
||||
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
|
||||
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
|
||||
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
|
||||
assert "TEST" in matcher
|
||||
assert "TEST2" not in matcher
|
||||
assert len(matcher(doc)) == 1
|
||||
|
||||
|
||||
def test_phrase_matcher_remove(en_vocab):
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("TEST1", None, Doc(en_vocab, words=["like"]))
|
||||
matcher.add("TEST2", None, Doc(en_vocab, words=["best"]))
|
||||
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
|
||||
assert "TEST1" in matcher
|
||||
assert "TEST2" in matcher
|
||||
assert "TEST3" not in matcher
|
||||
assert len(matcher(doc)) == 2
|
||||
matcher.remove("TEST1")
|
||||
assert "TEST1" not in matcher
|
||||
assert "TEST2" in matcher
|
||||
assert "TEST3" not in matcher
|
||||
assert len(matcher(doc)) == 1
|
||||
matcher.remove("TEST2")
|
||||
assert "TEST1" not in matcher
|
||||
assert "TEST2" not in matcher
|
||||
assert "TEST3" not in matcher
|
||||
assert len(matcher(doc)) == 0
|
||||
with pytest.raises(KeyError):
|
||||
matcher.remove("TEST3")
|
||||
assert "TEST1" not in matcher
|
||||
assert "TEST2" not in matcher
|
||||
assert "TEST3" not in matcher
|
||||
assert len(matcher(doc)) == 0
|
||||
|
||||
|
||||
def test_phrase_matcher_overlapping_with_remove(en_vocab):
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
|
||||
# TEST2 is added alongside TEST
|
||||
matcher.add("TEST2", None, Doc(en_vocab, words=["like"]))
|
||||
doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
|
||||
assert "TEST" in matcher
|
||||
assert len(matcher) == 2
|
||||
assert len(matcher(doc)) == 2
|
||||
# removing TEST does not remove the entry for TEST2
|
||||
matcher.remove("TEST")
|
||||
assert "TEST" not in matcher
|
||||
assert len(matcher) == 1
|
||||
assert len(matcher(doc)) == 1
|
||||
assert matcher(doc)[0][0] == en_vocab.strings["TEST2"]
|
||||
# removing TEST2 removes all
|
||||
matcher.remove("TEST2")
|
||||
assert "TEST2" not in matcher
|
||||
assert len(matcher) == 0
|
||||
assert len(matcher(doc)) == 0
|
||||
|
||||
|
||||
def test_phrase_matcher_string_attrs(en_vocab):
|
||||
words1 = ["I", "like", "cats"]
|
||||
pos1 = ["PRON", "VERB", "NOUN"]
|
||||
|
|
Loading…
Reference in New Issue
Block a user