diff --git a/spacy/errors.py b/spacy/errors.py index 24a9f0339..84eca8016 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -524,6 +524,9 @@ class Errors(metaclass=ErrorsWithCodes): E202 = ("Unsupported {name} mode '{mode}'. Supported modes: {modes}.") # New errors added in v3.x + E855 = ("Invalid {obj}: {obj} is not from the same doc.") + E856 = ("Error accessing span at position {i}: out of bounds in span group " + "of length {length}.") E857 = ("Entry '{name}' not found in edit tree lemmatizer labels.") E858 = ("The {mode} vector table does not support this operation. " "{alternative}") diff --git a/spacy/tests/doc/test_span_group.py b/spacy/tests/doc/test_span_group.py new file mode 100644 index 000000000..8c70a83e1 --- /dev/null +++ b/spacy/tests/doc/test_span_group.py @@ -0,0 +1,242 @@ +import pytest +from random import Random +from spacy.matcher import Matcher +from spacy.tokens import Span, SpanGroup + + +@pytest.fixture +def doc(en_tokenizer): + doc = en_tokenizer("0 1 2 3 4 5 6") + matcher = Matcher(en_tokenizer.vocab, validate=True) + + # fmt: off + matcher.add("4", [[{}, {}, {}, {}]]) + matcher.add("2", [[{}, {}, ]]) + matcher.add("1", [[{}, ]]) + # fmt: on + matches = matcher(doc) + spans = [] + for match in matches: + spans.append( + Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]]) + ) + Random(42).shuffle(spans) + doc.spans["SPANS"] = SpanGroup( + doc, name="SPANS", attrs={"key": "value"}, spans=spans + ) + return doc + + +@pytest.fixture +def other_doc(en_tokenizer): + doc = en_tokenizer("0 1 2 3 4 5 6") + matcher = Matcher(en_tokenizer.vocab, validate=True) + + # fmt: off + matcher.add("4", [[{}, {}, {}, {}]]) + matcher.add("2", [[{}, {}, ]]) + matcher.add("1", [[{}, ]]) + # fmt: on + + matches = matcher(doc) + spans = [] + for match in matches: + spans.append( + Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]]) + ) + Random(42).shuffle(spans) + doc.spans["SPANS"] = SpanGroup( + doc, name="SPANS", attrs={"key": "value"}, spans=spans + ) + return doc + + +@pytest.fixture +def span_group(en_tokenizer): + doc = en_tokenizer("0 1 2 3 4 5 6") + matcher = Matcher(en_tokenizer.vocab, validate=True) + + # fmt: off + matcher.add("4", [[{}, {}, {}, {}]]) + matcher.add("2", [[{}, {}, ]]) + matcher.add("1", [[{}, ]]) + # fmt: on + + matches = matcher(doc) + spans = [] + for match in matches: + spans.append( + Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]]) + ) + Random(42).shuffle(spans) + doc.spans["SPANS"] = SpanGroup( + doc, name="SPANS", attrs={"key": "value"}, spans=spans + ) + + +def test_span_group_copy(doc): + span_group = doc.spans["SPANS"] + clone = span_group.copy() + assert clone != span_group + assert clone.name == span_group.name + assert clone.attrs == span_group.attrs + assert len(clone) == len(span_group) + assert list(span_group) == list(clone) + clone.name = "new_name" + clone.attrs["key"] = "new_value" + clone.append(Span(doc, 0, 6, "LABEL")) + assert clone.name != span_group.name + assert clone.attrs != span_group.attrs + assert span_group.attrs["key"] == "value" + assert list(span_group) != list(clone) + + +def test_span_group_set_item(doc, other_doc): + span_group = doc.spans["SPANS"] + + index = 5 + span = span_group[index] + span.label_ = "NEW LABEL" + span.kb_id = doc.vocab.strings["KB_ID"] + + assert span_group[index].label != span.label + assert span_group[index].kb_id != span.kb_id + + span_group[index] = span + assert span_group[index].start == span.start + assert span_group[index].end == span.end + assert span_group[index].label == span.label + assert span_group[index].kb_id == span.kb_id + assert span_group[index] == span + + with pytest.raises(IndexError): + span_group[-100] = span + with pytest.raises(IndexError): + span_group[100] = span + + span = Span(other_doc, 0, 2) + with pytest.raises(ValueError): + span_group[index] = span + + +def test_span_group_has_overlap(doc): + span_group = doc.spans["SPANS"] + assert span_group.has_overlap + + +def test_span_group_concat(doc, other_doc): + span_group_1 = doc.spans["SPANS"] + spans = [doc[0:5], doc[0:6]] + span_group_2 = SpanGroup( + doc, + name="MORE_SPANS", + attrs={"key": "new_value", "new_key": "new_value"}, + spans=spans, + ) + span_group_3 = span_group_1._concat(span_group_2) + assert span_group_3.name == span_group_1.name + assert span_group_3.attrs == {"key": "value", "new_key": "new_value"} + span_list_expected = list(span_group_1) + list(span_group_2) + assert list(span_group_3) == list(span_list_expected) + + # Inplace + span_list_expected = list(span_group_1) + list(span_group_2) + span_group_3 = span_group_1._concat(span_group_2, inplace=True) + assert span_group_3 == span_group_1 + assert span_group_3.name == span_group_1.name + assert span_group_3.attrs == {"key": "value", "new_key": "new_value"} + assert list(span_group_3) == list(span_list_expected) + + span_group_2 = other_doc.spans["SPANS"] + with pytest.raises(ValueError): + span_group_1._concat(span_group_2) + + +def test_span_doc_delitem(doc): + span_group = doc.spans["SPANS"] + length = len(span_group) + index = 5 + span = span_group[index] + next_span = span_group[index + 1] + del span_group[index] + assert len(span_group) == length - 1 + assert span_group[index] != span + assert span_group[index] == next_span + + with pytest.raises(IndexError): + del span_group[-100] + with pytest.raises(IndexError): + del span_group[100] + + +def test_span_group_add(doc): + span_group_1 = doc.spans["SPANS"] + spans = [doc[0:5], doc[0:6]] + span_group_2 = SpanGroup( + doc, + name="MORE_SPANS", + attrs={"key": "new_value", "new_key": "new_value"}, + spans=spans, + ) + + span_group_3_expected = span_group_1._concat(span_group_2) + + span_group_3 = span_group_1 + span_group_2 + assert len(span_group_3) == len(span_group_3_expected) + assert span_group_3.attrs == {"key": "value", "new_key": "new_value"} + assert list(span_group_3) == list(span_group_3_expected) + + +def test_span_group_iadd(doc): + span_group_1 = doc.spans["SPANS"].copy() + spans = [doc[0:5], doc[0:6]] + span_group_2 = SpanGroup( + doc, + name="MORE_SPANS", + attrs={"key": "new_value", "new_key": "new_value"}, + spans=spans, + ) + + span_group_1_expected = span_group_1._concat(span_group_2) + + span_group_1 += span_group_2 + assert len(span_group_1) == len(span_group_1_expected) + assert span_group_1.attrs == {"key": "value", "new_key": "new_value"} + assert list(span_group_1) == list(span_group_1_expected) + + span_group_1 = doc.spans["SPANS"].copy() + span_group_1 += spans + assert len(span_group_1) == len(span_group_1_expected) + assert span_group_1.attrs == { + "key": "value", + } + assert list(span_group_1) == list(span_group_1_expected) + + +def test_span_group_extend(doc): + span_group_1 = doc.spans["SPANS"].copy() + spans = [doc[0:5], doc[0:6]] + span_group_2 = SpanGroup( + doc, + name="MORE_SPANS", + attrs={"key": "new_value", "new_key": "new_value"}, + spans=spans, + ) + + span_group_1_expected = span_group_1._concat(span_group_2) + + span_group_1.extend(span_group_2) + assert len(span_group_1) == len(span_group_1_expected) + assert span_group_1.attrs == {"key": "value", "new_key": "new_value"} + assert list(span_group_1) == list(span_group_1_expected) + + span_group_1 = doc.spans["SPANS"] + span_group_1.extend(spans) + assert len(span_group_1) == len(span_group_1_expected) + assert span_group_1.attrs == {"key": "value"} + assert list(span_group_1) == list(span_group_1_expected) + + +def test_span_group_dealloc(span_group): + with pytest.raises(AttributeError): + print(span_group.doc) diff --git a/spacy/tokens/span_group.pyx b/spacy/tokens/span_group.pyx index 6cfa75237..1c09f4ea2 100644 --- a/spacy/tokens/span_group.pyx +++ b/spacy/tokens/span_group.pyx @@ -1,10 +1,11 @@ +from typing import Iterable, Tuple, Union, Optional, TYPE_CHECKING import weakref import struct +from copy import deepcopy import srsly from spacy.errors import Errors from .span cimport Span -from libc.stdint cimport uint64_t, uint32_t, int32_t cdef class SpanGroup: @@ -48,6 +49,8 @@ cdef class SpanGroup: self.name = name self.attrs = dict(attrs) if attrs is not None else {} cdef Span span + if len(spans) : + self.c.reserve(len(spans)) for span in spans: self.push_back(span.c) @@ -89,6 +92,72 @@ cdef class SpanGroup: """ return self.c.size() + def __getitem__(self, int i) -> Span: + """Get a span from the group. Note that a copy of the span is returned, + so if any changes are made to this span, they are not reflected in the + corresponding member of the span group. + + i (int): The item index. + RETURNS (Span): The span at the given index. + + DOCS: https://spacy.io/api/spangroup#getitem + """ + i = self._normalize_index(i) + return Span.cinit(self.doc, self.c[i]) + + def __delitem__(self, int i): + """Delete a span from the span group at index i. + + i (int): The item index. + + DOCS: https://spacy.io/api/spangroup#delitem + """ + i = self._normalize_index(i) + self.c.erase(self.c.begin() + i - 1) + + def __setitem__(self, int i, Span span): + """Set a span in the span group. + + i (int): The item index. + span (Span): The span. + + DOCS: https://spacy.io/api/spangroup#setitem + """ + if span.doc is not self.doc: + raise ValueError(Errors.E855.format(obj="span")) + + i = self._normalize_index(i) + self.c[i] = span.c + + def __iadd__(self, other: Union[SpanGroup, Iterable["Span"]]) -> SpanGroup: + """Operator +=. Append a span group or spans to this group and return + the current span group. + + other (Union[SpanGroup, Iterable["Span"]]): The SpanGroup or spans to + add. + + RETURNS (SpanGroup): The current span group. + + DOCS: https://spacy.io/api/spangroup#iadd + """ + return self._concat(other, inplace=True) + + def __add__(self, other: SpanGroup) -> SpanGroup: + """Operator +. Concatenate a span group with this group and return a + new span group. + + other (SpanGroup): The SpanGroup to add. + + RETURNS (SpanGroup): The concatenated SpanGroup. + + DOCS: https://spacy.io/api/spangroup#add + """ + # For Cython 0.x and __add__, you cannot rely on `self` as being `self` + # or being the right type, so both types need to be checked explicitly. + if isinstance(self, SpanGroup) and isinstance(other, SpanGroup): + return self._concat(other) + return NotImplemented + def append(self, Span span): """Add a span to the group. The span must refer to the same Doc object as the span group. @@ -98,35 +167,18 @@ cdef class SpanGroup: DOCS: https://spacy.io/api/spangroup#append """ if span.doc is not self.doc: - raise ValueError("Cannot add span to group: refers to different Doc.") + raise ValueError(Errors.E855.format(obj="span")) self.push_back(span.c) - def extend(self, spans): - """Add multiple spans to the group. All spans must refer to the same - Doc object as the span group. + def extend(self, spans_or_span_group: Union[SpanGroup, Iterable["Span"]]): + """Add multiple spans or contents of another SpanGroup to the group. + All spans must refer to the same Doc object as the span group. - spans (Iterable[Span]): The spans to add. + spans (Union[SpanGroup, Iterable["Span"]]): The spans to add. DOCS: https://spacy.io/api/spangroup#extend """ - cdef Span span - for span in spans: - self.append(span) - - def __getitem__(self, int i): - """Get a span from the group. - - i (int): The item index. - RETURNS (Span): The span at the given index. - - DOCS: https://spacy.io/api/spangroup#getitem - """ - cdef int size = self.c.size() - if i < -size or i >= size: - raise IndexError(f"list index {i} out of range") - if i < 0: - i += size - return Span.cinit(self.doc, self.c[i]) + self._concat(spans_or_span_group, inplace=True) def to_bytes(self): """Serialize the SpanGroup's contents to a byte string. @@ -136,6 +188,7 @@ cdef class SpanGroup: DOCS: https://spacy.io/api/spangroup#to_bytes """ output = {"name": self.name, "attrs": self.attrs, "spans": []} + cdef int i for i in range(self.c.size()): span = self.c[i] # The struct.pack here is probably overkill, but it might help if @@ -187,3 +240,74 @@ cdef class SpanGroup: cdef void push_back(self, SpanC span) nogil: self.c.push_back(span) + + def copy(self) -> SpanGroup: + """Clones the span group. + + RETURNS (SpanGroup): A copy of the span group. + + DOCS: https://spacy.io/api/spangroup#copy + """ + return SpanGroup( + self.doc, + name=self.name, + attrs=deepcopy(self.attrs), + spans=list(self), + ) + + def _concat( + self, + other: Union[SpanGroup, Iterable["Span"]], + *, + inplace: bool = False, + ) -> SpanGroup: + """Concatenates the current span group with the provided span group or + spans, either in place or creating a copy. Preserves the name of self, + updates attrs only with values that are not in self. + + other (Union[SpanGroup, Iterable[Span]]): The spans to append. + inplace (bool): Indicates whether the operation should be performed in + place on the current span group. + + RETURNS (SpanGroup): Either a new SpanGroup or the current SpanGroup + depending on the value of inplace. + """ + cdef SpanGroup span_group = self if inplace else self.copy() + cdef SpanGroup other_group + cdef Span span + + if isinstance(other, SpanGroup): + other_group = other + if other_group.doc is not self.doc: + raise ValueError(Errors.E855.format(obj="span group")) + + other_attrs = deepcopy(other_group.attrs) + span_group.attrs.update({ + key: value for key, value in other_attrs.items() \ + if key not in span_group.attrs + }) + if len(other_group): + span_group.c.reserve(span_group.c.size() + other_group.c.size()) + span_group.c.insert(span_group.c.end(), other_group.c.begin(), other_group.c.end()) + else: + if len(other): + span_group.c.reserve(self.c.size() + len(other)) + for span in other: + if span.doc is not self.doc: + raise ValueError(Errors.E855.format(obj="span")) + span_group.c.push_back(span.c) + + return span_group + + def _normalize_index(self, int i) -> int: + """Checks list index boundaries and adjusts the index if negative. + + i (int): The index. + RETURNS (int): The adjusted index. + """ + cdef int length = self.c.size() + if i < -length or i >= length: + raise IndexError(Errors.E856.format(i=i, length=length)) + if i < 0: + i += length + return i diff --git a/website/docs/api/spangroup.md b/website/docs/api/spangroup.md index 654067eb1..337e61749 100644 --- a/website/docs/api/spangroup.md +++ b/website/docs/api/spangroup.md @@ -104,7 +104,10 @@ Get the number of spans in the group. ## SpanGroup.\_\_getitem\_\_ {#getitem tag="method"} -Get a span from the group. +Get a span from the group. Note that a copy of the span is returned, so if any +changes are made to this span, they are not reflected in the corresponding +member of the span group. The item or group will need to be reassigned for +changes to be reflected in the span group. > #### Example > @@ -113,6 +116,8 @@ Get a span from the group. > doc.spans["errors"] = [doc[0:1], doc[2:4]] > span = doc.spans["errors"][1] > assert span.text == "goi ng" +> span.label_ = 'LABEL' +> assert doc.spans["errors"][1] != 'LABEL' # The span within the group was not updated > ``` | Name | Description | @@ -120,6 +125,83 @@ Get a span from the group. | `i` | The item index. ~~int~~ | | **RETURNS** | The span at the given index. ~~Span~~ | +## SpanGroup.\_\_setitem\_\_ {#setitem tag="method", new="3.3"} + +Set a span in the span group. + +> #### Example +> +> ```python +> doc = nlp("Their goi ng home") +> doc.spans["errors"] = [doc[0:1], doc[2:4]] +> span = doc[0:2] +> doc.spans["errors"][0] = span +> assert doc.spans["errors"][0].text == "Their goi" +> ``` + +| Name | Description | +| ------ | ----------------------- | +| `i` | The item index. ~~int~~ | +| `span` | The new value. ~~Span~~ | + +## SpanGroup.\_\_delitem\_\_ {#delitem tag="method", new="3.3"} + +Delete a span from the span group. + +> #### Example +> +> ```python +> doc = nlp("Their goi ng home") +> doc.spans["errors"] = [doc[0:1], doc[2:4]] +> del doc.spans[0] +> assert len(doc.spans["errors"]) == 1 +> ``` + +| Name | Description | +| ---- | ----------------------- | +| `i` | The item index. ~~int~~ | + +## SpanGroup.\_\_add\_\_ {#add tag="method", new="3.3"} + +Concatenate the current span group with another span group and return the result +in a new span group. Any `attrs` from the first span group will have precedence +over `attrs` in the second. + +> #### Example +> +> ```python +> doc = nlp("Their goi ng home") +> doc.spans["errors"] = [doc[0:1], doc[2:4]] +> doc.spans["other"] = [doc[0:2], doc[1:3]] +> span_group = doc.spans["errors"] + doc.spans["other"] +> assert len(span_group) == 4 +> ``` + +| Name | Description | +| ----------- | ---------------------------------------------------------------------------- | +| `other` | The span group or spans to concatenate. ~~Union[SpanGroup, Iterable[Span]]~~ | +| **RETURNS** | The new span group. ~~SpanGroup~~ | + +## SpanGroup.\_\_iadd\_\_ {#iadd tag="method", new="3.3"} + +Append an iterable of spans or the content of a span group to the current span +group. Any `attrs` in the other span group will be added for keys that are not +already present in the current span group. + +> #### Example +> +> ```python +> doc = nlp("Their goi ng home") +> doc.spans["errors"] = [doc[0:1], doc[2:4]] +> doc.spans["errors"] += [doc[3:4], doc[2:3]] +> assert len(doc.spans["errors"]) == 4 +> ``` + +| Name | Description | +| ----------- | ----------------------------------------------------------------------- | +| `other` | The span group or spans to append. ~~Union[SpanGroup, Iterable[Span]]~~ | +| **RETURNS** | The span group. ~~SpanGroup~~ | + ## SpanGroup.append {#append tag="method"} Add a [`Span`](/api/span) object to the group. The span must refer to the same @@ -140,8 +222,9 @@ Add a [`Span`](/api/span) object to the group. The span must refer to the same ## SpanGroup.extend {#extend tag="method"} -Add multiple [`Span`](/api/span) objects to the group. All spans must refer to -the same [`Doc`](/api/doc) object as the span group. +Add multiple [`Span`](/api/span) objects or contents of another `SpanGroup` to +the group. All spans must refer to the same [`Doc`](/api/doc) object as the span +group. > #### Example > @@ -150,11 +233,31 @@ the same [`Doc`](/api/doc) object as the span group. > doc.spans["errors"] = [] > doc.spans["errors"].extend([doc[2:4], doc[0:1]]) > assert len(doc.spans["errors"]) == 2 +> span_group = SpanGroup([doc[1:4], doc[0:3]) +> doc.spans["errors"].extend(span_group) > ``` -| Name | Description | -| ------- | ------------------------------------ | -| `spans` | The spans to add. ~~Iterable[Span]~~ | +| Name | Description | +| ------- | -------------------------------------------------------- | +| `spans` | The spans to add. ~~Union[SpanGroup, Iterable["Span"]]~~ | + +## SpanGroup.copy {#copy tag="method", new="3.3"} + +Return a copy of the span group. + +> #### Example +> +> ```python +> from spacy.tokens import SpanGroup +> +> doc = nlp("Their goi ng home") +> doc.spans["errors"] = [doc[2:4], doc[0:3]] +> new_group = doc.spans["errors"].copy() +> ``` + +| Name | Description | +| ----------- | ----------------------------------------------- | +| **RETURNS** | A copy of the `SpanGroup` object. ~~SpanGroup~~ | ## SpanGroup.to_bytes {#to_bytes tag="method"}