Correct Span.label, Span.kb_id types. Fix SpanGroup.__iter__().

This commit is contained in:
Raphael Mitsch 2022-12-15 14:06:36 +01:00
parent e5c7f3b077
commit d5b43a4850
4 changed files with 23 additions and 3 deletions

View File

@ -1,7 +1,9 @@
from typing import List
import pytest import pytest
from random import Random from random import Random
from spacy.matcher import Matcher from spacy.matcher import Matcher
from spacy.tokens import Span, SpanGroup from spacy.tokens import Span, SpanGroup, Doc
@pytest.fixture @pytest.fixture
@ -240,3 +242,10 @@ def test_span_group_extend(doc):
def test_span_group_dealloc(span_group): def test_span_group_dealloc(span_group):
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
print(span_group.doc) print(span_group.doc)
def test_iter(doc: Doc):
span_group: SpanGroup = doc.spans["SPANS"]
spans: List[Span] = list(span_group)
for i, span in enumerate(span_group):
assert span == span_group[i] == spans[i]

View File

@ -115,8 +115,8 @@ class Span:
end: int end: int
start_char: int start_char: int
end_char: int end_char: int
label: int label: Union[int, str]
kb_id: int kb_id: Union[int, str]
id: int id: int
ent_id: int ent_id: int
ent_id_: str ent_id_: str

View File

@ -18,6 +18,7 @@ class SpanGroup:
def doc(self) -> Doc: ... def doc(self) -> Doc: ...
@property @property
def has_overlap(self) -> bool: ... def has_overlap(self) -> bool: ...
def __iter__(self): ...
def __len__(self) -> int: ... def __len__(self) -> int: ...
def append(self, span: Span) -> None: ... def append(self, span: Span) -> None: ...
def extend(self, spans: Iterable[Span]) -> None: ... def extend(self, spans: Iterable[Span]) -> None: ...

View File

@ -170,6 +170,16 @@ cdef class SpanGroup:
raise ValueError(Errors.E855.format(obj="span")) raise ValueError(Errors.E855.format(obj="span"))
self.push_back(span.c) self.push_back(span.c)
def __iter__(self):
"""
Iterate over the spans in this SpanGroup.
YIELDS (Span): An span in this SpanGroup.
DOCS: https://spacy.io/api/spangroup#iter
"""
for i in range(self.c.size()):
yield self[i]
def extend(self, spans_or_span_group: Union[SpanGroup, Iterable["Span"]]): def extend(self, spans_or_span_group: Union[SpanGroup, Iterable["Span"]]):
"""Add multiple spans or contents of another SpanGroup to the group. """Add multiple spans or contents of another SpanGroup to the group.
All spans must refer to the same Doc object as the span group. All spans must refer to the same Doc object as the span group.