TEST: Doc.ents as SpanGroup

Overview of required changes to support `SpanGroup` rather than
`Tuple[Span]`:

* implement slice for `SpanGroup`
* return `SpanGroup` for `SpanGroup + x` or `x + SpanGroup` rather than
refusing to concatenate (currently without good error handling)

Side effects:

* if appending to `Doc.ents`, only `Iterable[Span]` is supported rather
than other formats like raw entity tuples from matcher results, but you
can still assign mixed data in any of the currently supported formats to
`Doc.ents`
  * for the `Matcher` case, `as_spans` provides a good alternative
This commit is contained in:
Adriane Boyd 2023-03-07 16:26:02 +01:00
parent 41b3a0d932
commit 31cd5141bd
6 changed files with 29 additions and 11 deletions

View File

@ -815,7 +815,7 @@ def test_doc_set_ents(en_tokenizer):
doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified")
assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3]
assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0]
assert doc.ents == tuple()
assert len(doc.ents) == 0
# invalid IOB repaired after blocked
doc.ents = [Span(doc, 3, 5, "ENT")]
@ -892,7 +892,7 @@ def test_doc_init_iob():
words = ["a", "b", "c", "d", "e"]
ents = ["O"] * len(words)
doc = Doc(Vocab(), words=words, ents=ents)
assert doc.ents == ()
assert len(doc.ents) == 0
ents = ["B-PERSON", "I-PERSON", "O", "I-PERSON", "I-PERSON"]
doc = Doc(Vocab(), words=words, ents=ents)

View File

@ -87,7 +87,10 @@ def test_issue118_prefix_reorder(en_tokenizer, patterns):
matcher.add("BostonCeltics", patterns)
assert len(list(doc.ents)) == 0
matches = [(ORG, start, end) for _, start, end in matcher(doc)]
doc.ents += tuple(matches)[1:]
matches_spans = matcher(doc, as_spans=True)
for span in matches_spans:
span.label = ORG
doc.ents += matches_spans[1:]
assert matches == [(ORG, 9, 10), (ORG, 9, 11)]
ents = doc.ents
assert len(ents) == 1
@ -116,7 +119,8 @@ def test_issue242(en_tokenizer):
with pytest.raises(ValueError):
# One token can only be part of one entity, so test that the matches
# can't be added as entities
doc.ents += tuple(matches)
matches_spans = matcher(doc, as_spans=True)
doc.ents += tuple(matches_spans)
@pytest.mark.issue(587)

View File

@ -4,6 +4,7 @@ from cymem.cymem import Pool
from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged
from .span import Span
from .token import Token
from .span_group import SpanGroup
from .span_groups import SpanGroups
from .retokenizer import Retokenizer
from ..lexeme import Lexeme
@ -120,7 +121,7 @@ class Doc:
def text(self) -> str: ...
@property
def text_with_ws(self) -> str: ...
ents: Tuple[Span]
ents: SpanGroup
def set_ents(
self,
entities: List[Span],

View File

@ -19,6 +19,7 @@ import warnings
from .span cimport Span
from .token cimport MISSING_DEP
from .span_group import SpanGroup
from .span_groups import SpanGroups
from .token cimport Token
from ..lexeme cimport Lexeme, EMPTY_LEXEME
@ -743,7 +744,7 @@ cdef class Doc:
output.append(Span(self, start, self.length, label=label, kb_id=kb_id, span_id=ent_id))
# remove empty-label spans
output = [o for o in output if o.label_ != ""]
return tuple(output)
return SpanGroup(self, spans=output)
def __set__(self, ents):
# TODO:

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, Optional, Union, overload
from .doc import Doc
from .span import Span
@ -22,7 +22,10 @@ class SpanGroup:
def __len__(self) -> int: ...
def append(self, span: Span) -> None: ...
def extend(self, spans: Iterable[Span]) -> None: ...
@overload
def __getitem__(self, i: int) -> Span: ...
@overload
def __getitem__(self, i: slice) -> SpanGroup: ...
def to_bytes(self) -> bytes: ...
def from_bytes(self, bytes_data: bytes) -> SpanGroup: ...
def copy(self, doc: Optional[Doc] = ...) -> SpanGroup: ...

View File

@ -4,10 +4,13 @@ import struct
from copy import deepcopy
import srsly
from spacy.errors import Errors
from .span cimport Span
from libcpp.memory cimport make_shared
from .span cimport Span
from ..errors import Errors
from .. import util
cdef class SpanGroup:
"""A group of spans that all belong to the same Doc object. The group
@ -93,7 +96,7 @@ cdef class SpanGroup:
"""
return self.c.size()
def __getitem__(self, int i) -> Span:
def __getitem__(self, object i) -> Span:
"""Get a span from the group. Note that a copy of the span is returned,
so if any changes are made to this span, they are not reflected in the
corresponding member of the span group.
@ -103,6 +106,10 @@ cdef class SpanGroup:
DOCS: https://spacy.io/api/spangroup#getitem
"""
if isinstance(i, slice):
start, stop = util.normalize_slice(len(self), i.start, i.stop, i.step)
spans = [self[i] for i in range(start, stop)]
return SpanGroup(self.doc, spans=spans)
i = self._normalize_index(i)
return Span.cinit(self.doc, self.c[i])
@ -155,8 +162,10 @@ cdef class SpanGroup:
"""
# For Cython 0.x and __add__, you cannot rely on `self` as being `self`
# or being the right type, so both types need to be checked explicitly.
if isinstance(self, SpanGroup) and isinstance(other, SpanGroup):
if isinstance(self, SpanGroup):
return self._concat(other)
if isinstance(other, SpanGroup):
return other._concat(self)
return NotImplemented
def __iter__(self):