Refactoring

This commit is contained in:
Richard Hudson 2021-12-08 16:42:39 +01:00
parent 183d535ef4
commit 49f3fd39b9
2 changed files with 73 additions and 73 deletions

View File

@ -207,7 +207,7 @@ def test_visualization_dependency_tree_highly_nonprojective(en_vocab):
] ]
def test_visualization_get_entity_native_attribute_int(en_vocab): def test_visualization_render_native_attribute_int(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -224,10 +224,10 @@ def test_visualization_get_entity_native_attribute_int(en_vocab):
heads=[1, None, 3, 1, 1, 7, 7, 3, 1], heads=[1, None, 3, 1, 1, 7, 7, 3, 1],
deps=["dep"] * 9, deps=["dep"] * 9,
) )
assert Visualizer().get_entity(doc[2], "head.i") == "3" assert AttributeFormat("head.i").render(doc[2]) == "3"
def test_visualization_get_entity_native_attribute_str(en_vocab): def test_visualization_render_native_attribute_str(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -245,10 +245,10 @@ def test_visualization_get_entity_native_attribute_str(en_vocab):
deps=["dep"] * 9, deps=["dep"] * 9,
) )
assert Visualizer().get_entity(doc[2], "dep_") == "dep" assert AttributeFormat("dep_").render(doc[2]) == "dep"
def test_visualization_get_entity_colors(en_vocab): def test_visualization_render_colors(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -267,19 +267,18 @@ def test_visualization_get_entity_colors(en_vocab):
) )
assert ( assert (
Visualizer().get_entity( AttributeFormat(
doc[2],
"dep_", "dep_",
value_dependent_fg_colors={"dep": 2}, value_dependent_fg_colors={"dep": 2},
value_dependent_bg_colors={"dep": 11}, value_dependent_bg_colors={"dep": 11},
) ).render(doc[2])
== "\x1b[38;5;2;48;5;11mdep\x1b[0m" == "\x1b[38;5;2;48;5;11mdep\x1b[0m"
if supports_ansi if supports_ansi
else "dep" else "dep"
) )
def test_visualization_get_entity_colors_only_fg(en_vocab): def test_visualization_render_colors_only_fg(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -298,14 +297,17 @@ def test_visualization_get_entity_colors_only_fg(en_vocab):
) )
assert ( assert (
Visualizer().get_entity(doc[2], "dep_", value_dependent_fg_colors={"dep": 2}) AttributeFormat(
"dep_",
value_dependent_fg_colors={"dep": 2},
).render(doc[2])
== "\x1b[38;5;2mdep\x1b[0m" == "\x1b[38;5;2mdep\x1b[0m"
if supports_ansi if supports_ansi
else "dep" else "dep"
) )
def test_visualization_get_entity_colors_only_bg(en_vocab): def test_visualization_render_colors_only_bg(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -324,14 +326,17 @@ def test_visualization_get_entity_colors_only_bg(en_vocab):
) )
assert ( assert (
Visualizer().get_entity(doc[2], "dep_", value_dependent_bg_colors={"dep": 11}) AttributeFormat(
"dep_",
value_dependent_bg_colors={"dep": 11},
).render(doc[2])
== "\x1b[48;5;11mdep\x1b[0m" == "\x1b[48;5;11mdep\x1b[0m"
if supports_ansi if supports_ansi
else "dep" else "dep"
) )
def test_visualization_get_entity_native_attribute_missing(en_vocab): def test_visualization_render_native_attribute_missing(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -349,10 +354,10 @@ def test_visualization_get_entity_native_attribute_missing(en_vocab):
deps=["dep"] * 9, deps=["dep"] * 9,
) )
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
Visualizer().get_entity(doc[2], "depp") AttributeFormat("depp").render(doc[2])
def test_visualization_get_entity_custom_attribute_str(en_vocab): def test_visualization_render_custom_attribute_str(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -370,10 +375,10 @@ def test_visualization_get_entity_custom_attribute_str(en_vocab):
deps=["dep"] * 9, deps=["dep"] * 9,
) )
Token.set_extension("test", default="tested", force=True) Token.set_extension("test", default="tested", force=True)
assert Visualizer().get_entity(doc[2], "_.test") == "tested" assert AttributeFormat("_.test").render(doc[2]) == "tested"
def test_visualization_get_entity_nested_custom_attribute_str(en_vocab): def test_visualization_render_nested_custom_attribute_str(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -396,10 +401,10 @@ def test_visualization_get_entity_nested_custom_attribute_str(en_vocab):
self.inner_test = "tested" self.inner_test = "tested"
Token.set_extension("test", default=Test(), force=True) Token.set_extension("test", default=Test(), force=True)
assert Visualizer().get_entity(doc[2], "_.test.inner_test") == "tested" assert AttributeFormat("_.test.inner_test").render(doc[2]) == "tested"
def test_visualization_get_entity_custom_attribute_missing(en_vocab): def test_visualization_render_custom_attribute_missing(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -417,10 +422,10 @@ def test_visualization_get_entity_custom_attribute_missing(en_vocab):
deps=["dep"] * 9, deps=["dep"] * 9,
) )
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
Visualizer().get_entity(doc[2], "_.depp") AttributeFormat("_.depp").render(doc[2])
def test_visualization_get_entity_permitted_values(en_vocab): def test_visualization_render_permitted_values(en_vocab):
doc = Doc( doc = Doc(
en_vocab, en_vocab,
words=[ words=[
@ -437,8 +442,8 @@ def test_visualization_get_entity_permitted_values(en_vocab):
heads=[1, None, 3, 1, 1, 7, 7, 3, 1], heads=[1, None, 3, 1, 1, 7, 7, 3, 1],
deps=["dep"] * 9, deps=["dep"] * 9,
) )
visualizer = Visualizer() attribute_format = AttributeFormat("head.i", permitted_values=(3, 7))
assert [visualizer.get_entity(token, "head.i", permitted_values=(3, 7)) for token in doc] == [ assert [attribute_format.render(token) for token in doc] == [
"", "",
"", "",
"3", "3",
@ -510,6 +515,7 @@ def test_visualization_minimal_render_table_permitted_values(
""".strip() """.strip()
) )
def test_visualization_spacing( def test_visualization_spacing(
fully_featured_doc_one_sentence, fully_featured_doc_one_sentence,
): ):

View File

@ -29,6 +29,38 @@ class AttributeFormat:
self.permitted_values = permitted_values self.permitted_values = permitted_values
self.value_dependent_fg_colors = value_dependent_fg_colors self.value_dependent_fg_colors = value_dependent_fg_colors
self.value_dependent_bg_colors = value_dependent_bg_colors self.value_dependent_bg_colors = value_dependent_bg_colors
self.printer = wasabi.Printer(no_print=True)
def render(
    self,
    token: Token,
    *,
    ignore_colors: bool = False,
) -> str:
    """Render this format's attribute of *token* as a display string.

    ``self.attribute`` is a dot-separated path (for example ``head.i`` or
    ``_.test``). Each component is resolved with ``getattr`` starting from
    *token*, and the final value is converted with ``str``.

    token:          the object the attribute path is resolved against.
    ignore_colors:  when true, the colour mappings are not consulted and the
                    plain string is returned.

    Returns ``""`` when ``self.permitted_values`` is set and the value
    (compared as a string) is not among them. Otherwise returns the value,
    cut to ``self.max_width`` characters if that is set, and passed through
    ``self.printer.text`` when a foreground or background colour is found
    for it.

    Raises ``AttributeError`` (from ``getattr``) when a path component does
    not exist.
    """
    # Walk every component but the last, then read the last one as the value.
    target = token
    *leading, final = self.attribute.split(".")
    for name in leading:
        target = getattr(target, name)
    text = str(getattr(target, final))

    # Permitted values are compared by their string form.
    if self.permitted_values is not None:
        allowed = (str(candidate) for candidate in self.permitted_values)
        if text not in allowed:
            return ""

    if self.max_width is not None:
        text = text[: self.max_width]

    # The colour mappings are keyed on the (already truncated) text.
    foreground = None
    background = None
    if not ignore_colors:
        if self.value_dependent_fg_colors is not None:
            foreground = self.value_dependent_fg_colors.get(text, None)
        if self.value_dependent_bg_colors is not None:
            background = self.value_dependent_bg_colors.get(text, None)

    if foreground is None and background is None:
        return text
    return self.printer.text(text, color=foreground, bg_color=background)
SPACE = 0 SPACE = 0
@ -67,9 +99,6 @@ ROOT_LEFT_CHARS = {
class Visualizer: class Visualizer:
def __init__(self):
self.printer = wasabi.Printer(no_print=True)
@staticmethod @staticmethod
def render_dependency_tree(sent: Span, root_right: bool) -> list[str]: def render_dependency_tree(sent: Span, root_right: bool) -> list[str]:
""" """
@ -296,39 +325,6 @@ class Visualizer:
for vertical_position in range(sent.end - sent.start) for vertical_position in range(sent.end - sent.start)
] ]
def get_entity(
    self,
    token: Token,
    entity_name: str,
    *,
    permitted_values: Union[list[str], None] = None,
    value_dependent_fg_colors: Union[dict[str, Union[str, int]], None] = None,
    value_dependent_bg_colors: Union[dict[str, Union[str, int]], None] = None,
    truncate_at_width: Union[int, None] = None,
) -> str:
    """Return the string form of a (possibly nested) attribute of *token*.

    token:                      the object the attribute path is resolved against.
    entity_name:                dot-separated attribute path, e.g. ``head.i`` or
                                ``_.test``; each component is read with ``getattr``.
    permitted_values:           when given, ``""`` is returned unless the value's
                                string form equals the string form of one entry.
    value_dependent_fg_colors:  maps a value to the foreground colour passed to
                                ``self.printer.text``.
    value_dependent_bg_colors:  maps a value to the background colour passed to
                                ``self.printer.text``.
    truncate_at_width:          when given, the value is cut to this many
                                characters before the colour lookup.

    Raises ``AttributeError`` (from ``getattr``) when a path component does
    not exist.
    """
    obj = token
    parts = entity_name.split(".")
    for part in parts[:-1]:
        obj = getattr(obj, part)
    value = str(getattr(obj, parts[-1]))
    # Permitted values are compared by their string form.
    if permitted_values is not None and value not in (
        str(v) for v in permitted_values
    ):
        return ""
    if truncate_at_width is not None:
        value = value[:truncate_at_width]
    # The colour mappings are keyed on the (already truncated) value.
    fg_color = (
        value_dependent_fg_colors.get(value, None)
        if value_dependent_fg_colors is not None
        else None
    )
    bg_color = (
        value_dependent_bg_colors.get(value, None)
        if value_dependent_bg_colors is not None
        else None
    )
    if fg_color is not None or bg_color is not None:
        value = self.printer.text(value, color=fg_color, bg_color=bg_color)
    return value
def render_table( def render_table(
self, doc: Doc, columns: list[AttributeFormat], spacing: int = 3 self, doc: Doc, columns: list[AttributeFormat], spacing: int = 3
) -> str: ) -> str:
@ -347,12 +343,7 @@ class Visualizer:
width = len(tree_right[0]) width = len(tree_right[0])
else: else:
width = max( width = max(
len( len(column.render(token, ignore_colors=True)) for token in sent
self.get_entity(
token, column.attribute, permitted_values=column.permitted_values
)
)
for token in sent
) )
if column.max_width is not None: if column.max_width is not None:
width = min(width, column.max_width) width = min(width, column.max_width)
@ -364,14 +355,7 @@ class Visualizer:
if column.attribute == "tree_right" if column.attribute == "tree_right"
else tree_left[token_index] else tree_left[token_index]
if column.attribute == "tree_left" if column.attribute == "tree_left"
else self.get_entity( else column.render(token)
token,
column.attribute,
permitted_values=column.permitted_values,
value_dependent_fg_colors=column.value_dependent_fg_colors,
value_dependent_bg_colors=column.value_dependent_bg_colors,
truncate_at_width=widths[column_index],
)
for column_index, column in enumerate(columns) for column_index, column in enumerate(columns)
] ]
for token_index, token in enumerate(sent) for token_index, token in enumerate(sent)
@ -397,3 +381,13 @@ class Visualizer:
+ linesep + linesep
) )
return return_string return return_string
def render_text(self, doc: Doc, attributes: list[AttributeFormat]) -> str:
    """Return the text of *doc*, rebuilt from each token's ``text_with_ws``.

    doc:         iterable of tokens; every token must expose ``text_with_ws``.
    attributes:  formats that are rendered for each token with
                 ``render(token)``.

    The rendered attribute values are not written into the output yet; the
    inner branch is still a placeholder, exactly as in the previous body.
    """
    chunks = []
    for token in doc:
        chunks.append(token.text_with_ws)
        for attribute in attributes:
            # Previously ``self.get_entity(token,)``: that call lacked the
            # required attribute name and ``Visualizer`` no longer defines
            # ``get_entity``, so each format renders itself instead.
            if attribute.render(token):
                # TODO: decide how a non-empty attribute value is inserted
                # into the returned text.
                pass
    # The previous body fell off the end and returned None despite ``-> str``.
    return "".join(chunks)