diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index d321218b8..ba3559966 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -421,52 +421,67 @@ cdef class PhraseMatcher: cdef int max_length cdef attr_t* _phrase_key - def __init__(self, Vocab vocab, phrases, max_length=10): + cdef public object _callbacks + + def __init__(self, Vocab vocab, max_length=10): self.mem = Pool() self._phrase_key = self.mem.alloc(max_length, sizeof(attr_t)) self.max_length = max_length self.vocab = vocab self.matcher = Matcher(self.vocab) self.phrase_ids = PreshMap() - for phrase in phrases: - if len(phrase) < max_length: - self.add(phrase) - abstract_patterns = [] for length in range(1, max_length): abstract_patterns.append([{tag: True} for tag in get_bilou(length)]) self.matcher.add('Candidate', None, *abstract_patterns) + self._callbacks = {} - def add(self, Doc tokens): - cdef int length = tokens.length - assert length < self.max_length - tags = get_bilou(length) - assert len(tags) == length, length + def add(self, key, on_match, *docs): + cdef Doc doc + for doc in docs: + if len(doc) >= self.max_length: + msg = ( + "Pattern length (%d) >= phrase_matcher.max_length (%d). " + "Length can be set on initialization, up to 10." + ) + raise ValueError(msg % (len(doc), self.max_length)) + cdef hash_t ent_id = self.matcher._normalize_key(key) + self._callbacks[ent_id] = on_match + cdef int length cdef int i - for i in range(self.max_length): - self._phrase_key[i] = 0 - for i, tag in enumerate(tags): - lexeme = self.vocab[tokens.c[i].lex.orth] - lexeme.set_flag(tag, True) - self._phrase_key[i] = lexeme.orth - cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0) - self.phrase_ids[key] = True + cdef hash_t phrase_hash + for doc in docs: + length = doc.length + tags = get_bilou(length) + for i in range(self.max_length): + self._phrase_key[i] = 0 + for i, tag in enumerate(tags): + lexeme = self.vocab[doc.c[i].lex.orth] + lexeme.set_flag(tag, True) + self._phrase_key[i] = lexeme.orth + phrase_hash = hash64(self._phrase_key, + self.max_length * sizeof(attr_t), 0) + self.phrase_ids[phrase_hash] = ent_id def __call__(self, Doc doc): - matches = self.matcher(doc) - accepted = [] - for ent_id, start, end in matches: - if self.accept_match(doc, ent_id, start, end): - accepted.append((ent_id, start, end)) - return accepted + matches = [] + for _, start, end in self.matcher(doc): + ent_id = self.accept_match(doc, start, end) + if ent_id is not None: + matches.append((ent_id, start, end)) + for i, (ent_id, start, end) in enumerate(matches): + on_match = self._callbacks.get(ent_id) + if on_match is not None: + on_match(self, doc, i, matches) + return matches def pipe(self, stream, batch_size=1000, n_threads=2): for doc in stream: self(doc) yield doc - def accept_match(self, Doc doc, attr_t ent_id, int start, int end): + def accept_match(self, Doc doc, int start, int end): assert (end - start) < self.max_length cdef int i, j for i in range(self.max_length): @@ -474,7 +489,8 @@ cdef class PhraseMatcher: for i, j in enumerate(range(start, end)): self._phrase_key[i] = doc.c[j].lex.orth cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0) - if self.phrase_ids.get(key): - return True + ent_id = self.phrase_ids.get(key) + if ent_id == 0: + return None else: - return False + return ent_id diff --git a/spacy/tests/test_matcher.py b/spacy/tests/test_matcher.py index 651707019..1b9f92519 100644 --- a/spacy/tests/test_matcher.py +++ b/spacy/tests/test_matcher.py @@ -101,7 +101,8 @@ def test_matcher_match_multi(matcher): def test_matcher_phrase_matcher(en_vocab): words = ["Google", "Now"] doc = get_doc(en_vocab, words) - matcher = PhraseMatcher(en_vocab, [doc]) + matcher = PhraseMatcher(en_vocab) + matcher.add('COMPANY', None, doc) words = ["I", "like", "Google", "Now", "best"] doc = get_doc(en_vocab, words) assert len(matcher(doc)) == 1