Add as_spans to Matcher/PhraseMatcher

This commit is contained in:
Ines Montani 2020-08-31 14:53:22 +02:00
parent ec14744ee4
commit 6340d1c63d
4 changed files with 54 additions and 10 deletions

View File

@ -203,13 +203,16 @@ cdef class Matcher:
else:
yield doc
def __call__(self, object doclike):
def __call__(self, object doclike, as_spans=False):
"""Find all token sequences matching the supplied pattern.
doclike (Doc or Span): The document to match over.
RETURNS (list): A list of `(key, start, end)` tuples,
as_spans (bool): Return Span objects with labels instead of (match_id,
start, end) tuples.
RETURNS (list): A list of `(match_id, start, end)` tuples,
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers.
`doc[start:end]`. The `match_id` is an integer. If as_spans is set
to True, a list of Span objects is returned.
"""
if isinstance(doclike, Doc):
doc = doclike
@ -262,7 +265,10 @@ cdef class Matcher:
on_match = self._callbacks.get(key, None)
if on_match is not None:
on_match(self, doc, i, final_matches)
return final_matches
if as_spans:
return [Span(doc, start, end, label=key) for key, start, end in final_matches]
else:
return final_matches
def _normalize_key(self, key):
if isinstance(key, basestring):

View File

@ -7,6 +7,7 @@ import warnings
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
from ..structs cimport TokenC
from ..tokens.token cimport Token
from ..tokens.span cimport Span
from ..typedefs cimport attr_t
from ..schemas import TokenPattern
@ -216,13 +217,16 @@ cdef class PhraseMatcher:
result = internal_node
map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)
def __call__(self, doc):
def __call__(self, doc, as_spans=False):
"""Find all sequences matching the supplied patterns on the `Doc`.
doc (Doc): The document to match over.
RETURNS (list): A list of `(key, start, end)` tuples,
as_spans (bool): Return Span objects with labels instead of (match_id,
start, end) tuples.
RETURNS (list): A list of `(match_id, start, end)` tuples,
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers.
`doc[start:end]`. The `match_id` is an integer. If as_spans is set
to True, a list of Span objects is returned.
DOCS: https://spacy.io/api/phrasematcher#call
"""
@ -239,7 +243,10 @@ cdef class PhraseMatcher:
on_match = self._callbacks.get(self.vocab.strings[ent_id])
if on_match is not None:
on_match(self, doc, i, matches)
return matches
if as_spans:
return [Span(doc, start, end, label=key) for key, start, end in matches]
else:
return matches
cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil:
cdef MapStruct* current_node = self.c_map

View File

@ -2,7 +2,8 @@ import pytest
import re
from mock import Mock
from spacy.matcher import Matcher, DependencyMatcher
from spacy.tokens import Doc, Token
from spacy.tokens import Doc, Token, Span
from ..doc.test_underscore import clean_underscore # noqa: F401
@ -469,3 +470,17 @@ def test_matcher_span(matcher):
assert len(matcher(doc)) == 2
assert len(matcher(span_js)) == 1
assert len(matcher(span_java)) == 1
def test_matcher_as_spans(matcher):
"""Test the new as_spans=True API."""
text = "JavaScript is good but Java is better"
doc = Doc(matcher.vocab, words=text.split())
matches = matcher(doc, as_spans=True)
assert len(matches) == 2
assert isinstance(matches[0], Span)
assert matches[0].text == "JavaScript"
assert matches[0].label_ == "JS"
assert isinstance(matches[1], Span)
assert matches[1].text == "Java"
assert matches[1].label_ == "Java"

View File

@ -2,7 +2,7 @@ import pytest
import srsly
from mock import Mock
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc
from spacy.tokens import Doc, Span
from ..util import get_doc
@ -287,3 +287,19 @@ def test_phrase_matcher_pickle(en_vocab):
# clunky way to vaguely check that callback is unpickled
(vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
assert isinstance(callbacks.get("TEST2"), Mock)
def test_phrase_matcher_as_spans(en_vocab):
"""Test the new as_spans=True API."""
matcher = PhraseMatcher(en_vocab)
matcher.add("A", [Doc(en_vocab, words=["hello", "world"])])
matcher.add("B", [Doc(en_vocab, words=["test"])])
doc = Doc(en_vocab, words=["...", "hello", "world", "this", "is", "a", "test"])
matches = matcher(doc, as_spans=True)
assert len(matches) == 2
assert isinstance(matches[0], Span)
assert matches[0].text == "hello world"
assert matches[0].label_ == "A"
assert isinstance(matches[1], Span)
assert matches[1].text == "test"
assert matches[1].label_ == "B"