diff --git a/spacy/tests/doc/test_span_group.py b/spacy/tests/doc/test_span_group.py index 8c70a83e1..963126b46 100644 --- a/spacy/tests/doc/test_span_group.py +++ b/spacy/tests/doc/test_span_group.py @@ -1,7 +1,9 @@ +from typing import List + import pytest from random import Random from spacy.matcher import Matcher -from spacy.tokens import Span, SpanGroup +from spacy.tokens import Span, SpanGroup, Doc @pytest.fixture @@ -240,3 +242,10 @@ def test_span_group_extend(doc): def test_span_group_dealloc(span_group): with pytest.raises(AttributeError): print(span_group.doc) + + +def test_iter(doc: Doc): + span_group: SpanGroup = doc.spans["SPANS"] + spans: List[Span] = list(span_group) + for i, span in enumerate(span_group): + assert span == span_group[i] == spans[i] diff --git a/spacy/tokens/span.pyi b/spacy/tokens/span.pyi index 0a6f306a6..ecf8e38b2 100644 --- a/spacy/tokens/span.pyi +++ b/spacy/tokens/span.pyi @@ -115,8 +115,8 @@ class Span: end: int start_char: int end_char: int - label: int - kb_id: int + label: Union[int, str] + kb_id: Union[int, str] id: int ent_id: int ent_id_: str diff --git a/spacy/tokens/span_group.pyi b/spacy/tokens/span_group.pyi index 21cd124ab..0b4aa83aa 100644 --- a/spacy/tokens/span_group.pyi +++ b/spacy/tokens/span_group.pyi @@ -18,6 +18,7 @@ class SpanGroup: def doc(self) -> Doc: ... @property def has_overlap(self) -> bool: ... + def __iter__(self): ... def __len__(self) -> int: ... def append(self, span: Span) -> None: ... def extend(self, spans: Iterable[Span]) -> None: ... diff --git a/spacy/tokens/span_group.pyx b/spacy/tokens/span_group.pyx index 1aa3c0bc8..6b1b3366c 100644 --- a/spacy/tokens/span_group.pyx +++ b/spacy/tokens/span_group.pyx @@ -170,6 +170,16 @@ cdef class SpanGroup: raise ValueError(Errors.E855.format(obj="span")) self.push_back(span.c) + def __iter__(self): + """ + Iterate over the spans in this SpanGroup. + YIELDS (Span): An span in this SpanGroup. + + DOCS: https://spacy.io/api/spangroup#iter + """ + for i in range(self.c.size()): + yield self[i] + def extend(self, spans_or_span_group: Union[SpanGroup, Iterable["Span"]]): """Add multiple spans or contents of another SpanGroup to the group. All spans must refer to the same Doc object as the span group.