From 801d55a6d950f708a1911e84abff024c772ad466 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Fri, 9 Oct 2015 02:00:45 +1100
Subject: [PATCH] * Fix phrase matcher

---
 spacy/matcher.pyx | 176 ++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 144 insertions(+), 32 deletions(-)

diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx
index 88a4f9ba2..afafd3ddb 100644
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -1,11 +1,18 @@
+# cython: profile=True
+from __future__ import unicode_literals
+
 from os import path
 
 from .typedefs cimport attr_t
+from .typedefs cimport hash_t
 from .attrs cimport attr_id_t
-from .structs cimport TokenC
+from .structs cimport TokenC, LexemeC
+from .lexeme cimport Lexeme
 
 from cymem.cymem cimport Pool
+from preshed.maps cimport PreshMap
 from libcpp.vector cimport vector
+from murmurhash.mrmr cimport hash64
 
 from .attrs cimport LENGTH, ENT_TYPE, ORTH, NORM, LEMMA, LOWER, SHAPE
 from .attrs cimport FLAG13, FLAG14, FLAG15, FLAG16, FLAG17, FLAG18, FLAG19, FLAG20, FLAG21, FLAG22, FLAG23, FLAG24, FLAG25
@@ -15,6 +22,38 @@ from .vocab cimport Vocab
 
 from libcpp.vector cimport vector
 
+from .attrs import FLAG61 as U_ENT
+
+from .attrs import FLAG60 as B2_ENT
+from .attrs import FLAG59 as B3_ENT
+from .attrs import FLAG58 as B4_ENT
+from .attrs import FLAG57 as B5_ENT
+from .attrs import FLAG56 as B6_ENT
+from .attrs import FLAG55 as B7_ENT
+from .attrs import FLAG54 as B8_ENT
+from .attrs import FLAG53 as B9_ENT
+from .attrs import FLAG52 as B10_ENT
+
+from .attrs import FLAG51 as I3_ENT
+from .attrs import FLAG50 as I4_ENT
+from .attrs import FLAG49 as I5_ENT
+from .attrs import FLAG48 as I6_ENT
+from .attrs import FLAG47 as I7_ENT
+from .attrs import FLAG46 as I8_ENT
+from .attrs import FLAG45 as I9_ENT
+from .attrs import FLAG44 as I10_ENT
+
+from .attrs import FLAG43 as L2_ENT
+from .attrs import FLAG42 as L3_ENT
+from .attrs import FLAG41 as L4_ENT
+from .attrs import FLAG40 as L5_ENT
+from .attrs import FLAG39 as L6_ENT
+from .attrs import FLAG38 as L7_ENT
+from .attrs import FLAG37 as L8_ENT
+from .attrs import FLAG36 as L9_ENT
+from .attrs import FLAG35 as L10_ENT
+
+
 try:
     import ujson as json
 except ImportError:
@@ -41,7 +80,7 @@ cdef Pattern* init_pattern(Pool mem, object token_specs, attr_t entity_type) exc
             pattern[i].spec[j].attr = attr
             pattern[i].spec[j].value = value
     i = len(token_specs)
-    pattern[i].spec = <AttrValue*>mem.alloc(1, sizeof(AttrValue))
+    pattern[i].spec = <AttrValue*>mem.alloc(2, sizeof(AttrValue))
     pattern[i].spec[0].attr = ENT_TYPE
     pattern[i].spec[0].value = entity_type
     pattern[i].spec[1].attr = LENGTH
@@ -81,7 +120,33 @@ def _convert_strings(token_specs, string_store):
             value = int(value)
         converted[-1].append((attr, value))
     return converted
-    
+
+
+def get_bilou(length):
+    if length == 1:
+        return [U_ENT]
+    elif length == 2:
+        return [B2_ENT, L2_ENT]
+    elif length == 3:
+        return [B3_ENT, I3_ENT, L3_ENT]
+    elif length == 4:
+        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
+    elif length == 5:
+        return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
+    elif length == 6:
+        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
+    elif length == 7:
+        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
+    elif length == 8:
+        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
+    elif length == 9:
+        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, L9_ENT]
+    elif length == 10:
+        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
+                I10_ENT, I10_ENT, L10_ENT]
+    else:
+        raise ValueError("Max length currently 10 for phrase matching")
+
 
 def map_attr_name(attr):
     attr = attr.upper()
@@ -95,32 +160,6 @@ def map_attr_name(attr):
         return SHAPE
     elif attr == 'NORM':
         return NORM
-    elif attr == 'FLAG13':
-        return FLAG13
-    elif attr == 'FLAG14':
-        return FLAG14
-    elif attr == 'FLAG15':
-        return FLAG15
-    elif attr == 'FLAG16':
-        return FLAG16
-    elif attr == 'FLAG17':
-        return FLAG17
-    elif attr == 'FLAG18':
-        return FLAG18
-    elif attr == 'FLAG19':
-        return FLAG19
-    elif attr == 'FLAG20':
-        return FLAG20
-    elif attr == 'FLAG21':
-        return FLAG21
-    elif attr == 'FLAG22':
-        return FLAG22
-    elif attr == 'FLAG23':
-        return FLAG23
-    elif attr == 'FLAG24':
-        return FLAG24
-    elif attr == 'FLAG25':
-        return FLAG25
     else:
         raise Exception("TODO: Finish supporting attr mapping %s" % attr)
 
@@ -163,7 +202,7 @@ cdef class Matcher:
             spec = _convert_strings(spec, self.vocab.strings)
            self.patterns.push_back(init_pattern(self.mem, spec, etype))
 
-    def __call__(self, Doc doc):
+    def __call__(self, Doc doc, acceptor=None):
         cdef vector[Pattern*] partials
         cdef int n_partials = 0
         cdef int q = 0
@@ -174,21 +213,94 @@ cdef class Matcher:
         for token_i in range(doc.length):
             token = &doc.data[token_i]
             q = 0
+            # Go over the open matches, extending or finalizing if able. Otherwise,
+            # we over-write them (q doesn't advance)
             for i in range(partials.size()):
                 state = partials.at(i)
                 if match(state, token):
                     if is_final(state):
-                        matches.append(get_entity(state, token, token_i))
+                        label, start, end = get_entity(state, token, token_i)
+                        if acceptor is None or acceptor(doc, label, start, end):
+                            matches.append((label, start, end))
                     else:
                         partials[q] = state + 1
                         q += 1
             partials.resize(q)
+            # Check whether we open any new patterns on this token
             for i in range(self.n_patterns):
                 state = self.patterns[i]
                 if match(state, token):
                     if is_final(state):
-                        matches.append(get_entity(state, token, token_i))
+                        label, start, end = get_entity(state, token, token_i)
+                        if acceptor is None or acceptor(doc, label, start, end):
+                            matches.append((label, start, end))
                     else:
                         partials.push_back(state + 1)
         doc.ents = [(e.label, e.start, e.end) for e in doc.ents] + matches
         return matches
+
+
+cdef class PhraseMatcher:
+    cdef Pool mem
+    cdef Vocab vocab
+    cdef Matcher matcher
+    cdef PreshMap phrase_ids
+
+    cdef int max_length
+    cdef attr_t* _phrase_key
+
+    def __init__(self, Vocab vocab, phrases, max_length=10):
+        self.mem = Pool()
+        self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
+        self.max_length = max_length
+        self.vocab = vocab
+        self.matcher = Matcher(self.vocab, {})
+        self.phrase_ids = PreshMap()
+        for phrase in phrases:
+            if len(phrase) < max_length:
+                self.add(phrase)
+
+        abstract_patterns = []
+        for length in range(1, max_length):
+            abstract_patterns.append([{tag: True} for tag in get_bilou(length)])
+        self.matcher.add('Candidate', 'MWE', {}, abstract_patterns)
+
+    def add(self, Doc tokens):
+        cdef int length = tokens.length
+        assert length < self.max_length
+        tags = get_bilou(length)
+        assert len(tags) == length, length
+
+        cdef int i
+        for i in range(self.max_length):
+            self._phrase_key[i] = 0
+        for i, tag in enumerate(tags):
+            lexeme = self.vocab[tokens.data[i].lex.orth]
+            lexeme.set_flag(tag, True)
+            self._phrase_key[i] = lexeme.orth
+        cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
+        self.phrase_ids[key] = True
+
+    def __call__(self, Doc doc):
+        matches = []
+        for label, start, end in self.matcher(doc, acceptor=self.accept_match):
+            cand = doc[start : end]
+            start = cand[0].idx
+            end = cand[-1].idx + len(cand[-1])
+            matches.append((start, end, cand.root.tag_, cand.text, 'MWE'))
+        for match in matches:
+            doc.merge(*match)
+        return matches
+
+    def accept_match(self, Doc doc, int label, int start, int end):
+        assert (end - start) < self.max_length
+        cdef int i, j
+        for i in range(self.max_length):
+            self._phrase_key[i] = 0
+        for i, j in enumerate(range(start, end)):
+            self._phrase_key[i] = doc.data[j].lex.orth
+        cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
+        if self.phrase_ids.get(key):
+            return True
+        else:
+            return False
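
Notes for review: the fix is two-part. First, init_pattern's terminal entry
writes both spec[0] (ENT_TYPE) and spec[1] (LENGTH), so the old
alloc(1, sizeof(AttrValue)) was one AttrValue short; it now allocates two.
Second, the new PhraseMatcher works in two stages: add() sets a per-position
BILOU flag (FLAG35..FLAG61) on each phrase word's lexeme, and the Matcher
proposes candidate spans from abstract patterns over those flags; accept_match()
then rehashes the candidate's exact ORTH ids with hash64 and keeps only spans
whose key is in the PreshMap, so unrelated phrases that happen to share flags
are filtered out.

A minimal usage sketch, assuming the spaCy 0.x loading API (spacy.en.English)
and made-up example phrases; only PhraseMatcher itself comes from this patch:

    from spacy.en import English
    from spacy.matcher import PhraseMatcher

    nlp = English()

    # Phrases are pre-tokenized Docs, each shorter than max_length tokens.
    phrases = [nlp(u'New York City'), nlp(u'Golden Gate Bridge')]
    matcher = PhraseMatcher(nlp.vocab, phrases, max_length=10)

    doc = nlp(u'I crossed the Golden Gate Bridge into New York City.')

    # Matches are (start_char, end_char, root_tag, text, 'MWE') tuples; each
    # matched span has already been merged into a single token on doc.
    for start_char, end_char, tag, text, label in matcher(doc):
        print('%d %d %s %s' % (start_char, end_char, text, label))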