add helper function to improve readability

thomashacker 2022-12-05 12:32:34 +01:00
parent 9a41081a26
commit b1ec51491c
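The refactor below replaces each hand-written user_data key with a call to the new _get_tuple helper. As a point of reference, here is a minimal, hedged sketch of the behavior the tests rely on; it is not part of the commit, it assumes a spaCy build that includes this span-underscore change, and the blank pipeline, text, and extension value are illustrative only:

import spacy
from spacy.tokens import Span

nlp = spacy.blank("en")  # illustrative blank pipeline, not from the commit
Span.set_extension(name="span_extension", default=None)

doc = nlp("Hello world")
span = doc[0:2]
span._.span_extension = "value"

# The value lands in doc.user_data under a seven-element key; this is exactly
# the tuple that _get_tuple(span) returns in the tests below.
key = ("._.", "span_extension", span.start_char, span.end_char, span.label, span.kb_id, span.id)
assert doc.user_data[key] == "value"

The helper keeps that seven-element key in one place, so each test can assert on doc.user_data[_get_tuple(span)] instead of rebuilding the tuple inline.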


@@ -3,6 +3,10 @@ from mock import Mock
 from spacy.tokens import Doc, Span, Token
 from spacy.tokens.underscore import Underscore
 
+# Helper functions
+def _get_tuple(s: Span):
+    return "._.", "span_extension", s.start_char, s.end_char, s.label, s.kb_id, s.id
+
 
 @pytest.fixture(scope="function", autouse=True)
 def clean_underscore():
@@ -192,163 +196,34 @@ def test_underscore_for_unique_span(en_tokenizer):
     span_2._.span_extension = "span_2 extension"
 
     # Assert extensions
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1.start_char,
-                span_1.end_char,
-                span_1.label,
-                span_1.kb_id,
-                span_1.id,
-            )
-        ]
-        == "span_1 extension"
-    )
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_2.start_char,
-                span_2.end_char,
-                span_2.label,
-                span_2.kb_id,
-                span_2.id,
-            )
-        ]
-        == "span_2 extension"
-    )
+    assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
+    assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
 
     # Change label of span and assert extensions
     span_1.label_ = "NEW_LABEL"
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1.start_char,
-                span_1.end_char,
-                span_1.label,
-                span_1.kb_id,
-                span_1.id,
-            )
-        ]
-        == "span_1 extension"
-    )
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_2.start_char,
-                span_2.end_char,
-                span_2.label,
-                span_2.kb_id,
-                span_2.id,
-            )
-        ]
-        == "span_2 extension"
-    )
+    assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
+    assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
 
     # Change KB_ID and assert extensions
     span_1.kb_id_ = "KB_ID"
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1.start_char,
-                span_1.end_char,
-                span_1.label,
-                span_1.kb_id,
-                span_1.id,
-            )
-        ]
-        == "span_1 extension"
-    )
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_2.start_char,
-                span_2.end_char,
-                span_2.label,
-                span_2.kb_id,
-                span_2.id,
-            )
-        ]
-        == "span_2 extension"
-    )
+    assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
+    assert doc.user_data[_get_tuple(span_2)] == "span_2 extension"
 
     # Change extensions and assert
     span_2._.span_extension = "updated span_2 extension"
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1.start_char,
-                span_1.end_char,
-                span_1.label,
-                span_1.kb_id,
-                span_1.id,
-            )
-        ]
-        == "span_1 extension"
-    )
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_2.start_char,
-                span_2.end_char,
-                span_2.label,
-                span_2.kb_id,
-                span_2.id,
-            )
-        ]
-        == "updated span_2 extension"
-    )
+    assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
+    assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension"
 
     # Change span ID and assert extensions
     span_2.id = 2
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1.start_char,
-                span_1.end_char,
-                span_1.label,
-                span_1.kb_id,
-                span_1.id,
-            )
-        ]
-        == "span_1 extension"
-    )
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_2.start_char,
-                span_2.end_char,
-                span_2.label,
-                span_2.kb_id,
-                span_2.id,
-            )
-        ]
-        == "updated span_2 extension"
-    )
+    assert doc.user_data[_get_tuple(span_1)] == "span_1 extension"
+    assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension"
 
     # Assert extensions with original key
     assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension"
     assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension"
 
 
 def test_underscore_for_unique_span_from_docs(en_tokenizer):
     """Test that spans in the user_data keep the same data structure when using Doc.from_docs"""
     Span.set_extension(name="span_extension", default=None)
@@ -371,84 +246,15 @@ def test_underscore_for_unique_span_from_docs(en_tokenizer):
     span_1b._.span_extension = "span_1b extension"
     span_2a._.span_extension = "span_2a extension"
 
-    doc = Doc.from_docs([doc_1,doc_2])
+    doc = Doc.from_docs([doc_1, doc_2])
     # Assert extensions
-    assert (
-        doc_1.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1a.start_char,
-                span_1a.end_char,
-                span_1a.label,
-                span_1a.kb_id,
-                span_1a.id,
-            )
-        ]
-        == "span_1a extension"
-    )
-    assert (
-        doc_1.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1b.start_char,
-                span_1b.end_char,
-                span_1b.label,
-                span_1b.kb_id,
-                span_1b.id,
-            )
-        ]
-        == "span_1b extension"
-    )
-    assert (
-        doc_2.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_2a.start_char,
-                span_2a.end_char,
-                span_2a.label,
-                span_2a.kb_id,
-                span_2a.id,
-            )
-        ]
-        == "span_2a extension"
-    )
+    assert doc_1.user_data[_get_tuple(span_1a)] == "span_1a extension"
+    assert doc_1.user_data[_get_tuple(span_1b)] == "span_1b extension"
+    assert doc_2.user_data[_get_tuple(span_2a)] == "span_2a extension"
 
     # Check merged doc
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1a.start_char,
-                span_1a.end_char,
-                span_1a.label,
-                span_1a.kb_id,
-                span_1a.id,
-            )
-        ]
-        == "span_1a extension"
-    )
-    assert (
-        doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1b.start_char,
-                span_1b.end_char,
-                span_1b.label,
-                span_1b.kb_id,
-                span_1b.id,
-            )
-        ]
-        == "span_1b extension"
-    )
+    assert doc.user_data[_get_tuple(span_1a)] == "span_1a extension"
+    assert doc.user_data[_get_tuple(span_1b)] == "span_1b extension"
     assert (
         doc.user_data[
             (
@@ -464,6 +270,7 @@ def test_underscore_for_unique_span_from_docs(en_tokenizer):
         == "span_2a extension"
     )
 
+
 def test_underscore_for_unique_span_as_span(en_tokenizer):
     """Test that spans in the user_data keep the same data structure when using Span.as_doc"""
     Span.set_extension(name="span_extension", default=None)
@@ -481,33 +288,5 @@ def test_underscore_for_unique_span_as_span(en_tokenizer):
     span_doc = span_1.as_doc(copy_user_data=True)
 
     # Assert extensions
-    assert (
-        span_doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_1.start_char,
-                span_1.end_char,
-                span_1.label,
-                span_1.kb_id,
-                span_1.id,
-            )
-        ]
-        == "span_1 extension"
-    )
-
-    # Assert extensions
-    assert (
-        span_doc.user_data[
-            (
-                "._.",
-                "span_extension",
-                span_2.start_char,
-                span_2.end_char,
-                span_2.label,
-                span_2.kb_id,
-                span_2.id,
-            )
-        ]
-        == "span_2 extension"
-    )
+    assert span_doc.user_data[_get_tuple(span_1)] == "span_1 extension"
+    assert span_doc.user_data[_get_tuple(span_2)] == "span_2 extension"
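For the Doc.from_docs and Span.as_doc tests above, the point is that copied user_data keeps the same key structure, so the helper-built tuple still finds the value. A hedged, self-contained sketch of the as_doc case, under the same assumptions as the sketch before the diff (illustrative pipeline, text, and value; a spaCy build that includes this change):

import spacy
from spacy.tokens import Span

nlp = spacy.blank("en")  # illustrative setup, not from the commit
Span.set_extension(name="span_extension", default=None)

doc = nlp("Hello world")
span = doc[0:2]
span._.span_extension = "value"

# Copying the span to its own Doc with copy_user_data=True carries the entry
# over under the same seven-element key, mirroring what
# test_underscore_for_unique_span_as_span asserts via _get_tuple.
span_doc = span.as_doc(copy_user_data=True)
key = ("._.", "span_extension", span.start_char, span.end_char, span.label, span.kb_id, span.id)
assert span_doc.user_data[key] == "value"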