mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
368 lines
14 KiB
Cython
368 lines
14 KiB
Cython
# cython: infer_types=True
|
|
from collections import defaultdict
|
|
from typing import List
|
|
|
|
from preshed.maps cimport map_clear, map_get, map_init, map_iter, map_set
|
|
|
|
import warnings
|
|
|
|
from ..attrs cimport DEP, LEMMA, MORPH, POS, TAG
|
|
|
|
from ..attrs import IDS
|
|
|
|
from ..tokens.span cimport Span
|
|
from ..tokens.token cimport Token
|
|
from ..typedefs cimport attr_t
|
|
|
|
from ..errors import Errors, Warnings
|
|
from ..schemas import TokenPattern
|
|
|
|
|
|
cdef class PhraseMatcher:
|
|
"""Efficiently match large terminology lists. While the `Matcher` matches
|
|
sequences based on lists of token descriptions, the `PhraseMatcher` accepts
|
|
match patterns in the form of `Doc` objects.
|
|
|
|
DOCS: https://spacy.io/api/phrasematcher
|
|
USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher
|
|
|
|
Adapted from FlashText: https://github.com/vi3k6i5/flashtext
|
|
MIT License (see `LICENSE`)
|
|
Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
|
|
"""
|
|
|
|
def __init__(self, Vocab vocab, attr="ORTH", validate=False):
|
|
"""Initialize the PhraseMatcher.
|
|
|
|
vocab (Vocab): The shared vocabulary.
|
|
attr (int / str): Token attribute to match on.
|
|
validate (bool): Perform additional validation when patterns are added.
|
|
|
|
DOCS: https://spacy.io/api/phrasematcher#init
|
|
"""
|
|
self.vocab = vocab
|
|
self._callbacks = {}
|
|
self._docs = defaultdict(set)
|
|
self._validate = validate
|
|
|
|
self.mem = Pool()
|
|
self.c_map = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
|
self._terminal_hash = 826361138722620965
|
|
map_init(self.mem, self.c_map, 8)
|
|
|
|
if isinstance(attr, (int, long)):
|
|
self.attr = attr
|
|
else:
|
|
if attr is None:
|
|
attr = "ORTH"
|
|
attr = attr.upper()
|
|
if attr == "TEXT":
|
|
attr = "ORTH"
|
|
if attr == "IS_SENT_START":
|
|
attr = "SENT_START"
|
|
if attr.lower() not in TokenPattern().dict():
|
|
raise ValueError(Errors.E152.format(attr=attr))
|
|
self.attr = IDS.get(attr)
|
|
|
|
def __len__(self):
|
|
"""Get the number of match IDs added to the matcher.
|
|
|
|
RETURNS (int): The number of rules.
|
|
|
|
DOCS: https://spacy.io/api/phrasematcher#len
|
|
"""
|
|
return len(self._callbacks)
|
|
|
|
def __contains__(self, key):
|
|
"""Check whether the matcher contains rules for a match ID.
|
|
|
|
key (str): The match ID.
|
|
RETURNS (bool): Whether the matcher contains rules for this match ID.
|
|
|
|
DOCS: https://spacy.io/api/phrasematcher#contains
|
|
"""
|
|
return key in self._callbacks
|
|
|
|
def __reduce__(self):
|
|
data = (self.vocab, self._docs, self._callbacks, self.attr)
|
|
return (unpickle_matcher, data, None, None)
|
|
|
|
def remove(self, key):
|
|
"""Remove a rule from the matcher by match ID. A KeyError is raised if
|
|
the key does not exist.
|
|
|
|
key (str): The match ID.
|
|
|
|
DOCS: https://spacy.io/api/phrasematcher#remove
|
|
"""
|
|
if key not in self._docs:
|
|
raise KeyError(key)
|
|
cdef MapStruct* current_node
|
|
cdef MapStruct* terminal_map
|
|
cdef MapStruct* node_pointer
|
|
cdef void* result
|
|
cdef key_t terminal_key
|
|
cdef void* value
|
|
cdef int c_i = 0
|
|
cdef vector[MapStruct*] path_nodes
|
|
cdef vector[key_t] path_keys
|
|
cdef key_t key_to_remove
|
|
for keyword in sorted(self._docs[key], key=lambda x: len(x), reverse=True):
|
|
current_node = self.c_map
|
|
path_nodes.clear()
|
|
path_keys.clear()
|
|
for token in keyword:
|
|
result = map_get(current_node, token)
|
|
if result:
|
|
path_nodes.push_back(current_node)
|
|
path_keys.push_back(token)
|
|
current_node = <MapStruct*>result
|
|
else:
|
|
# if token is not found, break out of the loop
|
|
current_node = NULL
|
|
break
|
|
path_nodes.push_back(current_node)
|
|
path_keys.push_back(self._terminal_hash)
|
|
# remove the tokens from trie node if there are no other
|
|
# keywords with them
|
|
result = map_get(current_node, self._terminal_hash)
|
|
if current_node != NULL and result:
|
|
terminal_map = <MapStruct*>result
|
|
terminal_keys = []
|
|
c_i = 0
|
|
while map_iter(terminal_map, &c_i, &terminal_key, &value):
|
|
terminal_keys.append(self.vocab.strings[terminal_key])
|
|
# if this is the only remaining key, remove unnecessary paths
|
|
if terminal_keys == [key]:
|
|
while not path_nodes.empty():
|
|
node_pointer = path_nodes.back()
|
|
path_nodes.pop_back()
|
|
key_to_remove = path_keys.back()
|
|
path_keys.pop_back()
|
|
result = map_get(node_pointer, key_to_remove)
|
|
if node_pointer.filled == 1:
|
|
map_clear(node_pointer, key_to_remove)
|
|
self.mem.free(result)
|
|
else:
|
|
# more than one key means more than 1 path,
|
|
# delete not required path and keep the others
|
|
map_clear(node_pointer, key_to_remove)
|
|
self.mem.free(result)
|
|
break
|
|
# otherwise simply remove the key
|
|
else:
|
|
result = map_get(current_node, self._terminal_hash)
|
|
if result:
|
|
map_clear(<MapStruct*>result, self.vocab.strings[key])
|
|
|
|
del self._callbacks[key]
|
|
del self._docs[key]
|
|
|
|
def _add_from_arrays(self, key, specs, *, on_match=None):
|
|
"""Add a preprocessed list of specs, with an optional callback.
|
|
|
|
key (str): The match ID.
|
|
specs (List[List[int]]): A list of lists of hashes to match.
|
|
on_match (callable): Callback executed on match.
|
|
"""
|
|
cdef MapStruct* current_node
|
|
cdef MapStruct* internal_node
|
|
cdef void* result
|
|
|
|
self._callbacks[key] = on_match
|
|
for spec in specs:
|
|
self._docs[key].add(tuple(spec))
|
|
|
|
current_node = self.c_map
|
|
for token in spec:
|
|
if token == self._terminal_hash:
|
|
warnings.warn(Warnings.W021)
|
|
break
|
|
result = <MapStruct*>map_get(current_node, token)
|
|
if not result:
|
|
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
|
map_init(self.mem, internal_node, 8)
|
|
map_set(self.mem, current_node, token, internal_node)
|
|
result = internal_node
|
|
current_node = <MapStruct*>result
|
|
result = <MapStruct*>map_get(current_node, self._terminal_hash)
|
|
if not result:
|
|
internal_node = <MapStruct*>self.mem.alloc(1, sizeof(MapStruct))
|
|
map_init(self.mem, internal_node, 8)
|
|
map_set(self.mem, current_node, self._terminal_hash, internal_node)
|
|
result = internal_node
|
|
map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)
|
|
|
|
def add(self, key, docs, *, on_match=None):
|
|
"""Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
|
|
key, a list of one or more patterns, and (optionally) an on_match callback.
|
|
|
|
key (str): The match ID.
|
|
docs (list): List of `Doc` objects representing match patterns.
|
|
on_match (callable): Callback executed on match.
|
|
|
|
If any of the input Docs are invalid, no internal state will be updated.
|
|
|
|
DOCS: https://spacy.io/api/phrasematcher#add
|
|
"""
|
|
if isinstance(docs, Doc):
|
|
raise ValueError(Errors.E179.format(key=key))
|
|
if docs is None or not isinstance(docs, List):
|
|
raise ValueError(Errors.E948.format(name="PhraseMatcher", arg_type=type(docs)))
|
|
if on_match is not None and not hasattr(on_match, "__call__"):
|
|
raise ValueError(Errors.E171.format(name="PhraseMatcher", arg_type=type(on_match)))
|
|
|
|
_ = self.vocab[key]
|
|
specs = []
|
|
|
|
for doc in docs:
|
|
if len(doc) == 0:
|
|
continue
|
|
if not isinstance(doc, Doc):
|
|
raise ValueError(Errors.E4000.format(type=type(doc)))
|
|
|
|
attrs = (TAG, POS, MORPH, LEMMA, DEP)
|
|
has_annotation = {attr: doc.has_annotation(attr) for attr in attrs}
|
|
for attr in attrs:
|
|
if self.attr == attr and not has_annotation[attr]:
|
|
if attr == TAG:
|
|
pipe = "tagger"
|
|
elif attr in (POS, MORPH):
|
|
pipe = "morphologizer or tagger+attribute_ruler"
|
|
elif attr == LEMMA:
|
|
pipe = "lemmatizer"
|
|
elif attr == DEP:
|
|
pipe = "parser"
|
|
error_msg = Errors.E155.format(pipe=pipe, attr=self.vocab.strings.as_string(attr))
|
|
raise ValueError(error_msg)
|
|
if self._validate and any(has_annotation.values()) \
|
|
and self.attr not in attrs:
|
|
string_attr = self.vocab.strings[self.attr]
|
|
warnings.warn(Warnings.W012.format(key=key, attr=string_attr))
|
|
specs.append(self._convert_to_array(doc))
|
|
|
|
self._add_from_arrays(key, specs, on_match=on_match)
|
|
|
|
def __call__(self, object doclike, *, as_spans=False):
|
|
"""Find all sequences matching the supplied patterns on the `Doc`.
|
|
|
|
doclike (Doc or Span): The document to match over.
|
|
as_spans (bool): Return Span objects with labels instead of (match_id,
|
|
start, end) tuples.
|
|
RETURNS (list): A list of `(match_id, start, end)` tuples,
|
|
describing the matches. A match tuple describes a span
|
|
`doc[start:end]`. The `match_id` is an integer. If as_spans is set
|
|
to True, a list of Span objects is returned.
|
|
|
|
DOCS: https://spacy.io/api/phrasematcher#call
|
|
"""
|
|
matches = []
|
|
if doclike is None or len(doclike) == 0:
|
|
# if doc is empty or None just return empty list
|
|
return matches
|
|
if isinstance(doclike, Doc):
|
|
doc = doclike
|
|
start_idx = 0
|
|
end_idx = len(doc)
|
|
elif isinstance(doclike, Span):
|
|
doc = doclike.doc
|
|
start_idx = doclike.start
|
|
end_idx = doclike.end
|
|
else:
|
|
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
|
|
|
|
cdef vector[SpanC] c_matches
|
|
self.find_matches(doc, start_idx, end_idx, &c_matches)
|
|
for i in range(c_matches.size()):
|
|
matches.append((c_matches[i].label, c_matches[i].start, c_matches[i].end))
|
|
for i, (ent_id, start, end) in enumerate(matches):
|
|
on_match = self._callbacks.get(self.vocab.strings[ent_id])
|
|
if on_match is not None:
|
|
on_match(self, doc, i, matches)
|
|
if as_spans:
|
|
return [Span(doc, start, end, label=key) for key, start, end in matches]
|
|
else:
|
|
return matches
|
|
|
|
cdef void find_matches(self, Doc doc, int start_idx, int end_idx, vector[SpanC] *matches) nogil:
|
|
cdef MapStruct* current_node = self.c_map
|
|
cdef int start = 0
|
|
cdef int idx = start_idx
|
|
cdef int idy = start_idx
|
|
cdef key_t key
|
|
cdef void* value
|
|
cdef int i = 0
|
|
cdef SpanC ms
|
|
cdef void* result
|
|
while idx < end_idx:
|
|
start = idx
|
|
token = Token.get_struct_attr(&doc.c[idx], self.attr)
|
|
# look for sequences from this position
|
|
result = map_get(current_node, token)
|
|
if result:
|
|
current_node = <MapStruct*>result
|
|
idy = idx + 1
|
|
while idy < end_idx:
|
|
result = map_get(current_node, self._terminal_hash)
|
|
if result:
|
|
i = 0
|
|
while map_iter(<MapStruct*>result, &i, &key, &value):
|
|
ms = make_spanstruct(key, start, idy)
|
|
matches.push_back(ms)
|
|
inner_token = Token.get_struct_attr(&doc.c[idy], self.attr)
|
|
result = map_get(current_node, inner_token)
|
|
if result:
|
|
current_node = <MapStruct*>result
|
|
idy += 1
|
|
else:
|
|
break
|
|
else:
|
|
# end of doc reached
|
|
result = map_get(current_node, self._terminal_hash)
|
|
if result:
|
|
i = 0
|
|
while map_iter(<MapStruct*>result, &i, &key, &value):
|
|
ms = make_spanstruct(key, start, idy)
|
|
matches.push_back(ms)
|
|
current_node = self.c_map
|
|
idx += 1
|
|
|
|
def pipe(self, stream, batch_size=1000, return_matches=False, as_tuples=False):
|
|
"""Match a stream of documents, yielding them in turn. Deprecated as of
|
|
spaCy v3.0.
|
|
"""
|
|
warnings.warn(Warnings.W105.format(matcher="PhraseMatcher"), DeprecationWarning)
|
|
if as_tuples:
|
|
for doc, context in stream:
|
|
matches = self(doc)
|
|
if return_matches:
|
|
yield ((doc, matches), context)
|
|
else:
|
|
yield (doc, context)
|
|
else:
|
|
for doc in stream:
|
|
matches = self(doc)
|
|
if return_matches:
|
|
yield (doc, matches)
|
|
else:
|
|
yield doc
|
|
|
|
def _convert_to_array(self, Doc doc):
|
|
return [Token.get_struct_attr(&doc.c[i], self.attr) for i in range(len(doc))]
|
|
|
|
|
|
def unpickle_matcher(vocab, docs, callbacks, attr):
|
|
matcher = PhraseMatcher(vocab, attr=attr)
|
|
for key, specs in docs.items():
|
|
callback = callbacks.get(key, None)
|
|
matcher._add_from_arrays(key, specs, on_match=callback)
|
|
return matcher
|
|
|
|
|
|
cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil:
|
|
cdef SpanC spanc
|
|
spanc.label = label
|
|
spanc.start = start
|
|
spanc.end = end
|
|
return spanc
|