This commit is contained in:
thomashacker 2022-09-05 13:09:17 +02:00
parent 698b8b495f
commit 0adc7fde54
3 changed files with 228 additions and 6 deletions

View File

@ -171,3 +171,180 @@ def test_underscore_docstring(en_vocab):
doc = Doc(en_vocab, words=["hello", "world"])
assert test_method.__doc__ == "I am a docstring"
assert doc._.test_docstrings.__doc__.rsplit(". ")[-1] == "I am a docstring"
def test_underscore_for_unique_span(en_tokenizer):
    """Test that spans with the same boundaries but with different labels are uniquely identified (see #9706).

    Two spans covering the same tokens get their own `span_extension` value,
    and each value stays reachable in `doc.user_data` after the label, KB ID,
    extension value or span ID of one of the spans is changed.
    """

    def span_key(span):
        # Key under which the span's `span_extension` value is looked up in
        # `doc.user_data`. It is rebuilt on every call so that it always
        # reflects the span's *current* label, kb_id and id.
        return (
            "._.",
            "span_extension",
            span.start_char,
            span.end_char,
            span.label,
            span.kb_id,
            span.id,
        )

    Doc.set_extension(name="doc_extension", default=None)
    Span.set_extension(name="span_extension", default=None)
    Token.set_extension(name="token_extension", default=None)

    # Initialize doc
    text = "Hello, world!"
    doc = en_tokenizer(text)
    span_1 = Span(doc, 0, 2, "SPAN_1")
    span_2 = Span(doc, 0, 2, "SPAN_2")

    # Set custom extensions
    doc._.doc_extension = "doc extension"
    doc[0]._.token_extension = "token extension"
    span_1._.span_extension = "span_1 extension"
    span_2._.span_extension = "span_2 extension"

    # Assert extensions
    assert doc.user_data[span_key(span_1)] == "span_1 extension"
    assert doc.user_data[span_key(span_2)] == "span_2 extension"

    # Change label of span and assert extensions
    span_1.label_ = "NEW_LABEL"
    assert doc.user_data[span_key(span_1)] == "span_1 extension"
    assert doc.user_data[span_key(span_2)] == "span_2 extension"

    # Change KB_ID and assert extensions
    span_1.kb_id_ = "KB_ID"
    assert doc.user_data[span_key(span_1)] == "span_1 extension"
    assert doc.user_data[span_key(span_2)] == "span_2 extension"

    # Change extensions and assert
    span_2._.span_extension = "updated span_2 extension"
    assert doc.user_data[span_key(span_1)] == "span_1 extension"
    assert doc.user_data[span_key(span_2)] == "updated span_2 extension"

    # Change span ID and assert extensions
    span_2.id = 2
    assert doc.user_data[span_key(span_1)] == "span_1 extension"
    assert doc.user_data[span_key(span_2)] == "updated span_2 extension"

    # Assert extensions with original key
    assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension"
    assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension"

View File

@ -218,11 +218,10 @@ cdef class Span:
cdef SpanC* span_c = self.span_c()
"""Custom extension attributes registered via `set_extension`."""
return Underscore(Underscore.span_extensions, self,
start=span_c.start_char, end=span_c.end_char)
start=span_c.start_char, end=span_c.end_char, label=self.label, kb_id=self.kb_id, span_id=self.id)
def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None):
"""Create a `Doc` object with a copy of the `Span`'s data.
copy_user_data (bool): Whether or not to copy the original doc's user data.
array_head (tuple): `Doc` array attrs, can be passed in to speed up computation.
array (ndarray): `Doc` as array, can be passed in to speed up computation.
@ -791,21 +790,42 @@ cdef class Span:
return self.span_c().label
def __set__(self, attr_t label):
self.span_c().label = label
if label != self.span_c().label :
old_label = self.span_c().label
self.span_c().label = label
new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label = self.label, kb_id = self.kb_id, span_id=self.id)
old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=old_label, kb_id=self.kb_id, span_id=self.id)
Underscore._replace_keys(old, new)
else:
self.span_c().label = label
property kb_id:
    def __get__(self):
        return self.span_c().kb_id

    def __set__(self, attr_t kb_id):
        # The kb_id is part of the key under which this span's custom
        # extension values are stored in `doc.user_data`. Compare against
        # the stored value *before* overwriting it; assigning first would
        # make the comparison always false and leave the values stranded
        # under the old key.
        if kb_id != self.span_c().kb_id:
            old_kb_id = self.span_c().kb_id
            self.span_c().kb_id = kb_id
            new = Underscore(
                Underscore.span_extensions,
                self,
                start=self.span_c().start_char,
                end=self.span_c().end_char,
                label=self.label,
                kb_id=self.kb_id,
                span_id=self.id,
            )
            old = Underscore(
                Underscore.span_extensions,
                self,
                start=self.span_c().start_char,
                end=self.span_c().end_char,
                label=self.label,
                kb_id=old_kb_id,
                span_id=self.id,
            )
            # Move any stored extension values from the old key to the new one.
            Underscore._replace_keys(old, new)
        # No else branch: when the value is unchanged there is nothing to
        # assign and no key to migrate.
property id:
    def __get__(self):
        return self.span_c().id

    def __set__(self, attr_t id):
        # The span id is part of the key under which this span's custom
        # extension values are stored in `doc.user_data`. Compare against
        # the stored value *before* overwriting it; assigning first would
        # make the comparison always false and leave the values stranded
        # under the old key.
        if id != self.span_c().id:
            old_id = self.span_c().id
            self.span_c().id = id
            new = Underscore(
                Underscore.span_extensions,
                self,
                start=self.span_c().start_char,
                end=self.span_c().end_char,
                label=self.label,
                kb_id=self.kb_id,
                span_id=self.id,
            )
            old = Underscore(
                Underscore.span_extensions,
                self,
                start=self.span_c().start_char,
                end=self.span_c().end_char,
                label=self.label,
                kb_id=self.kb_id,
                span_id=old_id,
            )
            # Move any stored extension values from the old key to the new one.
            Underscore._replace_keys(old, new)
        # No else branch: when the value is unchanged there is nothing to
        # assign and no key to migrate.
property ent_id:
"""Alias for the span's ID."""

View File

@ -25,6 +25,9 @@ class Underscore:
obj: Union["Doc", "Span", "Token"],
start: Optional[int] = None,
end: Optional[int] = None,
label: Optional[Union[str, int]] = None,
kb_id: Optional[Union[str, int]] = None,
span_id: Optional[Union[str, int]] = None,
):
object.__setattr__(self, "_extensions", extensions)
object.__setattr__(self, "_obj", obj)
@ -36,6 +39,13 @@ class Underscore:
object.__setattr__(self, "_doc", obj.doc)
object.__setattr__(self, "_start", start)
object.__setattr__(self, "_end", end)
if label is not None or kb_id is not None or span_id is not None:
# It is reasonably safe to assume that if label, kb_id and span_id are None,
# they were not passed to this function. Span.label and kb_id are
# converted to 0 in Span's c-tor and setters if they are None
object.__setattr__(self, "_label", label)
object.__setattr__(self, "_kb_id", kb_id)
object.__setattr__(self, "_span_id", span_id)
def __dir__(self) -> List[str]:
# Hack to enable autocomplete on custom extensions
@ -89,7 +99,22 @@ class Underscore:
return name in self._extensions
def _get_key(
    self, name: str
) -> Union[
    Tuple[str, str, Optional[int], Optional[int]],
    Tuple[
        str,
        str,
        Optional[int],
        Optional[int],
        Optional[Union[str, int]],
        Optional[Union[str, int]],
        Optional[Union[str, int]],
    ],
]:
    """Build the `user_data` key for the extension attribute `name`.

    name (str): Name of the extension attribute.
    RETURNS (tuple): `("._.", name, start, end, label, kb_id, span_id)` when
        this instance carries a `_label` attribute, otherwise the four-field
        key `("._.", name, start, end)`.
    """
    # `_label`, `_kb_id` and `_span_id` are only set together by the
    # constructor, so the presence of `_label` selects the extended key.
    if hasattr(self, "_label"):
        return (
            "._.",
            name,
            self._start,
            self._end,
            self._label,
            self._kb_id,
            self._span_id,
        )
    return ("._.", name, self._start, self._end)
@staticmethod
def _replace_keys(old_underscore: "Underscore", new_underscore: "Underscore") -> None:
    """Move stored extension values from the old span key to the new one.

    Called by the Span setters for `label`, `kb_id` and `id` after one of
    those attributes was re-assigned. For every registered extension name,
    a value stored in the doc's `user_data` under the key produced by
    `old_underscore` is re-filed under the key produced by `new_underscore`.

    old_underscore (Underscore): Produces the keys used before the change.
    new_underscore (Underscore): Produces the keys to use after the change.
    """
    user_data = old_underscore._doc.user_data
    for name in old_underscore._extensions:
        old_key = old_underscore._get_key(name)
        new_key = new_underscore._get_key(name)
        # Only touch entries that actually exist and whose key changed.
        if old_key != new_key and old_key in user_data:
            user_data[new_key] = user_data.pop(old_key)
@classmethod
def get_state(cls) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]: