Mirror of https://github.com/explosion/spaCy.git (synced 2025-02-20 21:40:35 +03:00)

Merge branch 'feature/hashmatcher' into bugfix/tokenizer-special-cases-matcher

Commit 63b014d09f
@@ -86,6 +86,8 @@ class Warnings(object):
             "previously loaded vectors. See Issue #3853.")
     W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
             "loaded. (Shape: {shape})")
+    W021 = ("Unexpected hash collision in PhraseMatcher. Matches may be "
+            "incorrect. Modify PhraseMatcher._terminal_hash to fix.")
 
 
 @add_codes
@@ -1,5 +1,28 @@
 from libcpp.vector cimport vector
 
-from ..typedefs cimport hash_t
+from cymem.cymem cimport Pool
+from preshed.maps cimport key_t, MapStruct
 
-ctypedef vector[hash_t] hash_vec
+from ..attrs cimport attr_id_t
+from ..tokens.doc cimport Doc
+from ..vocab cimport Vocab
+
+
+cdef class PhraseMatcher:
+    cdef Vocab vocab
+    cdef attr_id_t attr
+    cdef object _callbacks
+    cdef object _keywords
+    cdef object _docs
+    cdef bint _validate
+    cdef MapStruct* c_map
+    cdef Pool mem
+    cdef key_t _terminal_hash
+
+    cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil
+
+
+cdef struct MatchStruct:
+    key_t match_id
+    int start
+    int end
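The declarations above replace the old hash-vector bookkeeping with a token trie: c_map is the root MapStruct, each level maps a token value to a child map, and _terminal_hash marks a special child whose keys are the match IDs that end at that node. A rough pure-Python picture of that layout (an illustrative sketch only; the string tokens and the TERMINAL placeholder stand in for hashed attribute values and _terminal_hash):

# Hypothetical dict analogue of the c_map trie after adding the pattern
# "Google Now" under the match ID "COMPANY" and "Google" under "GOOG".
TERMINAL = "<terminal>"                     # stands in for _terminal_hash
trie = {
    "Google": {
        TERMINAL: {"GOOG": None},           # a pattern ends after "Google"
        "Now": {
            TERMINAL: {"COMPANY": None},    # "Google Now" ends a pattern here
        },
    },
}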
@@ -2,28 +2,16 @@
 # cython: profile=True
 from __future__ import unicode_literals
 
-from libcpp.vector cimport vector
-from cymem.cymem cimport Pool
-from murmurhash.mrmr cimport hash64
-from preshed.maps cimport PreshMap
+from libc.stdint cimport uintptr_t
 
-from .matcher cimport Matcher
-from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
-from ..vocab cimport Vocab
-from ..tokens.doc cimport Doc, get_token_attr
-from ..typedefs cimport attr_t, hash_t
+from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter
+
+from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
+from ..structs cimport TokenC
+from ..tokens.token cimport Token
 
 from ._schemas import TOKEN_PATTERN_SCHEMA
 from ..errors import Errors, Warnings, deprecation_warning, user_warning
-from ..attrs import FLAG61 as U_ENT
-from ..attrs import FLAG60 as B2_ENT
-from ..attrs import FLAG59 as B3_ENT
-from ..attrs import FLAG58 as B4_ENT
-from ..attrs import FLAG43 as L2_ENT
-from ..attrs import FLAG42 as L3_ENT
-from ..attrs import FLAG41 as L4_ENT
-from ..attrs import FLAG42 as I3_ENT
-from ..attrs import FLAG41 as I4_ENT
 
 
 cdef class PhraseMatcher:
@@ -33,18 +21,11 @@ cdef class PhraseMatcher:
 
     DOCS: https://spacy.io/api/phrasematcher
     USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher
+
+    Adapted from FlashText: https://github.com/vi3k6i5/flashtext
+    MIT License (see `LICENSE`)
+    Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
     """
-    cdef Pool mem
-    cdef Vocab vocab
-    cdef Matcher matcher
-    cdef PreshMap phrase_ids
-    cdef vector[hash_vec] ent_id_matrix
-    cdef int max_length
-    cdef attr_id_t attr
-    cdef public object _callbacks
-    cdef public object _patterns
-    cdef public object _docs
-    cdef public object _validate
 
     def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
         """Initialize the PhraseMatcher.
@@ -58,10 +39,17 @@
         """
         if max_length != 0:
            deprecation_warning(Warnings.W010)
-        self.mem = Pool()
-        self.max_length = max_length
         self.vocab = vocab
-        self.matcher = Matcher(self.vocab, validate=False)
+        self._callbacks = {}
+        self._keywords = {}
+        self._docs = {}
+        self._validate = validate
+
+        self.mem = Pool()
+        self.c_map = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
+        self._terminal_hash = 826361138722620965
+        map_init(self.mem, self.c_map, 8)
+
         if isinstance(attr, long):
             self.attr = attr
         else:
@@ -71,28 +59,15 @@
             if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
                 raise ValueError(Errors.E152.format(attr=attr))
             self.attr = self.vocab.strings[attr]
-        self.phrase_ids = PreshMap()
-        abstract_patterns = [
-            [{U_ENT: True}],
-            [{B2_ENT: True}, {L2_ENT: True}],
-            [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}],
-            [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}],
-        ]
-        self.matcher.add("Candidate", None, *abstract_patterns)
-        self._callbacks = {}
-        self._docs = {}
-        self._validate = validate
 
     def __len__(self):
-        """Get the number of rules added to the matcher. Note that this only
-        returns the number of rules (identical with the number of IDs), not the
-        number of individual patterns.
+        """Get the number of match IDs added to the matcher.
 
         RETURNS (int): The number of rules.
 
         DOCS: https://spacy.io/api/phrasematcher#len
         """
-        return len(self._docs)
+        return len(self._callbacks)
 
     def __contains__(self, key):
         """Check whether the matcher contains rules for a match ID.
@@ -102,13 +77,78 @@
 
         DOCS: https://spacy.io/api/phrasematcher#contains
         """
-        cdef hash_t ent_id = self.matcher._normalize_key(key)
-        return ent_id in self._callbacks
+        return key in self._callbacks
 
     def __reduce__(self):
         data = (self.vocab, self._docs, self._callbacks)
         return (unpickle_matcher, data, None, None)
 
+    def remove(self, key):
+        """Remove a rule from the matcher by match ID. A KeyError is raised if
+        the key does not exist.
+
+        key (unicode): The match ID.
+        """
+        if key not in self._keywords:
+            raise KeyError(key)
+        cdef MapStruct* current_node
+        cdef MapStruct* terminal_map
+        cdef MapStruct* node_pointer
+        cdef void* result
+        cdef key_t terminal_key
+        cdef void* value
+        cdef int c_i = 0
+        cdef vector[MapStruct*] path_nodes
+        cdef vector[key_t] path_keys
+        cdef key_t key_to_remove
+        for keyword in self._keywords[key]:
+            current_node = self.c_map
+            for token in keyword:
+                result = map_get(current_node, token)
+                if result:
+                    path_nodes.push_back(current_node)
+                    path_keys.push_back(token)
+                    current_node = <MapStruct*>result
+                else:
+                    # if token is not found, break out of the loop
+                    current_node = NULL
+                    break
+            # remove the tokens from trie node if there are no other
+            # keywords with them
+            result = map_get(current_node, self._terminal_hash)
+            if current_node != NULL and result:
+                terminal_map = <MapStruct*>result
+                terminal_keys = []
+                c_i = 0
+                while map_iter(terminal_map, &c_i, &terminal_key, &value):
+                    terminal_keys.append(self.vocab.strings[terminal_key])
+                # if this is the only remaining key, remove unnecessary paths
+                if terminal_keys == [key]:
+                    while not path_nodes.empty():
+                        node_pointer = path_nodes.back()
+                        path_nodes.pop_back()
+                        key_to_remove = path_keys.back()
+                        path_keys.pop_back()
+                        result = map_get(node_pointer, key_to_remove)
+                        if node_pointer.filled == 1:
+                            map_clear(node_pointer, key_to_remove)
+                            self.mem.free(result)
+                        else:
+                            # more than one key means more than 1 path,
+                            # delete not required path and keep the others
+                            map_clear(node_pointer, key_to_remove)
+                            self.mem.free(result)
+                            break
+                # otherwise simply remove the key
+                else:
+                    result = map_get(current_node, self._terminal_hash)
+                    if result:
+                        map_clear(<MapStruct*>result, self.vocab.strings[key])
+
+        del self._keywords[key]
+        del self._callbacks[key]
+        del self._docs[key]
+
     def add(self, key, on_match, *docs):
         """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
         key, an on_match callback, and one or more patterns.
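Based on the tests added further down, the new remove() rounds out the public API alongside add(), __contains__(), and __len__(). A short usage sketch (assuming a blank Vocab; the tests themselves use the en_vocab fixture and the add(key, on_match, *docs) signature shown above):

from spacy.vocab import Vocab
from spacy.tokens import Doc
from spacy.matcher import PhraseMatcher

vocab = Vocab()
matcher = PhraseMatcher(vocab)
matcher.add("TEST", None, Doc(vocab, words=["like"]))
assert "TEST" in matcher and len(matcher) == 1

doc = Doc(vocab, words=["I", "like", "Google"])
assert len(matcher(doc)) == 1   # one (match_id, start, end) triple

matcher.remove("TEST")          # new on this branch
assert "TEST" not in matcher and len(matcher(doc)) == 0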
@@ -119,17 +159,19 @@
 
         DOCS: https://spacy.io/api/phrasematcher#add
         """
-        cdef Doc doc
-        cdef hash_t ent_id = self.matcher._normalize_key(key)
-        self._callbacks[ent_id] = on_match
-        self._docs[ent_id] = docs
-        cdef int length
-        cdef int i
-        cdef hash_t phrase_hash
-        cdef Pool mem = Pool()
+        _ = self.vocab[key]
+        self._callbacks[key] = on_match
+        self._keywords.setdefault(key, [])
+        self._docs.setdefault(key, set())
+        self._docs[key].update(docs)
+
+        cdef MapStruct* current_node
+        cdef MapStruct* internal_node
+        cdef void* result
+
         for doc in docs:
-            length = doc.length
-            if length == 0:
+            if len(doc) == 0:
                 continue
             if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
                 raise ValueError(Errors.E155.format())
@@ -139,33 +181,32 @@
                     and self.attr not in (DEP, POS, TAG, LEMMA):
                 string_attr = self.vocab.strings[self.attr]
                 user_warning(Warnings.W012.format(key=key, attr=string_attr))
-            tags = get_biluo(length)
-            phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
-            for i, tag in enumerate(tags):
-                attr_value = self.get_lex_value(doc, i)
-                lexeme = self.vocab[attr_value]
-                lexeme.set_flag(tag, True)
-                phrase_key[i] = lexeme.orth
-            phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
+            keyword = self._convert_to_array(doc)
+            # keep track of keywords per key to make remove easier
+            # (would use a set, but can't hash numpy arrays)
+            self._keywords[key].append(keyword)
 
-            if phrase_hash in self.phrase_ids:
-                phrase_index = self.phrase_ids[phrase_hash]
-                ent_id_list = self.ent_id_matrix[phrase_index]
-                ent_id_list.append(ent_id)
-                self.ent_id_matrix[phrase_index] = ent_id_list
-            else:
-                ent_id_list = hash_vec(1)
-                ent_id_list[0] = ent_id
-                new_index = self.ent_id_matrix.size()
-                if new_index == 0:
-                    # PreshMaps can not contain 0 as value, so storing a dummy at 0
-                    self.ent_id_matrix.push_back(hash_vec(0))
-                    new_index = 1
-                self.ent_id_matrix.push_back(ent_id_list)
-                self.phrase_ids.set(phrase_hash, <void*>new_index)
+            current_node = self.c_map
+            for token in keyword:
+                if token == self._terminal_hash:
+                    user_warning(Warnings.W021)
+                    break
+                result = <MapStruct*>map_get(current_node, token)
+                if not result:
+                    internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
+                    map_init(self.mem, internal_node, 8)
+                    map_set(self.mem, current_node, token, internal_node)
+                    result = internal_node
+                current_node = <MapStruct*>result
+            result = <MapStruct*>map_get(current_node, self._terminal_hash)
+            if not result:
+                internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
+                map_init(self.mem, internal_node, 8)
+                map_set(self.mem, current_node, self._terminal_hash, internal_node)
+                result = internal_node
+            map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)
 
-    def __call__(self, Doc doc):
+    def __call__(self, doc):
         """Find all sequences matching the supplied patterns on the `Doc`.
 
         doc (Doc): The document to match over.
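add() now converts each pattern Doc to a list of attribute values and walks it through the trie, allocating a child map wherever one is missing and finally registering the match ID under _terminal_hash. The same control flow in plain Python, using dicts in place of MapStruct (a hypothetical helper for illustration, not the library code):

def add_keyword(trie, tokens, match_id, TERMINAL="<terminal>"):
    # Mirrors the insertion loop in PhraseMatcher.add(): descend one token
    # at a time, creating missing nodes, then mark the terminal node.
    node = trie
    for token in tokens:
        node = node.setdefault(token, {})   # map_get, else map_init + map_set
    node.setdefault(TERMINAL, {})[match_id] = None
    return trie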
@@ -176,25 +217,63 @@
         DOCS: https://spacy.io/api/phrasematcher#call
         """
         matches = []
-        if self.attr == ORTH:
-            match_doc = doc
-        else:
-            # If we're not matching on the ORTH, match_doc will be a Doc whose
-            # token.orth values are the attribute values we're matching on,
-            # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
-            words = [self.get_lex_value(doc, i) for i in range(len(doc))]
-            match_doc = Doc(self.vocab, words=words)
-        for _, start, end in self.matcher(match_doc):
-            ent_ids = self.accept_match(match_doc, start, end)
-            if ent_ids is not None:
-                for ent_id in ent_ids:
-                    matches.append((ent_id, start, end))
+        if doc is None or len(doc) == 0:
+            # if doc is empty or None just return empty list
+            return matches
+
+        cdef vector[MatchStruct] c_matches
+        self.find_matches(doc, &c_matches)
+        for i in range(c_matches.size()):
+            matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end))
         for i, (ent_id, start, end) in enumerate(matches):
             on_match = self._callbacks.get(ent_id)
             if on_match is not None:
                 on_match(self, doc, i, matches)
         return matches
 
+    cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil:
+        cdef MapStruct* current_node = self.c_map
+        cdef int start = 0
+        cdef int idx = 0
+        cdef int idy = 0
+        cdef key_t key
+        cdef void* value
+        cdef int i = 0
+        cdef MatchStruct ms
+        cdef void* result
+        while idx < doc.length:
+            start = idx
+            token = Token.get_struct_attr(&doc.c[idx], self.attr)
+            # look for sequences from this position
+            result = map_get(current_node, token)
+            if result:
+                current_node = <MapStruct*>result
+                idy = idx + 1
+                while idy < doc.length:
+                    result = map_get(current_node, self._terminal_hash)
+                    if result:
+                        i = 0
+                        while map_iter(<MapStruct*>result, &i, &key, &value):
+                            ms = make_matchstruct(key, start, idy)
+                            matches.push_back(ms)
+                    inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
+                    result = map_get(current_node, inner_token)
+                    if result:
+                        current_node = <MapStruct*>result
+                        idy += 1
+                    else:
+                        break
+                else:
+                    # end of doc reached
+                    result = map_get(current_node, self._terminal_hash)
+                    if result:
+                        i = 0
+                        while map_iter(<MapStruct*>result, &i, &key, &value):
+                            ms = make_matchstruct(key, start, idy)
+                            matches.push_back(ms)
+            current_node = self.c_map
+            idx += 1
+
     def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False,
              as_tuples=False):
         """Match a stream of documents, yielding them in turn.
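find_matches() walks the trie once per start position: it follows children for successive tokens and, whenever the current node has a _terminal_hash child, emits one match per stored match ID, so overlapping and nested phrases all surface. A dict-based sketch of the same loop (hypothetical helper that mirrors the logic, not the nogil Cython code):

def find_matches_py(trie, tokens, TERMINAL="<terminal>"):
    matches = []
    for start in range(len(tokens)):
        node = trie.get(tokens[start])
        end = start + 1
        while node is not None:
            # any match IDs stored at this node end a phrase spanning [start:end]
            for match_id in node.get(TERMINAL, {}):
                matches.append((match_id, start, end))
            if end >= len(tokens):
                break
            node = node.get(tokens[end])
            end += 1
    return matches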
@@ -228,48 +307,8 @@
         else:
             yield doc
 
-    def accept_match(self, Doc doc, int start, int end):
-        cdef int i, j
-        cdef Pool mem = Pool()
-        phrase_key = <attr_t*>mem.alloc(end-start, sizeof(attr_t))
-        for i, j in enumerate(range(start, end)):
-            phrase_key[i] = doc.c[j].lex.orth
-        cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
-
-        ent_index = <hash_t>self.phrase_ids.get(key)
-        if ent_index == 0:
-            return None
-        return self.ent_id_matrix[ent_index]
-
-    def get_lex_value(self, Doc doc, int i):
-        if self.attr == ORTH:
-            # Return the regular orth value of the lexeme
-            return doc.c[i].lex.orth
-        # Get the attribute value instead, e.g. token.pos
-        attr_value = get_token_attr(&doc.c[i], self.attr)
-        if attr_value in (0, 1):
-            # Value is boolean, convert to string
-            string_attr_value = str(attr_value)
-        else:
-            string_attr_value = self.vocab.strings[attr_value]
-        string_attr_name = self.vocab.strings[self.attr]
-        # Concatenate the attr name and value to not pollute lexeme space
-        # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
-        # create false positive matches
-        return "matcher:{}-{}".format(string_attr_name, string_attr_value)
-
-
-def get_biluo(length):
-    if length == 0:
-        raise ValueError(Errors.E127)
-    elif length == 1:
-        return [U_ENT]
-    elif length == 2:
-        return [B2_ENT, L2_ENT]
-    elif length == 3:
-        return [B3_ENT, I3_ENT, L3_ENT]
-    else:
-        return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]
+    def _convert_to_array(self, Doc doc):
+        return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))]
 
 
 def unpickle_matcher(vocab, docs, callbacks):
@@ -278,3 +317,11 @@ def unpickle_matcher(vocab, docs, callbacks):
         callback = callbacks.get(key, None)
         matcher.add(key, callback, *specs)
     return matcher
+
+
+cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil:
+    cdef MatchStruct ms
+    ms.match_id = match_id
+    ms.start = start
+    ms.end = end
+    return ms
@@ -8,10 +8,31 @@ from ..util import get_doc
 
 
 def test_matcher_phrase_matcher(en_vocab):
-    doc = Doc(en_vocab, words=["Google", "Now"])
-    matcher = PhraseMatcher(en_vocab)
-    matcher.add("COMPANY", None, doc)
     doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
+    # intermediate phrase
+    pattern = Doc(en_vocab, words=["Google", "Now"])
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("COMPANY", None, pattern)
+    assert len(matcher(doc)) == 1
+    # initial token
+    pattern = Doc(en_vocab, words=["I"])
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("I", None, pattern)
+    assert len(matcher(doc)) == 1
+    # initial phrase
+    pattern = Doc(en_vocab, words=["I", "like"])
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("ILIKE", None, pattern)
+    assert len(matcher(doc)) == 1
+    # final token
+    pattern = Doc(en_vocab, words=["best"])
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("BEST", None, pattern)
+    assert len(matcher(doc)) == 1
+    # final phrase
+    pattern = Doc(en_vocab, words=["Now", "best"])
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("NOWBEST", None, pattern)
     assert len(matcher(doc)) == 1
 
 
@@ -31,6 +52,68 @@ def test_phrase_matcher_contains(en_vocab):
     assert "TEST2" not in matcher
 
 
+def test_phrase_matcher_repeated_add(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    # match ID only gets added once
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
+    assert "TEST" in matcher
+    assert "TEST2" not in matcher
+    assert len(matcher(doc)) == 1
+
+
+def test_phrase_matcher_remove(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("TEST1", None, Doc(en_vocab, words=["like"]))
+    matcher.add("TEST2", None, Doc(en_vocab, words=["best"]))
+    doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
+    assert "TEST1" in matcher
+    assert "TEST2" in matcher
+    assert "TEST3" not in matcher
+    assert len(matcher(doc)) == 2
+    matcher.remove("TEST1")
+    assert "TEST1" not in matcher
+    assert "TEST2" in matcher
+    assert "TEST3" not in matcher
+    assert len(matcher(doc)) == 1
+    matcher.remove("TEST2")
+    assert "TEST1" not in matcher
+    assert "TEST2" not in matcher
+    assert "TEST3" not in matcher
+    assert len(matcher(doc)) == 0
+    with pytest.raises(KeyError):
+        matcher.remove("TEST3")
+    assert "TEST1" not in matcher
+    assert "TEST2" not in matcher
+    assert "TEST3" not in matcher
+    assert len(matcher(doc)) == 0
+
+
+def test_phrase_matcher_overlapping_with_remove(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("TEST", None, Doc(en_vocab, words=["like"]))
+    # TEST2 is added alongside TEST
+    matcher.add("TEST2", None, Doc(en_vocab, words=["like"]))
+    doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"])
+    assert "TEST" in matcher
+    assert len(matcher) == 2
+    assert len(matcher(doc)) == 2
+    # removing TEST does not remove the entry for TEST2
+    matcher.remove("TEST")
+    assert "TEST" not in matcher
+    assert len(matcher) == 1
+    assert len(matcher(doc)) == 1
+    assert matcher(doc)[0][0] == en_vocab.strings["TEST2"]
+    # removing TEST2 removes all
+    matcher.remove("TEST2")
+    assert "TEST2" not in matcher
+    assert len(matcher) == 0
+    assert len(matcher(doc)) == 0
+
+
 def test_phrase_matcher_string_attrs(en_vocab):
     words1 = ["I", "like", "cats"]
     pos1 = ["PRON", "VERB", "NOUN"]