diff --git a/spacy/visualization.py b/spacy/visualization.py index 04b8c7516..546b13627 100644 --- a/spacy/visualization.py +++ b/spacy/visualization.py @@ -37,6 +37,7 @@ ROOT_LEFT_CHARS = { 16: ">", } + class Visualizer: @staticmethod def render_dependency_trees(doc: Doc) -> list[str]: @@ -147,20 +148,25 @@ class Visualizer: 2 * horizontal_line_lengths[working_token_index] ) + # Draw the corners of the relation char_matrix[first_index_in_relation][char_horizontal_line_length - 1] |= ( HORIZONTAL_LINE_PART_AWAY_FROM_ROOT + LOWER_VERTICAL_LINE_PART ) char_matrix[second_index_in_relation][char_horizontal_line_length - 1] |= ( - HORIZONTAL_LINE_PART_TOWARDS_ROOT + UPPER_VERTICAL_LINE_PART + HORIZONTAL_LINE_PART_AWAY_FROM_ROOT + UPPER_VERTICAL_LINE_PART ) + + # Draw the horizontal line for the governing token for working_horizontal_position in range(char_horizontal_line_length - 1): if ( - char_matrix[working_token_index][working_horizontal_position] + char_matrix[head_token_index][working_horizontal_position] != VERTICAL_LINE ): - char_matrix[working_token_index][ + char_matrix[head_token_index][ working_horizontal_position ] |= HORIZONTAL_LINE + + # Draw the vertical line for the relation for working_vertical_position in range( first_index_in_relation + 1, second_index_in_relation ): @@ -175,41 +181,45 @@ class Visualizer: ] |= VERTICAL_LINE for working_token_index in (i for i in range(len(doc)) if heads[i] is not None): for working_horizontal_position in range( - 2 * horizontal_line_lengths[working_token_index], -1, -1 + 2 * horizontal_line_lengths[working_token_index] - 2, -1, -1 ): - print(working_token_index, working_horizontal_position) if ( - working_horizontal_position > 0 - and char_matrix[working_token_index][working_horizontal_position] - == VERTICAL_LINE - and char_matrix[ - working_token_index][working_horizontal_position - 1 - ] - == 0 - ): - # jump over crossing line in non-projective parse case - continue + char_matrix[working_token_index][working_horizontal_position] + == VERTICAL_LINE) and working_horizontal_position > 1 and char_matrix[working_token_index][working_horizontal_position - 2] == SPACE: + # Cross over the existing vertical line, which is owing to a non-projective tree + continue if ( char_matrix[working_token_index][working_horizontal_position] != SPACE ): + # Draw the arrowhead to the right of what is already there char_matrix[working_token_index][ working_horizontal_position + 1 ] = ARROWHEAD break if working_horizontal_position == 0: + # Draw the arrowhead at the boundary of the diagram char_matrix[working_token_index][ - working_horizontal_position - ] = ARROWHEAD + working_horizontal_position] = ARROWHEAD else: + # Fill in the horizontal line for the governed token char_matrix[working_token_index][ working_horizontal_position ] |= HORIZONTAL_LINE - return [''.join(ROOT_RIGHT_CHARS[char_matrix[vertical_position][horizontal_position]] for horizontal_position in range((max_horizontal_line_length * 2 - 1))) for vertical_position in range(len(doc))] + return [ + "".join( + ROOT_RIGHT_CHARS[char_matrix[vertical_position][horizontal_position]] + for horizontal_position in range((max_horizontal_line_length * 2)) + ) + for vertical_position in range(len(doc)) + ] + import spacy nlp = spacy.load("en_core_web_sm") doc = nlp("I saw a horse yesterday that was injured") -print (Visualizer().render_dependency_trees(doc)) +lines = Visualizer().render_dependency_trees(doc) +for line in lines: + print(line)