Custom extensions for spans with equal boundaries (#11429)

* Init

* Fix return type for mypy

* adjust types and improve setting new attributes

* Add underscore changes to json conversion

* Add test and underscore changes to from_docs

* add underscore changes and test to span.to_doc

* update return values

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Add types to function

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* adjust formatting

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* shorten return type

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* add helper function to improve readability

* Improve code and add comments

* rerun azure tests

* Fix tests for json conversion

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
Edward 2022-12-12 08:55:53 +01:00 committed by GitHub
parent f5aabaf7d6
commit ca75190a3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 214 additions and 17 deletions

View File

@ -3,6 +3,10 @@ from mock import Mock
from spacy.tokens import Doc, Span, Token
from spacy.tokens.underscore import Underscore
# Helper functions
def _get_tuple(s: Span):
return "._.", "span_extension", s.start_char, s.end_char, s.label, s.kb_id, s.id
@pytest.fixture(scope="function", autouse=True)
def clean_underscore():
@ -171,3 +175,118 @@ def test_underscore_docstring(en_vocab):
doc = Doc(en_vocab, words=["hello", "world"])
assert test_method.__doc__ == "I am a docstring"
assert doc._.test_docstrings.__doc__.rsplit(". ")[-1] == "I am a docstring"
def test_underscore_for_unique_span(en_tokenizer):
"""Test that spans with the same boundaries but with different labels are uniquely identified (see #9706)."""
Doc.set_extension(name="doc_extension", default=None)
Span.set_extension(name="span_extension", default=None)
Token.set_extension(name="token_extension", default=None)
# Initialize doc
text = "Hello, world!"
doc = en_tokenizer(text)
span_1 = Span(doc, 0, 2, "SPAN_1")
span_2 = Span(doc, 0, 2, "SPAN_2")
# Set custom extensions
doc._.doc_extension = "doc extension"
doc[0]._.token_extension = "token extension"
span_1._.span_extension = "span_1 extension"
span_2._.span_extension = "span_2 extension"
# Assert extensions
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
# Change label of span and assert extensions
span_1.label_ = "NEW_LABEL"
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
# Change KB_ID and assert extensions
span_1.kb_id_ = "KB_ID"
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
# Change extensions and assert
span_2._.span_extension = "updated span_2 extension"
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension"
# Change span ID and assert extensions
span_2.id = 2
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension"
# Assert extensions with original key
assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension"
assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension"
def test_underscore_for_unique_span_from_docs(en_tokenizer):
"""Test that spans in the user_data keep the same data structure when using Doc.from_docs"""
Span.set_extension(name="span_extension", default=None)
Token.set_extension(name="token_extension", default=None)
# Initialize doc
text_1 = "Hello, world!"
doc_1 = en_tokenizer(text_1)
span_1a = Span(doc_1, 0, 2, "SPAN_1a")
span_1b = Span(doc_1, 0, 2, "SPAN_1b")
text_2 = "This is a test."
doc_2 = en_tokenizer(text_2)
span_2a = Span(doc_2, 0, 3, "SPAN_2a")
# Set custom extensions
doc_1[0]._.token_extension = "token_1"
doc_2[1]._.token_extension = "token_2"
span_1a._.span_extension = "span_1a extension"
span_1b._.span_extension = "span_1b extension"
span_2a._.span_extension = "span_2a extension"
doc = Doc.from_docs([doc_1, doc_2])
# Assert extensions
assert doc_1.user_data[_get_tuple(span_1a)] == "span_1a extension"
assert doc_1.user_data[_get_tuple(span_1b)] == "span_1b extension"
assert doc_2.user_data[_get_tuple(span_2a)] == "span_2a extension"
# Check extensions on merged doc
assert doc.user_data[_get_tuple(span_1a)] == "span_1a extension"
assert doc.user_data[_get_tuple(span_1b)] == "span_1b extension"
assert (
doc.user_data[
(
"._.",
"span_extension",
span_2a.start_char + len(doc_1.text) + 1,
span_2a.end_char + len(doc_1.text) + 1,
span_2a.label,
span_2a.kb_id,
span_2a.id,
)
]
== "span_2a extension"
)
def test_underscore_for_unique_span_as_span(en_tokenizer):
"""Test that spans in the user_data keep the same data structure when using Span.as_doc"""
Span.set_extension(name="span_extension", default=None)
# Initialize doc
text = "Hello, world!"
doc = en_tokenizer(text)
span_1 = Span(doc, 0, 2, "SPAN_1")
span_2 = Span(doc, 0, 2, "SPAN_2")
# Set custom extensions
span_1._.span_extension = "span_1 extension"
span_2._.span_extension = "span_2 extension"
span_doc = span_1.as_doc(copy_user_data=True)
# Assert extensions
assert span_doc.user_data[_get_tuple(span_1)] == "span_1 extension"
assert span_doc.user_data[_get_tuple(span_2)] == "span_2 extension"

View File

@ -1177,13 +1177,22 @@ cdef class Doc:
if "user_data" not in exclude:
for key, value in doc.user_data.items():
if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
data_type, name, start, end = key
if isinstance(key, tuple) and len(key) >= 4 and key[0] == "._.":
data_type = key[0]
name = key[1]
start = key[2]
end = key[3]
if start is not None or end is not None:
start += char_offset
if end is not None:
end += char_offset
_label = key[4]
_kb_id = key[5]
_span_id = key[6]
concat_user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value)
else:
concat_user_data[(data_type, name, start, end)] = copy.copy(value)
else:
warnings.warn(Warnings.W101.format(name=name))
else:
@ -1627,7 +1636,11 @@ cdef class Doc:
Span.set_extension(span_attr)
for span_data in doc_json["underscore_span"][span_attr]:
value = span_data["value"]
self.char_span(span_data["start"], span_data["end"])._.set(span_attr, value)
span = self.char_span(span_data["start"], span_data["end"])
span.label = span_data["label"]
span.kb_id = span_data["kb_id"]
span.id = span_data["id"]
span._.set(span_attr, value)
return self
def to_json(self, underscore=None):
@ -1705,13 +1718,16 @@ cdef class Doc:
if attr not in data["underscore_token"]:
data["underscore_token"][attr] = []
data["underscore_token"][attr].append({"start": start, "value": value})
# Span attribute
elif start is not None and end is not None:
# Else span attribute
elif end is not None:
_label = data_key[4]
_kb_id = data_key[5]
_span_id = data_key[6]
if "underscore_span" not in data:
data["underscore_span"] = {}
if attr not in data["underscore_span"]:
data["underscore_span"][attr] = []
data["underscore_span"][attr].append({"start": start, "end": end, "value": value})
data["underscore_span"][attr].append({"start": start, "end": end, "value": value, "label": _label, "kb_id": _kb_id, "id":_span_id})
for attr in underscore:
if attr not in user_keys:

View File

@ -218,11 +218,10 @@ cdef class Span:
cdef SpanC* span_c = self.span_c()
"""Custom extension attributes registered via `set_extension`."""
return Underscore(Underscore.span_extensions, self,
start=span_c.start_char, end=span_c.end_char)
start=span_c.start_char, end=span_c.end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None):
"""Create a `Doc` object with a copy of the `Span`'s data.
copy_user_data (bool): Whether or not to copy the original doc's user data.
array_head (tuple): `Doc` array attrs, can be passed in to speed up computation.
array (ndarray): `Doc` as array, can be passed in to speed up computation.
@ -275,11 +274,21 @@ cdef class Span:
char_offset = self.start_char
for key, value in self.doc.user_data.items():
if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
data_type, name, start, end = key
data_type = key[0]
name = key[1]
start = key[2]
end = key[3]
if start is not None or end is not None:
start -= char_offset
# Check if Span object
if end is not None:
end -= char_offset
_label = key[4]
_kb_id = key[5]
_span_id = key[6]
user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value)
# Else Token object
else:
user_data[(data_type, name, start, end)] = copy.copy(value)
else:
user_data[key] = copy.copy(value)
@ -781,21 +790,36 @@ cdef class Span:
return self.span_c().label
def __set__(self, attr_t label):
if label != self.span_c().label :
old_label = self.span_c().label
self.span_c().label = label
new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=old_label, kb_id=self.kb_id, span_id=self.id)
Underscore._replace_keys(old, new)
property kb_id:
def __get__(self):
return self.span_c().kb_id
def __set__(self, attr_t kb_id):
if kb_id != self.span_c().kb_id :
old_kb_id = self.span_c().kb_id
self.span_c().kb_id = kb_id
new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=old_kb_id, span_id=self.id)
Underscore._replace_keys(old, new)
property id:
def __get__(self):
return self.span_c().id
def __set__(self, attr_t id):
if id != self.span_c().id :
old_id = self.span_c().id
self.span_c().id = id
new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=old_id)
Underscore._replace_keys(old, new)
property ent_id:
"""Alias for the span's ID."""

View File

@ -2,10 +2,10 @@ from typing import Dict, Any, List, Optional, Tuple, Union, TYPE_CHECKING
import functools
import copy
from ..errors import Errors
from .span import Span
if TYPE_CHECKING:
from .doc import Doc
from .span import Span
from .token import Token
@ -25,6 +25,9 @@ class Underscore:
obj: Union["Doc", "Span", "Token"],
start: Optional[int] = None,
end: Optional[int] = None,
label: int = 0,
kb_id: int = 0,
span_id: int = 0,
):
object.__setattr__(self, "_extensions", extensions)
object.__setattr__(self, "_obj", obj)
@ -36,6 +39,10 @@ class Underscore:
object.__setattr__(self, "_doc", obj.doc)
object.__setattr__(self, "_start", start)
object.__setattr__(self, "_end", end)
if type(obj) == Span:
object.__setattr__(self, "_label", label)
object.__setattr__(self, "_kb_id", kb_id)
object.__setattr__(self, "_span_id", span_id)
def __dir__(self) -> List[str]:
# Hack to enable autocomplete on custom extensions
@ -88,8 +95,39 @@ class Underscore:
def has(self, name: str) -> bool:
return name in self._extensions
def _get_key(self, name: str) -> Tuple[str, str, Optional[int], Optional[int]]:
return ("._.", name, self._start, self._end)
def _get_key(
self, name: str
) -> Union[
Tuple[str, str, Optional[int], Optional[int]],
Tuple[str, str, Optional[int], Optional[int], int, int, int],
]:
if hasattr(self, "_label"):
return (
"._.",
name,
self._start,
self._end,
self._label,
self._kb_id,
self._span_id,
)
else:
return "._.", name, self._start, self._end
@staticmethod
def _replace_keys(old_underscore: "Underscore", new_underscore: "Underscore"):
"""
This function is called by Span when its kb_id or label are re-assigned.
It checks if any user_data is stored for this span and replaces the keys
"""
for name in old_underscore._extensions:
old_key = old_underscore._get_key(name)
old_doc = old_underscore._doc
new_key = new_underscore._get_key(name)
if old_key != new_key and old_key in old_doc.user_data:
old_underscore._doc.user_data[
new_key
] = old_underscore._doc.user_data.pop(old_key)
@classmethod
def get_state(cls) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]: