From 2d0e916220a9438b8e7b132e2ff18d6b2809d6af Mon Sep 17 00:00:00 2001 From: Richard Hudson Date: Sun, 28 Nov 2021 15:39:07 +0100 Subject: [PATCH] Work in progress --- spacy/visualization.py | 120 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 115 insertions(+), 5 deletions(-) diff --git a/spacy/visualization.py b/spacy/visualization.py index b95e6fb6a..04b8c7516 100644 --- a/spacy/visualization.py +++ b/spacy/visualization.py @@ -1,5 +1,41 @@ from spacy.tokens import Doc, Token +from spacy.util import working_dir +HORIZONTAL_LINE_PART_AWAY_FROM_ROOT = 1 +HORIZONTAL_LINE_PART_TOWARDS_ROOT = 2 +UPPER_VERTICAL_LINE_PART = 4 +LOWER_VERTICAL_LINE_PART = 8 +ARROWHEAD = 16 + +SPACE = 0 +HORIZONTAL_LINE = 3 +VERTICAL_LINE = 12 + +ROOT_RIGHT_CHARS = { + 0: " ", + 3: "═", + 5: "╝", + 7: "╩", + 9: "╗", + 11: "╦", + 12: "║", + 13: "╣", + 15: "╬", + 16: "<", +} + +ROOT_LEFT_CHARS = { + 0: " ", + 3: "═", + 5: "╚", + 7: "╩", + 9: "╔", + 11: "╦", + 12: "║", + 13: "╠", + 15: "╬", + 16: ">", +} class Visualizer: @staticmethod @@ -12,17 +48,17 @@ class Visualizer: Adapted from https://github.com/KoichiYasuoka/deplacy """ heads = [ - -1 + None if token.dep_.lower() == "root" or token.head.i == token.i else token.head.i for token in doc ] children_lists = [[] for _ in range(len(doc))] for child, head in enumerate(heads): - if head != -1: + if head is not None: children_lists[head].append(child) all_indices_ordered_by_column = [] - indices_in_current_column = [i for i, h in enumerate(heads) if h == -1] + indices_in_current_column = [i for i, h in enumerate(heads) if h is None] while len(indices_in_current_column) > 0: all_indices_ordered_by_column = ( indices_in_current_column + all_indices_ordered_by_column @@ -36,7 +72,7 @@ class Visualizer: # -1: root token with no arrow; 0: length not yet set horizontal_line_lengths = [ -1 - if heads[i] == -1 + if heads[i] is None else 1 if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1 else 0 @@ -97,9 +133,83 @@ class Visualizer: ] = working_horizontal_line_length print(horizontal_line_lengths) + max_horizontal_line_length = max(horizontal_line_lengths) + char_matrix = [ + [SPACE] * max_horizontal_line_length * 2 for _ in range(len(doc)) + ] + for working_token_index in range(len(doc)): + head_token_index = heads[working_token_index] + if head_token_index is None: + continue + first_index_in_relation = min(working_token_index, head_token_index) + second_index_in_relation = max(working_token_index, head_token_index) + char_horizontal_line_length = ( + 2 * horizontal_line_lengths[working_token_index] + ) + + char_matrix[first_index_in_relation][char_horizontal_line_length - 1] |= ( + HORIZONTAL_LINE_PART_AWAY_FROM_ROOT + LOWER_VERTICAL_LINE_PART + ) + char_matrix[second_index_in_relation][char_horizontal_line_length - 1] |= ( + HORIZONTAL_LINE_PART_TOWARDS_ROOT + UPPER_VERTICAL_LINE_PART + ) + for working_horizontal_position in range(char_horizontal_line_length - 1): + if ( + char_matrix[working_token_index][working_horizontal_position] + != VERTICAL_LINE + ): + char_matrix[working_token_index][ + working_horizontal_position + ] |= HORIZONTAL_LINE + for working_vertical_position in range( + first_index_in_relation + 1, second_index_in_relation + ): + if ( + char_matrix[working_vertical_position][ + char_horizontal_line_length - 1 + ] + != HORIZONTAL_LINE + ): + char_matrix[working_vertical_position][ + char_horizontal_line_length - 1 + ] |= VERTICAL_LINE + for working_token_index in (i for i in range(len(doc)) if heads[i] is not None): + for working_horizontal_position in range( + 2 * horizontal_line_lengths[working_token_index], -1, -1 + ): + print(working_token_index, working_horizontal_position) + if ( + working_horizontal_position > 0 + and char_matrix[working_token_index][working_horizontal_position] + == VERTICAL_LINE + and char_matrix[ + working_token_index][working_horizontal_position - 1 + ] + == 0 + ): + # jump over crossing line in non-projective parse case + continue + if ( + char_matrix[working_token_index][working_horizontal_position] + != SPACE + ): + char_matrix[working_token_index][ + working_horizontal_position + 1 + ] = ARROWHEAD + break + if working_horizontal_position == 0: + char_matrix[working_token_index][ + working_horizontal_position + ] = ARROWHEAD + else: + char_matrix[working_token_index][ + working_horizontal_position + ] |= HORIZONTAL_LINE + return [''.join(ROOT_RIGHT_CHARS[char_matrix[vertical_position][horizontal_position]] for horizontal_position in range((max_horizontal_line_length * 2 - 1))) for vertical_position in range(len(doc))] import spacy nlp = spacy.load("en_core_web_sm") doc = nlp("I saw a horse yesterday that was injured") -Visualizer().render_dependency_trees(doc) + +print (Visualizer().render_dependency_trees(doc))