From 3557c613b1236b3292b10659ccdc2a9a500ae123 Mon Sep 17 00:00:00 2001 From: richardpaulhudson Date: Wed, 2 Mar 2022 13:09:41 +0100 Subject: [PATCH] Improvements based on PR feedback --- spacy/tests/conftest.py | 5 + spacy/tests/test_visualization.py | 551 +++++++----------------------- spacy/visualization.py | 531 +++++++++++++--------------- 3 files changed, 369 insertions(+), 718 deletions(-) diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py index dda7ccbcb..bca3cab1b 100644 --- a/spacy/tests/conftest.py +++ b/spacy/tests/conftest.py @@ -292,6 +292,11 @@ def pl_tokenizer(): return get_lang_class("pl")().tokenizer +@pytest.fixture(scope="session") +def pl_vocab(): + return get_lang_class("pl")().vocab + + @pytest.fixture(scope="session") def pt_tokenizer(): return get_lang_class("pt")().tokenizer diff --git a/spacy/tests/test_visualization.py b/spacy/tests/test_visualization.py index 99b227086..727fa9e21 100644 --- a/spacy/tests/test_visualization.py +++ b/spacy/tests/test_visualization.py @@ -7,7 +7,27 @@ from spacy.tokens import Span, Doc, Token SUPPORTS_ANSI = supports_ansi() -def test_visualization_dependency_tree_basic(en_vocab): +@pytest.fixture +def horse_doc(en_vocab): + return Doc( + en_vocab, + words=[ + "I", + "saw", + "a", + "horse", + "yesterday", + "that", + "was", + "injured", + ".", + ], + heads=[1, None, 3, 1, 1, 7, 7, 3, 1], + deps=["dep"] * 9, + ) + + +def test_viz_dep_tree_basic(en_vocab): """Test basic dependency tree display.""" doc = Doc( en_vocab, @@ -25,7 +45,7 @@ def test_visualization_dependency_tree_basic(en_vocab): heads=[2, 2, 3, None, 6, 6, 3, 3, 3], deps=["dep"] * 9, ) - dep_tree = Visualizer.render_dependency_tree(doc[0 : len(doc)], True) + dep_tree = Visualizer.render_dep_tree(doc[0 : len(doc)], True) assert dep_tree == [ "<╗ ", "<╣ ", @@ -37,7 +57,7 @@ def test_visualization_dependency_tree_basic(en_vocab): "<══╣", "<══╝", ] - dep_tree = Visualizer.render_dependency_tree(doc[0 : len(doc)], False) + dep_tree = 
Visualizer.render_dep_tree(doc[0 : len(doc)], False) assert dep_tree == [ " ╔>", " ╠>", @@ -51,7 +71,7 @@ def test_visualization_dependency_tree_basic(en_vocab): ] -def test_visualization_dependency_tree_non_initial_sentence(en_vocab): +def test_viz_dep_tree_non_initial_sent(en_vocab): """Test basic dependency tree display.""" doc = Doc( en_vocab, @@ -72,7 +92,7 @@ def test_visualization_dependency_tree_non_initial_sentence(en_vocab): heads=[0, None, 0, 5, 5, 6, None, 9, 9, 6, 6, 6], deps=["dep"] * 12, ) - dep_tree = Visualizer.render_dependency_tree(doc[3 : len(doc)], True) + dep_tree = Visualizer.render_dep_tree(doc[3 : len(doc)], True) assert dep_tree == [ "<╗ ", "<╣ ", @@ -84,7 +104,7 @@ def test_visualization_dependency_tree_non_initial_sentence(en_vocab): "<══╣", "<══╝", ] - dep_tree = Visualizer.render_dependency_tree(doc[3 : len(doc)], False) + dep_tree = Visualizer.render_dep_tree(doc[3 : len(doc)], False) assert dep_tree == [ " ╔>", " ╠>", @@ -98,25 +118,9 @@ def test_visualization_dependency_tree_non_initial_sentence(en_vocab): ] -def test_visualization_dependency_tree_non_projective(en_vocab): +def test_viz_dep_tree_non_projective(horse_doc): """Test dependency tree display with a non-projective dependency.""" - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - dep_tree = Visualizer.render_dependency_tree(doc[0 : len(doc)], True) + dep_tree = Visualizer.render_dep_tree(horse_doc[0 : len(horse_doc)], True) assert dep_tree == [ "<╗ ", "═╩═══╗", @@ -128,7 +132,7 @@ def test_visualization_dependency_tree_non_projective(en_vocab): "═╝<╝ ║", "<════╝", ] - dep_tree = Visualizer.render_dependency_tree(doc[0 : len(doc)], False) + dep_tree = Visualizer.render_dep_tree(horse_doc[0 : len(horse_doc)], False) assert dep_tree == [ " ╔>", "╔═══╩═", @@ -142,33 +146,10 @@ def 
test_visualization_dependency_tree_non_projective(en_vocab): ] -def test_visualization_dependency_tree_input_not_span(en_vocab): - """Test dependency tree display behaviour when the input is not a Span.""" - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - - with pytest.raises(AssertionError): - Visualizer.render_dependency_tree(doc[1:3], True) - - -def test_visualization_dependency_tree_highly_nonprojective(en_vocab): +def test_viz_dep_tree_highly_nonprojective(pl_vocab): """Test a highly non-projective tree (colloquial Polish).""" doc = Doc( - en_vocab, + pl_vocab, words=[ "Owczarki", "przecież", @@ -182,7 +163,7 @@ def test_visualization_dependency_tree_highly_nonprojective(en_vocab): heads=[5, 5, 0, 5, 5, None, 4, 5], deps=["dep"] * 8, ) - dep_tree = Visualizer.render_dependency_tree(doc[0 : len(doc)], True) + dep_tree = Visualizer.render_dep_tree(doc[0 : len(doc)], True) assert dep_tree == [ "═╗<╗", " ║<╣", @@ -193,7 +174,7 @@ def test_visualization_dependency_tree_highly_nonprojective(en_vocab): "<╝ ║", "<══╝", ] - dep_tree = Visualizer.render_dependency_tree(doc[0 : len(doc)], False) + dep_tree = Visualizer.render_dep_tree(doc[0 : len(doc)], False) assert dep_tree == [ "╔>╔═", "╠>║ ", @@ -206,334 +187,99 @@ def test_visualization_dependency_tree_highly_nonprojective(en_vocab): ] -def test_visualization_render_native_attribute_int(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - assert AttributeFormat("head.i").render(doc[2]) == "3" +def test_viz_dep_tree_input_not_span(horse_doc): + """Test dependency tree display behaviour when the input is not a Span.""" + with pytest.raises(ValueError): + Visualizer.render_dep_tree(horse_doc[1:3], True) -def 
test_visualization_render_native_attribute_int_with_right_padding(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - assert AttributeFormat("head.i").render(doc[2], right_pad_to_length=3) == "3 " +def test_viz_render_native_attributes(horse_doc): + assert AttributeFormat("head.i").render(horse_doc[2]) == "3" + assert AttributeFormat("head.i").render(horse_doc[2], right_pad_to_len=3) == "3 " + assert AttributeFormat("dep_").render(horse_doc[2]) == "dep" + with pytest.raises(AttributeError): + AttributeFormat("depp").render(horse_doc[2]) -def test_visualization_render_native_attribute_str(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - - assert AttributeFormat("dep_").render(doc[2]) == "dep" - - -def test_visualization_render_colors(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - +def test_viz_render_colors(horse_doc): assert ( AttributeFormat( "dep_", - value_dependent_fg_colors={"dep": 2}, - value_dependent_bg_colors={"dep": 11}, - ).render(doc[2]) + value_dep_fg_colors={"dep": 2}, + value_dep_bg_colors={"dep": 11}, + ).render(horse_doc[2]) == "\x1b[38;5;2;48;5;11mdep\x1b[0m" if SUPPORTS_ANSI else "dep" ) - -def test_visualization_render_whole_row_colors(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - + # whole row assert ( AttributeFormat( "dep_", - ).render(doc[2], whole_row_fg_color=8, whole_row_bg_color=9) + ).render(horse_doc[2], 
whole_row_fg_color=8, whole_row_bg_color=9) == "\x1b[38;5;8;48;5;9mdep\x1b[0m" if SUPPORTS_ANSI else "dep" ) - -def test_visualization_render_whole_row_colors_with_value_dependent_colors(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - + # whole row with value dependent colors assert ( AttributeFormat( "dep_", - value_dependent_fg_colors={"dep": 2}, - value_dependent_bg_colors={"dep": 11}, - ).render(doc[2], whole_row_fg_color=8, whole_row_bg_color=9) + value_dep_fg_colors={"dep": 2}, + value_dep_bg_colors={"dep": 11}, + ).render(horse_doc[2], whole_row_fg_color=8, whole_row_bg_color=9) == "\x1b[38;5;8;48;5;9mdep\x1b[0m" if SUPPORTS_ANSI else "dep" ) - -def test_visualization_render_colors_only_fg(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - + # foreground only assert ( AttributeFormat( "dep_", - value_dependent_fg_colors={"dep": 2}, - ).render(doc[2]) + value_dep_fg_colors={"dep": 2}, + ).render(horse_doc[2]) == "\x1b[38;5;2mdep\x1b[0m" if SUPPORTS_ANSI else "dep" ) - -def test_visualization_render_entity_colors_only_bg(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - + # background only assert ( AttributeFormat( "dep_", - value_dependent_bg_colors={"dep": 11}, - ).render(doc[2]) + value_dep_bg_colors={"dep": 11}, + ).render(horse_doc[2]) == "\x1b[48;5;11mdep\x1b[0m" if SUPPORTS_ANSI else "dep" ) -def test_visualization_render_native_attribute_missing(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, 
None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - with pytest.raises(AttributeError): - AttributeFormat("depp").render(doc[2]) - - -def test_visualization_render_custom_attribute_str(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - Token.set_extension("test", default="tested", force=True) - assert AttributeFormat("_.test").render(doc[2]) == "tested" - - -def test_visualization_render_nested_custom_attribute_str(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) +def test_viz_render_custom_attributes(horse_doc): + Token.set_extension("test", default="tested1", force=True) + assert AttributeFormat("_.test").render(horse_doc[2]) == "tested1" class Test: def __init__(self): - self.inner_test = "tested" + self.inner_test = "tested2" Token.set_extension("test", default=Test(), force=True) - assert AttributeFormat("_.test.inner_test").render(doc[2]) == "tested" + assert AttributeFormat("_.test.inner_test").render(horse_doc[2]) == "tested2" - -def test_visualization_render_custom_attribute_missing(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) with pytest.raises(AttributeError): - AttributeFormat("._depp").render(doc[2]) + AttributeFormat("._depp").render(horse_doc[2]) -def test_visualization_render_permitted_values(en_vocab): - doc = Doc( - en_vocab, - words=[ - "I", - "saw", - "a", - "horse", - "yesterday", - "that", - "was", - "injured", - ".", - ], - heads=[1, None, 3, 1, 1, 7, 7, 3, 1], - deps=["dep"] * 9, - ) - attribute_format = AttributeFormat("head.i", permitted_values=(3, 7)) - assert 
[attribute_format.render(token) for token in doc] == [ - "", - "", - "3", - "", - "", - "7", - "7", - "3", - "", - ] +def test_viz_render_permitted_values(horse_doc): + attribute_format = AttributeFormat("head.i", permitted_vals=(3, 7)) + vals = ["", "", "3", "", "", "7", "7", "3", ""] + assert [attribute_format.render(token) for token in horse_doc] == vals -def test_visualization_minimal_render_table_one_sentence( +def test_viz_minimal_render_table_one_sentence( fully_featured_doc_one_sentence, ): formats = [ @@ -565,9 +311,10 @@ def test_visualization_minimal_render_table_one_sentence( ) -def test_visualization_minimal_render_table_empty_text_no_headers( +def test_viz_minimal_render_table_empty_text( en_vocab, ): + # no headers formats = [ AttributeFormat("tree_left"), AttributeFormat("dep_"), @@ -580,10 +327,7 @@ def test_visualization_minimal_render_table_empty_text_no_headers( ] assert Visualizer().render_table(Doc(en_vocab), formats, spacing=3).strip() == "" - -def test_visualization_minimal_render_table_empty_text_headers( - en_vocab, -): + # headers formats = [ AttributeFormat("tree_left", name="tree"), AttributeFormat("dep_"), @@ -597,14 +341,14 @@ def test_visualization_minimal_render_table_empty_text_headers( assert Visualizer().render_table(Doc(en_vocab), formats, spacing=3).strip() == "" -def test_visualization_minimal_render_table_permitted_values( +def test_viz_minimal_render_table_permitted_values( fully_featured_doc_one_sentence, ): formats = [ AttributeFormat("tree_left"), AttributeFormat("dep_"), AttributeFormat("text"), - AttributeFormat("lemma_", permitted_values=("fly", "to")), + AttributeFormat("lemma_", permitted_vals=("fly", "to")), AttributeFormat("pos_"), AttributeFormat("tag_"), AttributeFormat("morph"), @@ -629,7 +373,7 @@ def test_visualization_minimal_render_table_permitted_values( ) -def test_visualization_spacing( +def test_viz_minimal_render_table_spacing( fully_featured_doc_one_sentence, ): formats = [ @@ -661,7 +405,7 @@ def 
test_visualization_spacing( ) -def test_visualization_minimal_render_table_two_sentences( +def test_viz_minimal_render_table_two_sentences( fully_featured_doc_two_sentences, ): formats = [ @@ -700,7 +444,7 @@ def test_visualization_minimal_render_table_two_sentences( ) -def test_visualization_rich_render_table_one_sentence( +def test_viz_rich_render_table_one_sentence( fully_featured_doc_one_sentence, ): formats = [ @@ -716,8 +460,8 @@ def test_visualization_rich_render_table_one_sentence( "ent_type_", name="ent", fg_color=196, - value_dependent_fg_colors={"PERSON": 50}, - value_dependent_bg_colors={"PERSON": 12}, + value_dep_fg_colors={"PERSON": 50}, + value_dep_bg_colors={"PERSON": 12}, ), ] assert ( @@ -727,10 +471,7 @@ def test_visualization_rich_render_table_one_sentence( else "\n\x1b[38;5;2m tree\x1b[0m \x1b[38;5;2mdep \x1b[0m index text lemma pos tag morph ent \n\x1b[38;5;2m------\x1b[0m \x1b[38;5;2m--------\x1b[0m ----- ------- ------- ----- --- --------------- ------\n\x1b[38;5;2m ╔>╔═\x1b[0m \x1b[38;5;2mposs \x1b[0m 0 Sarah sarah PROPN NNP NounType=prop|N PERSON\n\x1b[38;5;2m ║ ╚>\x1b[0m \x1b[38;5;2mcase \x1b[0m 1 's 's PART POS Poss=yes \n\x1b[38;5;2m╔>╚═══\x1b[0m \x1b[38;5;2mnsubj \x1b[0m 2 sister sister NOUN NN Number=sing \n\x1b[38;5;2m╠═════\x1b[0m \x1b[38;5;2mROOT \x1b[0m 3 flew fly VERB VBD Tense=past|Verb \n\x1b[38;5;2m╠>╔═══\x1b[0m \x1b[38;5;2mprep \x1b[0m 4 to to ADP IN \n\x1b[38;5;2m║ ║ ╔>\x1b[0m \x1b[38;5;2mcompound\x1b[0m 5 Silicon silicon PROPN NNP NounType=prop|N GPE \n\x1b[38;5;2m║ ╚>╚═\x1b[0m \x1b[38;5;2mpobj \x1b[0m 6 Valley valley PROPN NNP NounType=prop|N GPE \n\x1b[38;5;2m╠══>╔═\x1b[0m \x1b[38;5;2mprep \x1b[0m 7 via via ADP IN \n\x1b[38;5;2m║ ╚>\x1b[0m \x1b[38;5;2mpobj \x1b[0m 8 London london PROPN NNP NounType=prop|N GPE \n\x1b[38;5;2m╚════>\x1b[0m \x1b[38;5;2mpunct \x1b[0m 9 . . PUNCT . 
PunctType=peri \n\n" ) - -def test_visualization_rich_render_table_one_sentence_trigger_value_shorter_than_maximum( - fully_featured_doc_one_sentence, -): + # trigger value for value_dep shorter than maximum length in column formats = [ AttributeFormat("tree_left", name="tree", aligns="r", fg_color=2), AttributeFormat("dep_", name="dep", fg_color=2), @@ -739,8 +480,8 @@ def test_visualization_rich_render_table_one_sentence_trigger_value_shorter_than "text", name="text", fg_color=196, - value_dependent_fg_colors={"'s": 50}, - value_dependent_bg_colors={"'s": 12}, + value_dep_fg_colors={"'s": 50}, + value_dep_bg_colors={"'s": 12}, ), AttributeFormat("lemma_", name="lemma"), AttributeFormat("pos_", name="pos", fg_color=100), @@ -759,7 +500,7 @@ def test_visualization_rich_render_table_one_sentence_trigger_value_shorter_than ) -def test_visualization_rich_render_table_two_sentences( +def test_viz_rich_render_table_two_sentences( fully_featured_doc_two_sentences, ): formats = [ @@ -775,8 +516,8 @@ def test_visualization_rich_render_table_two_sentences( "ent_type_", name="ent", fg_color=196, - value_dependent_fg_colors={"PERSON": 50}, - value_dependent_bg_colors={"PERSON": 12}, + value_dep_fg_colors={"PERSON": 50}, + value_dep_bg_colors={"PERSON": 12}, ), ] assert ( @@ -787,25 +528,25 @@ def test_visualization_rich_render_table_two_sentences( ) -def test_visualization_text_with_text_format( +def test_viz_text_with_text_format( fully_featured_doc_two_sentences, ): formats = [ AttributeFormat( "ent_type_", fg_color=50, - value_dependent_fg_colors={"PERSON": 50}, - value_dependent_bg_colors={"PERSON": 12}, + value_dep_fg_colors={"PERSON": 50}, + value_dep_bg_colors={"PERSON": 12}, ), AttributeFormat( "text", fg_color=50, bg_color=53, - value_dependent_fg_colors={"PERSON": 50}, - value_dependent_bg_colors={"PERSON": 12}, + value_dep_fg_colors={"PERSON": 50}, + value_dep_bg_colors={"PERSON": 12}, ), AttributeFormat( - "lemma_", fg_color=50, bg_color=53, 
permitted_values=("fly", "valley") + "lemma_", fg_color=50, bg_color=53, permitted_vals=("fly", "valley") ), ] assert ( @@ -816,16 +557,16 @@ def test_visualization_text_with_text_format( ) -def test_visualization_render_text_without_text_format( +def test_viz_render_text_without_text_format( fully_featured_doc_two_sentences, ): formats = [ AttributeFormat( "ent_type_", - value_dependent_fg_colors={"PERSON": 50}, - value_dependent_bg_colors={"PERSON": 12}, + value_dep_fg_colors={"PERSON": 50}, + value_dep_bg_colors={"PERSON": 12}, ), - AttributeFormat("lemma_", permitted_values=("fly", "valley")), + AttributeFormat("lemma_", permitted_vals=("fly", "valley")), ] assert ( Visualizer().render_text(fully_featured_doc_two_sentences, formats) @@ -835,9 +576,10 @@ def test_visualization_render_text_without_text_format( ) -def test_visualization_minimal_render_instances_two_sentences_type_non_grouping( +def test_viz_render_instances_two_sentences( fully_featured_doc_two_sentences, ): + # search on entity type display_columns = [ AttributeFormat("dep_"), AttributeFormat("text"), @@ -852,8 +594,8 @@ def test_visualization_minimal_render_instances_two_sentences_type_non_grouping( assert ( Visualizer().render_instances( fully_featured_doc_two_sentences, - search_attributes=search_attributes, - display_columns=display_columns, + search_attrs=search_attributes, + display_cols=display_columns, group=False, spacing=3, surrounding_tokens_height=0, @@ -863,10 +605,7 @@ def test_visualization_minimal_render_instances_two_sentences_type_non_grouping( == "\nposs Sarah sarah PROPN NNP NounType=prop|Number=sing PERSON\n\ncompound Silicon silicon PROPN NNP NounType=prop|Number=sing GPE \npobj Valley valley PROPN NNP NounType=prop|Number=sing GPE \n\npobj London london PROPN NNP NounType=prop|Number=sing GPE \n" ) - -def test_visualization_minimal_render_instances_two_sentences_value_non_grouping( - fully_featured_doc_two_sentences, -): + # search on entity type with permitted values 
display_columns = [ AttributeFormat("dep_"), AttributeFormat("text"), @@ -877,13 +616,13 @@ def test_visualization_minimal_render_instances_two_sentences_value_non_grouping AttributeFormat("ent_type_"), ] - search_attributes = [AttributeFormat("ent_type_", permitted_values=["PERSON"])] + search_attributes = [AttributeFormat("ent_type_", permitted_vals=["PERSON"])] assert ( Visualizer().render_instances( fully_featured_doc_two_sentences, - search_attributes=search_attributes, - display_columns=display_columns, + search_attrs=search_attributes, + display_cols=display_columns, group=False, spacing=3, surrounding_tokens_height=0, @@ -893,10 +632,7 @@ def test_visualization_minimal_render_instances_two_sentences_value_non_grouping == "\nposs Sarah sarah PROPN NNP NounType=prop|Number=sing PERSON\n" ) - -def test_visualization_minimal_render_instances_two_sentences_value_surrounding_sentences_non_grouping( - fully_featured_doc_two_sentences, -): + # include surrounding tokens display_columns = [ AttributeFormat("dep_"), AttributeFormat("text"), @@ -907,13 +643,13 @@ def test_visualization_minimal_render_instances_two_sentences_value_surrounding_ AttributeFormat("ent_type_"), ] - search_attributes = [AttributeFormat("ent_type_", permitted_values=["PERSON"])] + search_attributes = [AttributeFormat("ent_type_", permitted_vals=["PERSON"])] assert ( Visualizer().render_instances( fully_featured_doc_two_sentences, - search_attributes=search_attributes, - display_columns=display_columns, + search_attrs=search_attributes, + display_cols=display_columns, group=False, spacing=3, surrounding_tokens_height=2, @@ -926,9 +662,7 @@ def test_visualization_minimal_render_instances_two_sentences_value_surrounding_ ) -def test_visualization_render_instances_two_sentences_missing_value_non_grouping( - fully_featured_doc_two_sentences, -): + # missing permitted value display_columns = [ AttributeFormat("dep_", name="dep"), AttributeFormat("text", name="text"), @@ -939,13 +673,13 @@ def 
test_visualization_render_instances_two_sentences_missing_value_non_grouping AttributeFormat("ent_type_"), ] - search_attributes = [AttributeFormat("ent_type_", permitted_values=["PERSONN"])] + search_attributes = [AttributeFormat("ent_type_", permitted_vals=["PERSONN"])] assert ( Visualizer().render_instances( fully_featured_doc_two_sentences, - search_attributes=search_attributes, - display_columns=display_columns, + search_attrs=search_attributes, + display_cols=display_columns, group=False, spacing=3, surrounding_tokens_height=0, @@ -955,10 +689,7 @@ def test_visualization_render_instances_two_sentences_missing_value_non_grouping == "\ndep text \n--- ---- \n" ) - -def test_visualization_render_instances_two_sentences_missing_value_surrounding_sentences_non_grouping( - fully_featured_doc_two_sentences, -): + # missing permitted value, include surrounding tokens display_columns = [ AttributeFormat("dep_", name="dep"), AttributeFormat("text", name="text"), @@ -969,13 +700,13 @@ def test_visualization_render_instances_two_sentences_missing_value_surrounding_ AttributeFormat("ent_type_"), ] - search_attributes = [AttributeFormat("ent_type_", permitted_values=["PERSONN"])] + search_attributes = [AttributeFormat("ent_type_", permitted_vals=["PERSONN"])] assert ( Visualizer().render_instances( fully_featured_doc_two_sentences, - search_attributes=search_attributes, - display_columns=display_columns, + search_attrs=search_attributes, + display_cols=display_columns, group=False, spacing=3, surrounding_tokens_height=0, @@ -985,10 +716,7 @@ def test_visualization_render_instances_two_sentences_missing_value_surrounding_ == "\ndep text \n--- ---- \n" ) - -def test_visualization_render_instances_two_sentences_type_grouping( - fully_featured_doc_two_sentences, -): + # with grouping display_columns = [ AttributeFormat("dep_"), AttributeFormat("text"), @@ -1004,8 +732,8 @@ def test_visualization_render_instances_two_sentences_type_grouping( assert ( 
Visualizer().render_instances( fully_featured_doc_two_sentences, - search_attributes=search_attributes, - display_columns=display_columns, + search_attrs=search_attributes, + display_cols=display_columns, group=True, spacing=3, surrounding_tokens_height=0, @@ -1015,10 +743,7 @@ def test_visualization_render_instances_two_sentences_type_grouping( == "\npobj London london PROPN NNP NounType=prop|Number=sing GPE \n\ncompound Silicon silicon PROPN NNP NounType=prop|Number=sing GPE \npobj Valley valley PROPN NNP NounType=prop|Number=sing GPE \n\nposs Sarah sarah PROPN NNP NounType=prop|Number=sing PERSON\n" ) - -def test_visualization_render_instances_two_sentences_type_grouping_colors( - fully_featured_doc_two_sentences, -): + # with grouping and colors display_columns = [ AttributeFormat("dep_", fg_color=20), AttributeFormat("text", bg_color=30), @@ -1034,8 +759,8 @@ def test_visualization_render_instances_two_sentences_type_grouping_colors( assert ( Visualizer().render_instances( fully_featured_doc_two_sentences, - search_attributes=search_attributes, - display_columns=display_columns, + search_attrs=search_attributes, + display_cols=display_columns, group=True, spacing=3, surrounding_tokens_height=0, @@ -1046,35 +771,3 @@ def test_visualization_render_instances_two_sentences_type_grouping_colors( if SUPPORTS_ANSI else "npobj London london PROPN NNP NounType=prop|Number=sing GPE \n\ncompound Silicon silicon PROPN NNP NounType=prop|Number=sing GPE \npobj Valley valley PROPN NNP NounType=prop|Number=sing GPE \n\nposs Sarah sarah PROPN NNP NounType=prop|Number=sing PERSON\n" ) - - -def test_visualization_render_instances_two_sentences_type_grouping_colors_with_surrounding_sentences( - fully_featured_doc_two_sentences, -): - display_columns = [ - AttributeFormat("dep_", fg_color=20), - AttributeFormat("text", bg_color=30), - AttributeFormat("lemma_"), - AttributeFormat("pos_"), - AttributeFormat("tag_"), - AttributeFormat("morph"), - AttributeFormat("ent_type_"), - ] - - 
search_attributes = [AttributeFormat("ent_type_"), AttributeFormat("lemma_")] - - assert ( - Visualizer().render_instances( - fully_featured_doc_two_sentences, - search_attributes=search_attributes, - display_columns=display_columns, - group=True, - spacing=3, - surrounding_tokens_height=3, - surrounding_tokens_fg_color=11, - surrounding_tokens_bg_color=None, - ) - == "\n\x1b[38;5;20m\x1b[38;5;11mcompound\x1b[0m\x1b[0m \x1b[48;5;30m\x1b[38;5;11mSilicon\x1b[0m\x1b[0m \x1b[38;5;11msilicon\x1b[0m \x1b[38;5;11mPROPN\x1b[0m \x1b[38;5;11mNNP\x1b[0m \x1b[38;5;11mNounType=prop|Number=sing\x1b[0m \x1b[38;5;11mGPE\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mpobj\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mValley\x1b[0m \x1b[0m \x1b[38;5;11mvalley\x1b[0m \x1b[38;5;11mPROPN\x1b[0m \x1b[38;5;11mNNP\x1b[0m \x1b[38;5;11mNounType=prop|Number=sing\x1b[0m \x1b[38;5;11mGPE\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mprep\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mvia\x1b[0m \x1b[0m \x1b[38;5;11mvia\x1b[0m \x1b[38;5;11mADP\x1b[0m \x1b[38;5;11mIN\x1b[0m \n\x1b[38;5;20mpobj \x1b[0m \x1b[48;5;30mLondon \x1b[0m london PROPN NNP NounType=prop|Number=sing GPE \n\x1b[38;5;20m\x1b[38;5;11mpunct\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11m.\x1b[0m \x1b[0m \x1b[38;5;11m.\x1b[0m \x1b[38;5;11mPUNCT\x1b[0m \x1b[38;5;11m.\x1b[0m \x1b[38;5;11mPunctType=peri\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mnsubj\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mShe\x1b[0m \x1b[0m \x1b[38;5;11mshe\x1b[0m \x1b[38;5;11mPRON\x1b[0m \x1b[38;5;11mPRP\x1b[0m \x1b[38;5;11mCase=Nom|Gender=Fem|Number=Sing|Person=3|PronType=Prs\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mROOT\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mloved\x1b[0m \x1b[0m \x1b[38;5;11mlove\x1b[0m \x1b[38;5;11mVERB\x1b[0m \x1b[38;5;11mVBD\x1b[0m \x1b[38;5;11mTense=Past|VerbForm=Fin\x1b[0m \n\n\x1b[38;5;20m\x1b[38;5;11mnsubj\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11msister\x1b[0m \x1b[0m \x1b[38;5;11msister\x1b[0m \x1b[38;5;11mNOUN\x1b[0m \x1b[38;5;11mNN\x1b[0m \x1b[38;5;11mNumber=sing\x1b[0m 
\n\x1b[38;5;20m\x1b[38;5;11mROOT\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mflew\x1b[0m \x1b[0m \x1b[38;5;11mfly\x1b[0m \x1b[38;5;11mVERB\x1b[0m \x1b[38;5;11mVBD\x1b[0m \x1b[38;5;11mTense=past|VerbForm=fin\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mprep\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mto\x1b[0m \x1b[0m \x1b[38;5;11mto\x1b[0m \x1b[38;5;11mADP\x1b[0m \x1b[38;5;11mIN\x1b[0m \n\x1b[38;5;20mcompound\x1b[0m \x1b[48;5;30mSilicon\x1b[0m silicon PROPN NNP NounType=prop|Number=sing GPE \n\x1b[38;5;20mpobj \x1b[0m \x1b[48;5;30mValley \x1b[0m valley PROPN NNP NounType=prop|Number=sing GPE \n\x1b[38;5;20m\x1b[38;5;11mprep\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mvia\x1b[0m \x1b[0m \x1b[38;5;11mvia\x1b[0m \x1b[38;5;11mADP\x1b[0m \x1b[38;5;11mIN\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mpobj\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mLondon\x1b[0m \x1b[0m \x1b[38;5;11mlondon\x1b[0m \x1b[38;5;11mPROPN\x1b[0m \x1b[38;5;11mNNP\x1b[0m \x1b[38;5;11mNounType=prop|Number=sing\x1b[0m \x1b[38;5;11mGPE\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mpunct\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11m.\x1b[0m \x1b[0m \x1b[38;5;11m.\x1b[0m \x1b[38;5;11mPUNCT\x1b[0m \x1b[38;5;11m.\x1b[0m \x1b[38;5;11mPunctType=peri\x1b[0m \n\n\x1b[38;5;20mposs \x1b[0m \x1b[48;5;30mSarah \x1b[0m sarah PROPN NNP NounType=prop|Number=sing PERSON\n\x1b[38;5;20m\x1b[38;5;11mcase\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11m's\x1b[0m \x1b[0m \x1b[38;5;11m's\x1b[0m \x1b[38;5;11mPART\x1b[0m \x1b[38;5;11mPOS\x1b[0m \x1b[38;5;11mPoss=yes\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mnsubj\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11msister\x1b[0m \x1b[0m \x1b[38;5;11msister\x1b[0m \x1b[38;5;11mNOUN\x1b[0m \x1b[38;5;11mNN\x1b[0m \x1b[38;5;11mNumber=sing\x1b[0m \n\x1b[38;5;20m\x1b[38;5;11mROOT\x1b[0m \x1b[0m \x1b[48;5;30m\x1b[38;5;11mflew\x1b[0m \x1b[0m \x1b[38;5;11mfly\x1b[0m \x1b[38;5;11mVERB\x1b[0m \x1b[38;5;11mVBD\x1b[0m \x1b[38;5;11mTense=past|VerbForm=fin\x1b[0m \n" - if SUPPORTS_ANSI - else "\ncompound Silicon silicon PROPN NNP NounType=prop|Number=sing GPE 
\npobj Valley valley PROPN NNP NounType=prop|Number=sing GPE \nprep via via ADP IN \npobj London london PROPN NNP NounType=prop|Number=sing GPE \npunct . . PUNCT . PunctType=peri \nnsubj She she PRON PRP Case=Nom|Gender=Fem|Number=Sing|Person=3|PronType=Prs \nROOT loved love VERB VBD Tense=Past|VerbForm=Fin \n\nnsubj sister sister NOUN NN Number=sing \nROOT flew fly VERB VBD Tense=past|VerbForm=fin \nprep to to ADP IN \ncompound Silicon silicon PROPN NNP NounType=prop|Number=sing GPE \npobj Valley valley PROPN NNP NounType=prop|Number=sing GPE \nprep via via ADP IN \npobj London london PROPN NNP NounType=prop|Number=sing GPE \npunct . . PUNCT . PunctType=peri \n\nposs Sarah sarah PROPN NNP NounType=prop|Number=sing PERSON\ncase 's 's PART POS Poss=yes \nnsubj sister sister NOUN NN Number=sing \nROOT flew fly VERB VBD Tense=past|VerbForm=fin \n" - ) diff --git a/spacy/visualization.py b/spacy/visualization.py index d21084e2b..ac0de959a 100644 --- a/spacy/visualization.py +++ b/spacy/visualization.py @@ -52,25 +52,25 @@ class AttributeFormat: name: str = "", aligns: str = "l", max_width: Optional[int] = None, - fg_color: Union[str, int, None] = None, - bg_color: Union[str, int, None] = None, - permitted_values: Optional[tuple] = None, - value_dependent_fg_colors: Optional[Dict[str, Union[str, int]]] = None, - value_dependent_bg_colors: Optional[Dict[str, Union[str, int]]] = None, + fg_color: Optional[Union[str, int]] = None, + bg_color: Optional[Union[str, int]] = None, + permitted_vals: Optional[tuple] = None, + value_dep_fg_colors: Optional[Dict[str, Union[str, int]]] = None, + value_dep_bg_colors: Optional[Dict[str, Union[str, int]]] = None, ): """ - attribute: the token attribute, e.g. lemma_, ._.holmes.lemma - name: the name to display e.g. in column headers - aligns: where appropriate the column alignment 'l' (left, - default), 'r' (right) or 'c' (center). - max_width: a maximum width to which values of the attribute should be truncated. 
- fg_color: the foreground color that should be used to display instances of the attribute - bg_color: the background color that should be used to display instances of the attribute - permitted_values: a tuple of values of the attribute that should be displayed. If - permitted_values is not None and a value of the attribute is not - in permitted_values, the empty string is rendered instead of the value. - value_dependent_fg_colors: a dictionary from values to foreground colors that should be used to display those values. - value_dependent_bg_colors: a dictionary from values to background colors that should be used to display those values. + attribute: the token attribute, e.g. lemma_, ._.holmes.lemma + name: the name to display e.g. in column headers + aligns: where appropriate the column alignment 'l' (left, + default), 'r' (right) or 'c' (center). + max_width: a maximum width to which values of the attribute should be truncated. + fg_color: the foreground color that should be used to display instances of the attribute + bg_color: the background color that should be used to display instances of the attribute + permitted_vals: a tuple of values of the attribute that should be displayed. If + permitted_vals is not None and a value of the attribute is not + in permitted_vals, the empty string is rendered instead of the value. + value_dep_fg_colors: a dictionary from values to foreground colors that should be used to display those values. + value_dep_bg_colors: a dictionary from values to background colors that should be used to display those values. 
""" self.attribute = attribute self.name = name @@ -78,57 +78,58 @@ class AttributeFormat: self.max_width = max_width self.fg_color = fg_color self.bg_color = bg_color - self.permitted_values = permitted_values - self.value_dependent_fg_colors = value_dependent_fg_colors - self.value_dependent_bg_colors = value_dependent_bg_colors + self.permitted_vals = permitted_vals + self.value_dep_fg_colors = value_dep_fg_colors + self.value_dep_bg_colors = value_dep_bg_colors self.printer = wasabi.Printer(no_print=True) def render( self, token: Token, *, - right_pad_to_length: Optional[int] = None, + right_pad_to_len: Optional[int] = None, ignore_colors: bool = False, - render_all_colors_within_values: bool = False, + render_all_colors_in_vals: bool = False, whole_row_fg_color: Union[int, str, None] = None, whole_row_bg_color: Union[int, str, None] = None, ) -> str: """ - ignore_colors: no colors should be rendered, typically because the values are required to calculate widths - render_all_colors_within_values: when rendering a table, self.fg_color and self.bg_color are rendered in Wasabi. - This argument is set to True when rendering a text to signal that colors should be rendered here. - whole_row_fg_color: a foreground color used for the whole row. This takes precedence over value_dependent_fg_colors. - whole_row_bg_color: a background color used for the whole row. This takes precedence over value_dependent_bg_colors. + right_pad_to_len: the width to which values should be right-padded, or 'None' for no right-padding. + ignore_colors: no colors should be rendered, typically because the values are required to calculate widths + render_all_colors_in_vals: when rendering a table, self.fg_color and self.bg_color are rendered in Wasabi. + This argument is set to True when rendering a text to signal that colors should be rendered here. + whole_row_fg_color: a foreground color used for the whole row. This takes precedence over value_dependent_fg_colors. 
+ whole_row_bg_color: a background color used for the whole row. This takes precedence over value_dependent_bg_colors. """ obj = token parts = self.attribute.split(".") for part in parts[:-1]: obj = getattr(obj, part) value = str(getattr(obj, parts[-1])) - if self.permitted_values is not None and value not in ( - str(v) for v in self.permitted_values + if self.permitted_vals is not None and value not in ( + str(v) for v in self.permitted_vals ): return "" if self.max_width is not None: value = value[: self.max_width] fg_color = None bg_color = None - if right_pad_to_length is not None: - right_padding = " " * (right_pad_to_length - len(value)) + if right_pad_to_len is not None: + right_padding = " " * (right_pad_to_len - len(value)) else: right_padding = "" if SUPPORTS_ANSI and not ignore_colors and len(value) > 0: if whole_row_fg_color is not None: fg_color = whole_row_fg_color - elif self.value_dependent_fg_colors is not None: - fg_color = self.value_dependent_fg_colors.get(value, None) - if fg_color is None and render_all_colors_within_values: + elif self.value_dep_fg_colors is not None: + fg_color = self.value_dep_fg_colors.get(value, None) + if fg_color is None and render_all_colors_in_vals: fg_color = self.fg_color - if self.value_dependent_bg_colors is not None: - bg_color = self.value_dependent_bg_colors.get(value, None) + if self.value_dep_bg_colors is not None: + bg_color = self.value_dep_bg_colors.get(value, None) if whole_row_bg_color is not None: bg_color = whole_row_bg_color - elif bg_color is None and render_all_colors_within_values: + elif bg_color is None and render_all_colors_in_vals: bg_color = self.bg_color if fg_color is not None or bg_color is not None: value = self.printer.text(value, color=fg_color, bg_color=bg_color) @@ -137,7 +138,7 @@ class AttributeFormat: class Visualizer: @staticmethod - def render_dependency_tree(sent: Span, root_right: bool) -> List[str]: + def render_dep_tree(sent: Span, root_right: bool) -> List[str]: """ Returns 
an ASCII rendering of the document with a dependency tree for each sentence. The dependency tree output for a given token has the same index within the output list of @@ -150,276 +151,234 @@ class Visualizer: """ # Check sent is really a sentence - assert sent.start == sent[0].sent.start - assert sent.end == sent[0].sent.end - heads: List[Optional[int]] = [ - None - if token.dep_.lower() == "root" or token.head.i == token.i - else token.head.i - sent.start - for token in sent - ] + if sent.start != sent[0].sent.start or sent.end != sent[0].sent.end: + raise ValueError(f"Span is not a sentence: '{sent}'") + heads: List[Optional[int]] = [] + for token in sent: + if token.dep_.lower() == "root" or token.head.i == token.i: + heads.append(None) + else: + heads.append(token.head.i - sent.start) # Check there are no head references outside the sentence - assert ( - len( - [ - head - for head in heads - if head is not None and (head < 0 or head > sent.end - sent.start) - ] - ) - == 0 - ) + heads_outside_sent = [ + 1 for h in heads if h is not None and (h < 0 or h > sent.end - sent.start) + ] + if len(heads_outside_sent) > 0: + raise ValueError(f"Head reference outside sentence in sentence '{sent}'") children_lists: List[List[int]] = [[] for _ in range(sent.end - sent.start)] for child, head in enumerate(heads): if head is not None: children_lists[head].append(child) - all_indices_ordered_by_column: List[int] = [] + all_ind_ord_by_col: List[int] = [] # start with the root column - indices_in_current_column = [i for i, h in enumerate(heads) if h is None] - while len(indices_in_current_column) > 0: - assert ( - len( - [ - i - for i in indices_in_current_column - if i in all_indices_ordered_by_column - ] - ) - == 0 - ) - all_indices_ordered_by_column = ( - indices_in_current_column + all_indices_ordered_by_column - ) - indices_in_next_column = [] + inds_in_this_col = [i for i, h in enumerate(heads) if h is None] + while len(inds_in_this_col) > 0: + all_ind_ord_by_col = 
inds_in_this_col + all_ind_ord_by_col + inds_in_next_col = [] # The calculation order of the horizontal lengths of the children # on either given side of a head must ensure that children # closer to the head are processed first. - for index_in_current_column in indices_in_current_column: - following_children_indices = [ - i - for i in children_lists[index_in_current_column] - if i > index_in_current_column + for ind_in_this_col in inds_in_this_col: + following_child_inds = [ + i for i in children_lists[ind_in_this_col] if i > ind_in_this_col ] - indices_in_next_column.extend(following_children_indices) - preceding_children_indices = [ - i - for i in children_lists[index_in_current_column] - if i < index_in_current_column + inds_in_next_col.extend(following_child_inds) + preceding_child_inds = [ + i for i in children_lists[ind_in_this_col] if i < ind_in_this_col ] - preceding_children_indices.reverse() - indices_in_next_column.extend(preceding_children_indices) - indices_in_current_column = indices_in_next_column - horizontal_line_lengths = [ - -1 if heads[i] is None else 1 - # length == 1: governed by direct neighbour and has no children itself - if len(children_lists[i]) == 0 and abs(cast(int, heads[i]) - i) == 1 else 0 - for i in range(sent.end - sent.start) - ] - while 0 in horizontal_line_lengths: - for working_token_index in ( - i - for i in all_indices_ordered_by_column - if horizontal_line_lengths[i] == 0 + preceding_child_inds.reverse() + inds_in_next_col.extend(preceding_child_inds) + inds_in_this_col = inds_in_next_col + horiz_line_lens: List[int] = [] + for i in range(sent.end - sent.start): + if heads[i] is None: + horiz_line_lens.append(-1) + elif len(children_lists[i]) == 0 and abs(cast(int, heads[i]) - i) == 1: + # governed by direct neighbour and has no children itself + horiz_line_lens.append(1) + else: + horiz_line_lens.append(0) + while 0 in horiz_line_lens: + for working_token_ind in ( + i for i in all_ind_ord_by_col if horiz_line_lens[i] == 0 
): # render relation between this token and its head - first_index_in_relation = min( - working_token_index, - cast(int, heads[working_token_index]), + first_ind_in_rel = min( + working_token_ind, + cast(int, heads[working_token_ind]), ) - second_index_in_relation = max( - working_token_index, - cast(int, heads[working_token_index]), + second_ind_in_rel = max( + working_token_ind, + cast(int, heads[working_token_ind]), ) # If this token has children, they will already have been rendered. # The line needs to be one character longer than the longest of the # children's lines. - if len(children_lists[working_token_index]) > 0: - horizontal_line_lengths[working_token_index] = ( + if len(children_lists[working_token_ind]) > 0: + horiz_line_lens[working_token_ind] = ( max( [ - horizontal_line_lengths[i] - for i in children_lists[working_token_index] + horiz_line_lens[i] + for i in children_lists[working_token_ind] ] ) + 1 ) else: - horizontal_line_lengths[working_token_index] = 1 - for inbetween_index in ( + horiz_line_lens[working_token_ind] = 1 + for inbetween_ind in ( i - for i in range( - first_index_in_relation + 1, second_index_in_relation - ) - if horizontal_line_lengths[i] != 0 + for i in range(first_ind_in_rel + 1, second_ind_in_rel) + if horiz_line_lens[i] != 0 ): - horizontal_line_lengths[working_token_index] = max( - horizontal_line_lengths[working_token_index], - horizontal_line_lengths[inbetween_index] - if inbetween_index - in children_lists[cast(int, heads[working_token_index])] - and inbetween_index not in children_lists[working_token_index] - else horizontal_line_lengths[inbetween_index] + 1, - ) - max_horizontal_line_length = max(horizontal_line_lengths) + alt_ind: int + if ( + inbetween_ind + in children_lists[cast(int, heads[working_token_ind])] + and inbetween_ind not in children_lists[working_token_ind] + ): + alt_ind = horiz_line_lens[inbetween_ind] + else: + alt_ind = horiz_line_lens[inbetween_ind] + 1 + if alt_ind > 
horiz_line_lens[working_token_ind]: + horiz_line_lens[working_token_ind] = alt_ind + max_horiz_line_len = max(horiz_line_lens) char_matrix = [ - [SPACE] * max_horizontal_line_length * 2 - for _ in range(sent.start, sent.end) + [SPACE] * max_horiz_line_len * 2 for _ in range(sent.start, sent.end) ] - for working_token_index in range(sent.end - sent.start): - head_token_index = heads[working_token_index] - if head_token_index is None: + for working_token_ind in range(sent.end - sent.start): + head_token_ind = heads[working_token_ind] + if head_token_ind is None: continue - first_index_in_relation = min(working_token_index, head_token_index) - second_index_in_relation = max(working_token_index, head_token_index) - char_horizontal_line_length = ( - 2 * horizontal_line_lengths[working_token_index] - ) + first_ind_in_rel = min(working_token_ind, head_token_ind) + second_ind_in_rel = max(working_token_ind, head_token_ind) + char_horiz_line_len = 2 * horiz_line_lens[working_token_ind] # Draw the corners of the relation - char_matrix[first_index_in_relation][char_horizontal_line_length - 1] |= ( + char_matrix[first_ind_in_rel][char_horiz_line_len - 1] |= ( HALF_HORIZONTAL_LINE + LOWER_HALF_VERTICAL_LINE ) - char_matrix[second_index_in_relation][char_horizontal_line_length - 1] |= ( + char_matrix[second_ind_in_rel][char_horiz_line_len - 1] |= ( HALF_HORIZONTAL_LINE + UPPER_HALF_VERTICAL_LINE ) # Draw the horizontal line for the governing token - for working_horizontal_position in range(char_horizontal_line_length - 1): - if ( - char_matrix[head_token_index][working_horizontal_position] - != FULL_VERTICAL_LINE - ): - char_matrix[head_token_index][ - working_horizontal_position + for working_horiz_pos in range(char_horiz_line_len - 1): + if char_matrix[head_token_ind][working_horiz_pos] != FULL_VERTICAL_LINE: + char_matrix[head_token_ind][ + working_horiz_pos ] |= FULL_HORIZONTAL_LINE # Draw the vertical line for the relation - for working_vertical_position in range( - 
first_index_in_relation + 1, second_index_in_relation - ): + for working_vert_pos in range(first_ind_in_rel + 1, second_ind_in_rel): if ( - char_matrix[working_vertical_position][ - char_horizontal_line_length - 1 - ] + char_matrix[working_vert_pos][char_horiz_line_len - 1] != FULL_HORIZONTAL_LINE ): - char_matrix[working_vertical_position][ - char_horizontal_line_length - 1 + char_matrix[working_vert_pos][ + char_horiz_line_len - 1 ] |= FULL_VERTICAL_LINE - for working_token_index in ( + for working_token_ind in ( i for i in range(sent.end - sent.start) if heads[i] is not None ): - for working_horizontal_position in range( - 2 * horizontal_line_lengths[working_token_index] - 2, -1, -1 + for working_horiz_pos in range( + 2 * horiz_line_lens[working_token_ind] - 2, -1, -1 ): if ( ( - char_matrix[working_token_index][working_horizontal_position] + char_matrix[working_token_ind][working_horiz_pos] == FULL_VERTICAL_LINE ) - and working_horizontal_position > 1 - and char_matrix[working_token_index][ - working_horizontal_position - 2 - ] - == SPACE + and working_horiz_pos > 1 + and char_matrix[working_token_ind][working_horiz_pos - 2] == SPACE ): # Cross over the existing vertical line, which is owing to a non-projective tree continue - if ( - char_matrix[working_token_index][working_horizontal_position] - != SPACE - ): + if char_matrix[working_token_ind][working_horiz_pos] != SPACE: # Draw the arrowhead to the right of what is already there - char_matrix[working_token_index][ - working_horizontal_position + 1 - ] = ARROWHEAD + char_matrix[working_token_ind][working_horiz_pos + 1] = ARROWHEAD break - if working_horizontal_position == 0: + if working_horiz_pos == 0: # Draw the arrowhead at the boundary of the diagram - char_matrix[working_token_index][ - working_horizontal_position - ] = ARROWHEAD + char_matrix[working_token_ind][working_horiz_pos] = ARROWHEAD else: # Fill in the horizontal line for the governed token - char_matrix[working_token_index][ - 
working_horizontal_position + char_matrix[working_token_ind][ + working_horiz_pos ] |= FULL_HORIZONTAL_LINE if root_right: return [ "".join( - ROOT_RIGHT_CHARS[ - char_matrix[vertical_position][horizontal_position] - ] - for horizontal_position in range((max_horizontal_line_length * 2)) + ROOT_RIGHT_CHARS[char_matrix[vert_pos][horiz_pos]] + for horiz_pos in range((max_horiz_line_len * 2)) ) - for vertical_position in range(sent.end - sent.start) + for vert_pos in range(sent.end - sent.start) ] else: return [ "".join( - ROOT_LEFT_CHARS[char_matrix[vertical_position][horizontal_position]] - for horizontal_position in range((max_horizontal_line_length * 2)) + ROOT_LEFT_CHARS[char_matrix[vert_pos][horiz_pos]] + for horiz_pos in range((max_horiz_line_len * 2)) )[::-1] - for vertical_position in range(sent.end - sent.start) + for vert_pos in range(sent.end - sent.start) ] - def render_table( - self, doc: Doc, columns: List[AttributeFormat], spacing: int - ) -> str: + def render_table(self, doc: Doc, cols: List[AttributeFormat], spacing: int) -> str: """Renders a document as a table. TODO: specify a specific portion of the document to display. - columns: the attribute formats of the columns to display. - tree_right and tree_left are magic values for the - attributes that render dependency trees where the - roots are on the left or right respectively. - spacing: the number of spaces between each column in the table. + cols: the attribute formats of the columns to display. + tree_right and tree_left are magic values for the + attributes that render dependency trees where the + roots are on the left or right respectively. + spacing: the number of spaces between each column in the table. 
""" - return_string = "" + return_str = "" for sent in doc.sents: - if "tree_right" in (c.attribute for c in columns): - tree_right = self.render_dependency_tree(sent, True) - if "tree_left" in (c.attribute for c in columns): - tree_left = self.render_dependency_tree(sent, False) + if "tree_right" in (c.attribute for c in cols): + tree_right = self.render_dep_tree(sent, True) + if "tree_left" in (c.attribute for c in cols): + tree_left = self.render_dep_tree(sent, False) widths = [] - for column in columns: + for col in cols: # get the values without any color codes - if column.attribute == "tree_left": + if col.attribute == "tree_left": width = len(tree_left[0]) # type: ignore - elif column.attribute == "tree_right": + elif col.attribute == "tree_right": width = len(tree_right[0]) # type: ignore else: if len(sent) > 0: width = max( - len(column.render(token, ignore_colors=True)) - for token in sent + len(col.render(token, ignore_colors=True)) for token in sent ) else: width = 0 - if column.max_width is not None: - width = min(width, column.max_width) - width = max(width, len(column.name)) + if col.max_width is not None: + width = min(width, col.max_width) + width = max(width, len(col.name)) widths.append(width) - data = [ - [ - tree_right[token_index] # type: ignore - if column.attribute == "tree_right" - else tree_left[token_index] # type: ignore - if column.attribute == "tree_left" - else column.render(token, right_pad_to_length=widths[column_index]) - for column_index, column in enumerate(columns) - ] - for token_index, token in enumerate(sent) - ] + data: List[List[str]] = [] + for token_index, token in enumerate(sent): + inner_data: List[str] = [] + for col_index, col in enumerate(cols): + if col.attribute == "tree_right": + inner_data.append(tree_right[token_index]) + elif col.attribute == "tree_left": + inner_data.append(tree_left[token_index]) + else: + inner_data.append( + col.render(token, right_pad_to_len=widths[col_index]) + ) + data.append(inner_data) 
header: Optional[List[str]] - if len([1 for c in columns if len(c.name) > 0]) > 0: - header = [c.name for c in columns] + if len([1 for c in cols if len(c.name) > 0]) > 0: + header = [c.name for c in cols] else: header = None - aligns = [c.aligns for c in columns] - fg_colors = [c.fg_color for c in columns] - bg_colors = [c.bg_color for c in columns] - return_string += ( + aligns = [c.aligns for c in cols] + fg_colors = [c.fg_color for c in cols] + bg_colors = [c.bg_color for c in cols] + return_str += ( wasabi.table( data, header=header, @@ -432,41 +391,38 @@ class Visualizer: ) + "\n" ) - return return_string + return return_str - def render_text(self, doc: Doc, attributes: List[AttributeFormat]) -> str: + def render_text(self, doc: Doc, attrs: List[AttributeFormat]) -> str: """Renders a text interspersed with attribute labels. TODO: specify a specific portion of the document to display. """ - return_string = "" - text_attributes = [a for a in attributes if a.attribute == "text"] - text_attribute = ( - text_attributes[0] if len(text_attributes) > 0 else AttributeFormat("text") - ) + return_str = "" + text_attrs = [a for a in attrs if a.attribute == "text"] + text_attr = text_attrs[0] if len(text_attrs) > 0 else AttributeFormat("text") for token in doc: - this_token_strings = [""] - for attribute in (a for a in attributes if a.attribute != "text"): - attribute_text = attribute.render( - token, render_all_colors_within_values=True + this_token_strs = [""] + for attr in (a for a in attrs if a.attribute != "text"): + attr_text = attr.render(token, render_all_colors_in_vals=True) + if attr_text is not None and len(attr_text) > 0: + this_token_strs.append(" " + attr_text) + if len(this_token_strs) == 1: + this_token_strs[0] = token.text + else: + this_token_strs[0] = text_attr.render( + token, render_all_colors_in_vals=True ) - if attribute_text is not None and len(attribute_text) > 0: - this_token_strings.append(" " + attribute_text) - this_token_strings[0] = ( - 
token.text - if len(this_token_strings) == 1 - else text_attribute.render(token, render_all_colors_within_values=True) - ) - this_token_strings.append(token.whitespace_) - return_string += "".join(this_token_strings) - return return_string + this_token_strs.append(token.whitespace_) + return_str += "".join(this_token_strs) + return return_str def render_instances( self, doc: Doc, *, - search_attributes: List[AttributeFormat], - display_columns: List[AttributeFormat], + search_attrs: List[AttributeFormat], + display_cols: List[AttributeFormat], group: bool, spacing: int, surrounding_tokens_height: int, @@ -476,8 +432,8 @@ class Visualizer: """Shows all tokens in a document with specific attribute(s), e.g. entity labels, or attribute value(s), e.g. 'GPE'. TODO: specify a specific portion of the document to display. - search_attributes: the attribute(s) or attribute value(s) that cause a row to be displayed for a token. - display_columns: the attributes that should be displayed in each row. + search_attrs: the attribute(s) or attribute value(s) that cause a row to be displayed for a token. + display_cols: the attributes that should be displayed in each row. group: True if the rows should be ordered by the search attribute values, False if they should retain their in-document order. spacing: the number of spaces between each column. 
@@ -491,105 +447,102 @@ class Visualizer: """ def filter(token: Token) -> bool: - for attribute in search_attributes: - value = attribute.render(token, ignore_colors=True) + for attr in search_attrs: + value = attr.render(token, ignore_colors=True) if len(value) == 0: return False return True matched_tokens = [token for token in doc if filter(token)] - tokens_to_display_indices = [ - index - for token in matched_tokens - for index in range( + tokens_to_display_inds: List[int] = [] + for token in matched_tokens: + for ind in range( token.i - surrounding_tokens_height, token.i + surrounding_tokens_height + 1, - ) - if index >= 0 and index < len(doc) - ] + ): + if ind >= 0 and ind < len(doc): + tokens_to_display_inds.append(ind) widths = [] - for column in display_columns: - if len(tokens_to_display_indices) > 0: + for col in display_cols: + if len(tokens_to_display_inds) > 0: width = max( - len(column.render(doc[i], ignore_colors=True)) - for i in tokens_to_display_indices + len(col.render(doc[i], ignore_colors=True)) + for i in tokens_to_display_inds ) else: width = 0 - if column.max_width is not None: - width = min(width, column.max_width) - width = max(width, len(column.name)) + if col.max_width is not None: + width = min(width, col.max_width) + width = max(width, len(col.name)) widths.append(width) if group: matched_tokens.sort( key=( lambda token: [ - attribute.render(token, ignore_colors=True) - for attribute in search_attributes + attr.render(token, ignore_colors=True) for attr in search_attrs ] ) ) rows = [] - token_index_to_display = -1 - for matched_token_index, matched_token in enumerate(matched_tokens): + token_ind_to_display = -1 + for matched_token_ind, matched_token in enumerate(matched_tokens): if surrounding_tokens_height > 0: - surrounding_start_index = max( + surrounding_start_ind = max( 0, matched_token.i - surrounding_tokens_height ) - if token_index_to_display + 1 == matched_token.i: - surrounding_start_index = token_index_to_display + 1 - 
surrounding_end_index = min( + if token_ind_to_display + 1 == matched_token.i: + surrounding_start_ind = token_ind_to_display + 1 + surrounding_end_ind = min( len(doc), matched_token.i + surrounding_tokens_height + 1 ) if ( - matched_token_index + 1 < len(matched_tokens) - and matched_token.i + 1 == matched_tokens[matched_token_index + 1].i + matched_token_ind + 1 < len(matched_tokens) + and matched_token.i + 1 == matched_tokens[matched_token_ind + 1].i ): - surrounding_end_index = matched_token.i + 1 + surrounding_end_ind = matched_token.i + 1 else: - surrounding_start_index = matched_token.i - surrounding_end_index = surrounding_start_index + 1 - for token_index_to_display in range( - surrounding_start_index, surrounding_end_index + surrounding_start_ind = matched_token.i + surrounding_end_ind = surrounding_start_ind + 1 + for token_ind_to_display in range( + surrounding_start_ind, surrounding_end_ind ): - if token_index_to_display == matched_token.i: + if token_ind_to_display == matched_token.i: rows.append( [ - column.render( + col.render( matched_token, - right_pad_to_length=widths[column_index], + right_pad_to_len=widths[col_ind], ) - for column_index, column in enumerate(display_columns) + for col_ind, col in enumerate(display_cols) ] ) else: rows.append( [ - column.render( - doc[token_index_to_display], + col.render( + doc[token_ind_to_display], whole_row_fg_color=surrounding_tokens_fg_color, whole_row_bg_color=surrounding_tokens_bg_color, - right_pad_to_length=widths[column_index], + right_pad_to_len=widths[col_ind], ) - for column_index, column in enumerate(display_columns) + for col_ind, col in enumerate(display_cols) ] ) if ( - matched_token_index + 1 < len(matched_tokens) - and token_index_to_display + 1 - != matched_tokens[matched_token_index + 1].i + matched_token_ind + 1 < len(matched_tokens) + and token_ind_to_display + 1 != matched_tokens[matched_token_ind + 1].i ): rows.append([]) header: Optional[List[str]] - if len([1 for c in display_columns 
if len(c.name) > 0]) > 0: - header = [c.name for c in display_columns] + if len([1 for c in display_cols if len(c.name) > 0]) > 0: + header = [c.name for c in display_cols] else: header = None - aligns = [c.aligns for c in display_columns] - fg_colors = [c.fg_color for c in display_columns] - bg_colors = [c.bg_color for c in display_columns] + aligns = [c.aligns for c in display_cols] + fg_colors = [c.fg_color for c in display_cols] + bg_colors = [c.bg_color for c in display_cols] return wasabi.table( rows, header=header,