mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Add as_spans to Matcher/PhraseMatcher
This commit is contained in:
parent
ec14744ee4
commit
6340d1c63d
|
@ -203,13 +203,16 @@ cdef class Matcher:
|
|||
else:
|
||||
yield doc
|
||||
|
||||
def __call__(self, object doclike):
|
||||
def __call__(self, object doclike, as_spans=False):
|
||||
"""Find all token sequences matching the supplied pattern.
|
||||
|
||||
doclike (Doc or Span): The document to match over.
|
||||
RETURNS (list): A list of `(key, start, end)` tuples,
|
||||
as_spans (bool): Return Span objects with labels instead of (match_id,
|
||||
start, end) tuples.
|
||||
RETURNS (list): A list of `(match_id, start, end)` tuples,
|
||||
describing the matches. A match tuple describes a span
|
||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||
`doc[start:end]`. The `match_id` is an integer. If as_spans is set
|
||||
to True, a list of Span objects is returned.
|
||||
"""
|
||||
if isinstance(doclike, Doc):
|
||||
doc = doclike
|
||||
|
@ -262,7 +265,10 @@ cdef class Matcher:
|
|||
on_match = self._callbacks.get(key, None)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, final_matches)
|
||||
return final_matches
|
||||
if as_spans:
|
||||
return [Span(doc, start, end, label=key) for key, start, end in final_matches]
|
||||
else:
|
||||
return final_matches
|
||||
|
||||
def _normalize_key(self, key):
|
||||
if isinstance(key, basestring):
|
||||
|
|
|
@ -7,6 +7,7 @@ import warnings
|
|||
from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA
|
||||
from ..structs cimport TokenC
|
||||
from ..tokens.token cimport Token
|
||||
from ..tokens.span cimport Span
|
||||
from ..typedefs cimport attr_t
|
||||
|
||||
from ..schemas import TokenPattern
|
||||
|
@ -216,13 +217,16 @@ cdef class PhraseMatcher:
|
|||
result = internal_node
|
||||
map_set(self.mem, <MapStruct*>result, self.vocab.strings[key], NULL)
|
||||
|
||||
def __call__(self, doc):
|
||||
def __call__(self, doc, as_spans=False):
|
||||
"""Find all sequences matching the supplied patterns on the `Doc`.
|
||||
|
||||
doc (Doc): The document to match over.
|
||||
RETURNS (list): A list of `(key, start, end)` tuples,
|
||||
as_spans (bool): Return Span objects with labels instead of (match_id,
|
||||
start, end) tuples.
|
||||
RETURNS (list): A list of `(match_id, start, end)` tuples,
|
||||
describing the matches. A match tuple describes a span
|
||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||
`doc[start:end]`. The `match_id` is an integer. If as_spans is set
|
||||
to True, a list of Span objects is returned.
|
||||
|
||||
DOCS: https://spacy.io/api/phrasematcher#call
|
||||
"""
|
||||
|
@ -239,7 +243,10 @@ cdef class PhraseMatcher:
|
|||
on_match = self._callbacks.get(self.vocab.strings[ent_id])
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matches)
|
||||
return matches
|
||||
if as_spans:
|
||||
return [Span(doc, start, end, label=key) for key, start, end in matches]
|
||||
else:
|
||||
return matches
|
||||
|
||||
cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil:
|
||||
cdef MapStruct* current_node = self.c_map
|
||||
|
|
|
@ -2,7 +2,8 @@ import pytest
|
|||
import re
|
||||
from mock import Mock
|
||||
from spacy.matcher import Matcher, DependencyMatcher
|
||||
from spacy.tokens import Doc, Token
|
||||
from spacy.tokens import Doc, Token, Span
|
||||
|
||||
from ..doc.test_underscore import clean_underscore # noqa: F401
|
||||
|
||||
|
||||
|
@ -469,3 +470,17 @@ def test_matcher_span(matcher):
|
|||
assert len(matcher(doc)) == 2
|
||||
assert len(matcher(span_js)) == 1
|
||||
assert len(matcher(span_java)) == 1
|
||||
|
||||
|
||||
def test_matcher_as_spans(matcher):
|
||||
"""Test the new as_spans=True API."""
|
||||
text = "JavaScript is good but Java is better"
|
||||
doc = Doc(matcher.vocab, words=text.split())
|
||||
matches = matcher(doc, as_spans=True)
|
||||
assert len(matches) == 2
|
||||
assert isinstance(matches[0], Span)
|
||||
assert matches[0].text == "JavaScript"
|
||||
assert matches[0].label_ == "JS"
|
||||
assert isinstance(matches[1], Span)
|
||||
assert matches[1].text == "Java"
|
||||
assert matches[1].label_ == "Java"
|
||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
import srsly
|
||||
from mock import Mock
|
||||
from spacy.matcher import PhraseMatcher
|
||||
from spacy.tokens import Doc
|
||||
from spacy.tokens import Doc, Span
|
||||
from ..util import get_doc
|
||||
|
||||
|
||||
|
@ -287,3 +287,19 @@ def test_phrase_matcher_pickle(en_vocab):
|
|||
# clunky way to vaguely check that callback is unpickled
|
||||
(vocab, docs, callbacks, attr) = matcher_unpickled.__reduce__()[1]
|
||||
assert isinstance(callbacks.get("TEST2"), Mock)
|
||||
|
||||
|
||||
def test_phrase_matcher_as_spans(en_vocab):
|
||||
"""Test the new as_spans=True API."""
|
||||
matcher = PhraseMatcher(en_vocab)
|
||||
matcher.add("A", [Doc(en_vocab, words=["hello", "world"])])
|
||||
matcher.add("B", [Doc(en_vocab, words=["test"])])
|
||||
doc = Doc(en_vocab, words=["...", "hello", "world", "this", "is", "a", "test"])
|
||||
matches = matcher(doc, as_spans=True)
|
||||
assert len(matches) == 2
|
||||
assert isinstance(matches[0], Span)
|
||||
assert matches[0].text == "hello world"
|
||||
assert matches[0].label_ == "A"
|
||||
assert isinstance(matches[1], Span)
|
||||
assert matches[1].text == "test"
|
||||
assert matches[1].label_ == "B"
|
||||
|
|
Loading…
Reference in New Issue
Block a user