diff --git a/spacy/tests/test_visualization.py b/spacy/tests/test_visualization.py index af68f7e03..f64b21d30 100644 --- a/spacy/tests/test_visualization.py +++ b/spacy/tests/test_visualization.py @@ -420,6 +420,37 @@ def test_visualization_get_entity_custom_attribute_missing(en_vocab): Visualizer().get_entity(doc[2], "_.depp") +def test_visualization_get_entity_permitted_values(en_vocab): + doc = Doc( + en_vocab, + words=[ + "I", + "saw", + "a", + "horse", + "yesterday", + "that", + "was", + "injured", + ".", + ], + heads=[1, None, 3, 1, 1, 7, 7, 3, 1], + deps=["dep"] * 9, + ) + visualizer = Visualizer() + assert [visualizer.get_entity(token, "head.i", permitted_values=(3, 7)) for token in doc] == [ + "", + "", + "3", + "", + "", + "7", + "7", + "3", + "", + ] + + def test_visualization_minimal_render_table_one_sentence( fully_featured_doc_one_sentence, ): @@ -450,6 +481,67 @@ def test_visualization_minimal_render_table_one_sentence( ) +def test_visualization_minimal_render_table_permitted_values( + fully_featured_doc_one_sentence, +): + formats = [ + AttributeFormat("tree_left"), + AttributeFormat("dep_"), + AttributeFormat("text"), + AttributeFormat("lemma_", permitted_values=("fly", "to")), + AttributeFormat("pos_"), + AttributeFormat("tag_"), + AttributeFormat("morph"), + AttributeFormat("ent_type_"), + ] + assert ( + Visualizer().render_table(fully_featured_doc_one_sentence, formats).strip() + == """ + ╔>╔═ poss Sarah PROPN NNP NounType=prop|Number=sing PERSON + ║ ╚> case 's PART POS Poss=yes +╔>╚═══ nsubj sister NOUN NN Number=sing +╠═════ ROOT flew fly VERB VBD Tense=past|VerbForm=fin +╠>╔═══ prep to to ADP IN +║ ║ ╔> compound Silicon PROPN NNP NounType=prop|Number=sing GPE +║ ╚>╚═ pobj Valley PROPN NNP NounType=prop|Number=sing GPE +╠══>╔═ prep via ADP IN +║ ╚> pobj London PROPN NNP NounType=prop|Number=sing GPE +╚════> punct . PUNCT . PunctType=peri + """.strip() + ) + +def test_visualization_spacing( + fully_featured_doc_one_sentence, +): + formats = [ + AttributeFormat("tree_left"), + AttributeFormat("dep_"), + AttributeFormat("text"), + AttributeFormat("lemma_"), + AttributeFormat("pos_"), + AttributeFormat("tag_"), + AttributeFormat("morph"), + AttributeFormat("ent_type_"), + ] + assert ( + Visualizer() + .render_table(fully_featured_doc_one_sentence, formats, spacing=1) + .strip() + == """ + ╔>╔═ poss Sarah sarah PROPN NNP NounType=prop|Number=sing PERSON + ║ ╚> case 's 's PART POS Poss=yes +╔>╚═══ nsubj sister sister NOUN NN Number=sing +╠═════ ROOT flew fly VERB VBD Tense=past|VerbForm=fin +╠>╔═══ prep to to ADP IN +║ ║ ╔> compound Silicon silicon PROPN NNP NounType=prop|Number=sing GPE +║ ╚>╚═ pobj Valley valley PROPN NNP NounType=prop|Number=sing GPE +╠══>╔═ prep via via ADP IN +║ ╚> pobj London london PROPN NNP NounType=prop|Number=sing GPE +╚════> punct . . PUNCT . PunctType=peri + """.strip() + ) + + def test_visualization_minimal_render_table_two_sentences( fully_featured_doc_two_sentences, ): diff --git a/spacy/visualization.py b/spacy/visualization.py index e03e77bbd..f9f067e26 100644 --- a/spacy/visualization.py +++ b/spacy/visualization.py @@ -16,6 +16,7 @@ class AttributeFormat: max_width: int = None, fg_color: Union[str, int] = None, bg_color: Union[str, int] = None, + permitted_values: tuple = None, value_dependent_fg_colors: dict[str, Union[str, int]] = None, value_dependent_bg_colors: dict[str, Union[str, int]] = None, ): @@ -25,6 +26,7 @@ class AttributeFormat: self.max_width = max_width self.fg_color = fg_color self.bg_color = bg_color + self.permitted_values = permitted_values self.value_dependent_fg_colors = value_dependent_fg_colors self.value_dependent_bg_colors = value_dependent_bg_colors @@ -65,7 +67,6 @@ ROOT_LEFT_CHARS = { class Visualizer: - def __init__(self): self.printer = wasabi.Printer(no_print=True) @@ -300,19 +301,30 @@ class Visualizer: token: Token, entity_name: str, *, + permitted_values: list[str] = None, value_dependent_fg_colors: dict[str : Union[str, int]] = None, value_dependent_bg_colors: dict[str : Union[str, int]] = None, - truncate_at_width: int = None + truncate_at_width: int = None, ) -> str: obj = token parts = entity_name.split(".") for part in parts[:-1]: obj = getattr(obj, part) value = str(getattr(obj, parts[-1])) + if permitted_values is not None and value not in (str(v) for v in permitted_values): + return "" if truncate_at_width is not None: value = value[:truncate_at_width] - fg_color = value_dependent_fg_colors.get(value, None) if value_dependent_fg_colors is not None else None - bg_color = value_dependent_bg_colors.get(value, None) if value_dependent_bg_colors is not None else None + fg_color = ( + value_dependent_fg_colors.get(value, None) + if value_dependent_fg_colors is not None + else None + ) + bg_color = ( + value_dependent_bg_colors.get(value, None) + if value_dependent_bg_colors is not None + else None + ) if fg_color is not None or bg_color is not None: value = self.printer.text(value, color=fg_color, bg_color=bg_color) return value @@ -329,12 +341,19 @@ class Visualizer: widths = [] for column in columns: # get the values without any color codes - if column.attribute == 'tree_left': + if column.attribute == "tree_left": width = len(tree_left[0]) - elif column.attribute == 'tree_right': + elif column.attribute == "tree_right": width = len(tree_right[0]) else: - width = max(len(self.get_entity(token, column.attribute)) for token in sent) + width = max( + len( + self.get_entity( + token, column.attribute, permitted_values=column.permitted_values + ) + ) + for token in sent + ) if column.max_width is not None: width = min(width, column.max_width) width = max(width, len(column.name)) @@ -348,9 +367,10 @@ class Visualizer: else self.get_entity( token, column.attribute, + permitted_values=column.permitted_values, value_dependent_fg_colors=column.value_dependent_fg_colors, value_dependent_bg_colors=column.value_dependent_bg_colors, - truncate_at_width=widths[column_index] + truncate_at_width=widths[column_index], ) for column_index, column in enumerate(columns) ] @@ -372,6 +392,7 @@ class Visualizer: widths=widths, fg_colors=fg_colors, bg_colors=bg_colors, + spacing=spacing, ) + linesep )