mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 12:20:20 +03:00
Add permitted values
This commit is contained in:
parent
9f7f234b0f
commit
183d535ef4
|
@ -420,6 +420,37 @@ def test_visualization_get_entity_custom_attribute_missing(en_vocab):
|
||||||
Visualizer().get_entity(doc[2], "_.depp")
|
Visualizer().get_entity(doc[2], "_.depp")
|
||||||
|
|
||||||
|
|
||||||
|
def test_visualization_get_entity_permitted_values(en_vocab):
|
||||||
|
doc = Doc(
|
||||||
|
en_vocab,
|
||||||
|
words=[
|
||||||
|
"I",
|
||||||
|
"saw",
|
||||||
|
"a",
|
||||||
|
"horse",
|
||||||
|
"yesterday",
|
||||||
|
"that",
|
||||||
|
"was",
|
||||||
|
"injured",
|
||||||
|
".",
|
||||||
|
],
|
||||||
|
heads=[1, None, 3, 1, 1, 7, 7, 3, 1],
|
||||||
|
deps=["dep"] * 9,
|
||||||
|
)
|
||||||
|
visualizer = Visualizer()
|
||||||
|
assert [visualizer.get_entity(token, "head.i", permitted_values=(3, 7)) for token in doc] == [
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"3",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"7",
|
||||||
|
"7",
|
||||||
|
"3",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_minimal_render_table_one_sentence(
|
def test_visualization_minimal_render_table_one_sentence(
|
||||||
fully_featured_doc_one_sentence,
|
fully_featured_doc_one_sentence,
|
||||||
):
|
):
|
||||||
|
@ -450,6 +481,67 @@ def test_visualization_minimal_render_table_one_sentence(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_visualization_minimal_render_table_permitted_values(
|
||||||
|
fully_featured_doc_one_sentence,
|
||||||
|
):
|
||||||
|
formats = [
|
||||||
|
AttributeFormat("tree_left"),
|
||||||
|
AttributeFormat("dep_"),
|
||||||
|
AttributeFormat("text"),
|
||||||
|
AttributeFormat("lemma_", permitted_values=("fly", "to")),
|
||||||
|
AttributeFormat("pos_"),
|
||||||
|
AttributeFormat("tag_"),
|
||||||
|
AttributeFormat("morph"),
|
||||||
|
AttributeFormat("ent_type_"),
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
Visualizer().render_table(fully_featured_doc_one_sentence, formats).strip()
|
||||||
|
== """
|
||||||
|
╔>╔═ poss Sarah PROPN NNP NounType=prop|Number=sing PERSON
|
||||||
|
║ ╚> case 's PART POS Poss=yes
|
||||||
|
╔>╚═══ nsubj sister NOUN NN Number=sing
|
||||||
|
╠═════ ROOT flew fly VERB VBD Tense=past|VerbForm=fin
|
||||||
|
╠>╔═══ prep to to ADP IN
|
||||||
|
║ ║ ╔> compound Silicon PROPN NNP NounType=prop|Number=sing GPE
|
||||||
|
║ ╚>╚═ pobj Valley PROPN NNP NounType=prop|Number=sing GPE
|
||||||
|
╠══>╔═ prep via ADP IN
|
||||||
|
║ ╚> pobj London PROPN NNP NounType=prop|Number=sing GPE
|
||||||
|
╚════> punct . PUNCT . PunctType=peri
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_visualization_spacing(
|
||||||
|
fully_featured_doc_one_sentence,
|
||||||
|
):
|
||||||
|
formats = [
|
||||||
|
AttributeFormat("tree_left"),
|
||||||
|
AttributeFormat("dep_"),
|
||||||
|
AttributeFormat("text"),
|
||||||
|
AttributeFormat("lemma_"),
|
||||||
|
AttributeFormat("pos_"),
|
||||||
|
AttributeFormat("tag_"),
|
||||||
|
AttributeFormat("morph"),
|
||||||
|
AttributeFormat("ent_type_"),
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
Visualizer()
|
||||||
|
.render_table(fully_featured_doc_one_sentence, formats, spacing=1)
|
||||||
|
.strip()
|
||||||
|
== """
|
||||||
|
╔>╔═ poss Sarah sarah PROPN NNP NounType=prop|Number=sing PERSON
|
||||||
|
║ ╚> case 's 's PART POS Poss=yes
|
||||||
|
╔>╚═══ nsubj sister sister NOUN NN Number=sing
|
||||||
|
╠═════ ROOT flew fly VERB VBD Tense=past|VerbForm=fin
|
||||||
|
╠>╔═══ prep to to ADP IN
|
||||||
|
║ ║ ╔> compound Silicon silicon PROPN NNP NounType=prop|Number=sing GPE
|
||||||
|
║ ╚>╚═ pobj Valley valley PROPN NNP NounType=prop|Number=sing GPE
|
||||||
|
╠══>╔═ prep via via ADP IN
|
||||||
|
║ ╚> pobj London london PROPN NNP NounType=prop|Number=sing GPE
|
||||||
|
╚════> punct . . PUNCT . PunctType=peri
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_minimal_render_table_two_sentences(
|
def test_visualization_minimal_render_table_two_sentences(
|
||||||
fully_featured_doc_two_sentences,
|
fully_featured_doc_two_sentences,
|
||||||
):
|
):
|
||||||
|
|
|
@ -16,6 +16,7 @@ class AttributeFormat:
|
||||||
max_width: int = None,
|
max_width: int = None,
|
||||||
fg_color: Union[str, int] = None,
|
fg_color: Union[str, int] = None,
|
||||||
bg_color: Union[str, int] = None,
|
bg_color: Union[str, int] = None,
|
||||||
|
permitted_values: tuple = None,
|
||||||
value_dependent_fg_colors: dict[str, Union[str, int]] = None,
|
value_dependent_fg_colors: dict[str, Union[str, int]] = None,
|
||||||
value_dependent_bg_colors: dict[str, Union[str, int]] = None,
|
value_dependent_bg_colors: dict[str, Union[str, int]] = None,
|
||||||
):
|
):
|
||||||
|
@ -25,6 +26,7 @@ class AttributeFormat:
|
||||||
self.max_width = max_width
|
self.max_width = max_width
|
||||||
self.fg_color = fg_color
|
self.fg_color = fg_color
|
||||||
self.bg_color = bg_color
|
self.bg_color = bg_color
|
||||||
|
self.permitted_values = permitted_values
|
||||||
self.value_dependent_fg_colors = value_dependent_fg_colors
|
self.value_dependent_fg_colors = value_dependent_fg_colors
|
||||||
self.value_dependent_bg_colors = value_dependent_bg_colors
|
self.value_dependent_bg_colors = value_dependent_bg_colors
|
||||||
|
|
||||||
|
@ -65,7 +67,6 @@ ROOT_LEFT_CHARS = {
|
||||||
|
|
||||||
|
|
||||||
class Visualizer:
|
class Visualizer:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.printer = wasabi.Printer(no_print=True)
|
self.printer = wasabi.Printer(no_print=True)
|
||||||
|
|
||||||
|
@ -300,19 +301,30 @@ class Visualizer:
|
||||||
token: Token,
|
token: Token,
|
||||||
entity_name: str,
|
entity_name: str,
|
||||||
*,
|
*,
|
||||||
|
permitted_values: list[str] = None,
|
||||||
value_dependent_fg_colors: dict[str : Union[str, int]] = None,
|
value_dependent_fg_colors: dict[str : Union[str, int]] = None,
|
||||||
value_dependent_bg_colors: dict[str : Union[str, int]] = None,
|
value_dependent_bg_colors: dict[str : Union[str, int]] = None,
|
||||||
truncate_at_width: int = None
|
truncate_at_width: int = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
obj = token
|
obj = token
|
||||||
parts = entity_name.split(".")
|
parts = entity_name.split(".")
|
||||||
for part in parts[:-1]:
|
for part in parts[:-1]:
|
||||||
obj = getattr(obj, part)
|
obj = getattr(obj, part)
|
||||||
value = str(getattr(obj, parts[-1]))
|
value = str(getattr(obj, parts[-1]))
|
||||||
|
if permitted_values is not None and value not in (str(v) for v in permitted_values):
|
||||||
|
return ""
|
||||||
if truncate_at_width is not None:
|
if truncate_at_width is not None:
|
||||||
value = value[:truncate_at_width]
|
value = value[:truncate_at_width]
|
||||||
fg_color = value_dependent_fg_colors.get(value, None) if value_dependent_fg_colors is not None else None
|
fg_color = (
|
||||||
bg_color = value_dependent_bg_colors.get(value, None) if value_dependent_bg_colors is not None else None
|
value_dependent_fg_colors.get(value, None)
|
||||||
|
if value_dependent_fg_colors is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
bg_color = (
|
||||||
|
value_dependent_bg_colors.get(value, None)
|
||||||
|
if value_dependent_bg_colors is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
if fg_color is not None or bg_color is not None:
|
if fg_color is not None or bg_color is not None:
|
||||||
value = self.printer.text(value, color=fg_color, bg_color=bg_color)
|
value = self.printer.text(value, color=fg_color, bg_color=bg_color)
|
||||||
return value
|
return value
|
||||||
|
@ -329,12 +341,19 @@ class Visualizer:
|
||||||
widths = []
|
widths = []
|
||||||
for column in columns:
|
for column in columns:
|
||||||
# get the values without any color codes
|
# get the values without any color codes
|
||||||
if column.attribute == 'tree_left':
|
if column.attribute == "tree_left":
|
||||||
width = len(tree_left[0])
|
width = len(tree_left[0])
|
||||||
elif column.attribute == 'tree_right':
|
elif column.attribute == "tree_right":
|
||||||
width = len(tree_right[0])
|
width = len(tree_right[0])
|
||||||
else:
|
else:
|
||||||
width = max(len(self.get_entity(token, column.attribute)) for token in sent)
|
width = max(
|
||||||
|
len(
|
||||||
|
self.get_entity(
|
||||||
|
token, column.attribute, permitted_values=column.permitted_values
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for token in sent
|
||||||
|
)
|
||||||
if column.max_width is not None:
|
if column.max_width is not None:
|
||||||
width = min(width, column.max_width)
|
width = min(width, column.max_width)
|
||||||
width = max(width, len(column.name))
|
width = max(width, len(column.name))
|
||||||
|
@ -348,9 +367,10 @@ class Visualizer:
|
||||||
else self.get_entity(
|
else self.get_entity(
|
||||||
token,
|
token,
|
||||||
column.attribute,
|
column.attribute,
|
||||||
|
permitted_values=column.permitted_values,
|
||||||
value_dependent_fg_colors=column.value_dependent_fg_colors,
|
value_dependent_fg_colors=column.value_dependent_fg_colors,
|
||||||
value_dependent_bg_colors=column.value_dependent_bg_colors,
|
value_dependent_bg_colors=column.value_dependent_bg_colors,
|
||||||
truncate_at_width=widths[column_index]
|
truncate_at_width=widths[column_index],
|
||||||
)
|
)
|
||||||
for column_index, column in enumerate(columns)
|
for column_index, column in enumerate(columns)
|
||||||
]
|
]
|
||||||
|
@ -372,6 +392,7 @@ class Visualizer:
|
||||||
widths=widths,
|
widths=widths,
|
||||||
fg_colors=fg_colors,
|
fg_colors=fg_colors,
|
||||||
bg_colors=bg_colors,
|
bg_colors=bg_colors,
|
||||||
|
spacing=spacing,
|
||||||
)
|
)
|
||||||
+ linesep
|
+ linesep
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user