Support more internal methods for SpanGroup (#10476)

* Added new convenience cython functions to SpanGroup to avoid unnecessary allocation/deallocation of objects

* Replaced sorting in has_overlap with C++ for efficiency. Also, added a test for has_overlap

* Added a method to efficiently merge SpanGroups

* Added __delitem__, __add__ and __iadd__. Also, allowed to pass span lists to merge function. Replaced extend() body with call to merge

* Renamed merge to concat and added missing things to documentation

* Added operator+ and operator += in the documentation

* Added a test for Doc deallocation

* Update spacy/tokens/span_group.pyx

* Updated SpanGroup tests to use new span list comparison function rather than assert_span_list_equal, eliminating the need to have a separate assert_not_equal fnction

* Fixed typos in SpanGroup documentation

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Minor changes requested by Sofie: rearranged import statements. Added new=3.2.1 tag to SpanGroup.__setitem__ documentation

* SpanGroup: moved repetitive list index check/adjustment in a separate function

* Turn off formatting that hurts readability spacy/tests/doc/test_span_group.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Remove formatting that hurts readability spacy/tests/doc/test_span_group.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Turn off formatting that hurts readability in spacy/tests/doc/test_span_group.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Support more internal methods for SpanGroup

Add support for:

* `__setitem__`
* `__delitem__`
* `__iadd__`: for `SpanGroup` or `Iterable[Span]`
* `__add__`: for `SpanGroup` only

Adapted from #9698 with the scope limited to the magic methods.

* Use v3.3 as new version in docs

* Add new tag to SpanGroup.copy in API docs

* Remove duplicate import

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Remaining suggestions and formatting

Co-authored-by: nrodnova <nrodnova@hotmail.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: Natalia Rodnova <4512370+nrodnova@users.noreply.github.com>
This commit is contained in:
Adriane Boyd 2022-04-01 09:56:26 +02:00 committed by GitHub
parent c90dd6f265
commit ca54de27bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 502 additions and 30 deletions

View File

@ -524,6 +524,9 @@ class Errors(metaclass=ErrorsWithCodes):
E202 = ("Unsupported {name} mode '{mode}'. Supported modes: {modes}.") E202 = ("Unsupported {name} mode '{mode}'. Supported modes: {modes}.")
# New errors added in v3.x # New errors added in v3.x
E855 = ("Invalid {obj}: {obj} is not from the same doc.")
E856 = ("Error accessing span at position {i}: out of bounds in span group "
"of length {length}.")
E857 = ("Entry '{name}' not found in edit tree lemmatizer labels.") E857 = ("Entry '{name}' not found in edit tree lemmatizer labels.")
E858 = ("The {mode} vector table does not support this operation. " E858 = ("The {mode} vector table does not support this operation. "
"{alternative}") "{alternative}")

View File

@ -0,0 +1,242 @@
import pytest
from random import Random
from spacy.matcher import Matcher
from spacy.tokens import Span, SpanGroup
@pytest.fixture
def doc(en_tokenizer):
doc = en_tokenizer("0 1 2 3 4 5 6")
matcher = Matcher(en_tokenizer.vocab, validate=True)
# fmt: off
matcher.add("4", [[{}, {}, {}, {}]])
matcher.add("2", [[{}, {}, ]])
matcher.add("1", [[{}, ]])
# fmt: on
matches = matcher(doc)
spans = []
for match in matches:
spans.append(
Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]])
)
Random(42).shuffle(spans)
doc.spans["SPANS"] = SpanGroup(
doc, name="SPANS", attrs={"key": "value"}, spans=spans
)
return doc
@pytest.fixture
def other_doc(en_tokenizer):
doc = en_tokenizer("0 1 2 3 4 5 6")
matcher = Matcher(en_tokenizer.vocab, validate=True)
# fmt: off
matcher.add("4", [[{}, {}, {}, {}]])
matcher.add("2", [[{}, {}, ]])
matcher.add("1", [[{}, ]])
# fmt: on
matches = matcher(doc)
spans = []
for match in matches:
spans.append(
Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]])
)
Random(42).shuffle(spans)
doc.spans["SPANS"] = SpanGroup(
doc, name="SPANS", attrs={"key": "value"}, spans=spans
)
return doc
@pytest.fixture
def span_group(en_tokenizer):
doc = en_tokenizer("0 1 2 3 4 5 6")
matcher = Matcher(en_tokenizer.vocab, validate=True)
# fmt: off
matcher.add("4", [[{}, {}, {}, {}]])
matcher.add("2", [[{}, {}, ]])
matcher.add("1", [[{}, ]])
# fmt: on
matches = matcher(doc)
spans = []
for match in matches:
spans.append(
Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]])
)
Random(42).shuffle(spans)
doc.spans["SPANS"] = SpanGroup(
doc, name="SPANS", attrs={"key": "value"}, spans=spans
)
def test_span_group_copy(doc):
span_group = doc.spans["SPANS"]
clone = span_group.copy()
assert clone != span_group
assert clone.name == span_group.name
assert clone.attrs == span_group.attrs
assert len(clone) == len(span_group)
assert list(span_group) == list(clone)
clone.name = "new_name"
clone.attrs["key"] = "new_value"
clone.append(Span(doc, 0, 6, "LABEL"))
assert clone.name != span_group.name
assert clone.attrs != span_group.attrs
assert span_group.attrs["key"] == "value"
assert list(span_group) != list(clone)
def test_span_group_set_item(doc, other_doc):
span_group = doc.spans["SPANS"]
index = 5
span = span_group[index]
span.label_ = "NEW LABEL"
span.kb_id = doc.vocab.strings["KB_ID"]
assert span_group[index].label != span.label
assert span_group[index].kb_id != span.kb_id
span_group[index] = span
assert span_group[index].start == span.start
assert span_group[index].end == span.end
assert span_group[index].label == span.label
assert span_group[index].kb_id == span.kb_id
assert span_group[index] == span
with pytest.raises(IndexError):
span_group[-100] = span
with pytest.raises(IndexError):
span_group[100] = span
span = Span(other_doc, 0, 2)
with pytest.raises(ValueError):
span_group[index] = span
def test_span_group_has_overlap(doc):
span_group = doc.spans["SPANS"]
assert span_group.has_overlap
def test_span_group_concat(doc, other_doc):
span_group_1 = doc.spans["SPANS"]
spans = [doc[0:5], doc[0:6]]
span_group_2 = SpanGroup(
doc,
name="MORE_SPANS",
attrs={"key": "new_value", "new_key": "new_value"},
spans=spans,
)
span_group_3 = span_group_1._concat(span_group_2)
assert span_group_3.name == span_group_1.name
assert span_group_3.attrs == {"key": "value", "new_key": "new_value"}
span_list_expected = list(span_group_1) + list(span_group_2)
assert list(span_group_3) == list(span_list_expected)
# Inplace
span_list_expected = list(span_group_1) + list(span_group_2)
span_group_3 = span_group_1._concat(span_group_2, inplace=True)
assert span_group_3 == span_group_1
assert span_group_3.name == span_group_1.name
assert span_group_3.attrs == {"key": "value", "new_key": "new_value"}
assert list(span_group_3) == list(span_list_expected)
span_group_2 = other_doc.spans["SPANS"]
with pytest.raises(ValueError):
span_group_1._concat(span_group_2)
def test_span_doc_delitem(doc):
span_group = doc.spans["SPANS"]
length = len(span_group)
index = 5
span = span_group[index]
next_span = span_group[index + 1]
del span_group[index]
assert len(span_group) == length - 1
assert span_group[index] != span
assert span_group[index] == next_span
with pytest.raises(IndexError):
del span_group[-100]
with pytest.raises(IndexError):
del span_group[100]
def test_span_group_add(doc):
span_group_1 = doc.spans["SPANS"]
spans = [doc[0:5], doc[0:6]]
span_group_2 = SpanGroup(
doc,
name="MORE_SPANS",
attrs={"key": "new_value", "new_key": "new_value"},
spans=spans,
)
span_group_3_expected = span_group_1._concat(span_group_2)
span_group_3 = span_group_1 + span_group_2
assert len(span_group_3) == len(span_group_3_expected)
assert span_group_3.attrs == {"key": "value", "new_key": "new_value"}
assert list(span_group_3) == list(span_group_3_expected)
def test_span_group_iadd(doc):
span_group_1 = doc.spans["SPANS"].copy()
spans = [doc[0:5], doc[0:6]]
span_group_2 = SpanGroup(
doc,
name="MORE_SPANS",
attrs={"key": "new_value", "new_key": "new_value"},
spans=spans,
)
span_group_1_expected = span_group_1._concat(span_group_2)
span_group_1 += span_group_2
assert len(span_group_1) == len(span_group_1_expected)
assert span_group_1.attrs == {"key": "value", "new_key": "new_value"}
assert list(span_group_1) == list(span_group_1_expected)
span_group_1 = doc.spans["SPANS"].copy()
span_group_1 += spans
assert len(span_group_1) == len(span_group_1_expected)
assert span_group_1.attrs == {
"key": "value",
}
assert list(span_group_1) == list(span_group_1_expected)
def test_span_group_extend(doc):
span_group_1 = doc.spans["SPANS"].copy()
spans = [doc[0:5], doc[0:6]]
span_group_2 = SpanGroup(
doc,
name="MORE_SPANS",
attrs={"key": "new_value", "new_key": "new_value"},
spans=spans,
)
span_group_1_expected = span_group_1._concat(span_group_2)
span_group_1.extend(span_group_2)
assert len(span_group_1) == len(span_group_1_expected)
assert span_group_1.attrs == {"key": "value", "new_key": "new_value"}
assert list(span_group_1) == list(span_group_1_expected)
span_group_1 = doc.spans["SPANS"]
span_group_1.extend(spans)
assert len(span_group_1) == len(span_group_1_expected)
assert span_group_1.attrs == {"key": "value"}
assert list(span_group_1) == list(span_group_1_expected)
def test_span_group_dealloc(span_group):
with pytest.raises(AttributeError):
print(span_group.doc)

