render_dependency_trees complete

This commit is contained in:
Richard Hudson 2021-11-30 15:25:25 +01:00
parent a660a7d347
commit b4265eccf9

View File

@ -1,51 +1,53 @@
from spacy.tests.lang.ko.test_tokenizer import FULL_TAG_TESTS
from spacy.tokens import Doc, Token from spacy.tokens import Doc, Token
from spacy.util import working_dir from spacy.util import working_dir
HORIZONTAL_LINE_PART_AWAY_FROM_ROOT = 1 SPACE = 0
HORIZONTAL_LINE_PART_TOWARDS_ROOT = 2 HALF_HORIZONTAL_LINE = 1 # the half is the half further away from the root
UPPER_VERTICAL_LINE_PART = 4 FULL_HORIZONTAL_LINE = 3
LOWER_VERTICAL_LINE_PART = 8 UPPER_HALF_VERTICAL_LINE = 4
LOWER_HALF_VERTICAL_LINE = 8
FULL_VERTICAL_LINE = 12
ARROWHEAD = 16 ARROWHEAD = 16
SPACE = 0
HORIZONTAL_LINE = 3
VERTICAL_LINE = 12
ROOT_RIGHT_CHARS = { ROOT_RIGHT_CHARS = {
0: " ", SPACE: " ",
3: "═", FULL_HORIZONTAL_LINE: "═",
5: "╝", UPPER_HALF_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╝",
7: "╩", UPPER_HALF_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╩",
9: "╗", LOWER_HALF_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╗",
11: "╦", LOWER_HALF_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╦",
12: "║", FULL_VERTICAL_LINE: "║",
13: "╣", FULL_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╣",
15: "╬", FULL_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╬",
16: "<", ARROWHEAD: "<",
} }
ROOT_LEFT_CHARS = { ROOT_LEFT_CHARS = {
0: " ", SPACE: " ",
3: "═", FULL_HORIZONTAL_LINE: "═",
5: "╚", UPPER_HALF_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╚",
7: "╩", UPPER_HALF_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╩",
9: "╔", LOWER_HALF_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╔",
11: "╦", LOWER_HALF_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╦",
12: "║", FULL_VERTICAL_LINE: "║",
13: "╠", FULL_VERTICAL_LINE + HALF_HORIZONTAL_LINE: "╠",
15: "╬", FULL_VERTICAL_LINE + FULL_HORIZONTAL_LINE: "╬",
16: ">", ARROWHEAD: ">",
} }
class Visualizer: class Visualizer:
@staticmethod @staticmethod
def render_dependency_trees(doc: Doc) -> list[str]: def render_dependency_trees(doc: Doc, root_right: bool) -> list[str]:
""" """
Returns an ASCII rendering of the document with a dependency tree for each sentence. The Returns an ASCII rendering of the document with a dependency tree for each sentence. The
dependency tree output for a given token has the same index within the output list of dependency tree output for a given token has the same index within the output list of
strings as that token within the input document. strings as that token within the input document.
root_right: True if the tree should be rendered with the root on the right-hand side,
False if the tree should be rendered with the root on the left-hand side.
Adapted from https://github.com/KoichiYasuoka/deplacy Adapted from https://github.com/KoichiYasuoka/deplacy
""" """
heads = [ heads = [
@ -59,6 +61,7 @@ class Visualizer:
if head is not None: if head is not None:
children_lists[head].append(child) children_lists[head].append(child)
all_indices_ordered_by_column = [] all_indices_ordered_by_column = []
# start with the root column
indices_in_current_column = [i for i, h in enumerate(heads) if h is None] indices_in_current_column = [i for i, h in enumerate(heads) if h is None]
while len(indices_in_current_column) > 0: while len(indices_in_current_column) > 0:
all_indices_ordered_by_column = ( all_indices_ordered_by_column = (
@ -72,11 +75,9 @@ class Visualizer:
] ]
# -1: root token with no arrow; 0: length not yet set # -1: root token with no arrow; 0: length not yet set
horizontal_line_lengths = [ horizontal_line_lengths = [
-1 -1 if heads[i] is None else 1
if heads[i] is None # length == 1: governed by direct neighbour and has no children itself
else 1 if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1 else 0
if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1
else 0
for i in range(len(doc)) for i in range(len(doc))
] ]
while 0 in horizontal_line_lengths: while 0 in horizontal_line_lengths:
@ -85,6 +86,7 @@ class Visualizer:
for i in all_indices_ordered_by_column for i in all_indices_ordered_by_column
if horizontal_line_lengths[i] == 0 if horizontal_line_lengths[i] == 0
): ):
# render relation between this token and its head
first_index_in_relation = min( first_index_in_relation = min(
working_token_index, working_token_index,
heads[working_token_index], heads[working_token_index],
@ -93,8 +95,11 @@ class Visualizer:
working_token_index, working_token_index,
heads[working_token_index], heads[working_token_index],
) )
# If this token has children, they will already have been rendered.
# The line needs to be one character longer than the longest of the
# children's lines.
if len(children_lists[working_token_index]) > 0: if len(children_lists[working_token_index]) > 0:
working_horizontal_line_length = ( horizontal_line_lengths[working_token_index] = (
max( max(
[ [
horizontal_line_lengths[i] horizontal_line_lengths[i]
@ -104,36 +109,21 @@ class Visualizer:
+ 1 + 1
) )
else: else:
working_horizontal_line_length = 1 horizontal_line_lengths[working_token_index] = 1
for inbetween_index in ( for inbetween_index in (
i i
for i in range( for i in range(
first_index_in_relation + 1, second_index_in_relation first_index_in_relation + 1, second_index_in_relation
) )
if i not in children_lists[working_token_index] if horizontal_line_lengths[i] != 0
and horizontal_line_lengths[i] != 0
): ):
working_horizontal_line_length = max( horizontal_line_lengths[working_token_index] = max(
working_horizontal_line_length, horizontal_line_lengths[working_token_index],
horizontal_line_lengths[inbetween_index] horizontal_line_lengths[inbetween_index]
if inbetween_index in children_lists[heads[working_token_index]] if inbetween_index in children_lists[heads[working_token_index]]
and inbetween_index not in children_lists[working_token_index]
else horizontal_line_lengths[inbetween_index] + 1, else horizontal_line_lengths[inbetween_index] + 1,
) )
for child_horizontal_line_length in (
horizontal_line_lengths[i]
for i in children_lists[working_token_index]
if (i < first_index_in_relation or i > second_index_in_relation)
and horizontal_line_lengths[i] != 0
):
working_horizontal_line_length = max(
working_horizontal_line_length,
child_horizontal_line_length + 1,
)
horizontal_line_lengths[
working_token_index
] = working_horizontal_line_length
print(horizontal_line_lengths)
max_horizontal_line_length = max(horizontal_line_lengths) max_horizontal_line_length = max(horizontal_line_lengths)
char_matrix = [ char_matrix = [
[SPACE] * max_horizontal_line_length * 2 for _ in range(len(doc)) [SPACE] * max_horizontal_line_length * 2 for _ in range(len(doc))
@ -150,21 +140,21 @@ class Visualizer:
# Draw the corners of the relation # Draw the corners of the relation
char_matrix[first_index_in_relation][char_horizontal_line_length - 1] |= ( char_matrix[first_index_in_relation][char_horizontal_line_length - 1] |= (
HORIZONTAL_LINE_PART_AWAY_FROM_ROOT + LOWER_VERTICAL_LINE_PART HALF_HORIZONTAL_LINE + LOWER_HALF_VERTICAL_LINE
) )
char_matrix[second_index_in_relation][char_horizontal_line_length - 1] |= ( char_matrix[second_index_in_relation][char_horizontal_line_length - 1] |= (
HORIZONTAL_LINE_PART_AWAY_FROM_ROOT + UPPER_VERTICAL_LINE_PART HALF_HORIZONTAL_LINE + UPPER_HALF_VERTICAL_LINE
) )
# Draw the horizontal line for the governing token # Draw the horizontal line for the governing token
for working_horizontal_position in range(char_horizontal_line_length - 1): for working_horizontal_position in range(char_horizontal_line_length - 1):
if ( if (
char_matrix[head_token_index][working_horizontal_position] char_matrix[head_token_index][working_horizontal_position]
!= VERTICAL_LINE != FULL_VERTICAL_LINE
): ):
char_matrix[head_token_index][ char_matrix[head_token_index][
working_horizontal_position working_horizontal_position
] |= HORIZONTAL_LINE ] |= FULL_HORIZONTAL_LINE
# Draw the vertical line for the relation # Draw the vertical line for the relation
for working_vertical_position in range( for working_vertical_position in range(
@ -174,20 +164,28 @@ class Visualizer:
char_matrix[working_vertical_position][ char_matrix[working_vertical_position][
char_horizontal_line_length - 1 char_horizontal_line_length - 1
] ]
!= HORIZONTAL_LINE != FULL_HORIZONTAL_LINE
): ):
char_matrix[working_vertical_position][ char_matrix[working_vertical_position][
char_horizontal_line_length - 1 char_horizontal_line_length - 1
] |= VERTICAL_LINE ] |= FULL_VERTICAL_LINE
for working_token_index in (i for i in range(len(doc)) if heads[i] is not None): for working_token_index in (i for i in range(len(doc)) if heads[i] is not None):
for working_horizontal_position in range( for working_horizontal_position in range(
2 * horizontal_line_lengths[working_token_index] - 2, -1, -1 2 * horizontal_line_lengths[working_token_index] - 2, -1, -1
): ):
if ( if (
char_matrix[working_token_index][working_horizontal_position] (
== VERTICAL_LINE) and working_horizontal_position > 1 and char_matrix[working_token_index][working_horizontal_position - 2] == SPACE: char_matrix[working_token_index][working_horizontal_position]
# Cross over the existing vertical line, which is owing to a non-projective tree == FULL_VERTICAL_LINE
continue )
and working_horizontal_position > 1
and char_matrix[working_token_index][
working_horizontal_position - 2
]
== SPACE
):
# Cross over the existing vertical line, which is owing to a non-projective tree
continue
if ( if (
char_matrix[working_token_index][working_horizontal_position] char_matrix[working_token_index][working_horizontal_position]
!= SPACE != SPACE
@ -200,26 +198,28 @@ class Visualizer:
if working_horizontal_position == 0: if working_horizontal_position == 0:
# Draw the arrowhead at the boundary of the diagram # Draw the arrowhead at the boundary of the diagram
char_matrix[working_token_index][ char_matrix[working_token_index][
working_horizontal_position] = ARROWHEAD working_horizontal_position
] = ARROWHEAD
else: else:
# Fill in the horizontal line for the governed token # Fill in the horizontal line for the governed token
char_matrix[working_token_index][ char_matrix[working_token_index][
working_horizontal_position working_horizontal_position
] |= HORIZONTAL_LINE ] |= FULL_HORIZONTAL_LINE
return [ if root_right:
"".join( return [
ROOT_RIGHT_CHARS[char_matrix[vertical_position][horizontal_position]] "".join(
for horizontal_position in range((max_horizontal_line_length * 2)) ROOT_RIGHT_CHARS[
) char_matrix[vertical_position][horizontal_position]
for vertical_position in range(len(doc)) ]
] for horizontal_position in range((max_horizontal_line_length * 2))
)
for vertical_position in range(len(doc))
import spacy ]
else:
nlp = spacy.load("en_core_web_sm") return [
doc = nlp("I saw a horse yesterday that was injured") "".join(
ROOT_LEFT_CHARS[char_matrix[vertical_position][horizontal_position]]
lines = Visualizer().render_dependency_trees(doc) for horizontal_position in range((max_horizontal_line_length * 2))
for line in lines: )[::-1]
print(line) for vertical_position in range(len(doc))
]