From 49f3fd39b9395d98978b8e81398cf99b2f13e675 Mon Sep 17 00:00:00 2001 From: Richard Hudson Date: Wed, 8 Dec 2021 16:42:39 +0100 Subject: [PATCH] Refactoring --- spacy/tests/test_visualization.py | 52 +++++++++-------- spacy/visualization.py | 94 +++++++++++++++---------------- 2 files changed, 73 insertions(+), 73 deletions(-) diff --git a/spacy/tests/test_visualization.py b/spacy/tests/test_visualization.py index f64b21d30..c5366c7c6 100644 --- a/spacy/tests/test_visualization.py +++ b/spacy/tests/test_visualization.py @@ -207,7 +207,7 @@ def test_visualization_dependency_tree_highly_nonprojective(en_vocab): ] -def test_visualization_get_entity_native_attribute_int(en_vocab): +def test_visualization_render_native_attribute_int(en_vocab): doc = Doc( en_vocab, words=[ @@ -224,10 +224,10 @@ def test_visualization_get_entity_native_attribute_int(en_vocab): heads=[1, None, 3, 1, 1, 7, 7, 3, 1], deps=["dep"] * 9, ) - assert Visualizer().get_entity(doc[2], "head.i") == "3" + assert AttributeFormat("head.i").render(doc[2]) == "3" -def test_visualization_get_entity_native_attribute_str(en_vocab): +def test_visualization_render_native_attribute_str(en_vocab): doc = Doc( en_vocab, words=[ @@ -245,10 +245,10 @@ def test_visualization_get_entity_native_attribute_str(en_vocab): deps=["dep"] * 9, ) - assert Visualizer().get_entity(doc[2], "dep_") == "dep" + assert AttributeFormat("dep_").render(doc[2]) == "dep" -def test_visualization_get_entity_colors(en_vocab): +def test_visualization_render_colors(en_vocab): doc = Doc( en_vocab, words=[ @@ -267,19 +267,18 @@ def test_visualization_get_entity_colors(en_vocab): ) assert ( - Visualizer().get_entity( - doc[2], + AttributeFormat( "dep_", value_dependent_fg_colors={"dep": 2}, value_dependent_bg_colors={"dep": 11}, - ) + ).render(doc[2]) == "\x1b[38;5;2;48;5;11mdep\x1b[0m" if supports_ansi else "dep" ) -def test_visualization_get_entity_colors_only_fg(en_vocab): +def test_visualization_render_colors_only_fg(en_vocab): doc = Doc( 
en_vocab, words=[ @@ -298,14 +297,17 @@ def test_visualization_get_entity_colors_only_fg(en_vocab): ) assert ( - Visualizer().get_entity(doc[2], "dep_", value_dependent_fg_colors={"dep": 2}) + AttributeFormat( + "dep_", + value_dependent_fg_colors={"dep": 2}, + ).render(doc[2]) == "\x1b[38;5;2mdep\x1b[0m" if supports_ansi else "dep" ) -def test_visualization_get_entity_colors_only_bg(en_vocab): +def test_visualization_render_colors_only_bg(en_vocab): doc = Doc( en_vocab, words=[ @@ -324,14 +326,17 @@ def test_visualization_get_entity_colors_only_bg(en_vocab): ) assert ( - Visualizer().get_entity(doc[2], "dep_", value_dependent_bg_colors={"dep": 11}) + AttributeFormat( + "dep_", + value_dependent_bg_colors={"dep": 11}, + ).render(doc[2]) == "\x1b[48;5;11mdep\x1b[0m" if supports_ansi else "dep" ) -def test_visualization_get_entity_native_attribute_missing(en_vocab): +def test_visualization_render_native_attribute_missing(en_vocab): doc = Doc( en_vocab, words=[ @@ -349,10 +354,10 @@ def test_visualization_get_entity_native_attribute_missing(en_vocab): deps=["dep"] * 9, ) with pytest.raises(AttributeError): - Visualizer().get_entity(doc[2], "depp") + AttributeFormat("depp").render(doc[2]) -def test_visualization_get_entity_custom_attribute_str(en_vocab): +def test_visualization_render_custom_attribute_str(en_vocab): doc = Doc( en_vocab, words=[ @@ -370,10 +375,10 @@ def test_visualization_get_entity_custom_attribute_str(en_vocab): deps=["dep"] * 9, ) Token.set_extension("test", default="tested", force=True) - assert Visualizer().get_entity(doc[2], "_.test") == "tested" + assert AttributeFormat("_.test").render(doc[2]) == "tested" -def test_visualization_get_entity_nested_custom_attribute_str(en_vocab): +def test_visualization_render_nested_custom_attribute_str(en_vocab): doc = Doc( en_vocab, words=[ @@ -396,10 +401,10 @@ def test_visualization_get_entity_nested_custom_attribute_str(en_vocab): self.inner_test = "tested" Token.set_extension("test", default=Test(), 
force=True) - assert Visualizer().get_entity(doc[2], "_.test.inner_test") == "tested" + assert AttributeFormat("_.test.inner_test").render(doc[2]) == "tested" -def test_visualization_get_entity_custom_attribute_missing(en_vocab): +def test_visualization_render_custom_attribute_missing(en_vocab): doc = Doc( en_vocab, words=[ @@ -417,10 +422,10 @@ def test_visualization_get_entity_custom_attribute_missing(en_vocab): deps=["dep"] * 9, ) with pytest.raises(AttributeError): - Visualizer().get_entity(doc[2], "_.depp") + AttributeFormat("_.depp").render(doc[2]) -def test_visualization_get_entity_permitted_values(en_vocab): +def test_visualization_render_permitted_values(en_vocab): doc = Doc( en_vocab, words=[ @@ -437,8 +442,8 @@ def test_visualization_get_entity_permitted_values(en_vocab): heads=[1, None, 3, 1, 1, 7, 7, 3, 1], deps=["dep"] * 9, ) - visualizer = Visualizer() - assert [visualizer.get_entity(token, "head.i", permitted_values=(3, 7)) for token in doc] == [ + attribute_format = AttributeFormat("head.i", permitted_values=(3, 7)) + assert [attribute_format.render(token) for token in doc] == [ "", "", "3", @@ -510,6 +515,7 @@ def test_visualization_minimal_render_table_permitted_values( """.strip() ) + def test_visualization_spacing( fully_featured_doc_one_sentence, ): diff --git a/spacy/visualization.py b/spacy/visualization.py index f9f067e26..e0796c8b3 100644 --- a/spacy/visualization.py +++ b/spacy/visualization.py @@ -29,6 +29,38 @@ class AttributeFormat: self.permitted_values = permitted_values self.value_dependent_fg_colors = value_dependent_fg_colors self.value_dependent_bg_colors = value_dependent_bg_colors + self.printer = wasabi.Printer(no_print=True) + + def render( + self, + token: Token, + *, + ignore_colors: bool = False, + ) -> str: + obj = token + parts = self.attribute.split(".") + for part in parts[:-1]: + obj = getattr(obj, part) + value = str(getattr(obj, parts[-1])) + if self.permitted_values is not None and value not in ( + str(v) for v in 
self.permitted_values + ): + return "" + if self.max_width is not None: + value = value[: self.max_width] + fg_color = ( + self.value_dependent_fg_colors.get(value, None) + if not ignore_colors and self.value_dependent_fg_colors is not None + else None + ) + bg_color = ( + self.value_dependent_bg_colors.get(value, None) + if not ignore_colors and self.value_dependent_bg_colors is not None + else None + ) + if fg_color is not None or bg_color is not None: + value = self.printer.text(value, color=fg_color, bg_color=bg_color) + return value SPACE = 0 @@ -67,9 +99,6 @@ ROOT_LEFT_CHARS = { class Visualizer: - def __init__(self): - self.printer = wasabi.Printer(no_print=True) - @staticmethod def render_dependency_tree(sent: Span, root_right: bool) -> list[str]: """ @@ -296,39 +325,6 @@ class Visualizer: for vertical_position in range(sent.end - sent.start) ] - def get_entity( - self, - token: Token, - entity_name: str, - *, - permitted_values: list[str] = None, - value_dependent_fg_colors: dict[str : Union[str, int]] = None, - value_dependent_bg_colors: dict[str : Union[str, int]] = None, - truncate_at_width: int = None, - ) -> str: - obj = token - parts = entity_name.split(".") - for part in parts[:-1]: - obj = getattr(obj, part) - value = str(getattr(obj, parts[-1])) - if permitted_values is not None and value not in (str(v) for v in permitted_values): - return "" - if truncate_at_width is not None: - value = value[:truncate_at_width] - fg_color = ( - value_dependent_fg_colors.get(value, None) - if value_dependent_fg_colors is not None - else None - ) - bg_color = ( - value_dependent_bg_colors.get(value, None) - if value_dependent_bg_colors is not None - else None - ) - if fg_color is not None or bg_color is not None: - value = self.printer.text(value, color=fg_color, bg_color=bg_color) - return value - def render_table( self, doc: Doc, columns: list[AttributeFormat], spacing: int = 3 ) -> str: @@ -347,12 +343,7 @@ class Visualizer: width = len(tree_right[0]) else: 
width = max( - len( - self.get_entity( - token, column.attribute, permitted_values=column.permitted_values - ) - ) - for token in sent + len(column.render(token, ignore_colors=True)) for token in sent ) if column.max_width is not None: width = min(width, column.max_width) @@ -364,14 +355,7 @@ class Visualizer: if column.attribute == "tree_right" else tree_left[token_index] if column.attribute == "tree_left" - else self.get_entity( - token, - column.attribute, - permitted_values=column.permitted_values, - value_dependent_fg_colors=column.value_dependent_fg_colors, - value_dependent_bg_colors=column.value_dependent_bg_colors, - truncate_at_width=widths[column_index], - ) + else column.render(token) for column_index, column in enumerate(columns) ] for token_index, token in enumerate(sent) @@ -397,3 +381,13 @@ class Visualizer: + linesep ) return return_string + + def render_text(self, doc: Doc, attributes: list[AttributeFormat]) -> str: + return_string = "" + for token in doc: + return_string += token.text_with_ws + for attribute in attributes: + if attribute.render( + token, + ): + pass