mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 04:10:20 +03:00
Refactoring
This commit is contained in:
parent
183d535ef4
commit
49f3fd39b9
|
@ -207,7 +207,7 @@ def test_visualization_dependency_tree_highly_nonprojective(en_vocab):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_native_attribute_int(en_vocab):
|
def test_visualization_render_native_attribute_int(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -224,10 +224,10 @@ def test_visualization_get_entity_native_attribute_int(en_vocab):
|
||||||
heads=[1, None, 3, 1, 1, 7, 7, 3, 1],
|
heads=[1, None, 3, 1, 1, 7, 7, 3, 1],
|
||||||
deps=["dep"] * 9,
|
deps=["dep"] * 9,
|
||||||
)
|
)
|
||||||
assert Visualizer().get_entity(doc[2], "head.i") == "3"
|
assert AttributeFormat("head.i").render(doc[2]) == "3"
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_native_attribute_str(en_vocab):
|
def test_visualization_render_native_attribute_str(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -245,10 +245,10 @@ def test_visualization_get_entity_native_attribute_str(en_vocab):
|
||||||
deps=["dep"] * 9,
|
deps=["dep"] * 9,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert Visualizer().get_entity(doc[2], "dep_") == "dep"
|
assert AttributeFormat("dep_").render(doc[2]) == "dep"
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_colors(en_vocab):
|
def test_visualization_render_colors(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -267,19 +267,18 @@ def test_visualization_get_entity_colors(en_vocab):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
Visualizer().get_entity(
|
AttributeFormat(
|
||||||
doc[2],
|
|
||||||
"dep_",
|
"dep_",
|
||||||
value_dependent_fg_colors={"dep": 2},
|
value_dependent_fg_colors={"dep": 2},
|
||||||
value_dependent_bg_colors={"dep": 11},
|
value_dependent_bg_colors={"dep": 11},
|
||||||
)
|
).render(doc[2])
|
||||||
== "\x1b[38;5;2;48;5;11mdep\x1b[0m"
|
== "\x1b[38;5;2;48;5;11mdep\x1b[0m"
|
||||||
if supports_ansi
|
if supports_ansi
|
||||||
else "dep"
|
else "dep"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_colors_only_fg(en_vocab):
|
def test_visualization_render_colors_only_fg(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -298,14 +297,17 @@ def test_visualization_get_entity_colors_only_fg(en_vocab):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
Visualizer().get_entity(doc[2], "dep_", value_dependent_fg_colors={"dep": 2})
|
AttributeFormat(
|
||||||
|
"dep_",
|
||||||
|
value_dependent_fg_colors={"dep": 2},
|
||||||
|
).render(doc[2])
|
||||||
== "\x1b[38;5;2mdep\x1b[0m"
|
== "\x1b[38;5;2mdep\x1b[0m"
|
||||||
if supports_ansi
|
if supports_ansi
|
||||||
else "dep"
|
else "dep"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_colors_only_bg(en_vocab):
|
def test_visualization_render_entity_colors_only_bg(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -324,14 +326,17 @@ def test_visualization_get_entity_colors_only_bg(en_vocab):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
Visualizer().get_entity(doc[2], "dep_", value_dependent_bg_colors={"dep": 11})
|
AttributeFormat(
|
||||||
|
"dep_",
|
||||||
|
value_dependent_bg_colors={"dep": 11},
|
||||||
|
).render(doc[2])
|
||||||
== "\x1b[48;5;11mdep\x1b[0m"
|
== "\x1b[48;5;11mdep\x1b[0m"
|
||||||
if supports_ansi
|
if supports_ansi
|
||||||
else "dep"
|
else "dep"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_native_attribute_missing(en_vocab):
|
def test_visualization_render_native_attribute_missing(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -349,10 +354,10 @@ def test_visualization_get_entity_native_attribute_missing(en_vocab):
|
||||||
deps=["dep"] * 9,
|
deps=["dep"] * 9,
|
||||||
)
|
)
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
Visualizer().get_entity(doc[2], "depp")
|
AttributeFormat("depp").render(doc[2])
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_custom_attribute_str(en_vocab):
|
def test_visualization_render_custom_attribute_str(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -370,10 +375,10 @@ def test_visualization_get_entity_custom_attribute_str(en_vocab):
|
||||||
deps=["dep"] * 9,
|
deps=["dep"] * 9,
|
||||||
)
|
)
|
||||||
Token.set_extension("test", default="tested", force=True)
|
Token.set_extension("test", default="tested", force=True)
|
||||||
assert Visualizer().get_entity(doc[2], "_.test") == "tested"
|
assert AttributeFormat("_.test").render(doc[2]) == "tested"
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_nested_custom_attribute_str(en_vocab):
|
def test_visualization_render_nested_custom_attribute_str(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -396,10 +401,10 @@ def test_visualization_get_entity_nested_custom_attribute_str(en_vocab):
|
||||||
self.inner_test = "tested"
|
self.inner_test = "tested"
|
||||||
|
|
||||||
Token.set_extension("test", default=Test(), force=True)
|
Token.set_extension("test", default=Test(), force=True)
|
||||||
assert Visualizer().get_entity(doc[2], "_.test.inner_test") == "tested"
|
assert AttributeFormat("_.test.inner_test").render(doc[2]) == "tested"
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_custom_attribute_missing(en_vocab):
|
def test_visualization_render_custom_attribute_missing(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -417,10 +422,10 @@ def test_visualization_get_entity_custom_attribute_missing(en_vocab):
|
||||||
deps=["dep"] * 9,
|
deps=["dep"] * 9,
|
||||||
)
|
)
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
Visualizer().get_entity(doc[2], "_.depp")
|
# Ask for an unregistered extension under Token._ so that the AttributeError
# comes from the custom-attribute lookup, not from an empty attribute name.
AttributeFormat("_.depp").render(doc[2])
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_get_entity_permitted_values(en_vocab):
|
def test_visualization_render_permitted_values(en_vocab):
|
||||||
doc = Doc(
|
doc = Doc(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
words=[
|
words=[
|
||||||
|
@ -437,8 +442,8 @@ def test_visualization_get_entity_permitted_values(en_vocab):
|
||||||
heads=[1, None, 3, 1, 1, 7, 7, 3, 1],
|
heads=[1, None, 3, 1, 1, 7, 7, 3, 1],
|
||||||
deps=["dep"] * 9,
|
deps=["dep"] * 9,
|
||||||
)
|
)
|
||||||
visualizer = Visualizer()
|
attribute_format = AttributeFormat("head.i", permitted_values=(3, 7))
|
||||||
assert [visualizer.get_entity(token, "head.i", permitted_values=(3, 7)) for token in doc] == [
|
assert [attribute_format.render(token) for token in doc] == [
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
"3",
|
"3",
|
||||||
|
@ -510,6 +515,7 @@ def test_visualization_minimal_render_table_permitted_values(
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_visualization_spacing(
|
def test_visualization_spacing(
|
||||||
fully_featured_doc_one_sentence,
|
fully_featured_doc_one_sentence,
|
||||||
):
|
):
|
||||||
|
|
|
@ -29,6 +29,38 @@ class AttributeFormat:
|
||||||
self.permitted_values = permitted_values
|
self.permitted_values = permitted_values
|
||||||
self.value_dependent_fg_colors = value_dependent_fg_colors
|
self.value_dependent_fg_colors = value_dependent_fg_colors
|
||||||
self.value_dependent_bg_colors = value_dependent_bg_colors
|
self.value_dependent_bg_colors = value_dependent_bg_colors
|
||||||
|
self.printer = wasabi.Printer(no_print=True)
|
||||||
|
|
||||||
|
def render(
    self,
    token: Token,
    *,
    ignore_colors: bool = False,
) -> str:
    """
    Return the string form of this format's attribute for *token*.

    token:          the object on which the dotted path in 'self.attribute'
                    is resolved, one 'getattr' per path segment.
    ignore_colors:  if 'True', the value-dependent colour tables are not
                    consulted and the plain string is returned.

    The result is an empty string when 'self.permitted_values' is set and the
    stringified value matches none of its stringified members. Otherwise the
    value is cut to 'self.max_width' characters (when set) and, if either
    colour table has an entry for the resulting string, passed through
    'self.printer.text'. An unresolvable path segment raises AttributeError.
    """
    # Walk the dotted path segment by segment, e.g. "head.i" -> token.head.i
    target = token
    for name in self.attribute.split("."):
        target = getattr(target, name)
    text = str(target)

    allowed = self.permitted_values
    if allowed is not None and all(str(candidate) != text for candidate in allowed):
        return ""

    limit = self.max_width
    if limit is not None:
        text = text[:limit]

    if ignore_colors:
        return text

    # The colour tables are keyed on the (possibly truncated) string value.
    fg_map = self.value_dependent_fg_colors
    foreground = None if fg_map is None else fg_map.get(text, None)
    bg_map = self.value_dependent_bg_colors
    background = None if bg_map is None else bg_map.get(text, None)

    if foreground is None and background is None:
        return text
    return self.printer.text(text, color=foreground, bg_color=background)
|
||||||
|
|
||||||
|
|
||||||
SPACE = 0
|
SPACE = 0
|
||||||
|
@ -67,9 +99,6 @@ ROOT_LEFT_CHARS = {
|
||||||
|
|
||||||
|
|
||||||
class Visualizer:
|
class Visualizer:
|
||||||
def __init__(self):
|
|
||||||
self.printer = wasabi.Printer(no_print=True)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def render_dependency_tree(sent: Span, root_right: bool) -> list[str]:
|
def render_dependency_tree(sent: Span, root_right: bool) -> list[str]:
|
||||||
"""
|
"""
|
||||||
|
@ -296,39 +325,6 @@ class Visualizer:
|
||||||
for vertical_position in range(sent.end - sent.start)
|
for vertical_position in range(sent.end - sent.start)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_entity(
|
|
||||||
self,
|
|
||||||
token: Token,
|
|
||||||
entity_name: str,
|
|
||||||
*,
|
|
||||||
permitted_values: list[str] = None,
|
|
||||||
value_dependent_fg_colors: dict[str : Union[str, int]] = None,
|
|
||||||
value_dependent_bg_colors: dict[str : Union[str, int]] = None,
|
|
||||||
truncate_at_width: int = None,
|
|
||||||
) -> str:
|
|
||||||
obj = token
|
|
||||||
parts = entity_name.split(".")
|
|
||||||
for part in parts[:-1]:
|
|
||||||
obj = getattr(obj, part)
|
|
||||||
value = str(getattr(obj, parts[-1]))
|
|
||||||
if permitted_values is not None and value not in (str(v) for v in permitted_values):
|
|
||||||
return ""
|
|
||||||
if truncate_at_width is not None:
|
|
||||||
value = value[:truncate_at_width]
|
|
||||||
fg_color = (
|
|
||||||
value_dependent_fg_colors.get(value, None)
|
|
||||||
if value_dependent_fg_colors is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
bg_color = (
|
|
||||||
value_dependent_bg_colors.get(value, None)
|
|
||||||
if value_dependent_bg_colors is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if fg_color is not None or bg_color is not None:
|
|
||||||
value = self.printer.text(value, color=fg_color, bg_color=bg_color)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def render_table(
|
def render_table(
|
||||||
self, doc: Doc, columns: list[AttributeFormat], spacing: int = 3
|
self, doc: Doc, columns: list[AttributeFormat], spacing: int = 3
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -347,12 +343,7 @@ class Visualizer:
|
||||||
width = len(tree_right[0])
|
width = len(tree_right[0])
|
||||||
else:
|
else:
|
||||||
width = max(
|
width = max(
|
||||||
len(
|
len(column.render(token, ignore_colors=True)) for token in sent
|
||||||
self.get_entity(
|
|
||||||
token, column.attribute, permitted_values=column.permitted_values
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for token in sent
|
|
||||||
)
|
)
|
||||||
if column.max_width is not None:
|
if column.max_width is not None:
|
||||||
width = min(width, column.max_width)
|
width = min(width, column.max_width)
|
||||||
|
@ -364,14 +355,7 @@ class Visualizer:
|
||||||
if column.attribute == "tree_right"
|
if column.attribute == "tree_right"
|
||||||
else tree_left[token_index]
|
else tree_left[token_index]
|
||||||
if column.attribute == "tree_left"
|
if column.attribute == "tree_left"
|
||||||
else self.get_entity(
|
else column.render(token)
|
||||||
token,
|
|
||||||
column.attribute,
|
|
||||||
permitted_values=column.permitted_values,
|
|
||||||
value_dependent_fg_colors=column.value_dependent_fg_colors,
|
|
||||||
value_dependent_bg_colors=column.value_dependent_bg_colors,
|
|
||||||
truncate_at_width=widths[column_index],
|
|
||||||
)
|
|
||||||
for column_index, column in enumerate(columns)
|
for column_index, column in enumerate(columns)
|
||||||
]
|
]
|
||||||
for token_index, token in enumerate(sent)
|
for token_index, token in enumerate(sent)
|
||||||
|
@ -397,3 +381,13 @@ class Visualizer:
|
||||||
+ linesep
|
+ linesep
|
||||||
)
|
)
|
||||||
return return_string
|
return return_string
|
||||||
|
|
||||||
|
def render_text(self, doc: Doc, attributes: list[AttributeFormat]) -> str:
    """
    Return the text of *doc*, rebuilt from each token's 'text_with_ws'.

    doc:         the tokens to render, in order.
    attributes:  the formats to render for each token.

    Work in progress: every format in 'attributes' is rendered for every
    token, but the rendered values are not yet added to the returned string.
    An attribute path that cannot be resolved still raises AttributeError.
    """
    return_string = ""
    for token in doc:
        return_string += token.text_with_ws
        for attribute in attributes:
            # Rendering lives on AttributeFormat; Visualizer no longer has a
            # 'get_entity' method to call here.
            if attribute.render(token):
                # TODO: decide how a non-empty rendered value is shown.
                pass
    return return_string
|
||||||
|
|
Loading…
Reference in New Issue
Block a user