diff --git a/spacy/tests/doc/test_underscore.py b/spacy/tests/doc/test_underscore.py index b934221af..a62010820 100644 --- a/spacy/tests/doc/test_underscore.py +++ b/spacy/tests/doc/test_underscore.py @@ -171,3 +171,180 @@ def test_underscore_docstring(en_vocab): doc = Doc(en_vocab, words=["hello", "world"]) assert test_method.__doc__ == "I am a docstring" assert doc._.test_docstrings.__doc__.rsplit(". ")[-1] == "I am a docstring" + + +def test_underscore_for_unique_span(en_tokenizer): + """Test that spans with the same boundaries but with different labels are uniquely identified (see #9706).""" + Doc.set_extension(name="doc_extension", default=None) + Span.set_extension(name="span_extension", default=None) + Token.set_extension(name="token_extension", default=None) + + # Initialize doc + text = "Hello, world!" + doc = en_tokenizer(text) + span_1 = Span(doc, 0, 2, "SPAN_1") + span_2 = Span(doc, 0, 2, "SPAN_2") + + # Set custom extensions + doc._.doc_extension = "doc extension" + doc[0]._.token_extension = "token extension" + span_1._.span_extension = "span_1 extension" + span_2._.span_extension = "span_2 extension" + + # Assert extensions + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_1.start_char, + span_1.end_char, + span_1.label, + span_1.kb_id, + span_1.id, + ) + ] + == "span_1 extension" + ) + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_2.start_char, + span_2.end_char, + span_2.label, + span_2.kb_id, + span_2.id, + ) + ] + == "span_2 extension" + ) + + # Change label of span and assert extensions + span_1.label_ = "NEW_LABEL" + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_1.start_char, + span_1.end_char, + span_1.label, + span_1.kb_id, + span_1.id, + ) + ] + == "span_1 extension" + ) + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_2.start_char, + span_2.end_char, + span_2.label, + span_2.kb_id, + span_2.id, + ) + ] + == "span_2 extension" + ) + + # Change KB_ID and assert extensions + span_1.kb_id_ = "KB_ID" + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_1.start_char, + span_1.end_char, + span_1.label, + span_1.kb_id, + span_1.id, + ) + ] + == "span_1 extension" + ) + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_2.start_char, + span_2.end_char, + span_2.label, + span_2.kb_id, + span_2.id, + ) + ] + == "span_2 extension" + ) + + # Change extensions and assert + span_2._.span_extension = "updated span_2 extension" + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_1.start_char, + span_1.end_char, + span_1.label, + span_1.kb_id, + span_1.id, + ) + ] + == "span_1 extension" + ) + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_2.start_char, + span_2.end_char, + span_2.label, + span_2.kb_id, + span_2.id, + ) + ] + == "updated span_2 extension" + ) + + # Change span ID and assert extensions + span_2.id = 2 + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_1.start_char, + span_1.end_char, + span_1.label, + span_1.kb_id, + span_1.id, + ) + ] + == "span_1 extension" + ) + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_2.start_char, + span_2.end_char, + span_2.label, + span_2.kb_id, + span_2.id, + ) + ] + == "updated span_2 extension" + ) + + # Assert extensions with original key + assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension" + assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension" diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 89d9727e9..d853f6fef 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -218,11 +218,10 @@ cdef class Span: cdef SpanC* span_c = self.span_c() """Custom extension attributes registered via `set_extension`.""" return Underscore(Underscore.span_extensions, self, - start=span_c.start_char, end=span_c.end_char) + start=span_c.start_char, end=span_c.end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None): """Create a `Doc` object with a copy of the `Span`'s data. - copy_user_data (bool): Whether or not to copy the original doc's user data. array_head (tuple): `Doc` array attrs, can be passed in to speed up computation. array (ndarray): `Doc` as array, can be passed in to speed up computation. @@ -791,21 +790,42 @@ cdef class Span: return self.span_c().label def __set__(self, attr_t label): - self.span_c().label = label + if label != self.span_c().label : + old_label = self.span_c().label + self.span_c().label = label + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label = self.label, kb_id = self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=old_label, kb_id=self.kb_id, span_id=self.id) + Underscore._replace_keys(old, new) + else: + self.span_c().label = label property kb_id: def __get__(self): return self.span_c().kb_id def __set__(self, attr_t kb_id): - self.span_c().kb_id = kb_id + if kb_id != self.span_c().kb_id : + old_kb_id = self.span_c().kb_id + self.span_c().kb_id = kb_id + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=old_kb_id, span_id=self.id) + Underscore._replace_keys(old, new) + else: + self.span_c().kb_id = kb_id property id: def __get__(self): return self.span_c().id def __set__(self, attr_t id): - self.span_c().id = id + if id != self.span_c().id : + old_id = self.span_c().id + self.span_c().id = id + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=old_id) + Underscore._replace_keys(old, new) + else: + self.span_c().id = id property ent_id: """Alias for the span's ID.""" diff --git a/spacy/tokens/underscore.py b/spacy/tokens/underscore.py index e9a4e1862..a464e3622 100644 --- a/spacy/tokens/underscore.py +++ b/spacy/tokens/underscore.py @@ -25,6 +25,9 @@ class Underscore: obj: Union["Doc", "Span", "Token"], start: Optional[int] = None, end: Optional[int] = None, + label: Optional[Union[str, int]] = None, + kb_id: Optional[Union[str, int]] = None, + span_id: Optional[Union[str, int]] = None, ): object.__setattr__(self, "_extensions", extensions) object.__setattr__(self, "_obj", obj) @@ -36,6 +39,13 @@ class Underscore: object.__setattr__(self, "_doc", obj.doc) object.__setattr__(self, "_start", start) object.__setattr__(self, "_end", end) + if label is not None or kb_id is not None or span_id is not None: + # It is reasonably safe to assume that if label, kb_id and span_id are None, + # they were not passed to this function. Span.label and kb_id are + # converted to 0 in Span's c-tor and setters if they are None + object.__setattr__(self, "_label", label) + object.__setattr__(self, "_kb_id", kb_id) + object.__setattr__(self, "_span_id", span_id) def __dir__(self) -> List[str]: # Hack to enable autocomplete on custom extensions @@ -89,7 +99,22 @@ class Underscore: return name in self._extensions def _get_key(self, name: str) -> Tuple[str, str, Optional[int], Optional[int]]: - return ("._.", name, self._start, self._end) + if hasattr(self, "_label") : + return ("._.", name, self._start, self._end, self._label, self._kb_id, self._span_id) + else : + return ("._.", name, self._start, self._end) + + @staticmethod + def _replace_keys(old_underscore, new_underscore): + """ + This function is called by Span when its kb_id or label are re-assigned. + It checks if any user_data is stored for this span and replaces the keys + """ + for name in old_underscore._extensions : + old_key = old_underscore._get_key(name) + new_key = new_underscore._get_key(name) + if old_key in old_underscore._doc.user_data and old_key != new_key: + old_underscore._doc.user_data[new_key] = old_underscore._doc.user_data.pop(old_key) @classmethod def get_state(cls) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]: