From ca75190a3d2972392d9333d731fdf6d1b4a48793 Mon Sep 17 00:00:00 2001 From: Edward <43848523+thomashacker@users.noreply.github.com> Date: Mon, 12 Dec 2022 08:55:53 +0100 Subject: [PATCH] Custom extensions for spans with equal boundaries (#11429) * Init * Fix return type for mypy * adjust types and improve setting new attributes * Add underscore changes to json conversion * Add test and underscore changes to from_docs * add underscore changes and test to span.to_doc * update return values Co-authored-by: Sofie Van Landeghem * Add types to function Co-authored-by: Sofie Van Landeghem * adjust formatting Co-authored-by: Sofie Van Landeghem * shorten return type Co-authored-by: Sofie Van Landeghem * add helper function to improve readability * Improve code and add comments * rerun azure tests * Fix tests for json conversion Co-authored-by: Sofie Van Landeghem --- spacy/tests/doc/test_underscore.py | 119 +++++++++++++++++++++++++++++ spacy/tokens/doc.pyx | 30 ++++++-- spacy/tokens/span.pyx | 38 +++++++-- spacy/tokens/underscore.py | 44 ++++++++++- 4 files changed, 214 insertions(+), 17 deletions(-) diff --git a/spacy/tests/doc/test_underscore.py b/spacy/tests/doc/test_underscore.py index b934221af..d23bb3162 100644 --- a/spacy/tests/doc/test_underscore.py +++ b/spacy/tests/doc/test_underscore.py @@ -3,6 +3,10 @@ from mock import Mock from spacy.tokens import Doc, Span, Token from spacy.tokens.underscore import Underscore +# Helper functions +def _get_tuple(s: Span): + return "._.", "span_extension", s.start_char, s.end_char, s.label, s.kb_id, s.id + @pytest.fixture(scope="function", autouse=True) def clean_underscore(): @@ -171,3 +175,118 @@ def test_underscore_docstring(en_vocab): doc = Doc(en_vocab, words=["hello", "world"]) assert test_method.__doc__ == "I am a docstring" assert doc._.test_docstrings.__doc__.rsplit(". ")[-1] == "I am a docstring" + + +def test_underscore_for_unique_span(en_tokenizer): + """Test that spans with the same boundaries but with different labels are uniquely identified (see #9706).""" + Doc.set_extension(name="doc_extension", default=None) + Span.set_extension(name="span_extension", default=None) + Token.set_extension(name="token_extension", default=None) + + # Initialize doc + text = "Hello, world!" + doc = en_tokenizer(text) + span_1 = Span(doc, 0, 2, "SPAN_1") + span_2 = Span(doc, 0, 2, "SPAN_2") + + # Set custom extensions + doc._.doc_extension = "doc extension" + doc[0]._.token_extension = "token extension" + span_1._.span_extension = "span_1 extension" + span_2._.span_extension = "span_2 extension" + + # Assert extensions + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "span_2 extension" + + # Change label of span and assert extensions + span_1.label_ = "NEW_LABEL" + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "span_2 extension" + + # Change KB_ID and assert extensions + span_1.kb_id_ = "KB_ID" + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "span_2 extension" + + # Change extensions and assert + span_2._.span_extension = "updated span_2 extension" + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension" + + # Change span ID and assert extensions + span_2.id = 2 + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension" + + # Assert extensions with original key + assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension" + assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension" + + +def test_underscore_for_unique_span_from_docs(en_tokenizer): + """Test that spans in the user_data keep the same data structure when using Doc.from_docs""" + Span.set_extension(name="span_extension", default=None) + Token.set_extension(name="token_extension", default=None) + + # Initialize doc + text_1 = "Hello, world!" + doc_1 = en_tokenizer(text_1) + span_1a = Span(doc_1, 0, 2, "SPAN_1a") + span_1b = Span(doc_1, 0, 2, "SPAN_1b") + + text_2 = "This is a test." + doc_2 = en_tokenizer(text_2) + span_2a = Span(doc_2, 0, 3, "SPAN_2a") + + # Set custom extensions + doc_1[0]._.token_extension = "token_1" + doc_2[1]._.token_extension = "token_2" + span_1a._.span_extension = "span_1a extension" + span_1b._.span_extension = "span_1b extension" + span_2a._.span_extension = "span_2a extension" + + doc = Doc.from_docs([doc_1, doc_2]) + # Assert extensions + assert doc_1.user_data[_get_tuple(span_1a)] == "span_1a extension" + assert doc_1.user_data[_get_tuple(span_1b)] == "span_1b extension" + assert doc_2.user_data[_get_tuple(span_2a)] == "span_2a extension" + + # Check extensions on merged doc + assert doc.user_data[_get_tuple(span_1a)] == "span_1a extension" + assert doc.user_data[_get_tuple(span_1b)] == "span_1b extension" + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_2a.start_char + len(doc_1.text) + 1, + span_2a.end_char + len(doc_1.text) + 1, + span_2a.label, + span_2a.kb_id, + span_2a.id, + ) + ] + == "span_2a extension" + ) + + +def test_underscore_for_unique_span_as_span(en_tokenizer): + """Test that spans in the user_data keep the same data structure when using Span.as_doc""" + Span.set_extension(name="span_extension", default=None) + + # Initialize doc + text = "Hello, world!" + doc = en_tokenizer(text) + span_1 = Span(doc, 0, 2, "SPAN_1") + span_2 = Span(doc, 0, 2, "SPAN_2") + + # Set custom extensions + span_1._.span_extension = "span_1 extension" + span_2._.span_extension = "span_2 extension" + + span_doc = span_1.as_doc(copy_user_data=True) + + # Assert extensions + assert span_doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert span_doc.user_data[_get_tuple(span_2)] == "span_2 extension" diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index bf3da0ce4..b7506f745 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1177,13 +1177,22 @@ cdef class Doc: if "user_data" not in exclude: for key, value in doc.user_data.items(): - if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.": - data_type, name, start, end = key + if isinstance(key, tuple) and len(key) >= 4 and key[0] == "._.": + data_type = key[0] + name = key[1] + start = key[2] + end = key[3] if start is not None or end is not None: start += char_offset if end is not None: end += char_offset - concat_user_data[(data_type, name, start, end)] = copy.copy(value) + _label = key[4] + _kb_id = key[5] + _span_id = key[6] + concat_user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value) + else: + concat_user_data[(data_type, name, start, end)] = copy.copy(value) + else: warnings.warn(Warnings.W101.format(name=name)) else: @@ -1627,7 +1636,11 @@ cdef class Doc: Span.set_extension(span_attr) for span_data in doc_json["underscore_span"][span_attr]: value = span_data["value"] - self.char_span(span_data["start"], span_data["end"])._.set(span_attr, value) + span = self.char_span(span_data["start"], span_data["end"]) + span.label = span_data["label"] + span.kb_id = span_data["kb_id"] + span.id = span_data["id"] + span._.set(span_attr, value) return self def to_json(self, underscore=None): @@ -1705,13 +1718,16 @@ cdef class Doc: if attr not in data["underscore_token"]: data["underscore_token"][attr] = [] data["underscore_token"][attr].append({"start": start, "value": value}) - # Span attribute - elif start is not None and end is not None: + # Else span attribute + elif end is not None: + _label = data_key[4] + _kb_id = data_key[5] + _span_id = data_key[6] if "underscore_span" not in data: data["underscore_span"] = {} if attr not in data["underscore_span"]: data["underscore_span"][attr] = [] - data["underscore_span"][attr].append({"start": start, "end": end, "value": value}) + data["underscore_span"][attr].append({"start": start, "end": end, "value": value, "label": _label, "kb_id": _kb_id, "id":_span_id}) for attr in underscore: if attr not in user_keys: diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 5530dd127..6f2d0379c 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -218,11 +218,10 @@ cdef class Span: cdef SpanC* span_c = self.span_c() """Custom extension attributes registered via `set_extension`.""" return Underscore(Underscore.span_extensions, self, - start=span_c.start_char, end=span_c.end_char) + start=span_c.start_char, end=span_c.end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None): """Create a `Doc` object with a copy of the `Span`'s data. - copy_user_data (bool): Whether or not to copy the original doc's user data. array_head (tuple): `Doc` array attrs, can be passed in to speed up computation. array (ndarray): `Doc` as array, can be passed in to speed up computation. @@ -275,12 +274,22 @@ cdef class Span: char_offset = self.start_char for key, value in self.doc.user_data.items(): if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.": - data_type, name, start, end = key + data_type = key[0] + name = key[1] + start = key[2] + end = key[3] if start is not None or end is not None: start -= char_offset + # Check if Span object if end is not None: end -= char_offset - user_data[(data_type, name, start, end)] = copy.copy(value) + _label = key[4] + _kb_id = key[5] + _span_id = key[6] + user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value) + # Else Token object + else: + user_data[(data_type, name, start, end)] = copy.copy(value) else: user_data[key] = copy.copy(value) doc.user_data = user_data @@ -781,21 +790,36 @@ cdef class Span: return self.span_c().label def __set__(self, attr_t label): - self.span_c().label = label + if label != self.span_c().label : + old_label = self.span_c().label + self.span_c().label = label + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=old_label, kb_id=self.kb_id, span_id=self.id) + Underscore._replace_keys(old, new) property kb_id: def __get__(self): return self.span_c().kb_id def __set__(self, attr_t kb_id): - self.span_c().kb_id = kb_id + if kb_id != self.span_c().kb_id : + old_kb_id = self.span_c().kb_id + self.span_c().kb_id = kb_id + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=old_kb_id, span_id=self.id) + Underscore._replace_keys(old, new) property id: def __get__(self): return self.span_c().id def __set__(self, attr_t id): - self.span_c().id = id + if id != self.span_c().id : + old_id = self.span_c().id + self.span_c().id = id + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=old_id) + Underscore._replace_keys(old, new) property ent_id: """Alias for the span's ID.""" diff --git a/spacy/tokens/underscore.py b/spacy/tokens/underscore.py index e9a4e1862..f2f357441 100644 --- a/spacy/tokens/underscore.py +++ b/spacy/tokens/underscore.py @@ -2,10 +2,10 @@ from typing import Dict, Any, List, Optional, Tuple, Union, TYPE_CHECKING import functools import copy from ..errors import Errors +from .span import Span if TYPE_CHECKING: from .doc import Doc - from .span import Span from .token import Token @@ -25,6 +25,9 @@ class Underscore: obj: Union["Doc", "Span", "Token"], start: Optional[int] = None, end: Optional[int] = None, + label: int = 0, + kb_id: int = 0, + span_id: int = 0, ): object.__setattr__(self, "_extensions", extensions) object.__setattr__(self, "_obj", obj) @@ -36,6 +39,10 @@ class Underscore: object.__setattr__(self, "_doc", obj.doc) object.__setattr__(self, "_start", start) object.__setattr__(self, "_end", end) + if type(obj) == Span: + object.__setattr__(self, "_label", label) + object.__setattr__(self, "_kb_id", kb_id) + object.__setattr__(self, "_span_id", span_id) def __dir__(self) -> List[str]: # Hack to enable autocomplete on custom extensions @@ -88,8 +95,39 @@ class Underscore: def has(self, name: str) -> bool: return name in self._extensions - def _get_key(self, name: str) -> Tuple[str, str, Optional[int], Optional[int]]: - return ("._.", name, self._start, self._end) + def _get_key( + self, name: str + ) -> Union[ + Tuple[str, str, Optional[int], Optional[int]], + Tuple[str, str, Optional[int], Optional[int], int, int, int], + ]: + if hasattr(self, "_label"): + return ( + "._.", + name, + self._start, + self._end, + self._label, + self._kb_id, + self._span_id, + ) + else: + return "._.", name, self._start, self._end + + @staticmethod + def _replace_keys(old_underscore: "Underscore", new_underscore: "Underscore"): + """ + This function is called by Span when its kb_id or label are re-assigned. + It checks if any user_data is stored for this span and replaces the keys + """ + for name in old_underscore._extensions: + old_key = old_underscore._get_key(name) + old_doc = old_underscore._doc + new_key = new_underscore._get_key(name) + if old_key != new_key and old_key in old_doc.user_data: + old_underscore._doc.user_data[ + new_key + ] = old_underscore._doc.user_data.pop(old_key) @classmethod def get_state(cls) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]: