mirror of
https://github.com/explosion/spaCy.git
synced 2024-09-21 11:29:13 +03:00
483dddc9bc
* Add custom MatchPatternError * Improve validators and add validation option to Matcher * Adjust formatting * Never validate in Matcher within PhraseMatcher If we do decide to make validate default to True, the PhraseMatcher's Matcher shouldn't ever validate. Here, we create the patterns automatically anyway (and it's currently unclear whether the validation has performance impacts at a very large scale).
216 lines
8.2 KiB
Cython
216 lines
8.2 KiB
Cython
# cython: infer_types=True
|
|
# cython: profile=True
|
|
from __future__ import unicode_literals
|
|
|
|
from cymem.cymem cimport Pool
|
|
from murmurhash.mrmr cimport hash64
|
|
from preshed.maps cimport PreshMap
|
|
|
|
from .matcher cimport Matcher
|
|
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA, attr_id_t
|
|
from ..vocab cimport Vocab
|
|
from ..tokens.doc cimport Doc, get_token_attr
|
|
from ..typedefs cimport attr_t, hash_t
|
|
|
|
from ..errors import Warnings, deprecation_warning, user_warning
|
|
from ..attrs import FLAG61 as U_ENT
|
|
from ..attrs import FLAG60 as B2_ENT
|
|
from ..attrs import FLAG59 as B3_ENT
|
|
from ..attrs import FLAG58 as B4_ENT
|
|
from ..attrs import FLAG43 as L2_ENT
|
|
from ..attrs import FLAG42 as L3_ENT
|
|
from ..attrs import FLAG41 as L4_ENT
|
|
from ..attrs import FLAG42 as I3_ENT
|
|
from ..attrs import FLAG41 as I4_ENT
|
|
|
|
|
|
cdef class PhraseMatcher:
    """Match sequences of tokens against phrase patterns given as `Doc`s.

    Every token of an added pattern has a position flag (see `get_bilou`)
    set on the lexeme representing its match value. An internal `Matcher`
    with four abstract patterns finds candidate spans from runs of those
    flags. `accept_match` then hashes the candidate's lexeme IDs and looks
    the hash up in `phrase_ids` to confirm the exact phrase and recover its
    match ID.
    """
    cdef Pool mem
    cdef Vocab vocab
    # Internal matcher holding only the abstract "Candidate" patterns.
    cdef Matcher matcher
    # hash64 of a phrase's lexeme orth IDs -> match ID (ent_id).
    cdef PreshMap phrase_ids
    # Deprecated (warning W010): stored, but never read for matching.
    cdef int max_length
    # ID of the token attribute to match on (e.g. ORTH, LOWER, POS).
    cdef attr_id_t attr
    # Match ID -> on_match callback (may be None).
    cdef public object _callbacks
    # NOTE(review): declared but never assigned anywhere in this class.
    cdef public object _patterns
    # If truthy, `add` warns (W012) about tagged/parsed pattern docs.
    cdef public object _validate

    def __init__(self, Vocab vocab, max_length=0, attr='ORTH', validate=False):
        """Create the PhraseMatcher.

        vocab (Vocab): The vocabulary shared with the documents matched.
        max_length (int): Deprecated; any non-zero value emits warning W010.
        attr (int / unicode): Token attribute to match on, either as an
            attribute ID or as its string name (e.g. 'LOWER').
        validate (bool): Warn in `add` when a pattern doc is tagged or parsed
            but `attr` does not need that information.
        """
        if max_length != 0:
            deprecation_warning(Warnings.W010)
        self.mem = Pool()
        self.max_length = max_length
        self.vocab = vocab
        # The internal Matcher only ever receives the generated abstract
        # patterns below, so it never validates.
        self.matcher = Matcher(self.vocab, validate=False)
        # Accept any integer as an attribute ID. On Python 2 a small int is
        # not a `long`, and would otherwise be sent to the string lookup.
        if isinstance(attr, (int, long)):
            self.attr = attr
        else:
            self.attr = self.vocab.strings[attr]
        self.phrase_ids = PreshMap()
        # One abstract pattern per phrase-length class (1, 2, 3, 4+ tokens).
        # Each matches a run of the position flags that `add` sets.
        abstract_patterns = [
            [{U_ENT: True}],
            [{B2_ENT: True}, {L2_ENT: True}],
            [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}],
            [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"},
             {L4_ENT: True}],
        ]
        self.matcher.add('Candidate', None, *abstract_patterns)
        self._callbacks = {}
        self._validate = validate

    def __len__(self):
        """Get the number of rules added to the matcher. Note that this only
        returns the number of rules (identical with the number of IDs), not the
        number of individual patterns.

        RETURNS (int): The number of rules.
        """
        # Count match IDs (one entry per `add` key), not `phrase_ids`.
        # `phrase_ids` has one entry per distinct pattern phrase, so a rule
        # with several patterns would be counted several times.
        return len(self._callbacks)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        return ent_id in self._callbacks

    def __reduce__(self):
        # Rebuild with the same vocab, match attribute and validation flag.
        # `max_length` is deprecated and unused, so pass 0 rather than
        # re-emitting W010 on unpickling.
        # NOTE: added patterns and callbacks are NOT pickled. The unpickled
        # matcher starts with empty `phrase_ids` and `_callbacks`.
        return (self.__class__, (self.vocab, 0, self.attr, self._validate),
                None, None)

    def add(self, key, on_match, *docs):
        """Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.

        key (unicode): The match ID.
        on_match (callable): Callback executed on match.
        *docs (Doc): `Doc` objects representing match patterns. Empty docs
            are skipped.
        """
        cdef Doc doc
        cdef int length
        cdef int i
        cdef hash_t phrase_hash
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        # Scratch memory for the per-phrase key buffers; freed on return.
        cdef Pool mem = Pool()
        self._callbacks[ent_id] = on_match
        for doc in docs:
            length = doc.length
            if length == 0:
                continue
            if (self._validate and (doc.is_tagged or doc.is_parsed)
                    and self.attr not in (DEP, POS, TAG, LEMMA)):
                string_attr = self.vocab.strings[self.attr]
                user_warning(Warnings.W012.format(key=key, attr=string_attr))
            tags = get_bilou(length)
            phrase_key = <attr_t*>mem.alloc(length, sizeof(attr_t))
            for i, tag in enumerate(tags):
                # Flag the lexeme standing for this token's match value with
                # its position tag, and record that lexeme's orth ID.
                attr_value = self.get_lex_value(doc, i)
                lexeme = self.vocab[attr_value]
                lexeme.set_flag(tag, True)
                phrase_key[i] = lexeme.orth
            # Must hash the same orth IDs that `accept_match` reads back
            # from the (match) doc, so the lookup finds this phrase.
            phrase_hash = hash64(phrase_key, length * sizeof(attr_t), 0)
            self.phrase_ids.set(phrase_hash, <void*>ent_id)

    def __call__(self, Doc doc):
        """Find all sequences matching the supplied patterns on the `Doc`.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples,
            describing the matches. A match tuple describes a span
            `doc[start:end]`. The `key` is the integer match ID.
        """
        matches = []
        if self.attr == ORTH:
            match_doc = doc
        else:
            # If we're not matching on the ORTH, match_doc will be a Doc whose
            # token.orth values are the attribute values we're matching on,
            # e.g. Doc(nlp.vocab, words=[token.pos_ for token in doc])
            words = [self.get_lex_value(doc, i) for i in range(len(doc))]
            match_doc = Doc(self.vocab, words=words)
        for _, start, end in self.matcher(match_doc):
            ent_id = self.accept_match(match_doc, start, end)
            if ent_id is not None:
                matches.append((ent_id, start, end))
        # Callbacks receive the original doc, not match_doc.
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

    def pipe(self, stream, batch_size=1000, n_threads=1, return_matches=False,
             as_tuples=False):
        """Match a stream of documents, yielding them in turn.

        stream (iterable): A stream of documents.
        batch_size (int): Number of documents to accumulate into a working set.
            Currently unused: documents are processed one at a time.
        n_threads (int): The number of threads with which to work on the buffer
            in parallel, if the implementation supports multi-threading.
            Currently unused.
        return_matches (bool): Yield the match lists along with the docs, making
            results (doc, matches) tuples.
        as_tuples (bool): Interpret the input stream as (doc, context) tuples,
            and yield (result, context) tuples out.
            If both return_matches and as_tuples are True, the output will
            be a sequence of ((doc, matches), context) tuples.
        YIELDS (Doc): Documents, in order.
        """
        if as_tuples:
            for doc, context in stream:
                matches = self(doc)
                if return_matches:
                    yield ((doc, matches), context)
                else:
                    yield (doc, context)
        else:
            for doc in stream:
                matches = self(doc)
                if return_matches:
                    yield (doc, matches)
                else:
                    yield doc

    def accept_match(self, Doc doc, int start, int end):
        """Check whether `doc[start:end]` is exactly one of the added phrases.

        doc (Doc): The document the candidate span was found in. When matching
            on a non-ORTH attribute, this is the derived match doc.
        start (int): Index of the candidate's first token.
        end (int): Index one past the candidate's last token.
        RETURNS (int or None): The match ID stored for the span's phrase hash,
            or None if the lookup yields 0.
        """
        cdef int i, j
        cdef Pool mem = Pool()
        phrase_key = <attr_t*>mem.alloc(end-start, sizeof(attr_t))
        for i, j in enumerate(range(start, end)):
            phrase_key[i] = doc.c[j].lex.orth
        cdef hash_t key = hash64(phrase_key, (end-start) * sizeof(attr_t), 0)
        ent_id = <hash_t>self.phrase_ids.get(key)
        # A lookup result of 0 is treated as "no such phrase".
        if ent_id == 0:
            return None
        else:
            return ent_id

    def get_lex_value(self, Doc doc, int i):
        """Get the value that represents token `i` of `doc` for matching.

        doc (Doc): The document.
        i (int): Index of the token.
        RETURNS (int or unicode): The lexeme's orth ID when matching on ORTH.
            Otherwise a string of the form 'matcher:{ATTR}-{value}'.
        """
        if self.attr == ORTH:
            # Return the regular orth value of the lexeme
            return doc.c[i].lex.orth
        # Get the attribute value instead, e.g. token.pos
        attr_value = get_token_attr(&doc.c[i], self.attr)
        if attr_value in (0, 1):
            # Value is boolean, convert to string
            string_attr_value = str(attr_value)
        else:
            string_attr_value = self.vocab.strings[attr_value]
        string_attr_name = self.vocab.strings[self.attr]
        # Concatenate the attr name and value to not pollute lexeme space
        # e.g. 'POS-VERB' instead of just 'VERB', which could otherwise
        # create false positive matches
        return 'matcher:{}-{}'.format(string_attr_name, string_attr_value)
