From 9efda9e9abec9e0303787671adab007c48cc8629 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Tue, 13 Feb 2018 16:27:46 +0100
Subject: [PATCH] Add PhraseMatcher in matcher2.pyx

---
 spacy/matcher2.pyx | 195 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 194 insertions(+), 1 deletion(-)

diff --git a/spacy/matcher2.pyx b/spacy/matcher2.pyx
index 4545a2f31..d3de94911 100644
--- a/spacy/matcher2.pyx
+++ b/spacy/matcher2.pyx
@@ -12,6 +12,34 @@ from .tokens.doc cimport Doc
 from .tokens.doc cimport get_token_attr
 from .attrs cimport ID, attr_id_t, NULL_ATTR
 from .attrs import IDS
+from .attrs import FLAG61 as U_ENT
+from .attrs import FLAG60 as B2_ENT
+from .attrs import FLAG59 as B3_ENT
+from .attrs import FLAG58 as B4_ENT
+from .attrs import FLAG57 as B5_ENT
+from .attrs import FLAG56 as B6_ENT
+from .attrs import FLAG55 as B7_ENT
+from .attrs import FLAG54 as B8_ENT
+from .attrs import FLAG53 as B9_ENT
+from .attrs import FLAG52 as B10_ENT
+from .attrs import FLAG51 as I3_ENT
+from .attrs import FLAG50 as I4_ENT
+from .attrs import FLAG49 as I5_ENT
+from .attrs import FLAG48 as I6_ENT
+from .attrs import FLAG47 as I7_ENT
+from .attrs import FLAG46 as I8_ENT
+from .attrs import FLAG45 as I9_ENT
+from .attrs import FLAG44 as I10_ENT
+from .attrs import FLAG43 as L2_ENT
+from .attrs import FLAG42 as L3_ENT
+from .attrs import FLAG41 as L4_ENT
+from .attrs import FLAG40 as L5_ENT
+from .attrs import FLAG39 as L6_ENT
+from .attrs import FLAG38 as L7_ENT
+from .attrs import FLAG37 as L8_ENT
+from .attrs import FLAG36 as L9_ENT
+from .attrs import FLAG35 as L10_ENT
+
 
 
 cdef enum quantifier_t:
@@ -435,6 +463,20 @@ cdef class Matcher:
         if key not in self._patterns:
             return default
         return (self._callbacks[key], self._patterns[key])
+
+    def pipe(self, docs, batch_size=1000, n_threads=2):
+        """Match a stream of documents, yielding them in turn.
+
+        docs (iterable): A stream of documents.
+        batch_size (int): Number of documents to accumulate into a working
+            set.
+        n_threads (int): The number of threads with which to work on the
+            buffer in parallel, if the implementation supports multi-threading.
+        YIELDS (Doc): Documents, in order.
+        """
+        for doc in docs:
+            self(doc)
+            yield doc
 
     def __call__(self, Doc doc):
         """Find all token sequences matching the supplied pattern.
@@ -466,4 +508,155 @@ def unpickle_matcher(vocab, patterns, callbacks):
     return matcher
 
 
-
+def get_bilou(length):
+    if length == 1:
+        return [U_ENT]
+    elif length == 2:
+        return [B2_ENT, L2_ENT]
+    elif length == 3:
+        return [B3_ENT, I3_ENT, L3_ENT]
+    elif length == 4:
+        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
+    elif length == 5:
+        return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
+    elif length == 6:
+        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
+    elif length == 7:
+        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
+    elif length == 8:
+        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
+    elif length == 9:
+        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT,
+                L9_ENT]
+    elif length == 10:
+        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
+                I10_ENT, I10_ENT, L10_ENT]
+    else:
+        raise ValueError("Max length currently 10 for phrase matching")
+
+
+cdef class PhraseMatcher:
+    cdef Pool mem
+    cdef Vocab vocab
+    cdef Matcher matcher
+    cdef PreshMap phrase_ids
+    cdef int max_length
+    cdef attr_t* _phrase_key
+    cdef public object _callbacks
+    cdef public object _patterns
+
+    def __init__(self, Vocab vocab, max_length=10):
+        self.mem = Pool()
+        self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
+        self.max_length = max_length
+        self.vocab = vocab
+        self.matcher = Matcher(self.vocab)
+        self.phrase_ids = PreshMap()
+        abstract_patterns = []
+        for length in range(1, max_length):
+            abstract_patterns.append([{tag: True}
+                                      for tag in get_bilou(length)])
+        self.matcher.add('Candidate', None, *abstract_patterns)
+        self._callbacks = {}
+
+    def __len__(self):
+        """Get the number of rules added to the matcher. Note that this only
+        returns the number of rules (identical with the number of IDs), not
+        the number of individual patterns.
+
+        RETURNS (int): The number of rules.
+        """
+        return len(self.phrase_ids)
+
+    def __contains__(self, key):
+        """Check whether the matcher contains rules for a match ID.
+
+        key (unicode): The match ID.
+        RETURNS (bool): Whether the matcher contains rules for this match ID.
+        """
+        cdef hash_t ent_id = self.matcher._normalize_key(key)
+        return ent_id in self._callbacks
+
+    def __reduce__(self):
+        return (self.__class__, (self.vocab,), None, None)
+
+    def add(self, key, on_match, *docs):
+        """Add a match-rule to the matcher. A match-rule consists of: an ID
+        key, an on_match callback, and one or more patterns.
+
+        key (unicode): The match ID.
+        on_match (callable): Callback executed on match.
+        *docs (Doc): `Doc` objects representing match patterns.
+        """
+        cdef Doc doc
+        for doc in docs:
+            if len(doc) >= self.max_length:
+                msg = (
+                    "Pattern length (%d) >= phrase_matcher.max_length (%d). "
+                    "Length can be set on initialization, up to 10."
+                )
+                raise ValueError(msg % (len(doc), self.max_length))
+        cdef hash_t ent_id = self.matcher._normalize_key(key)
+        self._callbacks[ent_id] = on_match
+        cdef int length
+        cdef int i
+        cdef hash_t phrase_hash
+        for doc in docs:
+            length = doc.length
+            tags = get_bilou(length)
+            for i in range(self.max_length):
+                self._phrase_key[i] = 0
+            for i, tag in enumerate(tags):
+                lexeme = self.vocab[doc.c[i].lex.orth]
+                lexeme.set_flag(tag, True)
+                self._phrase_key[i] = lexeme.orth
+            phrase_hash = hash64(self._phrase_key,
+                                 self.max_length * sizeof(attr_t), 0)
+            self.phrase_ids.set(phrase_hash, <void*>ent_id)
+
+    def __call__(self, Doc doc):
+        """Find all sequences matching the supplied patterns on the `Doc`.
+
+        doc (Doc): The document to match over.
+        RETURNS (list): A list of `(key, start, end)` tuples, describing the
+            matches. A match tuple describes a span `doc[start:end]`. The
+            `key` is an integer ID.
+        """
+        matches = []
+        for _, start, end in self.matcher(doc):
+            ent_id = self.accept_match(doc, start, end)
+            if ent_id is not None:
+                matches.append((ent_id, start, end))
+        for i, (ent_id, start, end) in enumerate(matches):
+            on_match = self._callbacks.get(ent_id)
+            if on_match is not None:
+                on_match(self, doc, i, matches)
+        return matches
+
+    def pipe(self, stream, batch_size=1000, n_threads=2):
+        """Match a stream of documents, yielding them in turn.
+
+        stream (iterable): A stream of documents.
+        batch_size (int): Number of documents to accumulate into a working
+            set.
+        n_threads (int): The number of threads with which to work on the
+            buffer in parallel, if the implementation supports multi-threading.
+        YIELDS (Doc): Documents, in order.
+        """
+        for doc in stream:
+            self(doc)
+            yield doc
+
+    def accept_match(self, Doc doc, int start, int end):
+        assert (end - start) < self.max_length
+        cdef int i, j
+        for i in range(self.max_length):
+            self._phrase_key[i] = 0
+        for i, j in enumerate(range(start, end)):
+            self._phrase_key[i] = doc.c[j].lex.orth
+        cdef hash_t key = hash64(self._phrase_key,
+                                 self.max_length * sizeof(attr_t), 0)
+        ent_id = <hash_t>self.phrase_ids.get(key)
+        if ent_id == 0:
+            return None
+        else:
+            return ent_id
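
Note on the design (editorial, not part of the commit): PhraseMatcher is a
two-stage filter. add() sets a per-lexeme BILOU flag (B2_ENT, I3_ENT, L4_ENT,
and so on) on every word of every phrase and registers abstract flag-patterns
with the token Matcher, so at match time the Matcher cheaply proposes
candidate spans; accept_match() then confirms each candidate by hashing its
exact ORTH sequence into _phrase_key and looking that hash up in phrase_ids.
A plain-Python sketch of the same idea (illustrative only: Python's built-in
hash() stands in for murmurhash's hash64, and a set of flagged words stands
in for the lexeme bit flags):

    class ToyPhraseMatcher(object):
        def __init__(self, max_length=10):
            self.max_length = max_length
            self.flagged = set()   # words seen in any phrase (coarse filter)
            self.phrase_ids = {}   # hash of exact word tuple -> match key

        def add(self, key, *phrases):
            for words in phrases:
                assert len(words) < self.max_length
                # Like lexeme.set_flag(tag, True) in the Cython code:
                self.flagged.update(words)
                # Like hash64(_phrase_key, ...) feeding phrase_ids.set(...):
                self.phrase_ids[hash(tuple(words))] = key

        def __call__(self, doc):
            matches = []
            for start in range(len(doc)):
                max_end = min(len(doc), start + self.max_length - 1)
                for end in range(start + 1, max_end + 1):
                    window = tuple(doc[start:end])
                    # Stage 1: cheap filter, every token must be flagged.
                    # (The real code is stricter: the BILOU flags also encode
                    # each word's position within a phrase of a given length.)
                    if all(w in self.flagged for w in window):
                        # Stage 2: exact check against registered phrases.
                        key = self.phrase_ids.get(hash(window))
                        if key is not None:
                            matches.append((key, start, end))
            return matches

    matcher = ToyPhraseMatcher()
    matcher.add('OBAMA', ['Barack', 'Obama'])
    print(matcher(['Barack', 'Obama', 'visited', 'Berlin']))
    # -> [('OBAMA', 0, 2)]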
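
A minimal usage sketch against the API this patch adds. Assumptions: the
spacy.matcher import path is where the class ends up exposed (the patch
itself targets the matcher2.pyx staging module), and 'en_core_web_sm' is
only an example model. The calls themselves, add(key, on_match, *docs) and
matcher(doc), follow the code above:

    import spacy
    from spacy.matcher import PhraseMatcher   # assumed export location

    nlp = spacy.load('en_core_web_sm')        # any model/vocab will do
    matcher = PhraseMatcher(nlp.vocab)

    def on_match(matcher, doc, i, matches):
        key, start, end = matches[i]
        print('Matched:', doc[start:end].text)

    # Patterns are Doc objects, so they tokenize exactly like the target text.
    matcher.add('FRUIT_DISH', on_match, nlp(u'apple pie'), nlp(u'banana bread'))

    doc = nlp(u'She baked a banana bread and an apple pie.')
    matches = matcher(doc)   # list of (key, start, end); key is the ID hash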