View File

@ -1,10 +1,11 @@
from typing import Iterable, Tuple, Union, Optional, TYPE_CHECKING
import weakref import weakref
import struct import struct
from copy import deepcopy
import srsly import srsly
from spacy.errors import Errors from spacy.errors import Errors
from .span cimport Span from .span cimport Span
from libc.stdint cimport uint64_t, uint32_t, int32_t
cdef class SpanGroup: cdef class SpanGroup:
@ -48,6 +49,8 @@ cdef class SpanGroup:
self.name = name self.name = name
self.attrs = dict(attrs) if attrs is not None else {} self.attrs = dict(attrs) if attrs is not None else {}
cdef Span span cdef Span span
if len(spans) :
self.c.reserve(len(spans))
for span in spans: for span in spans:
self.push_back(span.c) self.push_back(span.c)
@ -89,6 +92,72 @@ cdef class SpanGroup:
""" """
return self.c.size() return self.c.size()
def __getitem__(self, int i) -> Span:
"""Get a span from the group. Note that a copy of the span is returned,
so if any changes are made to this span, they are not reflected in the
corresponding member of the span group.
i (int): The item index.
RETURNS (Span): The span at the given index.
DOCS: https://spacy.io/api/spangroup#getitem
"""
i = self._normalize_index(i)
return Span.cinit(self.doc, self.c[i])
def __delitem__(self, int i):
"""Delete a span from the span group at index i.
i (int): The item index.
DOCS: https://spacy.io/api/spangroup#delitem
"""
i = self._normalize_index(i)
self.c.erase(self.c.begin() + i - 1)
def __setitem__(self, int i, Span span):
"""Set a span in the span group.
i (int): The item index.
span (Span): The span.
DOCS: https://spacy.io/api/spangroup#setitem
"""
if span.doc is not self.doc:
raise ValueError(Errors.E855.format(obj="span"))
i = self._normalize_index(i)
self.c[i] = span.c
def __iadd__(self, other: Union[SpanGroup, Iterable["Span"]]) -> SpanGroup:
"""Operator +=. Append a span group or spans to this group and return
the current span group.
other (Union[SpanGroup, Iterable["Span"]]): The SpanGroup or spans to
add.
RETURNS (SpanGroup): The current span group.
DOCS: https://spacy.io/api/spangroup#iadd
"""
return self._concat(other, inplace=True)
def __add__(self, other: SpanGroup) -> SpanGroup:
"""Operator +. Concatenate a span group with this group and return a
new span group.
other (SpanGroup): The SpanGroup to add.
RETURNS (SpanGroup): The concatenated SpanGroup.
DOCS: https://spacy.io/api/spangroup#add
"""
# For Cython 0.x and __add__, you cannot rely on `self` as being `self`
# or being the right type, so both types need to be checked explicitly.
if isinstance(self, SpanGroup) and isinstance(other, SpanGroup):
return self._concat(other)
return NotImplemented
def append(self, Span span): def append(self, Span span):
"""Add a span to the group. The span must refer to the same Doc """Add a span to the group. The span must refer to the same Doc
object as the span group. object as the span group.
@ -98,35 +167,18 @@ cdef class SpanGroup:
DOCS: https://spacy.io/api/spangroup#append DOCS: https://spacy.io/api/spangroup#append
""" """
if span.doc is not self.doc: if span.doc is not self.doc:
raise ValueError("Cannot add span to group: refers to different Doc.") raise ValueError(Errors.E855.format(obj="span"))
self.push_back(span.c) self.push_back(span.c)
def extend(self, spans): def extend(self, spans_or_span_group: Union[SpanGroup, Iterable["Span"]]):
"""Add multiple spans to the group. All spans must refer to the same """Add multiple spans or contents of another SpanGroup to the group.
Doc object as the span group. All spans must refer to the same Doc object as the span group.
spans (Iterable[Span]): The spans to add. spans (Union[SpanGroup, Iterable["Span"]]): The spans to add.
DOCS: https://spacy.io/api/spangroup#extend DOCS: https://spacy.io/api/spangroup#extend
""" """
cdef Span span self._concat(spans_or_span_group, inplace=True)
for span in spans:
self.append(span)
def __getitem__(self, int i):
"""Get a span from the group.
i (int): The item index.
RETURNS (Span): The span at the given index.
DOCS: https://spacy.io/api/spangroup#getitem
"""
cdef int size = self.c.size()
if i < -size or i >= size:
raise IndexError(f"list index {i} out of range")
if i < 0:
i += size
return Span.cinit(self.doc, self.c[i])
def to_bytes(self): def to_bytes(self):
"""Serialize the SpanGroup's contents to a byte string. """Serialize the SpanGroup's contents to a byte string.
@ -136,6 +188,7 @@ cdef class SpanGroup:
DOCS: https://spacy.io/api/spangroup#to_bytes DOCS: https://spacy.io/api/spangroup#to_bytes
""" """
output = {"name": self.name, "attrs": self.attrs, "spans": []} output = {"name": self.name, "attrs": self.attrs, "spans": []}
cdef int i
for i in range(self.c.size()): for i in range(self.c.size()):
span = self.c[i] span = self.c[i]
# The struct.pack here is probably overkill, but it might help if # The struct.pack here is probably overkill, but it might help if
@ -187,3 +240,74 @@ cdef class SpanGroup:
cdef void push_back(self, SpanC span) nogil: cdef void push_back(self, SpanC span) nogil:
self.c.push_back(span) self.c.push_back(span)
def copy(self) -> SpanGroup:
"""Clones the span group.
RETURNS (SpanGroup): A copy of the span group.
DOCS: https://spacy.io/api/spangroup#copy
"""
return SpanGroup(
self.doc,
name=self.name,
attrs=deepcopy(self.attrs),
spans=list(self),
)
def _concat(
self,
other: Union[SpanGroup, Iterable["Span"]],
*,
inplace: bool = False,
) -> SpanGroup:
"""Concatenates the current span group with the provided span group or
spans, either in place or creating a copy. Preserves the name of self,
updates attrs only with values that are not in self.
other (Union[SpanGroup, Iterable[Span]]): The spans to append.
inplace (bool): Indicates whether the operation should be performed in
place on the current span group.
RETURNS (SpanGroup): Either a new SpanGroup or the current SpanGroup
depending on the value of inplace.
"""
cdef SpanGroup span_group = self if inplace else self.copy()
cdef SpanGroup other_group
cdef Span span
if isinstance(other, SpanGroup):
other_group = other
if other_group.doc is not self.doc:
raise ValueError(Errors.E855.format(obj="span group"))
other_attrs = deepcopy(other_group.attrs)
span_group.attrs.update({
key: value for key, value in other_attrs.items() \
if key not in span_group.attrs
})
if len(other_group):
span_group.c.reserve(span_group.c.size() + other_group.c.size())
span_group.c.insert(span_group.c.end(), other_group.c.begin(), other_group.c.end())
else:
if len(other):
span_group.c.reserve(self.c.size() + len(other))
for span in other:
if span.doc is not self.doc:
raise ValueError(Errors.E855.format(obj="span"))
span_group.c.push_back(span.c)
return span_group
def _normalize_index(self, int i) -> int:
"""Checks list index boundaries and adjusts the index if negative.
i (int): The index.
RETURNS (int): The adjusted index.
"""
cdef int length = self.c.size()
if i < -length or i >= length:
raise IndexError(Errors.E856.format(i=i, length=length))
if i < 0:
i += length
return i

