From 31cd5141bd59c0c5d15467b585c8464d1fff21a2 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Tue, 7 Mar 2023 16:26:02 +0100 Subject: [PATCH] TEST: Doc.ents as SpanGroup Overview of required changes to support `SpanGroup` rather than `Tuple[Span]`: * implement slice for `SpanGroup` * return `SpanGroup` for `SpanGroup + x` or `x + SpanGroup` rather than refusing to concatenate (currently without good error handling) Side effects: * if appending to `Doc.ents`, only `Iterable[Span]` is supported rather than other formats like raw entity tuples from matcher results, but you can still assign mixed data in any of the currently supported formats to `Doc.ents` * for the `Matcher` case, `as_spans` provides a good alternative --- spacy/tests/doc/test_doc_api.py | 4 ++-- spacy/tests/matcher/test_matcher_logic.py | 8 ++++++-- spacy/tokens/doc.pyi | 3 ++- spacy/tokens/doc.pyx | 3 ++- spacy/tokens/span_group.pyi | 5 ++++- spacy/tokens/span_group.pyx | 17 +++++++++++++---- 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index 2009a29d6..438bd7b7a 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -815,7 +815,7 @@ def test_doc_set_ents(en_tokenizer): doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified") assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3] assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0] - assert doc.ents == tuple() + assert len(doc.ents) == 0 # invalid IOB repaired after blocked doc.ents = [Span(doc, 3, 5, "ENT")] @@ -892,7 +892,7 @@ def test_doc_init_iob(): words = ["a", "b", "c", "d", "e"] ents = ["O"] * len(words) doc = Doc(Vocab(), words=words, ents=ents) - assert doc.ents == () + assert len(doc.ents) == 0 ents = ["B-PERSON", "I-PERSON", "O", "I-PERSON", "I-PERSON"] doc = Doc(Vocab(), words=words, ents=ents) diff --git a/spacy/tests/matcher/test_matcher_logic.py b/spacy/tests/matcher/test_matcher_logic.py index 3b65fee23..7d1fad02b 100644 --- a/spacy/tests/matcher/test_matcher_logic.py +++ b/spacy/tests/matcher/test_matcher_logic.py @@ -87,7 +87,10 @@ def test_issue118_prefix_reorder(en_tokenizer, patterns): matcher.add("BostonCeltics", patterns) assert len(list(doc.ents)) == 0 matches = [(ORG, start, end) for _, start, end in matcher(doc)] - doc.ents += tuple(matches)[1:] + matches_spans = matcher(doc, as_spans=True) + for span in matches_spans: + span.label = ORG + doc.ents += matches_spans[1:] assert matches == [(ORG, 9, 10), (ORG, 9, 11)] ents = doc.ents assert len(ents) == 1 @@ -116,7 +119,8 @@ def test_issue242(en_tokenizer): with pytest.raises(ValueError): # One token can only be part of one entity, so test that the matches # can't be added as entities - doc.ents += tuple(matches) + matches_spans = matcher(doc, as_spans=True) + doc.ents += tuple(matches_spans) @pytest.mark.issue(587) diff --git a/spacy/tokens/doc.pyi b/spacy/tokens/doc.pyi index 48bc21c27..fae8efcdc 100644 --- a/spacy/tokens/doc.pyi +++ b/spacy/tokens/doc.pyi @@ -4,6 +4,7 @@ from cymem.cymem import Pool from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged from .span import Span from .token import Token +from .span_group import SpanGroup from .span_groups import SpanGroups from .retokenizer import Retokenizer from ..lexeme import Lexeme @@ -120,7 +121,7 @@ class Doc: def text(self) -> str: ... @property def text_with_ws(self) -> str: ... - ents: Tuple[Span] + ents: SpanGroup def set_ents( self, entities: List[Span], diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 0ea2c39ab..2ac2b1c50 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -19,6 +19,7 @@ import warnings from .span cimport Span from .token cimport MISSING_DEP +from .span_group import SpanGroup from .span_groups import SpanGroups from .token cimport Token from ..lexeme cimport Lexeme, EMPTY_LEXEME @@ -743,7 +744,7 @@ cdef class Doc: output.append(Span(self, start, self.length, label=label, kb_id=kb_id, span_id=ent_id)) # remove empty-label spans output = [o for o in output if o.label_ != ""] - return tuple(output) + return SpanGroup(self, spans=output) def __set__(self, ents): # TODO: diff --git a/spacy/tokens/span_group.pyi b/spacy/tokens/span_group.pyi index 0b4aa83aa..46b13d937 100644 --- a/spacy/tokens/span_group.pyi +++ b/spacy/tokens/span_group.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional, Union, overload from .doc import Doc from .span import Span @@ -22,7 +22,10 @@ class SpanGroup: def __len__(self) -> int: ... def append(self, span: Span) -> None: ... def extend(self, spans: Iterable[Span]) -> None: ... + @overload def __getitem__(self, i: int) -> Span: ... + @overload + def __getitem__(self, i: slice) -> SpanGroup: ... def to_bytes(self) -> bytes: ... def from_bytes(self, bytes_data: bytes) -> SpanGroup: ... def copy(self, doc: Optional[Doc] = ...) -> SpanGroup: ... diff --git a/spacy/tokens/span_group.pyx b/spacy/tokens/span_group.pyx index 7325c1fa7..9bf73f122 100644 --- a/spacy/tokens/span_group.pyx +++ b/spacy/tokens/span_group.pyx @@ -4,10 +4,13 @@ import struct from copy import deepcopy import srsly -from spacy.errors import Errors -from .span cimport Span from libcpp.memory cimport make_shared +from .span cimport Span + +from ..errors import Errors +from .. import util + cdef class SpanGroup: """A group of spans that all belong to the same Doc object. The group @@ -93,7 +96,7 @@ cdef class SpanGroup: """ return self.c.size() - def __getitem__(self, int i) -> Span: + def __getitem__(self, object i) -> Span: """Get a span from the group. Note that a copy of the span is returned, so if any changes are made to this span, they are not reflected in the corresponding member of the span group. @@ -103,6 +106,10 @@ cdef class SpanGroup: DOCS: https://spacy.io/api/spangroup#getitem """ + if isinstance(i, slice): + start, stop = util.normalize_slice(len(self), i.start, i.stop, i.step) + spans = [self[i] for i in range(start, stop)] + return SpanGroup(self.doc, spans=spans) i = self._normalize_index(i) return Span.cinit(self.doc, self.c[i]) @@ -155,8 +162,10 @@ cdef class SpanGroup: """ # For Cython 0.x and __add__, you cannot rely on `self` as being `self` # or being the right type, so both types need to be checked explicitly. - if isinstance(self, SpanGroup) and isinstance(other, SpanGroup): + if isinstance(self, SpanGroup): return self._concat(other) + if isinstance(other, SpanGroup): + return other._concat(self) return NotImplemented def __iter__(self):