Add PhraseMatcher in matcher2.pyx
parent 0004331895
commit 9efda9e9ab
@@ -12,6 +12,34 @@ from .tokens.doc cimport Doc
from .tokens.doc cimport get_token_attr
from .attrs cimport ID, attr_id_t, NULL_ATTR
from .attrs import IDS
from .attrs import FLAG61 as U_ENT
from .attrs import FLAG60 as B2_ENT
from .attrs import FLAG59 as B3_ENT
from .attrs import FLAG58 as B4_ENT
from .attrs import FLAG57 as B5_ENT
from .attrs import FLAG56 as B6_ENT
from .attrs import FLAG55 as B7_ENT
from .attrs import FLAG54 as B8_ENT
from .attrs import FLAG53 as B9_ENT
from .attrs import FLAG52 as B10_ENT
from .attrs import FLAG51 as I3_ENT
from .attrs import FLAG50 as I4_ENT
from .attrs import FLAG49 as I5_ENT
from .attrs import FLAG48 as I6_ENT
from .attrs import FLAG47 as I7_ENT
from .attrs import FLAG46 as I8_ENT
from .attrs import FLAG45 as I9_ENT
from .attrs import FLAG44 as I10_ENT
from .attrs import FLAG43 as L2_ENT
from .attrs import FLAG42 as L3_ENT
from .attrs import FLAG41 as L4_ENT
from .attrs import FLAG40 as L5_ENT
from .attrs import FLAG39 as L6_ENT
from .attrs import FLAG38 as L7_ENT
from .attrs import FLAG37 as L8_ENT
from .attrs import FLAG36 as L9_ENT
from .attrs import FLAG35 as L10_ENT


cdef enum quantifier_t:
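
Note: the imports above repurpose spaCy's spare lexeme flag bits (FLAG35 through FLAG61) as BILOU-style phrase markers: U for single-token phrases, plus per-length B (begin), I (inside) and L (last) tags for phrases of 2 to 10 tokens. A minimal sketch of the numbering scheme implied by the aliases (the helper below is illustrative, not part of the commit):

    def bilou_flag_ids(n):
        # U_ENT is FLAG61; for n-token phrases (2 <= n <= 10) the flag IDs
        # count down as n grows: B{n} = FLAG(62 - n), I{n} = FLAG(54 - n)
        # (I tags exist only for n >= 3), and L{n} = FLAG(45 - n).
        return {'B': 62 - n, 'I': (54 - n) if n >= 3 else None, 'L': 45 - n}
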
@@ -435,6 +463,20 @@ cdef class Matcher:
        if key not in self._patterns:
            return default
        return (self._callbacks[key], self._patterns[key])

    def pipe(self, docs, batch_size=1000, n_threads=2):
        """Match a stream of documents, yielding them in turn.

        docs (iterable): A stream of documents.
        batch_size (int): Number of documents to accumulate into a working set.
        n_threads (int): The number of threads with which to work on the buffer
            in parallel, if the implementation supports multi-threading.
        YIELDS (Doc): Documents, in order.
        """
        for doc in docs:
            self(doc)
            yield doc

    def __call__(self, Doc doc):
        """Find all token sequences matching the supplied pattern.
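
`Matcher.pipe` simply applies the matcher to each document and yields it in order, so `on_match` callbacks fire as a side effect; `batch_size` and `n_threads` are accepted for API compatibility, but the loop itself is sequential. A usage sketch, assuming a loaded pipeline `nlp`, a list of strings `texts`, and a matcher with patterns already added:

    docs = nlp.pipe(texts)                # `nlp` and `texts` are assumed here
    for doc in matcher.pipe(docs, batch_size=1000):
        pass                              # callbacks have already run on `doc`
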
@@ -466,4 +508,155 @@ def unpickle_matcher(vocab, patterns, callbacks):
    return matcher


def get_bilou(length):
    if length == 1:
        return [U_ENT]
    elif length == 2:
        return [B2_ENT, L2_ENT]
    elif length == 3:
        return [B3_ENT, I3_ENT, L3_ENT]
    elif length == 4:
        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
    elif length == 5:
        return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
    elif length == 6:
        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
    elif length == 7:
        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
    elif length == 8:
        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
    elif length == 9:
        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT,
                L9_ENT]
    elif length == 10:
        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
                I10_ENT, I10_ENT, L10_ENT]
    else:
        raise ValueError("Max length currently 10 for phrase matching")
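
`get_bilou` maps a phrase length to one length-specific flag per token position, for example (illustrative assertions, not part of the commit):

    assert get_bilou(1) == [U_ENT]                           # single token: U
    assert get_bilou(4) == [B4_ENT, I4_ENT, I4_ENT, L4_ENT]  # begin, 2x inside, last
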
cdef class PhraseMatcher:
    cdef Pool mem
    cdef Vocab vocab
    cdef Matcher matcher
    cdef PreshMap phrase_ids
    cdef int max_length
    cdef attr_t* _phrase_key
    cdef public object _callbacks
    cdef public object _patterns

    def __init__(self, Vocab vocab, max_length=10):
        self.mem = Pool()
        self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
        self.max_length = max_length
        self.vocab = vocab
        self.matcher = Matcher(self.vocab)
        self.phrase_ids = PreshMap()
        abstract_patterns = []
        for length in range(1, max_length):
            abstract_patterns.append([{tag: True}
                                      for tag in get_bilou(length)])
        self.matcher.add('Candidate', None, *abstract_patterns)
        self._callbacks = {}
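
`__init__` pre-registers one abstract pattern per phrase length; note that `range(1, max_length)` covers lengths 1 to max_length - 1, which is consistent with the `>=` length check in `add()` below. For length 3, the abstract pattern built here is equivalent to (illustrative):

    [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}]

so the inner token `Matcher` proposes any span whose lexemes carry a consistent set of length-3 flags, and `accept_match()` later filters those candidates by hash.
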
    def __len__(self):
        """Get the number of rules added to the matcher. Note that this only
        returns the number of rules (identical with the number of IDs), not the
        number of individual patterns.

        RETURNS (int): The number of rules.
        """
        return len(self.phrase_ids)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        return ent_id in self._callbacks

    def __reduce__(self):
        return (self.__class__, (self.vocab,), None, None)

    def add(self, key, on_match, *docs):
        """Add a match-rule to the matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.

        key (unicode): The match ID.
        on_match (callable): Callback executed on match.
        *docs (Doc): `Doc` objects representing match patterns.
        """
        cdef Doc doc
        for doc in docs:
            if len(doc) >= self.max_length:
                msg = (
                    "Pattern length (%d) >= phrase_matcher.max_length (%d). "
                    "Length can be set on initialization, up to 10."
                )
                raise ValueError(msg % (len(doc), self.max_length))
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        self._callbacks[ent_id] = on_match
        cdef int length
        cdef int i
        cdef hash_t phrase_hash
        for doc in docs:
            length = doc.length
            tags = get_bilou(length)
            for i in range(self.max_length):
                self._phrase_key[i] = 0
            for i, tag in enumerate(tags):
                lexeme = self.vocab[doc.c[i].lex.orth]
                lexeme.set_flag(tag, True)
                self._phrase_key[i] = lexeme.orth
            phrase_hash = hash64(self._phrase_key,
                                 self.max_length * sizeof(attr_t), 0)
            self.phrase_ids.set(phrase_hash, <void*>ent_id)
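
`add()` flags each pattern token's lexeme with the length-specific BILOU bit and stores a hash of the zero-padded key array in `phrase_ids`, so match-time lookup is constant-time. A usage sketch, assuming a loaded pipeline `nlp` (the key string is a made-up example):

    matcher = PhraseMatcher(nlp.vocab)
    matcher.add('OBAMA', None, nlp(u'Barack Obama'))
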
    def __call__(self, Doc doc):
        """Find all sequences matching the supplied patterns on the `Doc`.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples,
            describing the matches. A match tuple describes a span
            `doc[start:end]`. The `label_id` and `key` are both integers.
        """
        matches = []
        for _, start, end in self.matcher(doc):
            ent_id = self.accept_match(doc, start, end)
            if ent_id is not None:
                matches.append((ent_id, start, end))
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches
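
Calling the matcher then returns `(key, start, end)` triples for the accepted spans, e.g. (continuing the sketch above):

    doc = nlp(u'Barack Obama was the 44th president')
    for ent_id, start, end in matcher(doc):
        span = doc[start:end]             # the matched phrase, e.g. doc[0:2]
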
    def pipe(self, stream, batch_size=1000, n_threads=2):
        """Match a stream of documents, yielding them in turn.

        docs (iterable): A stream of documents.
        batch_size (int): Number of documents to accumulate into a working set.
        n_threads (int): The number of threads with which to work on the buffer
            in parallel, if the implementation supports multi-threading.
        YIELDS (Doc): Documents, in order.
        """
        for doc in stream:
            self(doc)
            yield doc

    def accept_match(self, Doc doc, int start, int end):
        assert (end - start) < self.max_length
        cdef int i, j
        for i in range(self.max_length):
            self._phrase_key[i] = 0
        for i, j in enumerate(range(start, end)):
            self._phrase_key[i] = doc.c[j].lex.orth
        cdef hash_t key = hash64(self._phrase_key,
                                 self.max_length * sizeof(attr_t), 0)
        ent_id = <hash_t>self.phrase_ids.get(key)
        if ent_id == 0:
            return None
        else:
            return ent_id
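
`accept_match` mirrors the key construction in `add()`: both build a fixed-width, zero-padded array of `max_length` orth IDs and hash it with `hash64`, so a candidate span is accepted only if exactly that token sequence was registered. In plain Python the key construction is roughly (illustrative only; the real code hashes the raw `attr_t` array):

    def phrase_key(orths, max_length=10):
        key = [0] * max_length            # fixed width, zero padded
        key[:len(orths)] = orths          # orth IDs of the candidate span
        return tuple(key)                 # stands in for hash64 over the array
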