# spaCy/spacy/matcher/phrasematcher.pyx
# (Cython source, 281 lines, 11 KiB; header captured from a web source viewer)

# cython: infer_types=True
# cython: profile=True
from __future__ import unicode_literals
from libcpp.vector cimport vector
from cymem.cymem cimport Pool
from murmurhash.mrmr cimport hash64
from preshed.maps cimport PreshMap
from .matcher cimport Matcher
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc, get_token_attr
from ..typedefs cimport attr_t, hash_t
# 2019-08-21 15:00:37 +03:00 (timestamp left over from a blame view)
from ._schemas import TOKEN_PATTERN_SCHEMA
from ..errors import Errors, Warnings, deprecation_warning, user_warning
from ..attrs import FLAG61 as U_ENT
from ..attrs import FLAG60 as B2_ENT
from ..attrs import FLAG59 as B3_ENT
from ..attrs import FLAG58 as B4_ENT
from ..attrs import FLAG43 as L2_ENT
from ..attrs import FLAG42 as L3_ENT
from ..attrs import FLAG41 as L4_ENT
from ..attrs import FLAG42 as I3_ENT
from ..attrs import FLAG41 as I4_ENT
cdef class PhraseMatcher:
    """Efficiently match large terminology lists. While the `Matcher` matches
    sequences based on lists of token descriptions, the `PhraseMatcher` accepts
    match patterns in the form of `Doc` objects.

    How it works: `add` sets length-dependent BILUO-style flags on the lexemes
    of every pattern, so an internal token `Matcher` with four abstract
    "Candidate" patterns can propose candidate spans. `accept_match` then
    confirms a candidate by hashing the orth IDs of its lexemes and looking
    the hash up in `phrase_ids`, which indexes into `ent_id_matrix`.

    DOCS: https://spacy.io/api/phrasematcher
    USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher
    """
    cdef Pool mem
    cdef Vocab vocab
    cdef Matcher matcher
    # Phrase hash -> row index into ent_id_matrix (row 0 is a dummy, because
    # a PreshMap cannot store 0 as a value).
    cdef PreshMap phrase_ids
    # Row i holds the match IDs registered for the phrase stored at index i.
    # NOTE(review): hash_vec is not declared in this file; presumably a
    # ctypedef of vector[hash_t] declared elsewhere -- confirm.
    cdef vector[hash_vec] ent_id_matrix
    cdef int max_length
    cdef attr_id_t attr
    cdef public object _callbacks
    cdef public object _patterns
    cdef public object _docs
    cdef public object _validate

    def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
        """Initialize the PhraseMatcher.

        vocab (Vocab): The shared vocabulary.
        max_length (int): Deprecated. Stored but not used by matching; a
            non-zero value only triggers a deprecation warning.
        attr (int / unicode): Token attribute to match on.
        validate (bool): Perform additional validation when patterns are added.
        RETURNS (PhraseMatcher): The newly constructed object.

        DOCS: https://spacy.io/api/phrasematcher#init
        """
        if max_length != 0:
            deprecation_warning(Warnings.W010)
        self.mem = Pool()
        self.max_length = max_length
        self.vocab = vocab
        self.matcher = Matcher(self.vocab, validate=False)
        # Accept both int and long so integer attribute IDs (e.g. the value
        # passed back by unpickle_matcher) are recognised on Python 2 too,
        # where small integers are not instances of `long`.
        if isinstance(attr, (int, long)):
            self.attr = attr
        else:
            attr = attr.upper()
            if attr == "TEXT":
                attr = "ORTH"
            if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
                raise ValueError(Errors.E152.format(attr=attr))
            self.attr = self.vocab.strings[attr]
        self.phrase_ids = PreshMap()
        # One candidate pattern per phrase-length class (1, 2, 3 and 4+
        # tokens), written over the flags that get_biluo() assigns in add().
        abstract_patterns = [
            [{U_ENT: True}],
            [{B2_ENT: True}, {L2_ENT: True}],
            [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}],
            [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}],
        ]
        self.matcher.add("Candidate", None, *abstract_patterns)
        self._callbacks = {}
        self._docs = {}
        self._validate = validate

    def __len__(self):
        """Get the number of rules added to the matcher. Note that this only
        returns the number of rules (identical with the number of IDs), not the
        number of individual patterns.

        RETURNS (int): The number of rules.

        DOCS: https://spacy.io/api/phrasematcher#len
        """
        return len(self._docs)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.

        DOCS: https://spacy.io/api/phrasematcher#contains
        """
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        return ent_id in self._callbacks

    def __reduce__(self):
        # Include the match attribute so the unpickled matcher re-indexes its
        # patterns on the same attribute instead of falling back to ORTH.
        data = (self.vocab, self._docs, self._callbacks, self.attr)
        return (unpickle_matcher, data, None, None)

    def add(self, key, on_match, *docs):
        """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.

        Calling add() again with an existing key keeps the earlier patterns and
        adds the new ones (the callback is replaced). Every pattern is checked
        before anything is registered, so a failing pattern leaves the matcher
        unchanged.

        key (unicode): The match ID.
        on_match (callable): Callback executed on match.
        *docs (Doc): `Doc` objects representing match patterns.
        RAISES (ValueError): E155 if matching on POS, TAG or LEMMA and a
            pattern is not tagged; E156 if matching on DEP and a pattern is
            not parsed.

        DOCS: https://spacy.io/api/phrasematcher#add
        """
        cdef Doc doc
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        cdef int length
        cdef int i
        cdef hash_t phrase_hash
        cdef Pool mem = Pool()
        # Validate all patterns up front so an error cannot leave a partially
        # registered rule (callback/docs stored, phrases only half indexed).
        for doc in docs:
            if doc.length == 0:
                continue
            if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
                raise ValueError(Errors.E155.format())
            if self.attr == DEP and not doc.is_parsed:
                raise ValueError(Errors.E156.format())
            if self._validate and (doc.is_tagged or doc.is_parsed) \
                    and self.attr not in (DEP, POS, TAG, LEMMA):
                string_attr = self.vocab.strings[self.attr]
                user_warning(Warnings.W012.format(key=key, attr=string_attr))
        self._callbacks[ent_id] = on_match
        # Accumulate instead of overwriting, so __reduce__ can restore every
        # pattern ever added under this key.
        self._docs[ent_id] = tuple(self._docs.get(ent_id, ())) + docs
        for doc in docs:
            length = doc.length
            if length == 0:
                continue
            tags = get_biluo(length)
            phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
            for i, tag in enumerate(tags):
                attr_value = self.get_lex_value(doc, i)
                lexeme = self.vocab[attr_value]
                # The flag lets the internal Matcher propose this position as
                # part of a candidate of the right length.
                lexeme.set_flag(tag, True)
                phrase_key[i] = lexeme.orth
            phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
            if phrase_hash in self.phrase_ids:
                phrase_index = self.phrase_ids[phrase_hash]
                ent_id_list = self.ent_id_matrix[phrase_index]
                # Re-adding a phrase under the same ID must not register the
                # ID twice, otherwise every match would be reported twice.
                if ent_id not in ent_id_list:
                    ent_id_list.append(ent_id)
                    self.ent_id_matrix[phrase_index] = ent_id_list
            else:
                ent_id_list = hash_vec(1)
                ent_id_list[0] = ent_id
                new_index = self.ent_id_matrix.size()
                if new_index == 0:
                    # PreshMaps can not contain 0 as value, so storing a dummy at 0
                    self.ent_id_matrix.push_back(hash_vec(0))
                    new_index = 1
                self.ent_id_matrix.push_back(ent_id_list)
                self.phrase_ids.set(phrase_hash, <void*>new_index)

    def __call__(self, Doc doc):
        """Find all sequences matching the supplied patterns on the `Doc`.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples,
            describing the matches. A match tuple describes a span
            `doc[start:end]`. The `key` is the integer match ID.

        DOCS: https://spacy.io/api/phrasematcher#call
        """
        matches = []
        if self.attr == ORTH:
            match_doc = doc
        else:
            # If we're not matching on ORTH, match_doc is a Doc whose token
            # orth values encode the attribute values we're matching on (see
            # get_lex_value), e.g. words like "matcher:POS-VERB".
            words = [self.get_lex_value(doc, i) for i in range(len(doc))]
            match_doc = Doc(self.vocab, words=words)
        for _, start, end in self.matcher(match_doc):
            ent_ids = self.accept_match(match_doc, start, end)
            if ent_ids is not None:
                for ent_id in ent_ids:
                    matches.append((ent_id, start, end))
        # Callbacks run only after collection, so each sees the full list.
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

    def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False,
             as_tuples=False):
        """Match a stream of documents, yielding them in turn.

        stream (iterable): A stream of documents.
        batch_size (int): Number of documents to accumulate into a working set.
        n_threads (int): Deprecated; any value other than -1 only triggers a
            deprecation warning.
        return_matches (bool): Yield the match lists along with the docs, making
            results (doc, matches) tuples.
        as_tuples (bool): Interpret the input stream as (doc, context) tuples,
            and yield (result, context) tuples out.
            If both return_matches and as_tuples are True, the output will
            be a sequence of ((doc, matches), context) tuples.
        YIELDS (Doc): Documents, in order.

        DOCS: https://spacy.io/api/phrasematcher#pipe
        """
        if n_threads != -1:
            deprecation_warning(Warnings.W016)
        if as_tuples:
            for doc, context in stream:
                matches = self(doc)
                if return_matches:
                    yield ((doc, matches), context)
                else:
                    yield (doc, context)
        else:
            for doc in stream:
                matches = self(doc)
                if return_matches:
                    yield (doc, matches)
                else:
                    yield doc

    def accept_match(self, Doc doc, int start, int end):
        """Confirm a candidate span proposed by the internal matcher.

        doc (Doc): The document the candidate was found in (the attribute-
            mapped document when not matching on ORTH).
        start (int): Index of the candidate's first token.
        end (int): Index one past the candidate's last token.
        RETURNS (list / None): The match IDs registered for the phrase whose
            hash equals the span's hash, or None if no phrase matches.
        """
        cdef int i
        cdef int length = end - start
        cdef Pool mem = Pool()
        cdef hash_t key
        phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
        for i in range(length):
            phrase_key[i] = doc.c[start + i].lex.orth
        key = hash64(phrase_key, length * sizeof(attr_t), 0)
        # add() reserves row 0 as a dummy, so 0 can never be a real index and
        # is treated as "not a registered phrase".
        ent_index = <hash_t>self.phrase_ids.get(key)
        if ent_index == 0:
            return None
        return self.ent_id_matrix[ent_index]

    def get_lex_value(self, Doc doc, int i):
        """Return the value identifying token `i` of `doc` for matching.

        When matching on ORTH this is the token's orth ID. Otherwise it is a
        string "matcher:{ATTR}-{value}" that is used as the text of a
        synthetic lexeme.
        """
        if self.attr == ORTH:
            # Return the regular orth value of the lexeme
            return doc.c[i].lex.orth
        # Get the attribute value instead, e.g. token.pos
        attr_value = get_token_attr(&doc.c[i], self.attr)
        if attr_value in (0, 1):
            # Value is boolean, convert to string
            string_attr_value = str(attr_value)
        else:
            string_attr_value = self.vocab.strings[attr_value]
        string_attr_name = self.vocab.strings[self.attr]
        # Concatenate the attr name and value to not pollute lexeme space
        # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
        # create false positive matches
        return "matcher:{}-{}".format(string_attr_name, string_attr_value)


def get_biluo(length):
    """Return the flag IDs to set on the tokens of a `length`-token pattern.

    Patterns of 1, 2, 3 and 4+ tokens each use their own flag family, in line
    with the abstract "Candidate" patterns registered in
    PhraseMatcher.__init__. I3_ENT/L3_ENT and I4_ENT/L4_ENT are imported from
    the same flags (FLAG42 and FLAG41), so inner and last tokens are not
    distinguished. accept_match's hash lookup rejects any extra candidates
    this produces.

    RAISES (ValueError): If `length` is 0 (E127).
    """
    if length == 0:
        raise ValueError(Errors.E127)
    elif length == 1:
        return [U_ENT]
    elif length == 2:
        return [B2_ENT, L2_ENT]
    elif length == 3:
        return [B3_ENT, I3_ENT, L3_ENT]
    else:
        return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT]


def unpickle_matcher(vocab, docs, callbacks, attr="ORTH"):
    """Rebuild a PhraseMatcher from the data produced by its __reduce__.

    vocab (Vocab): The shared vocabulary.
    docs (dict): Match ID -> sequence of pattern `Doc` objects.
    callbacks (dict): Match ID -> on_match callback (or None).
    attr (int / unicode): Attribute to match on. Defaults to "ORTH" so
        pickles written before the attribute was stored still load.
    RETURNS (PhraseMatcher): The reconstructed matcher.
    """
    matcher = PhraseMatcher(vocab, attr=attr)
    for key, specs in docs.items():
        callback = callbacks.get(key, None)
        matcher.add(key, callback, *specs)
    return matcher