Add test and underscore changes to from_docs

This commit is contained in:
thomashacker 2022-11-04 13:51:04 +01:00
parent 543c0d1410
commit 53dc321bb9
2 changed files with 127 additions and 3 deletions

View File

@ -348,3 +348,118 @@ def test_underscore_for_unique_span(en_tokenizer):
# Assert extensions with original key
assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension"
assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension"
def test_underscore_for_unique_span_from_docs(en_tokenizer):
"""Test that spans in the user_data keep the same data structure"""
Span.set_extension(name="span_extension", default=None)
Token.set_extension(name="token_extension", default=None)
# Initialize doc
text_1 = "Hello, world!"
doc_1 = en_tokenizer(text_1)
span_1a = Span(doc_1, 0, 2, "SPAN_1a")
span_1b = Span(doc_1, 0, 2, "SPAN_1b")
text_2 = "This is a test."
doc_2 = en_tokenizer(text_2)
span_2a = Span(doc_2, 0, 3, "SPAN_2a")
# Set custom extensions
doc_1[0]._.token_extension = "token_1"
doc_2[1]._.token_extension = "token_2"
span_1a._.span_extension = "span_1a extension"
span_1b._.span_extension = "span_1b extension"
span_2a._.span_extension = "span_2a extension"
doc = Doc.from_docs([doc_1,doc_2])
# Assert extensions
assert (
doc_1.user_data[
(
"._.",
"span_extension",
span_1a.start_char,
span_1a.end_char,
span_1a.label,
span_1a.kb_id,
span_1a.id,
)
]
== "span_1a extension"
)
assert (
doc_1.user_data[
(
"._.",
"span_extension",
span_1b.start_char,
span_1b.end_char,
span_1b.label,
span_1b.kb_id,
span_1b.id,
)
]
== "span_1b extension"
)
assert (
doc_2.user_data[
(
"._.",
"span_extension",
span_2a.start_char,
span_2a.end_char,
span_2a.label,
span_2a.kb_id,
span_2a.id,
)
]
== "span_2a extension"
)
# Check merged doc
assert (
doc.user_data[
(
"._.",
"span_extension",
span_1a.start_char,
span_1a.end_char,
span_1a.label,
span_1a.kb_id,
span_1a.id,
)
]
== "span_1a extension"
)
assert (
doc.user_data[
(
"._.",
"span_extension",
span_1b.start_char,
span_1b.end_char,
span_1b.label,
span_1b.kb_id,
span_1b.id,
)
]
== "span_1b extension"
)
assert (
doc.user_data[
(
"._.",
"span_extension",
span_2a.start_char + len(doc_1.text) + 1,
span_2a.end_char + len(doc_1.text) + 1,
span_2a.label,
span_2a.kb_id,
span_2a.id,
)
]
== "span_2a extension"
)

View File

@ -1178,13 +1178,22 @@ cdef class Doc:
if "user_data" not in exclude:
for key, value in doc.user_data.items():
if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.":
data_type, name, start, end = key
if isinstance(key, tuple) and len(key) >= 4 and key[0] == "._.":
data_type = key[0]
name = key[1]
start = key[2]
end = key[3]
if start is not None or end is not None:
start += char_offset
if end is not None:
end += char_offset
concat_user_data[(data_type, name, start, end)] = copy.copy(value)
_label = key[4]
_kb_id = key[5]
_span_id = key[6]
concat_user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value)
else:
concat_user_data[(data_type, name, start, end)] = copy.copy(value)
else:
warnings.warn(Warnings.W101.format(name=name))
else: