mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-02 20:54:10 +03:00
Custom extensions for spans with equal boundaries (#11429)
* Init * Fix return type for mypy * adjust types and improve setting new attributes * Add underscore changes to json conversion * Add test and underscore changes to from_docs * add underscore changes and test to span.to_doc * update return values Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Add types to function Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * adjust formatting Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * shorten return type Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * add helper function to improve readability * Improve code and add comments * rerun azure tests * Fix tests for json conversion Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
f5aabaf7d6
commit
ca75190a3d
|
@ -3,6 +3,10 @@ from mock import Mock
|
|||
from spacy.tokens import Doc, Span, Token
|
||||
from spacy.tokens.underscore import Underscore
|
||||
|
||||
# Helper functions
|
||||
def _get_tuple(s: Span):
|
||||
return "._.", "span_extension", s.start_char, s.end_char, s.label, s.kb_id, s.id
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def clean_underscore():
|
||||
|
@ -171,3 +175,118 @@ def test_underscore_docstring(en_vocab):
|
|||
doc = Doc(en_vocab, words=["hello", "world"])
|
||||
assert test_method.__doc__ == "I am a docstring"
|
||||
assert doc._.test_docstrings.__doc__.rsplit(". ")[-1] == "I am a docstring"
|
||||
|
||||
|
||||
def test_underscore_for_unique_span(en_tokenizer):
|
||||
"""Test that spans with the same boundaries but with different labels are uniquely identified (see #9706)."""
|
||||
Doc.set_extension(name="doc_extension", default=None)
|
||||
Span.set_extension(name="span_extension", default=None)
|
||||
Token.set_extension(name="token_extension", default=None)
|
||||
|
||||
# Initialize doc
|
||||
text = "Hello, world!"
|
||||
doc = en_tokenizer(text)
|
||||
span_1 = Span(doc, 0, 2, "SPAN_1")
|
||||
span_2 = Span(doc, 0, 2, "SPAN_2")
|
||||
|
||||
# Set custom extensions
|
||||
doc._.doc_extension = "doc extension"
|
||||
doc[0]._.token_extension = "token extension"
|
||||
span_1._.span_extension = "span_1 extension"
|
||||
span_2._.span_extension = "span_2 extension"
|
||||
|
||||
# Assert extensions
|
||||
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
|
||||
assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
|
||||
|
||||
# Change label of span and assert extensions
|
||||
span_1.label_ = "NEW_LABEL"
|
||||
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
|
||||
assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
|
||||
|
||||
# Change KB_ID and assert extensions
|
||||
span_1.kb_id_ = "KB_ID"
|
||||
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
|
||||
assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
|
||||
|
||||
# Change extensions and assert
|
||||
span_2._.span_extension = "updated span_2 extension"
|
||||
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
|
||||
assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension"
|
||||
|
||||
# Change span ID and assert extensions
|
||||
span_2.id = 2
|
||||
assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
|
||||
assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension"
|
||||
|
||||
# Assert extensions with original key
|
||||
assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension"
|
||||
assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension"
|
||||
|
||||
|
||||
def test_underscore_for_unique_span_from_docs(en_tokenizer):
|
||||
"""Test that spans in the user_data keep the same data structure when using Doc.from_docs"""
|
||||
Span.set_extension(name="span_extension", default=None)
|
||||
Token.set_extension(name="token_extension", default=None)
|
||||
|
||||
# Initialize doc
|
||||
text_1 = "Hello, world!"
|
||||
doc_1 = en_tokenizer(text_1)
|
||||
span_1a = Span(doc_1, 0, 2, "SPAN_1a")
|
||||
span_1b = Span(doc_1, 0, 2, "SPAN_1b")
|
||||
|
||||
text_2 = "This is a test."
|
||||
doc_2 = en_tokenizer(text_2)
|
||||
span_2a = Span(doc_2, 0, 3, "SPAN_2a")
|
||||
|
||||
# Set custom extensions
|
||||
doc_1[0]._.token_extension = "token_1"
|
||||
doc_2[1]._.token_extension = "token_2"
|
||||
span_1a._.span_extension = "span_1a extension"
|
||||
span_1b._.span_extension = "span_1b extension"
|
||||
span_2a._.span_extension = "span_2a extension"
|
||||
|
||||
doc = Doc.from_docs([doc_1, doc_2])
|
||||
# Assert extensions
|
||||
assert doc_1.user_data[_get_tuple(span_1a)] == "span_1a extension"
|
||||
assert doc_1.user_data[_get_tuple(span_1b)] == "span_1b extension"
|
||||
assert doc_2.user_data[_get_tuple(span_2a)] == "span_2a extension"
|
||||
|
||||
# Check extensions on merged doc
|
||||
assert doc.user_data[_get_tuple(span_1a)] == "span_1a extension"
|
||||
assert doc.user_data[_get_tuple(span_1b)] == "span_1b extension"
|
||||
assert (
|
||||
doc.user_data[
|
||||
(
|
||||
"._.",
|
||||
"span_extension",
|
||||
span_2a.start_char + len(doc_1.text) + 1,
|
||||
span_2a.end_char + len(doc_1.text) + 1,
|
||||
span_2a.label,
|
||||
span_2a.kb_id,
|
||||
span_2a.id,
|
||||
)
|
||||
]
|
||||
== "span_2a extension"
|
||||
)
|
||||
|
||||
|
||||
def test_underscore_for_unique_span_as_span(en_tokenizer):
|
||||
"""Test that spans in the user_data keep the same data structure when using Span.as_doc"""
|
||||
Span.set_extension(name="span_extension", default=None)
|
||||
|
||||
# Initialize doc
|
||||
text = "Hello, world!"
|
||||
doc = en_tokenizer(text)
|
||||
span_1 = Span(doc, 0, 2, "SPAN_1")
|
||||
span_2 = Span(doc, 0, 2, "SPAN_2")
|
||||
|
||||
# Set custom extensions
|
||||
span_1._.span_extension = "span_1 extension"
|
||||
span_2._.span_extension = "span_2 extension"
|
||||
|
||||
span_doc = span_1.as_doc(copy_user_data=True)
|
||||
|
||||
# Assert extensions
|
||||
assert span_doc.user_data[_get_tuple(span_1)] == "span_1 extension"
|
||||
assert span_doc.user_data[_get_tuple(span_2)] == "span_2 extension"
|
||||
|
|
|
@ -1177,13 +1177,22 @@ cdef class Doc:
|
|||
|
||||
if "user_data" not in exclude:
|
||||
for key, value in doc.user_data.items():
|
||||
if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
|
||||
data_type, name, start, end = key
|
||||
if isinstance(key, tuple) and len(key) >= 4 and key[0] == "._.":
|
||||
data_type = key[0]
|
||||
name = key[1]
|
||||
start = key[2]
|
||||
end = key[3]
|
||||
if start is not None or end is not None:
|
||||
start += char_offset
|
||||
if end is not None:
|
||||
end += char_offset
|
||||
_label = key[4]
|
||||
_kb_id = key[5]
|
||||
_span_id = key[6]
|
||||
concat_user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value)
|
||||
else:
|
||||
concat_user_data[(data_type, name, start, end)] = copy.copy(value)
|
||||
|
||||
else:
|
||||
warnings.warn(Warnings.W101.format(name=name))
|
||||
else:
|
||||
|
@ -1627,7 +1636,11 @@ cdef class Doc:
|
|||
Span.set_extension(span_attr)
|
||||
for span_data in doc_json["underscore_span"][span_attr]:
|
||||
value = span_data["value"]
|
||||
self.char_span(span_data["start"], span_data["end"])._.set(span_attr, value)
|
||||
span = self.char_span(span_data["start"], span_data["end"])
|
||||
span.label = span_data["label"]
|
||||
span.kb_id = span_data["kb_id"]
|
||||
span.id = span_data["id"]
|
||||
span._.set(span_attr, value)
|
||||
return self
|
||||
|
||||
def to_json(self, underscore=None):
|
||||
|
@ -1705,13 +1718,16 @@ cdef class Doc:
|
|||
if attr not in data["underscore_token"]:
|
||||
data["underscore_token"][attr] = []
|
||||
data["underscore_token"][attr].append({"start": start, "value": value})
|
||||
# Span attribute
|
||||
elif start is not None and end is not None:
|
||||
# Else span attribute
|
||||
elif end is not None:
|
||||
_label = data_key[4]
|
||||
_kb_id = data_key[5]
|
||||
_span_id = data_key[6]
|
||||
if "underscore_span" not in data:
|
||||
data["underscore_span"] = {}
|
||||
if attr not in data["underscore_span"]:
|
||||
data["underscore_span"][attr] = []
|
||||
data["underscore_span"][attr].append({"start": start, "end": end, "value": value})
|
||||
data["underscore_span"][attr].append({"start": start, "end": end, "value": value, "label": _label, "kb_id": _kb_id, "id":_span_id})
|
||||
|
||||
for attr in underscore:
|
||||
if attr not in user_keys:
|
||||
|
|
|
@ -218,11 +218,10 @@ cdef class Span:
|
|||
cdef SpanC* span_c = self.span_c()
|
||||
"""Custom extension attributes registered via `set_extension`."""
|
||||
return Underscore(Underscore.span_extensions, self,
|
||||
start=span_c.start_char, end=span_c.end_char)
|
||||
start=span_c.start_char, end=span_c.end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
|
||||
|
||||
def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None):
|
||||
"""Create a `Doc` object with a copy of the `Span`'s data.
|
||||
|
||||
copy_user_data (bool): Whether or not to copy the original doc's user data.
|
||||
array_head (tuple): `Doc` array attrs, can be passed in to speed up computation.
|
||||
array (ndarray): `Doc` as array, can be passed in to speed up computation.
|
||||
|
@ -275,11 +274,21 @@ cdef class Span:
|
|||
char_offset = self.start_char
|
||||
for key, value in self.doc.user_data.items():
|
||||
if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
|
||||
data_type, name, start, end = key
|
||||
data_type = key[0]
|
||||
name = key[1]
|
||||
start = key[2]
|
||||
end = key[3]
|
||||
if start is not None or end is not None:
|
||||
start -= char_offset
|
||||
# Check if Span object
|
||||
if end is not None:
|
||||
end -= char_offset
|
||||
_label = key[4]
|
||||
_kb_id = key[5]
|
||||
_span_id = key[6]
|
||||
user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value)
|
||||
# Else Token object
|
||||
else:
|
||||
user_data[(data_type, name, start, end)] = copy.copy(value)
|
||||
else:
|
||||
user_data[key] = copy.copy(value)
|
||||
|
@ -781,21 +790,36 @@ cdef class Span:
|
|||
return self.span_c().label
|
||||
|
||||
def __set__(self, attr_t label):
|
||||
if label != self.span_c().label :
|
||||
old_label = self.span_c().label
|
||||
self.span_c().label = label
|
||||
new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
|
||||
old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=old_label, kb_id=self.kb_id, span_id=self.id)
|
||||
Underscore._replace_keys(old, new)
|
||||
|
||||
property kb_id:
|
||||
def __get__(self):
|
||||
return self.span_c().kb_id
|
||||
|
||||
def __set__(self, attr_t kb_id):
|
||||
if kb_id != self.span_c().kb_id :
|
||||
old_kb_id = self.span_c().kb_id
|
||||
self.span_c().kb_id = kb_id
|
||||
new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
|
||||
old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=old_kb_id, span_id=self.id)
|
||||
Underscore._replace_keys(old, new)
|
||||
|
||||
property id:
|
||||
def __get__(self):
|
||||
return self.span_c().id
|
||||
|
||||
def __set__(self, attr_t id):
|
||||
if id != self.span_c().id :
|
||||
old_id = self.span_c().id
|
||||
self.span_c().id = id
|
||||
new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
|
||||
old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=old_id)
|
||||
Underscore._replace_keys(old, new)
|
||||
|
||||
property ent_id:
|
||||
"""Alias for the span's ID."""
|
||||
|
|
|
@ -2,10 +2,10 @@ from typing import Dict, Any, List, Optional, Tuple, Union, TYPE_CHECKING
|
|||
import functools
|
||||
import copy
|
||||
from ..errors import Errors
|
||||
from .span import Span
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .doc import Doc
|
||||
from .span import Span
|
||||
from .token import Token
|
||||
|
||||
|
||||
|
@ -25,6 +25,9 @@ class Underscore:
|
|||
obj: Union["Doc", "Span", "Token"],
|
||||
start: Optional[int] = None,
|
||||
end: Optional[int] = None,
|
||||
label: int = 0,
|
||||
kb_id: int = 0,
|
||||
span_id: int = 0,
|
||||
):
|
||||
object.__setattr__(self, "_extensions", extensions)
|
||||
object.__setattr__(self, "_obj", obj)
|
||||
|
@ -36,6 +39,10 @@ class Underscore:
|
|||
object.__setattr__(self, "_doc", obj.doc)
|
||||
object.__setattr__(self, "_start", start)
|
||||
object.__setattr__(self, "_end", end)
|
||||
if type(obj) == Span:
|
||||
object.__setattr__(self, "_label", label)
|
||||
object.__setattr__(self, "_kb_id", kb_id)
|
||||
object.__setattr__(self, "_span_id", span_id)
|
||||
|
||||
def __dir__(self) -> List[str]:
|
||||
# Hack to enable autocomplete on custom extensions
|
||||
|
@ -88,8 +95,39 @@ class Underscore:
|
|||
def has(self, name: str) -> bool:
|
||||
return name in self._extensions
|
||||
|
||||
def _get_key(self, name: str) -> Tuple[str, str, Optional[int], Optional[int]]:
|
||||
return ("._.", name, self._start, self._end)
|
||||
def _get_key(
|
||||
self, name: str
|
||||
) -> Union[
|
||||
Tuple[str, str, Optional[int], Optional[int]],
|
||||
Tuple[str, str, Optional[int], Optional[int], int, int, int],
|
||||
]:
|
||||
if hasattr(self, "_label"):
|
||||
return (
|
||||
"._.",
|
||||
name,
|
||||
self._start,
|
||||
self._end,
|
||||
self._label,
|
||||
self._kb_id,
|
||||
self._span_id,
|
||||
)
|
||||
else:
|
||||
return "._.", name, self._start, self._end
|
||||
|
||||
@staticmethod
|
||||
def _replace_keys(old_underscore: "Underscore", new_underscore: "Underscore"):
|
||||
"""
|
||||
This function is called by Span when its kb_id or label are re-assigned.
|
||||
It checks if any user_data is stored for this span and replaces the keys
|
||||
"""
|
||||
for name in old_underscore._extensions:
|
||||
old_key = old_underscore._get_key(name)
|
||||
old_doc = old_underscore._doc
|
||||
new_key = new_underscore._get_key(name)
|
||||
if old_key != new_key and old_key in old_doc.user_data:
|
||||
old_underscore._doc.user_data[
|
||||
new_key
|
||||
] = old_underscore._doc.user_data.pop(old_key)
|
||||
|
||||
@classmethod
|
||||
def get_state(cls) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]:
|
||||
|
|
Loading…
Reference in New Issue
Block a user