diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx
index 16ab73735..fdce7e9fa 100644
--- a/spacy/matcher/matcher.pyx
+++ b/spacy/matcher/matcher.pyx
@@ -203,13 +203,16 @@ cdef class Matcher:
             else:
                 yield doc
 
-    def __call__(self, object doclike):
+    def __call__(self, object doclike, as_spans=False):
         """Find all token sequences matching the supplied pattern.
 
         doclike (Doc or Span): The document to match over.
-        RETURNS (list): A list of `(key, start, end)` tuples,
+        as_spans (bool): Return Span objects with labels instead of (match_id,
+            start, end) tuples.
+        RETURNS (list): A list of `(match_id, start, end)` tuples,
             describing the matches. A match tuple describes a span
-            `doc[start:end]`. The `label_id` and `key` are both integers.
+            `doc[start:end]`. The `match_id` is an integer. If as_spans is set
+            to True, a list of Span objects is returned.
         """
         if isinstance(doclike, Doc):
             doc = doclike
@@ -262,7 +265,10 @@ cdef class Matcher:
             on_match = self._callbacks.get(key, None)
             if on_match is not None:
                 on_match(self, doc, i, final_matches)
-        return final_matches
+        if as_spans:
+            return [Span(doc, start, end, label=key) for key, start, end in final_matches]
+        else:
+            return final_matches
 
     def _normalize_key(self, key):
         if isinstance(key, basestring):
diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx
index 060c4d37f..6658c713e 100644
--- a/spacy/matcher/phrasematcher.pyx
+++ b/spacy/matcher/phrasematcher.pyx
@@ -7,6 +7,7 @@ import warnings
 from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
 from ..structs cimport TokenC
 from ..tokens.token cimport Token
+from ..tokens.span cimport Span
 from ..typedefs cimport attr_t
 
 from ..schemas import TokenPattern
@@ -216,13 +217,16 @@ cdef class PhraseMatcher:
                 result = internal_node
             map_set(self.mem, result, self.vocab.strings[key], NULL)
 
-    def __call__(self, doc):
+    def __call__(self, doc, as_spans=False):
         """Find all sequences matching the supplied patterns on the `Doc`.
 
         doc (Doc): The document to match over.
-        RETURNS (list): A list of `(key, start, end)` tuples,
+        as_spans (bool): Return Span objects with labels instead of (match_id,
+            start, end) tuples.
+        RETURNS (list): A list of `(match_id, start, end)` tuples,
             describing the matches. A match tuple describes a span
-            `doc[start:end]`. The `label_id` and `key` are both integers.
+            `doc[start:end]`. The `match_id` is an integer. If as_spans is set
+            to True, a list of Span objects is returned.
 
         DOCS: https://spacy.io/api/phrasematcher#call
         """
@@ -239,7 +243,10 @@ cdef class PhraseMatcher:
             on_match = self._callbacks.get(self.vocab.strings[ent_id])
             if on_match is not None:
                 on_match(self, doc, i, matches)
-        return matches
+        if as_spans:
+            return [Span(doc, start, end, label=key) for key, start, end in matches]
+        else:
+            return matches
 
     cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil:
         cdef MapStruct* current_node = self.c_map
diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py
index bcb224bd3..aeac509b5 100644
--- a/spacy/tests/matcher/test_matcher_api.py
+++ b/spacy/tests/matcher/test_matcher_api.py
@@ -2,7 +2,8 @@ import pytest
 import re
 from mock import Mock
 from spacy.matcher import Matcher, DependencyMatcher
-from spacy.tokens import Doc, Token
+from spacy.tokens import Doc, Token, Span
+
 from ..doc.test_underscore import clean_underscore  # noqa: F401
@@ -469,3 +470,17 @@ def test_matcher_span(matcher):
     assert len(matcher(doc)) == 2
     assert len(matcher(span_js)) == 1
     assert len(matcher(span_java)) == 1
+
+
+def test_matcher_as_spans(matcher):
+    """Test the new as_spans=True API."""
+    text = "JavaScript is good but Java is better"
+    doc = Doc(matcher.vocab, words=text.split())
+    matches = matcher(doc, as_spans=True)
+    assert len(matches) == 2
+    assert isinstance(matches[0], Span)
+    assert matches[0].text == "JavaScript"
+    assert matches[0].label_ == "JS"
+    assert isinstance(matches[1], Span)
+    assert matches[1].text == "Java"
+    assert matches[1].label_ == "Java"
diff --git a/spacy/tests/matcher/test_phrase_matcher.py b/spacy/tests/matcher/test_phrase_matcher.py
index 2a3c7d693..edffaa900 100644
--- a/spacy/tests/matcher/test_phrase_matcher.py
+++ b/spacy/tests/matcher/test_phrase_matcher.py
@@ -2,7 +2,7 @@ import pytest
 import srsly
 from mock import Mock
 from spacy.matcher import PhraseMatcher
-from spacy.tokens import Doc
+from spacy.tokens import Doc, Span
 
 from ..util import get_doc
@@ -287,3 +287,19 @@ def test_phrase_matcher_pickle(en_vocab):
     # clunky way to vaguely check that callback is unpickled
     (vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
     assert isinstance(callbacks.get("TEST2"), Mock)
+
+
+def test_phrase_matcher_as_spans(en_vocab):
+    """Test the new as_spans=True API."""
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add("A", [Doc(en_vocab, words=["hello", "world"])])
+    matcher.add("B", [Doc(en_vocab, words=["test"])])
+    doc = Doc(en_vocab, words=["...", "hello", "world", "this", "is", "a", "test"])
+    matches = matcher(doc, as_spans=True)
+    assert len(matches) == 2
+    assert isinstance(matches[0], Span)
+    assert matches[0].text == "hello world"
+    assert matches[0].label_ == "A"
+    assert isinstance(matches[1], Span)
+    assert matches[1].text == "test"
+    assert matches[1].label_ == "B"
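
Below is a minimal usage sketch of the `as_spans=True` flag introduced by this diff, written against the public spaCy API rather than the Cython internals. The pattern keys ("JS", "JAVA") and the example sentence are illustrative only and are not part of the patch:

# Minimal sketch of the as_spans=True behaviour added in this diff.
# Assumes a blank English pipeline; the keys and the sentence are illustrative.
import spacy
from spacy.matcher import Matcher, PhraseMatcher
from spacy.tokens import Span

nlp = spacy.blank("en")
doc = nlp("JavaScript is good but Java is better")

matcher = Matcher(nlp.vocab)
matcher.add("JS", [[{"ORTH": "JavaScript"}]])

tuple_matches = matcher(doc)                # default: (match_id, start, end) tuples
span_matches = matcher(doc, as_spans=True)  # new: labelled Span objects
assert all(isinstance(m, Span) for m in span_matches)
assert span_matches[0].text == "JavaScript"
assert span_matches[0].label_ == "JS"

phrase_matcher = PhraseMatcher(nlp.vocab)
phrase_matcher.add("JAVA", [nlp("Java")])
phrase_spans = phrase_matcher(doc, as_spans=True)
assert phrase_spans[0].text == "Java"
assert phrase_spans[0].label_ == "JAVA"

Because each returned span carries its label, the result can be used directly (for example assigned to `doc.ents`) instead of first converting `(match_id, start, end)` tuples by hand.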