Add permitted values

This commit is contained in:
Richard Hudson 2021-12-08 14:58:02 +01:00
parent 9f7f234b0f
commit 183d535ef4
2 changed files with 121 additions and 8 deletions

View File

@ -420,6 +420,37 @@ def test_visualization_get_entity_custom_attribute_missing(en_vocab):
Visualizer().get_entity(doc[2], "_.depp") Visualizer().get_entity(doc[2], "_.depp")
def test_visualization_get_entity_permitted_values(en_vocab):
    """Check that get_entity returns "" for values outside permitted_values.

    Every token's ``head.i`` is requested with ``permitted_values=(3, 7)``;
    only tokens whose head index is 3 or 7 should yield a non-empty string.
    """
    words = ["I", "saw", "a", "horse", "yesterday", "that", "was", "injured", "."]
    doc = Doc(
        en_vocab,
        words=words,
        heads=[1, None, 3, 1, 1, 7, 7, 3, 1],
        deps=["dep"] * 9,
    )
    visualizer = Visualizer()
    expected = ["", "", "3", "", "", "7", "7", "3", ""]
    actual = []
    for token in doc:
        # The permitted values are passed as ints; the result is a string.
        actual.append(visualizer.get_entity(token, "head.i", permitted_values=(3, 7)))
    assert actual == expected
def test_visualization_minimal_render_table_one_sentence( def test_visualization_minimal_render_table_one_sentence(
fully_featured_doc_one_sentence, fully_featured_doc_one_sentence,
): ):
@ -450,6 +481,67 @@ def test_visualization_minimal_render_table_one_sentence(
) )
def test_visualization_minimal_render_table_permitted_values(
    fully_featured_doc_one_sentence,
):
    """Render a table whose ``lemma_`` column only permits "fly" and "to".

    All other columns are unrestricted; the rendered output (stripped) is
    compared against the expected table below.
    """
    formats = [AttributeFormat(name) for name in ("tree_left", "dep_", "text")]
    formats.append(AttributeFormat("lemma_", permitted_values=("fly", "to")))
    formats.extend(
        AttributeFormat(name) for name in ("pos_", "tag_", "morph", "ent_type_")
    )
    rendered = Visualizer().render_table(fully_featured_doc_one_sentence, formats)
    rendered = rendered.strip()
    # NOTE(review): the column padding inside this expected table looks as if
    # it was collapsed to single spaces in this view -- confirm against the
    # real render_table output before relying on it.
    expected = """
> poss Sarah PROPN NNP NounType=prop|Number=sing PERSON
> case 's PART POS Poss=yes
> nsubj sister NOUN NN Number=sing
ROOT flew fly VERB VBD Tense=past|VerbForm=fin
> prep to to ADP IN
> compound Silicon PROPN NNP NounType=prop|Number=sing GPE
> pobj Valley PROPN NNP NounType=prop|Number=sing GPE
> prep via ADP IN
> pobj London PROPN NNP NounType=prop|Number=sing GPE
> punct . PUNCT . PunctType=peri
""".strip()
    assert rendered == expected
def test_visualization_spacing(
    fully_featured_doc_one_sentence,
):
    """Render the one-sentence table with ``spacing=1`` and compare the output."""
    column_attributes = (
        "tree_left",
        "dep_",
        "text",
        "lemma_",
        "pos_",
        "tag_",
        "morph",
        "ent_type_",
    )
    formats = [AttributeFormat(attribute) for attribute in column_attributes]
    rendered = Visualizer().render_table(
        fully_featured_doc_one_sentence, formats, spacing=1
    )
    rendered = rendered.strip()
    # NOTE(review): the column padding inside this expected table looks as if
    # it was collapsed to single spaces in this view -- confirm against the
    # real render_table output before relying on it.
    expected = """
> poss Sarah sarah PROPN NNP NounType=prop|Number=sing PERSON
> case 's 's PART POS Poss=yes
> nsubj sister sister NOUN NN Number=sing
ROOT flew fly VERB VBD Tense=past|VerbForm=fin
> prep to to ADP IN
> compound Silicon silicon PROPN NNP NounType=prop|Number=sing GPE
> pobj Valley valley PROPN NNP NounType=prop|Number=sing GPE
> prep via via ADP IN
> pobj London london PROPN NNP NounType=prop|Number=sing GPE
> punct . . PUNCT . PunctType=peri
""".strip()
    assert rendered == expected
def test_visualization_minimal_render_table_two_sentences( def test_visualization_minimal_render_table_two_sentences(
fully_featured_doc_two_sentences, fully_featured_doc_two_sentences,
): ):

View File

@ -16,6 +16,7 @@ class AttributeFormat:
max_width: int = None, max_width: int = None,
fg_color: Union[str, int] = None, fg_color: Union[str, int] = None,
bg_color: Union[str, int] = None, bg_color: Union[str, int] = None,
permitted_values: tuple = None,
value_dependent_fg_colors: dict[str, Union[str, int]] = None, value_dependent_fg_colors: dict[str, Union[str, int]] = None,
value_dependent_bg_colors: dict[str, Union[str, int]] = None, value_dependent_bg_colors: dict[str, Union[str, int]] = None,
): ):
@ -25,6 +26,7 @@ class AttributeFormat:
self.max_width = max_width self.max_width = max_width
self.fg_color = fg_color self.fg_color = fg_color
self.bg_color = bg_color self.bg_color = bg_color
self.permitted_values = permitted_values
self.value_dependent_fg_colors = value_dependent_fg_colors self.value_dependent_fg_colors = value_dependent_fg_colors
self.value_dependent_bg_colors = value_dependent_bg_colors self.value_dependent_bg_colors = value_dependent_bg_colors
@ -65,7 +67,6 @@ ROOT_LEFT_CHARS = {
class Visualizer: class Visualizer:
def __init__(self): def __init__(self):
self.printer = wasabi.Printer(no_print=True) self.printer = wasabi.Printer(no_print=True)
@ -300,19 +301,30 @@ class Visualizer:
token: Token, token: Token,
entity_name: str, entity_name: str,
*, *,
permitted_values: list[str] = None,
value_dependent_fg_colors: dict[str : Union[str, int]] = None, value_dependent_fg_colors: dict[str : Union[str, int]] = None,
value_dependent_bg_colors: dict[str : Union[str, int]] = None, value_dependent_bg_colors: dict[str : Union[str, int]] = None,
truncate_at_width: int = None truncate_at_width: int = None,
) -> str: ) -> str:
obj = token obj = token
parts = entity_name.split(".") parts = entity_name.split(".")
for part in parts[:-1]: for part in parts[:-1]:
obj = getattr(obj, part) obj = getattr(obj, part)
value = str(getattr(obj, parts[-1])) value = str(getattr(obj, parts[-1]))
if permitted_values is not None and value not in (str(v) for v in permitted_values):
return ""
if truncate_at_width is not None: if truncate_at_width is not None:
value = value[:truncate_at_width] value = value[:truncate_at_width]
fg_color = value_dependent_fg_colors.get(value, None) if value_dependent_fg_colors is not None else None fg_color = (
bg_color = value_dependent_bg_colors.get(value, None) if value_dependent_bg_colors is not None else None value_dependent_fg_colors.get(value, None)
if value_dependent_fg_colors is not None
else None
)
bg_color = (
value_dependent_bg_colors.get(value, None)
if value_dependent_bg_colors is not None
else None
)
if fg_color is not None or bg_color is not None: if fg_color is not None or bg_color is not None:
value = self.printer.text(value, color=fg_color, bg_color=bg_color) value = self.printer.text(value, color=fg_color, bg_color=bg_color)
return value return value
@ -329,12 +341,19 @@ class Visualizer:
widths = [] widths = []
for column in columns: for column in columns:
# get the values without any color codes # get the values without any color codes
if column.attribute == 'tree_left': if column.attribute == "tree_left":
width = len(tree_left[0]) width = len(tree_left[0])
elif column.attribute == 'tree_right': elif column.attribute == "tree_right":
width = len(tree_right[0]) width = len(tree_right[0])
else: else:
width = max(len(self.get_entity(token, column.attribute)) for token in sent) width = max(
len(
self.get_entity(
token, column.attribute, permitted_values=column.permitted_values
)
)
for token in sent
)
if column.max_width is not None: if column.max_width is not None:
width = min(width, column.max_width) width = min(width, column.max_width)
width = max(width, len(column.name)) width = max(width, len(column.name))
@ -348,9 +367,10 @@ class Visualizer:
else self.get_entity( else self.get_entity(
token, token,
column.attribute, column.attribute,
permitted_values=column.permitted_values,
value_dependent_fg_colors=column.value_dependent_fg_colors, value_dependent_fg_colors=column.value_dependent_fg_colors,
value_dependent_bg_colors=column.value_dependent_bg_colors, value_dependent_bg_colors=column.value_dependent_bg_colors,
truncate_at_width=widths[column_index] truncate_at_width=widths[column_index],
) )
for column_index, column in enumerate(columns) for column_index, column in enumerate(columns)
] ]
@ -372,6 +392,7 @@ class Visualizer:
widths=widths, widths=widths,
fg_colors=fg_colors, fg_colors=fg_colors,
bg_colors=bg_colors, bg_colors=bg_colors,
spacing=spacing,
) )
+ linesep + linesep
) )