mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 12:20:20 +03:00
render_dependency_trees complete
This commit is contained in:
parent
a660a7d347
commit
b4265eccf9
|
@ -1,51 +1,53 @@
|
||||||
|
from spacy.tests.lang.ko.test_tokenizer import FULL_TAG_TESTS
|
||||||
from spacy.tokens import Doc, Token
|
from spacy.tokens import Doc, Token
|
||||||
from spacy.util import working_dir
|
from spacy.util import working_dir
|
||||||
|
|
||||||
HORIZONTAL_LINE_PART_AWAY_FROM_ROOT = 1
|
SPACE = 0
|
||||||
HORIZONTAL_LINE_PART_TOWARDS_ROOT = 2
|
HALF_HORIZONTAL_LINE = 1 # the half is the half further away from the root
|
||||||
UPPER_VERTICAL_LINE_PART = 4
|
FULL_HORIZONTAL_LINE = 3
|
||||||
LOWER_VERTICAL_LINE_PART = 8
|
UPPER_HALF_VERTICAL_LINE = 4
|
||||||
|
LOWER_HALF_VERTICAL_LINE = 8
|
||||||
|
FULL_VERTICAL_LINE = 12
|
||||||
ARROWHEAD = 16
|
ARROWHEAD = 16
|
||||||
|
|
||||||
SPACE = 0
|
|
||||||
HORIZONTAL_LINE = 3
|
|
||||||
VERTICAL_LINE = 12
|
|
||||||
|
|
||||||
ROOT_RIGHT_CHARS = {
|
ROOT_RIGHT_CHARS = {
|
||||||
0: " ",
|
SPACE: " ",
|
||||||
3: "═",
|
FULL_HORIZONTAL_LINE: "═",
|
||||||
5: "╝",
|
UPPER_HALF_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╝",
|
||||||
7: "╩",
|
UPPER_HALF_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╩",
|
||||||
9: "╗",
|
LOWER_HALF_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╗",
|
||||||
11: "╦",
|
LOWER_HALF_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╦",
|
||||||
12: "║",
|
FULL_VERTICAL_LINE: "║",
|
||||||
13: "╣",
|
FULL_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╣",
|
||||||
15: "╬",
|
FULL_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╬",
|
||||||
16: "<",
|
ARROWHEAD: "<",
|
||||||
}
|
}
|
||||||
|
|
||||||
ROOT_LEFT_CHARS = {
|
ROOT_LEFT_CHARS = {
|
||||||
0: " ",
|
SPACE: " ",
|
||||||
3: "═",
|
FULL_HORIZONTAL_LINE: "═",
|
||||||
5: "╚",
|
UPPER_HALF_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╚",
|
||||||
7: "╩",
|
UPPER_HALF_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╩",
|
||||||
9: "╔",
|
LOWER_HALF_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╔",
|
||||||
11: "╦",
|
LOWER_HALF_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╦",
|
||||||
12: "║",
|
FULL_VERTICAL_LINE: "║",
|
||||||
13: "╠",
|
FULL_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╠",
|
||||||
15: "╬",
|
FULL_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╬",
|
||||||
16: ">",
|
ARROWHEAD: ">",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Visualizer:
|
class Visualizer:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def render_dependency_trees(doc: Doc) -> list[str]:
|
def render_dependency_trees(doc: Doc, root_right: bool) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Returns an ASCII rendering of the document with a dependency tree for each sentence. The
|
Returns an ASCII rendering of the document with a dependency tree for each sentence. The
|
||||||
dependency tree output for a given token has the same index within the output list of
|
dependency tree output for a given token has the same index within the output list of
|
||||||
strings as that token within the input document.
|
strings as that token within the input document.
|
||||||
|
|
||||||
|
root_right: True if the tree should be rendered with the root on the right-hand side,
|
||||||
|
False if the tree should be rendered with the root on the left-hand side.
|
||||||
|
|
||||||
Adapted from https://github.com/KoichiYasuoka/deplacy
|
Adapted from https://github.com/KoichiYasuoka/deplacy
|
||||||
"""
|
"""
|
||||||
heads = [
|
heads = [
|
||||||
|
@ -59,6 +61,7 @@ class Visualizer:
|
||||||
if head is not None:
|
if head is not None:
|
||||||
children_lists[head].append(child)
|
children_lists[head].append(child)
|
||||||
all_indices_ordered_by_column = []
|
all_indices_ordered_by_column = []
|
||||||
|
# start with the root column
|
||||||
indices_in_current_column = [i for i, h in enumerate(heads) if h is None]
|
indices_in_current_column = [i for i, h in enumerate(heads) if h is None]
|
||||||
while len(indices_in_current_column) > 0:
|
while len(indices_in_current_column) > 0:
|
||||||
all_indices_ordered_by_column = (
|
all_indices_ordered_by_column = (
|
||||||
|
@ -72,11 +75,9 @@ class Visualizer:
|
||||||
]
|
]
|
||||||
# -1: root token with no arrow; 0: length not yet set
|
# -1: root token with no arrow; 0: length not yet set
|
||||||
horizontal_line_lengths = [
|
horizontal_line_lengths = [
|
||||||
-1
|
-1 if heads[i] is None else 1
|
||||||
if heads[i] is None
|
# length == 1: governed by direct neighbour and has no children itself
|
||||||
else 1
|
if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1 else 0
|
||||||
if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1
|
|
||||||
else 0
|
|
||||||
for i in range(len(doc))
|
for i in range(len(doc))
|
||||||
]
|
]
|
||||||
while 0 in horizontal_line_lengths:
|
while 0 in horizontal_line_lengths:
|
||||||
|
@ -85,6 +86,7 @@ class Visualizer:
|
||||||
for i in all_indices_ordered_by_column
|
for i in all_indices_ordered_by_column
|
||||||
if horizontal_line_lengths[i] == 0
|
if horizontal_line_lengths[i] == 0
|
||||||
):
|
):
|
||||||
|
# render relation between this token and its head
|
||||||
first_index_in_relation = min(
|
first_index_in_relation = min(
|
||||||
working_token_index,
|
working_token_index,
|
||||||
heads[working_token_index],
|
heads[working_token_index],
|
||||||
|
@ -93,8 +95,11 @@ class Visualizer:
|
||||||
working_token_index,
|
working_token_index,
|
||||||
heads[working_token_index],
|
heads[working_token_index],
|
||||||
)
|
)
|
||||||
|
# If this token has children, they will already have been rendered.
|
||||||
|
# The line needs to be one character longer than the longest of the
|
||||||
|
# children's lines.
|
||||||
if len(children_lists[working_token_index]) > 0:
|
if len(children_lists[working_token_index]) > 0:
|
||||||
working_horizontal_line_length = (
|
horizontal_line_lengths[working_token_index] = (
|
||||||
max(
|
max(
|
||||||
[
|
[
|
||||||
horizontal_line_lengths[i]
|
horizontal_line_lengths[i]
|
||||||
|
@ -104,36 +109,21 @@ class Visualizer:
|
||||||
+ 1
|
+ 1
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
working_horizontal_line_length = 1
|
horizontal_line_lengths[working_token_index] = 1
|
||||||
for inbetween_index in (
|
for inbetween_index in (
|
||||||
i
|
i
|
||||||
for i in range(
|
for i in range(
|
||||||
first_index_in_relation + 1, second_index_in_relation
|
first_index_in_relation + 1, second_index_in_relation
|
||||||
)
|
)
|
||||||
if i not in children_lists[working_token_index]
|
if horizontal_line_lengths[i] != 0
|
||||||
and horizontal_line_lengths[i] != 0
|
|
||||||
):
|
):
|
||||||
working_horizontal_line_length = max(
|
horizontal_line_lengths[working_token_index] = max(
|
||||||
working_horizontal_line_length,
|
horizontal_line_lengths[working_token_index],
|
||||||
horizontal_line_lengths[inbetween_index]
|
horizontal_line_lengths[inbetween_index]
|
||||||
if inbetween_index in children_lists[heads[working_token_index]]
|
if inbetween_index in children_lists[heads[working_token_index]]
|
||||||
|
and inbetween_index not in children_lists[working_token_index]
|
||||||
else horizontal_line_lengths[inbetween_index] + 1,
|
else horizontal_line_lengths[inbetween_index] + 1,
|
||||||
)
|
)
|
||||||
for child_horizontal_line_length in (
|
|
||||||
horizontal_line_lengths[i]
|
|
||||||
for i in children_lists[working_token_index]
|
|
||||||
if (i < first_index_in_relation or i > second_index_in_relation)
|
|
||||||
and horizontal_line_lengths[i] != 0
|
|
||||||
):
|
|
||||||
working_horizontal_line_length = max(
|
|
||||||
working_horizontal_line_length,
|
|
||||||
child_horizontal_line_length + 1,
|
|
||||||
)
|
|
||||||
horizontal_line_lengths[
|
|
||||||
working_token_index
|
|
||||||
] = working_horizontal_line_length
|
|
||||||
print(horizontal_line_lengths)
|
|
||||||
|
|
||||||
max_horizontal_line_length = max(horizontal_line_lengths)
|
max_horizontal_line_length = max(horizontal_line_lengths)
|
||||||
char_matrix = [
|
char_matrix = [
|
||||||
[SPACE] * max_horizontal_line_length * 2 for _ in range(len(doc))
|
[SPACE] * max_horizontal_line_length * 2 for _ in range(len(doc))
|
||||||
|
@ -150,21 +140,21 @@ class Visualizer:
|
||||||
|
|
||||||
# Draw the corners of the relation
|
# Draw the corners of the relation
|
||||||
char_matrix[first_index_in_relation][char_horizontal_line_length - 1] |= (
|
char_matrix[first_index_in_relation][char_horizontal_line_length - 1] |= (
|
||||||
HORIZONTAL_LINE_PART_AWAY_FROM_ROOT + LOWER_VERTICAL_LINE_PART
|
HALF_HORIZONTAL_LINE + LOWER_HALF_VERTICAL_LINE
|
||||||
)
|
)
|
||||||
char_matrix[second_index_in_relation][char_horizontal_line_length - 1] |= (
|
char_matrix[second_index_in_relation][char_horizontal_line_length - 1] |= (
|
||||||
HORIZONTAL_LINE_PART_AWAY_FROM_ROOT + UPPER_VERTICAL_LINE_PART
|
HALF_HORIZONTAL_LINE + UPPER_HALF_VERTICAL_LINE
|
||||||
)
|
)
|
||||||
|
|
||||||
# Draw the horizontal line for the governing token
|
# Draw the horizontal line for the governing token
|
||||||
for working_horizontal_position in range(char_horizontal_line_length - 1):
|
for working_horizontal_position in range(char_horizontal_line_length - 1):
|
||||||
if (
|
if (
|
||||||
char_matrix[head_token_index][working_horizontal_position]
|
char_matrix[head_token_index][working_horizontal_position]
|
||||||
!= VERTICAL_LINE
|
!= FULL_VERTICAL_LINE
|
||||||
):
|
):
|
||||||
char_matrix[head_token_index][
|
char_matrix[head_token_index][
|
||||||
working_horizontal_position
|
working_horizontal_position
|
||||||
] |= HORIZONTAL_LINE
|
] |= FULL_HORIZONTAL_LINE
|
||||||
|
|
||||||
# Draw the vertical line for the relation
|
# Draw the vertical line for the relation
|
||||||
for working_vertical_position in range(
|
for working_vertical_position in range(
|
||||||
|
@ -174,20 +164,28 @@ class Visualizer:
|
||||||
char_matrix[working_vertical_position][
|
char_matrix[working_vertical_position][
|
||||||
char_horizontal_line_length - 1
|
char_horizontal_line_length - 1
|
||||||
]
|
]
|
||||||
!= HORIZONTAL_LINE
|
!= FULL_HORIZONTAL_LINE
|
||||||
):
|
):
|
||||||
char_matrix[working_vertical_position][
|
char_matrix[working_vertical_position][
|
||||||
char_horizontal_line_length - 1
|
char_horizontal_line_length - 1
|
||||||
] |= VERTICAL_LINE
|
] |= FULL_VERTICAL_LINE
|
||||||
for working_token_index in (i for i in range(len(doc)) if heads[i] is not None):
|
for working_token_index in (i for i in range(len(doc)) if heads[i] is not None):
|
||||||
for working_horizontal_position in range(
|
for working_horizontal_position in range(
|
||||||
2 * horizontal_line_lengths[working_token_index] - 2, -1, -1
|
2 * horizontal_line_lengths[working_token_index] - 2, -1, -1
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
char_matrix[working_token_index][working_horizontal_position]
|
(
|
||||||
== VERTICAL_LINE) and working_horizontal_position > 1 and char_matrix[working_token_index][working_horizontal_position - 2] == SPACE:
|
char_matrix[working_token_index][working_horizontal_position]
|
||||||
# Cross over the existing vertical line, which is owing to a non-projective tree
|
== FULL_VERTICAL_LINE
|
||||||
continue
|
)
|
||||||
|
and working_horizontal_position > 1
|
||||||
|
and char_matrix[working_token_index][
|
||||||
|
working_horizontal_position - 2
|
||||||
|
]
|
||||||
|
== SPACE
|
||||||
|
):
|
||||||
|
# Cross over the existing vertical line, which is owing to a non-projective tree
|
||||||
|
continue
|
||||||
if (
|
if (
|
||||||
char_matrix[working_token_index][working_horizontal_position]
|
char_matrix[working_token_index][working_horizontal_position]
|
||||||
!= SPACE
|
!= SPACE
|
||||||
|
@ -200,26 +198,28 @@ class Visualizer:
|
||||||
if working_horizontal_position == 0:
|
if working_horizontal_position == 0:
|
||||||
# Draw the arrowhead at the boundary of the diagram
|
# Draw the arrowhead at the boundary of the diagram
|
||||||
char_matrix[working_token_index][
|
char_matrix[working_token_index][
|
||||||
working_horizontal_position] = ARROWHEAD
|
working_horizontal_position
|
||||||
|
] = ARROWHEAD
|
||||||
else:
|
else:
|
||||||
# Fill in the horizontal line for the governed token
|
# Fill in the horizontal line for the governed token
|
||||||
char_matrix[working_token_index][
|
char_matrix[working_token_index][
|
||||||
working_horizontal_position
|
working_horizontal_position
|
||||||
] |= HORIZONTAL_LINE
|
] |= FULL_HORIZONTAL_LINE
|
||||||
return [
|
if root_right:
|
||||||
"".join(
|
return [
|
||||||
ROOT_RIGHT_CHARS[char_matrix[vertical_position][horizontal_position]]
|
"".join(
|
||||||
for horizontal_position in range((max_horizontal_line_length * 2))
|
ROOT_RIGHT_CHARS[
|
||||||
)
|
char_matrix[vertical_position][horizontal_position]
|
||||||
for vertical_position in range(len(doc))
|
]
|
||||||
]
|
for horizontal_position in range((max_horizontal_line_length * 2))
|
||||||
|
)
|
||||||
|
for vertical_position in range(len(doc))
|
||||||
import spacy
|
]
|
||||||
|
else:
|
||||||
nlp = spacy.load("en_core_web_sm")
|
return [
|
||||||
doc = nlp("I saw a horse yesterday that was injured")
|
"".join(
|
||||||
|
ROOT_LEFT_CHARS[char_matrix[vertical_position][horizontal_position]]
|
||||||
lines = Visualizer().render_dependency_trees(doc)
|
for horizontal_position in range((max_horizontal_line_length * 2))
|
||||||
for line in lines:
|
)[::-1]
|
||||||
print(line)
|
for vertical_position in range(len(doc))
|
||||||
|
]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user