Correct Span.label, Span.kb_id types. Fix SpanGroup.__iter__().

This commit is contained in:
Raphael Mitsch 2022-12-15 14:06:36 +01:00
parent e5c7f3b077
commit d5b43a4850
4 changed files with 23 additions and 3 deletions

View File

@ -1,7 +1,9 @@
from typing import List
import pytest
from random import Random
from spacy.matcher import Matcher
from spacy.tokens import Span, SpanGroup
from spacy.tokens import Span, SpanGroup, Doc
@pytest.fixture
@ -240,3 +242,10 @@ def test_span_group_extend(doc):
def test_span_group_dealloc(span_group):
with pytest.raises(AttributeError):
print(span_group.doc)
def test_iter(doc: Doc):
span_group: SpanGroup = doc.spans["SPANS"]
spans: List[Span] = list(span_group)
for i, span in enumerate(span_group):
assert span == span_group[i] == spans[i]

View File

@ -115,8 +115,8 @@ class Span:
end: int
start_char: int
end_char: int
label: int
kb_id: int
label: Union[int, str]
kb_id: Union[int, str]
id: int
ent_id: int
ent_id_: str

View File

@ -18,6 +18,7 @@ class SpanGroup:
def doc(self) -> Doc: ...
@property
def has_overlap(self) -> bool: ...
def __iter__(self): ...
def __len__(self) -> int: ...
def append(self, span: Span) -> None: ...
def extend(self, spans: Iterable[Span]) -> None: ...

View File

@ -170,6 +170,16 @@ cdef class SpanGroup:
raise ValueError(Errors.E855.format(obj="span"))
self.push_back(span.c)
def __iter__(self):
"""
Iterate over the spans in this SpanGroup.
YIELDS (Span): An span in this SpanGroup.
DOCS: https://spacy.io/api/spangroup#iter
"""
for i in range(self.c.size()):
yield self[i]
def extend(self, spans_or_span_group: Union[SpanGroup, Iterable["Span"]]):
"""Add multiple spans or contents of another SpanGroup to the group.
All spans must refer to the same Doc object as the span group.