TEST: Doc.ents as SpanGroup

Overview of required changes to support `SpanGroup` rather than
`Tuple[Span]`:

* implement slicing for `SpanGroup`, so that `SpanGroup[i:j]` returns a new
`SpanGroup`
* return a `SpanGroup` for `SpanGroup + x` or `x + SpanGroup` rather than
refusing to concatenate (currently without good error handling)

Side effects:

* when appending to `Doc.ents` (e.g. `doc.ents += ...`), only an
`Iterable[Span]` is supported; other formats, such as raw
`(label, start, end)` entity tuples built from matcher results, can no
longer be appended. You can still assign mixed data in any of the currently
supported formats directly to `Doc.ents`.
  * for the `Matcher` case, calling the matcher with `as_spans=True` provides
  a good alternative
This commit is contained in:
Adriane Boyd 2023-03-07 16:26:02 +01:00
parent 41b3a0d932
commit 31cd5141bd
6 changed files with 29 additions and 11 deletions

View File

@ -815,7 +815,7 @@ def test_doc_set_ents(en_tokenizer):
doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified") doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified")
assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3] assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3]
assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0] assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0]
assert doc.ents == tuple() assert len(doc.ents) == 0
# invalid IOB repaired after blocked # invalid IOB repaired after blocked
doc.ents = [Span(doc, 3, 5, "ENT")] doc.ents = [Span(doc, 3, 5, "ENT")]
@ -892,7 +892,7 @@ def test_doc_init_iob():
words = ["a", "b", "c", "d", "e"] words = ["a", "b", "c", "d", "e"]
ents = ["O"] * len(words) ents = ["O"] * len(words)
doc = Doc(Vocab(), words=words, ents=ents) doc = Doc(Vocab(), words=words, ents=ents)
assert doc.ents == () assert len(doc.ents) == 0
ents = ["B-PERSON", "I-PERSON", "O", "I-PERSON", "I-PERSON"] ents = ["B-PERSON", "I-PERSON", "O", "I-PERSON", "I-PERSON"]
doc = Doc(Vocab(), words=words, ents=ents) doc = Doc(Vocab(), words=words, ents=ents)

View File

@ -87,7 +87,10 @@ def test_issue118_prefix_reorder(en_tokenizer, patterns):
matcher.add("BostonCeltics", patterns) matcher.add("BostonCeltics", patterns)
assert len(list(doc.ents)) == 0 assert len(list(doc.ents)) == 0
matches = [(ORG, start, end) for _, start, end in matcher(doc)] matches = [(ORG, start, end) for _, start, end in matcher(doc)]
doc.ents += tuple(matches)[1:] matches_spans = matcher(doc, as_spans=True)
for span in matches_spans:
span.label = ORG
doc.ents += matches_spans[1:]
assert matches == [(ORG, 9, 10), (ORG, 9, 11)] assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
ents = doc.ents ents = doc.ents
assert len(ents) == 1 assert len(ents) == 1
@ -116,7 +119,8 @@ def test_issue242(en_tokenizer):
with pytest.raises(ValueError): with pytest.raises(ValueError):
# One token can only be part of one entity, so test that the matches # One token can only be part of one entity, so test that the matches
# can't be added as entities # can't be added as entities
doc.ents += tuple(matches) matches_spans = matcher(doc, as_spans=True)
doc.ents += tuple(matches_spans)
@pytest.mark.issue(587) @pytest.mark.issue(587)

View File