|
|
|
|
|
|
def get_bilou(length):
    """Return the position-flag IDs to set on each token of a phrase.

    The PhraseMatcher tags every token of a pattern phrase with one flag,
    chosen by the phrase length and the token's position:

    * length 1    -> [U_ENT]
    * length 2    -> [B2_ENT, L2_ENT]
    * length 3    -> [B3_ENT, I3_ENT, L3_ENT]
    * length >= 4 -> [B4_ENT] + [I4_ENT] * (length - 2) + [L4_ENT]

    Note that the module imports alias I3_ENT to the same flag as L3_ENT
    (FLAG42), and I4_ENT to the same flag as L4_ENT (FLAG41).

    length (int): Number of tokens in the phrase.
    RETURNS (list): One flag ID per token, `length` items in total.
    RAISES (ValueError): If `length` is less than 1.
    """
    # Reject every length below 1, not just 0. A negative length would
    # otherwise fall through to the 4+ branch and yield a bogus 3-item list.
    if length < 1:
        raise ValueError("Length must be >= 1")
    if length == 1:
        return [U_ENT]
    if length == 2:
        return [B2_ENT, L2_ENT]
    if length == 3:
        return [B3_ENT, I3_ENT, L3_ENT]
    return [B4_ENT, I4_ENT] + [I4_ENT] * (length - 3) + [L4_ENT]
|