diff --git a/spacy/tests/doc/test_underscore.py b/spacy/tests/doc/test_underscore.py
index a62010820..96ddb730a 100644
--- a/spacy/tests/doc/test_underscore.py
+++ b/spacy/tests/doc/test_underscore.py
@@ -348,3 +348,73 @@ def test_underscore_for_unique_span(en_tokenizer):
     # Assert extensions with original key
     assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension"
     assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension"
+
+
+def test_underscore_for_unique_span_from_docs(en_tokenizer):
+    """Test that Doc.from_docs keeps token and span extensions.
+
+    Spans with identical boundaries but different labels must keep separate
+    user_data entries, and entries from later docs must have their character
+    offsets shifted.
+    """
+    Span.set_extension(name="span_extension", default=None)
+    Token.set_extension(name="token_extension", default=None)
+
+    def span_key(span, char_offset=0):
+        # user_data key of a span extension, optionally shifted by char_offset
+        return (
+            "._.",
+            "span_extension",
+            span.start_char + char_offset,
+            span.end_char + char_offset,
+            span.label,
+            span.kb_id,
+            span.id,
+        )
+
+    # Initialize docs
+    doc_1 = en_tokenizer("Hello, world!")
+    span_1a = Span(doc_1, 0, 2, "SPAN_1a")
+    span_1b = Span(doc_1, 0, 2, "SPAN_1b")
+
+    doc_2 = en_tokenizer("This is a test.")
+    span_2a = Span(doc_2, 0, 3, "SPAN_2a")
+
+    # Set custom extensions
+    doc_1[0]._.token_extension = "token_1"
+    doc_2[1]._.token_extension = "token_2"
+    span_1a._.span_extension = "span_1a extension"
+    span_1b._.span_extension = "span_1b extension"
+    span_2a._.span_extension = "span_2a extension"
+
+    doc = Doc.from_docs([doc_1, doc_2])
+    # Entries of doc_2 are expected at len(doc_1.text) + 1 in the merged doc
+    # (NOTE(review): the extra 1 is presumably a separator space; confirm).
+    offset = len(doc_1.text) + 1
+
+    # Assert extensions on the source docs
+    assert doc_1.user_data[span_key(span_1a)] == "span_1a extension"
+    assert doc_1.user_data[span_key(span_1b)] == "span_1b extension"
+    assert doc_2.user_data[span_key(span_2a)] == "span_2a extension"
+
+    # Check merged doc
+    assert doc.user_data[span_key(span_1a)] == "span_1a extension"
+    assert doc.user_data[span_key(span_1b)] == "span_1b extension"
+    assert doc.user_data[span_key(span_2a, offset)] == "span_2a extension"
+    assert doc.user_data[("._.", "token_extension", doc_1[0].idx, None)] == "token_1"
+    assert (
+        doc.user_data[("._.", "token_extension", doc_2[1].idx + offset, None)]
+        == "token_2"
+    )
+
+
+def test_underscore_legacy_span_key_from_docs(en_tokenizer):
+    """Test that Doc.from_docs shifts 4-tuple span keys instead of raising."""
+    doc_1 = en_tokenizer("Hello, world!")
+    doc_2 = en_tokenizer("This is a test.")
+    # Key with start and end offsets but without label, kb_id and id
+    doc_2.user_data[("._.", "span_extension", 0, 4)] = "legacy"
+
+    doc = Doc.from_docs([doc_1, doc_2])
+    offset = len(doc_1.text) + 1
+    assert doc.user_data[("._.", "span_extension", offset, 4 + offset)] == "legacy"
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 1259ad2a6..9a83a17d4 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -1178,13 +1178,18 @@ cdef class Doc:
 
             if "user_data" not in exclude:
                 for key, value in doc.user_data.items():
-                    if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
-                        data_type, name, start, end = key
+                    # Extension keys start with ("._.", name, start, end); span
+                    # keys may append further fields (label, kb_id, span id).
+                    if isinstance(key, tuple) and len(key) >= 4 and key[0] == "._.":
+                        data_type, name, start, end = key[:4]
                         if start is not None or end is not None:
                             start += char_offset
                             if end is not None:
                                 end += char_offset
-                            concat_user_data[(data_type, name, start, end)] = copy.copy(value)
+                            # Only the character offsets are shifted; trailing key
+                            # fields are carried over unchanged, so keys with
+                            # exactly four fields no longer raise an IndexError.
+                            concat_user_data[(data_type, name, start, end) + key[4:]] = copy.copy(value)
                         else:
                             warnings.warn(Warnings.W101.format(name=name))
                     else: