Replace the PhraseMatcher implementation with the Aho-Corasick algorithm run over numpy arrays of the hash values for the relevant attribute. The implementation is based on FlashText. The speed should be similar to the previous PhraseMatcher. Match IDs can now be removed easily, and matches no longer go missing with large keyword lists / vocabularies. Fixes #4308.
# cython: infer_types=True
# cython: profile=True
from __future__ import unicode_literals

import numpy as np

from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc, get_token_attr

from ._schemas import TOKEN_PATTERN_SCHEMA
from ..errors import Errors, Warnings, deprecation_warning, user_warning


cdef class PhraseMatcher:
    """Efficiently match large terminology lists. While the `Matcher` matches
    sequences based on lists of token descriptions, the `PhraseMatcher` accepts
    match patterns in the form of `Doc` objects.

    DOCS: https://spacy.io/api/phrasematcher
    USAGE: https://spacy.io/usage/rule-based-matching#phrasematcher

    Adapted from FlashText: https://github.com/vi3k6i5/flashtext
    MIT License (see `LICENSE`)
    Copyright (c) 2017 Vikash Singh (vikash.duliajan@gmail.com)
    """
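    # A minimal usage sketch (illustration only, not part of the module;
    # `nlp` is assumed to be a loaded pipeline):
    #
    #     matcher = PhraseMatcher(nlp.vocab)
    #     matcher.add("OBAMA", None, nlp("Barack Obama"))
    #     matches = matcher(nlp("Barack Obama visited Berlin"))
    #     # matches is a list of (match_id, start, end) tuples
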
    cdef Vocab vocab
    cdef unicode _terminal
    cdef object keyword_trie_dict
    cdef attr_id_t attr
    cdef object _callbacks
    cdef object _keywords
    cdef bint _validate

    def __init__(self, Vocab vocab, max_length=0, attr="ORTH", validate=False):
        """Initialize the PhraseMatcher.

        vocab (Vocab): The shared vocabulary.
        attr (int / unicode): Token attribute to match on.
        validate (bool): Perform additional validation when patterns are added.
        RETURNS (PhraseMatcher): The newly constructed object.

        DOCS: https://spacy.io/api/phrasematcher#init
        """
        if max_length != 0:
            deprecation_warning(Warnings.W010)
        self.vocab = vocab
        self._terminal = '_terminal_'
        self.keyword_trie_dict = dict()
        self._callbacks = {}
        self._keywords = {}
        self._validate = validate

        if isinstance(attr, long):
            self.attr = attr
        else:
            attr = attr.upper()
            if attr == "TEXT":
                attr = "ORTH"
            if attr not in TOKEN_PATTERN_SCHEMA["items"]["properties"]:
                raise ValueError(Errors.E152.format(attr=attr))
            self.attr = self.vocab.strings[attr]

    def __len__(self):
        """Get the number of match IDs added to the matcher.

        RETURNS (int): The number of rules.

        DOCS: https://spacy.io/api/phrasematcher#len
        """
        return len(self._callbacks)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.

        DOCS: https://spacy.io/api/phrasematcher#contains
        """
        return key in self._callbacks

    def remove(self, key):
        """Remove a match-rule from the matcher by match ID.

        key (unicode): The match ID.
        """
        if key not in self._keywords:
            return
        for keyword in self._keywords[key]:
            current_dict = self.keyword_trie_dict
            token_trie_list = []
            for tokens in keyword:
                if tokens in current_dict:
                    token_trie_list.append((tokens, current_dict))
                    current_dict = current_dict[tokens]
                else:
                    # if the token is not found, break out of the loop
                    current_dict = None
                    break
            # remove the tokens from the trie dict if there are no other
            # keywords with them
            if current_dict and self._terminal in current_dict:
                # if this is the only remaining key, remove unnecessary paths
                if current_dict[self._terminal] == {key}:
                    # we found a complete match for the input keyword
                    token_trie_list.append((self._terminal, current_dict))
                    token_trie_list.reverse()
                    for key_to_remove, dict_pointer in token_trie_list:
                        if len(dict_pointer.keys()) == 1:
                            dict_pointer.pop(key_to_remove)
                        else:
                            # more than one key means more than one path,
                            # delete the path that's no longer required and keep the other
                            dict_pointer.pop(key_to_remove)
                            break
                # otherwise simply remove the key
                else:
                    current_dict[self._terminal].remove(key)

        del self._keywords[key]
        del self._callbacks[key]

    def add(self, key, on_match, *docs):
        """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.

        key (unicode): The match ID.
        on_match (callable): Callback executed on match.
        *docs (Doc): `Doc` objects representing match patterns.

        DOCS: https://spacy.io/api/phrasematcher#add
        """

        _ = self.vocab[key]
        self._callbacks[key] = on_match
        self._keywords.setdefault(key, [])

        for doc in docs:
            if len(doc) == 0:
                continue
            if self.attr in (POS, TAG, LEMMA) and not doc.is_tagged:
                raise ValueError(Errors.E155.format())
            if self.attr == DEP and not doc.is_parsed:
                raise ValueError(Errors.E156.format())
            if self._validate and (doc.is_tagged or doc.is_parsed) \
                    and self.attr not in (DEP, POS, TAG, LEMMA):
                string_attr = self.vocab.strings[self.attr]
                user_warning(Warnings.W012.format(key=key, attr=string_attr))
            keyword = self._convert_to_array(doc)
            # keep track of keywords per key to make remove easier
            # (would use a set, but can't hash numpy arrays; compare explicitly
            # with np.array_equal so equal-length arrays don't raise on `in`)
            if not any(np.array_equal(keyword, existing) for existing in self._keywords[key]):
                self._keywords[key].append(keyword)
            current_dict = self.keyword_trie_dict
            for token in keyword:
                current_dict = current_dict.setdefault(token, {})
            current_dict.setdefault(self._terminal, set())
            current_dict[self._terminal].add(key)

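    # Rough sketch of the trie that add() builds and __call__() walks
    # (illustration only; "KEY1"/"KEY2" are hypothetical match IDs, and the
    # real dict keys are uint64 attribute hashes from _convert_to_array):
    #
    #     {hash("Barack"): {"_terminal_": {"KEY2"},
    #                       hash("Obama"): {"_terminal_": {"KEY1"}}}}
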
    def __call__(self, doc):
        """Find all sequences matching the supplied patterns on the `Doc`.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples,
            describing the matches. A match tuple describes a span
            `doc[start:end]`. The `key` is the hash of the match ID and is an integer.

        DOCS: https://spacy.io/api/phrasematcher#call
        """
        doc_array = self._convert_to_array(doc)
        matches = []
        if doc_array is None or len(doc_array) == 0:
            # if doc_array is empty or None just return an empty list
            return matches
        current_dict = self.keyword_trie_dict
        start = 0
        reset_current_dict = False
        idx = 0
        doc_array_len = len(doc_array)
        while idx < doc_array_len:
            token = doc_array[idx]
            # if a terminal is present in current_dict, a keyword ends here
            if self._terminal in current_dict or token in current_dict:
                if self._terminal in current_dict:
                    ent_ids = current_dict[self._terminal]
                    for ent_id in ent_ids:
                        matches.append((self.vocab.strings[ent_id], start, idx))

                # look for longer sequences from this position
                if token in current_dict:
                    current_dict_continued = current_dict[token]

                    idy = idx + 1
                    while idy < doc_array_len:
                        inner_token = doc_array[idy]
                        if self._terminal in current_dict_continued:
                            ent_ids = current_dict_continued[self._terminal]
                            for ent_id in ent_ids:
                                matches.append((self.vocab.strings[ent_id], start, idy))
                        if inner_token in current_dict_continued:
                            current_dict_continued = current_dict_continued[inner_token]
                        else:
                            break
                        idy += 1
                    else:
                        # end of doc_array reached
                        if self._terminal in current_dict_continued:
                            ent_ids = current_dict_continued[self._terminal]
                            for ent_id in ent_ids:
                                matches.append((self.vocab.strings[ent_id], start, idy))
                current_dict = self.keyword_trie_dict
                reset_current_dict = True
            else:
                # we reset current_dict
                current_dict = self.keyword_trie_dict
                reset_current_dict = True
            # if we are at the end of doc_array and have a sequence discovered
            if idx + 1 >= doc_array_len:
                if self._terminal in current_dict:
                    ent_ids = current_dict[self._terminal]
                    for ent_id in ent_ids:
                        matches.append((self.vocab.strings[ent_id], start, doc_array_len))
            idx += 1
            if reset_current_dict:
                reset_current_dict = False
                start = idx
        for i, (ent_id, start, end) in enumerate(matches):
            # matches store the hash of the key, so convert back before
            # looking up the callback registered under the original key
            on_match = self._callbacks.get(self.vocab.strings[ent_id])
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

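    # A sketch of typical handling of the matches returned by __call__ above
    # (illustration only; `matcher` and `doc` are assumed to exist):
    #
    #     for match_id, start, end in matcher(doc):
    #         rule_id = doc.vocab.strings[match_id]  # hash back to the unicode key
    #         span = doc[start:end]                  # the matched span
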
    def pipe(self, stream, batch_size=1000, n_threads=-1, return_matches=False,
             as_tuples=False):
        """Match a stream of documents, yielding them in turn.

        stream (iterable): A stream of documents.
        batch_size (int): Number of documents to accumulate into a working set.
        return_matches (bool): Yield the match lists along with the docs, making
            results (doc, matches) tuples.
        as_tuples (bool): Interpret the input stream as (doc, context) tuples,
            and yield (result, context) tuples out.
            If both return_matches and as_tuples are True, the output will
            be a sequence of ((doc, matches), context) tuples.
        YIELDS (Doc): Documents, in order.

        DOCS: https://spacy.io/api/phrasematcher#pipe
        """
        if n_threads != -1:
            deprecation_warning(Warnings.W016)
        if as_tuples:
            for doc, context in stream:
                matches = self(doc)
                if return_matches:
                    yield ((doc, matches), context)
                else:
                    yield (doc, context)
        else:
            for doc in stream:
                matches = self(doc)
                if return_matches:
                    yield (doc, matches)
                else:
                    yield doc

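    # A streaming sketch for pipe() above (illustration only; `nlp`, `texts`,
    # and a populated `matcher` are assumed to exist):
    #
    #     for doc, matches in matcher.pipe(nlp.pipe(texts), return_matches=True):
    #         pass  # each item is a (doc, [(match_id, start, end), ...]) pair
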
    def get_lex_value(self, Doc doc, int i):
        if self.attr == ORTH:
            # Return the regular orth value of the lexeme
            return doc.c[i].lex.orth
        # Get the attribute value instead, e.g. token.pos
        attr_value = get_token_attr(&doc.c[i], self.attr)
        if attr_value in (0, 1):
            # Value is boolean, convert to string
            string_attr_value = str(attr_value)
        else:
            string_attr_value = self.vocab.strings[attr_value]
        string_attr_name = self.vocab.strings[self.attr]
        # Concatenate the attr name and value to not pollute lexeme space
        # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
        # create false positive matches
        matcher_attr_string = "matcher:{}-{}".format(string_attr_name, string_attr_value)
        # Add new string to vocab
        _ = self.vocab[matcher_attr_string]
        return self.vocab.strings[matcher_attr_string]

    def _convert_to_array(self, Doc doc):
        return np.array([self.get_lex_value(doc, i) for i in range(len(doc))], dtype=np.uint64)
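

# A sketch of matching on a non-ORTH attribute (illustration only; `nlp` is
# assumed to be a loaded pipeline):
#
#     matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
#     matcher.add("GREETING", None, nlp("hello world"))
#     matches = matcher(nlp("HELLO WORLD everyone"))  # matches case-insensitively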