diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx index 375b25b46..048b5bd68 100644 --- a/spacy/matcher/phrasematcher.pyx +++ b/spacy/matcher/phrasematcher.pyx @@ -184,56 +184,34 @@ cdef class PhraseMatcher: return matches current_dict = self.keyword_trie_dict start = 0 - reset_current_dict = False idx = 0 doc_array_len = len(doc_array) while idx < doc_array_len: + start = idx token = doc_array[idx] - # if end is present in current_dict - if self._terminal in current_dict or token in current_dict: - if self._terminal in current_dict: - ent_ids = current_dict[self._terminal] - for ent_id in ent_ids: - matches.append((self.vocab.strings[ent_id], start, idx)) - - # look for longer sequences from this position - if token in current_dict: - current_dict_continued = current_dict[token] - - idy = idx + 1 - while idy < doc_array_len: - inner_token = doc_array[idy] - if self._terminal in current_dict_continued: - ent_ids = current_dict_continued[self._terminal] - for ent_id in ent_ids: - matches.append((self.vocab.strings[ent_id], start, idy)) - if inner_token in current_dict_continued: - current_dict_continued = current_dict_continued[inner_token] - else: - break + # look for sequences from this position + if token in current_dict: + current_dict_continued = current_dict[token] + idy = idx + 1 + while idy < doc_array_len: + if self._terminal in current_dict_continued: + ent_ids = current_dict_continued[self._terminal] + for ent_id in ent_ids: + matches.append((self.vocab.strings[ent_id], start, idy)) + inner_token = doc_array[idy] + if inner_token in current_dict_continued: + current_dict_continued = current_dict_continued[inner_token] idy += 1 else: - # end of doc_array reached - if self._terminal in current_dict_continued: - ent_ids = current_dict_continued[self._terminal] - for ent_id in ent_ids: - matches.append((self.vocab.strings[ent_id], start, idy)) - current_dict = self.keyword_trie_dict - reset_current_dict = True - else: - # we reset current_dict - current_dict = self.keyword_trie_dict - reset_current_dict = True - # if we are end of doc_array and have a sequence discovered - if idx + 1 >= doc_array_len: - if self._terminal in current_dict: - ent_ids = current_dict[self._terminal] - for ent_id in ent_ids: - matches.append((self.vocab.strings[ent_id], start, doc_array_len)) + break + else: + # end of doc_array reached + if self._terminal in current_dict_continued: + ent_ids = current_dict_continued[self._terminal] + for ent_id in ent_ids: + matches.append((self.vocab.strings[ent_id], start, idy)) + current_dict = self.keyword_trie_dict idx += 1 - if reset_current_dict: - reset_current_dict = False - start = idx for i, (ent_id, start, end) in enumerate(matches): on_match = self._callbacks.get(ent_id) if on_match is not None: diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py index a9f5ac990..3eb454c39 100644 --- a/spacy/tests/matcher/test_phrase_matcher.py +++ b/spacy/tests/matcher/test_phrase_matcher.py @@ -8,10 +8,31 @@ from ..util import get_doc def test_matcher_phrase_matcher(en_vocab): - doc = Doc(en_vocab, words=["Google", "Now"]) - matcher = PhraseMatcher(en_vocab) - matcher.add("COMPANY", None, doc) doc = Doc(en_vocab, words=["I", "like", "Google", "Now", "best"]) + # intermediate phrase + pattern = Doc(en_vocab, words=["Google", "Now"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("COMPANY", None, pattern) + assert len(matcher(doc)) == 1 + # initial token + pattern = Doc(en_vocab, words=["I"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("I", None, pattern) + assert len(matcher(doc)) == 1 + # initial phrase + pattern = Doc(en_vocab, words=["I", "like"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("ILIKE", None, pattern) + assert len(matcher(doc)) == 1 + # final token + pattern = Doc(en_vocab, words=["best"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("BEST", None, pattern) + assert len(matcher(doc)) == 1 + # final phrase + pattern = Doc(en_vocab, words=["Now", "best"]) + matcher = PhraseMatcher(en_vocab) + matcher.add("NOWBEST", None, pattern) assert len(matcher(doc)) == 1