View File

@ -104,7 +104,10 @@ Get the number of spans in the group.
## SpanGroup.\_\_getitem\_\_ {#getitem tag="method"} ## SpanGroup.\_\_getitem\_\_ {#getitem tag="method"}
Get a span from the group. Get a span from the group. Note that a copy of the span is returned, so if any
changes are made to this span, they are not reflected in the corresponding
member of the span group. The item or group will need to be reassigned for
changes to be reflected in the span group.
> #### Example > #### Example
> >
@ -113,6 +116,8 @@ Get a span from the group.
> doc.spans["errors"] = [doc[0:1], doc[2:4]] > doc.spans["errors"] = [doc[0:1], doc[2:4]]
> span = doc.spans["errors"][1] > span = doc.spans["errors"][1]
> assert span.text == "goi ng" > assert span.text == "goi ng"
> span.label_ = 'LABEL'
> assert doc.spans["errors"][1] != 'LABEL' # The span within the group was not updated
> ``` > ```
| Name | Description | | Name | Description |
@ -120,6 +125,83 @@ Get a span from the group.
| `i` | The item index. ~~int~~ | | `i` | The item index. ~~int~~ |
| **RETURNS** | The span at the given index. ~~Span~~ | | **RETURNS** | The span at the given index. ~~Span~~ |
## SpanGroup.\_\_setitem\_\_ {#setitem tag="method", new="3.3"}
Set a span in the span group.
> #### Example
>
> ```python
> doc = nlp("Their goi ng home")
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
> span = doc[0:2]
> doc.spans["errors"][0] = span
> assert doc.spans["errors"][0].text == "Their goi"
> ```
| Name | Description |
| ------ | ----------------------- |
| `i` | The item index. ~~int~~ |
| `span` | The new value. ~~Span~~ |
## SpanGroup.\_\_delitem\_\_ {#delitem tag="method", new="3.3"}
Delete a span from the span group.
> #### Example
>
> ```python
> doc = nlp("Their goi ng home")
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
> del doc.spans[0]
> assert len(doc.spans["errors"]) == 1
> ```
| Name | Description |
| ---- | ----------------------- |
| `i` | The item index. ~~int~~ |
## SpanGroup.\_\_add\_\_ {#add tag="method", new="3.3"}
Concatenate the current span group with another span group and return the result
in a new span group. Any `attrs` from the first span group will have precedence
over `attrs` in the second.
> #### Example
>
> ```python
> doc = nlp("Their goi ng home")
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
> doc.spans["other"] = [doc[0:2], doc[1:3]]
> span_group = doc.spans["errors"] + doc.spans["other"]
> assert len(span_group) == 4
> ```
| Name | Description |
| ----------- | ---------------------------------------------------------------------------- |
| `other` | The span group or spans to concatenate. ~~Union[SpanGroup, Iterable[Span]]~~ |
| **RETURNS** | The new span group. ~~SpanGroup~~ |
## SpanGroup.\_\_iadd\_\_ {#iadd tag="method", new="3.3"}
Append an iterable of spans or the content of a span group to the current span
group. Any `attrs` in the other span group will be added for keys that are not
already present in the current span group.
> #### Example
>
> ```python
> doc = nlp("Their goi ng home")
> doc.spans["errors"] = [doc[0:1], doc[2:4]]
> doc.spans["errors"] += [doc[3:4], doc[2:3]]
> assert len(doc.spans["errors"]) == 4
> ```
| Name | Description |
| ----------- | ----------------------------------------------------------------------- |
| `other` | The span group or spans to append. ~~Union[SpanGroup, Iterable[Span]]~~ |
| **RETURNS** | The span group. ~~SpanGroup~~ |
## SpanGroup.append {#append tag="method"} ## SpanGroup.append {#append tag="method"}
Add a [`Span`](/api/span) object to the group. The span must refer to the same Add a [`Span`](/api/span) object to the group. The span must refer to the same
@ -140,8 +222,9 @@ Add a [`Span`](/api/span) object to the group. The span must refer to the same
## SpanGroup.extend {#extend tag="method"} ## SpanGroup.extend {#extend tag="method"}
Add multiple [`Span`](/api/span) objects to the group. All spans must refer to Add multiple [`Span`](/api/span) objects or contents of another `SpanGroup` to
the same [`Doc`](/api/doc) object as the span group. the group. All spans must refer to the same [`Doc`](/api/doc) object as the span
group.
> #### Example > #### Example
> >
@ -150,11 +233,31 @@ the same [`Doc`](/api/doc) object as the span group.
> doc.spans["errors"] = [] > doc.spans["errors"] = []
> doc.spans["errors"].extend([doc[2:4], doc[0:1]]) > doc.spans["errors"].extend([doc[2:4], doc[0:1]])
> assert len(doc.spans["errors"]) == 2 > assert len(doc.spans["errors"]) == 2
> span_group = SpanGroup([doc[1:4], doc[0:3])
> doc.spans["errors"].extend(span_group)
> ``` > ```
| Name | Description | | Name | Description |
| ------- | ------------------------------------ | | ------- | -------------------------------------------------------- |
| `spans` | The spans to add. ~~Iterable[Span]~~ | | `spans` | The spans to add. ~~Union[SpanGroup, Iterable["Span"]]~~ |
## SpanGroup.copy {#copy tag="method", new="3.3"}
Return a copy of the span group.
> #### Example
>
> ```python
> from spacy.tokens import SpanGroup
>
> doc = nlp("Their goi ng home")
> doc.spans["errors"] = [doc[2:4], doc[0:3]]
> new_group = doc.spans["errors"].copy()
> ```
| Name | Description |
| ----------- | ----------------------------------------------- |
| **RETURNS** | A copy of the `SpanGroup` object. ~~SpanGroup~~ |
## SpanGroup.to_bytes {#to_bytes tag="method"} ## SpanGroup.to_bytes {#to_bytes tag="method"}