@ -4,6 +4,7 @@ from cymem.cymem import Pool
from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged
from .span import Span from .span import Span
from .token import Token from .token import Token
from .span_group import SpanGroup
from .span_groups import SpanGroups from .span_groups import SpanGroups
from .retokenizer import Retokenizer from .retokenizer import Retokenizer
from ..lexeme import Lexeme from ..lexeme import Lexeme
@ -120,7 +121,7 @@ class Doc:
def text(self) -> str: ... def text(self) -> str: ...
@property @property
def text_with_ws(self) -> str: ... def text_with_ws(self) -> str: ...
ents: Tuple[Span] ents: SpanGroup
def set_ents( def set_ents(
self, self,
entities: List[Span], entities: List[Span],

View File

@ -19,6 +19,7 @@ import warnings
from .span cimport Span from .span cimport Span
from .token cimport MISSING_DEP from .token cimport MISSING_DEP
from .span_group import SpanGroup
from .span_groups import SpanGroups from .span_groups import SpanGroups
from .token cimport Token from .token cimport Token
from ..lexeme cimport Lexeme, EMPTY_LEXEME from ..lexeme cimport Lexeme, EMPTY_LEXEME
@ -743,7 +744,7 @@ cdef class Doc:
output.append(Span(self, start, self.length, label=label, kb_id=kb_id, span_id=ent_id)) output.append(Span(self, start, self.length, label=label, kb_id=kb_id, span_id=ent_id))
# remove empty-label spans # remove empty-label spans
output = [o for o in output if o.label_ != ""] output = [o for o in output if o.label_ != ""]
return tuple(output) return SpanGroup(self, spans=output)
def __set__(self, ents): def __set__(self, ents):
# TODO: # TODO:

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, Optional from typing import Any, Dict, Iterable, Optional, Union, overload
from .doc import Doc from .doc import Doc
from .span import Span from .span import Span
@ -22,7 +22,10 @@ class SpanGroup:
def __len__(self) -> int: ... def __len__(self) -> int: ...
def append(self, span: Span) -> None: ... def append(self, span: Span) -> None: ...
def extend(self, spans: Iterable[Span]) -> None: ... def extend(self, spans: Iterable[Span]) -> None: ...
@overload
def __getitem__(self, i: int) -> Span: ... def __getitem__(self, i: int) -> Span: ...
@overload
def __getitem__(self, i: slice) -> SpanGroup: ...
def to_bytes(self) -> bytes: ... def to_bytes(self) -> bytes: ...
def from_bytes(self, bytes_data: bytes) -> SpanGroup: ... def from_bytes(self, bytes_data: bytes) -> SpanGroup: ...
def copy(self, doc: Optional[Doc] = ...) -> SpanGroup: ... def copy(self, doc: Optional[Doc] = ...) -> SpanGroup: ...

View File

@ -4,10 +4,13 @@ import struct
from copy import deepcopy from copy import deepcopy
import srsly import srsly
from spacy.errors import Errors
from .span cimport Span
from libcpp.memory cimport make_shared from libcpp.memory cimport make_shared
from .span cimport Span
from ..errors import Errors
from .. import util
cdef class SpanGroup: cdef class SpanGroup:
"""A group of spans that all belong to the same Doc object. The group """A group of spans that all belong to the same Doc object. The group
@ -93,7 +96,7 @@ cdef class SpanGroup:
""" """
return self.c.size() return self.c.size()
def __getitem__(self, int i) -> Span: def __getitem__(self, object i) -> Span:
"""Get a span from the group. Note that a copy of the span is returned, """Get a span from the group. Note that a copy of the span is returned,
so if any changes are made to this span, they are not reflected in the so if any changes are made to this span, they are not reflected in the
corresponding member of the span group. corresponding member of the span group.
@ -103,6 +106,10 @@ cdef class SpanGroup:
DOCS: https://spacy.io/api/spangroup#getitem DOCS: https://spacy.io/api/spangroup#getitem
""" """
if isinstance(i, slice):
start, stop = util.normalize_slice(len(self), i.start, i.stop, i.step)
spans = [self[i] for i in range(start, stop)]
return SpanGroup(self.doc, spans=spans)
i = self._normalize_index(i) i = self._normalize_index(i)
return Span.cinit(self.doc, self.c[i]) return Span.cinit(self.doc, self.c[i])
@ -155,8 +162,10 @@ cdef class SpanGroup:
""" """
# For Cython 0.x and __add__, you cannot rely on `self` as being `self` # For Cython 0.x and __add__, you cannot rely on `self` as being `self`
# or being the right type, so both types need to be checked explicitly. # or being the right type, so both types need to be checked explicitly.
if isinstance(self, SpanGroup) and isinstance(other, SpanGroup): if isinstance(self, SpanGroup):
return self._concat(other) return self._concat(other)
if isinstance(other, SpanGroup):
return other._concat(self)
return NotImplemented return NotImplemented
def __iter__(self): def __iter__(self):