diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index 59213bfc1..b9d7ea5f4 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -532,30 +532,16 @@ def _get_longest_matches(matches): def get_bilou(length): - if length == 1: + if length == 0: + raise ValueError("Length must be >= 1") + elif length == 1: return [U_ENT] elif length == 2: return [B2_ENT, L2_ENT] elif length == 3: return [B3_ENT, I3_ENT, L3_ENT] - elif length == 4: - return [B4_ENT, I4_ENT, I4_ENT, L4_ENT] - elif length == 5: - return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT] - elif length == 6: - return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT] - elif length == 7: - return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT] - elif length == 8: - return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT] - elif length == 9: - return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, - L9_ENT] - elif length == 10: - return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, - I10_ENT, I10_ENT, L10_ENT] else: - raise ValueError("Max length currently 10 for phrase matching") + return [B4_ENT, I4_ENT] + [I4_ENT] * (length-3) + [L4_ENT] cdef class PhraseMatcher: @@ -564,21 +550,21 @@ cdef class PhraseMatcher: cdef Matcher matcher cdef PreshMap phrase_ids cdef int max_length - cdef attr_t* _phrase_key cdef public object _callbacks cdef public object _patterns def __init__(self, Vocab vocab, max_length=10): self.mem = Pool() - self._phrase_key = self.mem.alloc(max_length, sizeof(attr_t)) self.max_length = max_length self.vocab = vocab self.matcher = Matcher(self.vocab) self.phrase_ids = PreshMap() - abstract_patterns = [] - for length in range(1, max_length): - abstract_patterns.append([{tag: True} - for tag in get_bilou(length)]) + abstract_patterns = [ + [{U_ENT: True}], + [{B2_ENT: True}, {L2_ENT: True}], + [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}], + [{B4_ENT: True}, {I4_ENT: True}, {I4_ENT: True, "OP": "+"}, {L4_ENT: True}], + ] self.matcher.add('Candidate', None, *abstract_patterns) self._callbacks = {} @@ -612,29 +598,24 @@ cdef class PhraseMatcher: *docs (Doc): `Doc` objects representing match patterns. """ cdef Doc doc - for doc in docs: - if len(doc) >= self.max_length: - msg = ( - "Pattern length (%d) >= phrase_matcher.max_length (%d). " - "Length can be set on initialization, up to 10." - ) - raise ValueError(msg % (len(doc), self.max_length)) cdef hash_t ent_id = self.matcher._normalize_key(key) self._callbacks[ent_id] = on_match cdef int length cdef int i cdef hash_t phrase_hash + cdef Pool mem = Pool() for doc in docs: length = doc.length + if length == 0: + continue tags = get_bilou(length) - for i in range(self.max_length): - self._phrase_key[i] = 0 + phrase_key = mem.alloc(length, sizeof(attr_t)) for i, tag in enumerate(tags): lexeme = self.vocab[doc.c[i].lex.orth] lexeme.set_flag(tag, True) - self._phrase_key[i] = lexeme.orth - phrase_hash = hash64(self._phrase_key, - self.max_length * sizeof(attr_t), 0) + phrase_key[i] = lexeme.orth + phrase_hash = hash64(phrase_key, + length * sizeof(attr_t), 0) self.phrase_ids.set(phrase_hash, ent_id) def __call__(self, Doc doc): @@ -670,14 +651,13 @@ cdef class PhraseMatcher: yield doc def accept_match(self, Doc doc, int start, int end): - assert (end - start) < self.max_length cdef int i, j - for i in range(self.max_length): - self._phrase_key[i] = 0 + cdef Pool mem = Pool() + phrase_key = mem.alloc(end-start, sizeof(attr_t)) for i, j in enumerate(range(start, end)): - self._phrase_key[i] = doc.c[j].lex.orth - cdef hash_t key = hash64(self._phrase_key, - self.max_length * sizeof(attr_t), 0) + phrase_key[i] = doc.c[j].lex.orth + cdef hash_t key = hash64(phrase_key, + (end-start) * sizeof(attr_t), 0) ent_id = self.phrase_ids.get(key) if ent_id == 0: return None