diff --git a/spacy/visualization.py b/spacy/visualization.py index bde843a71..01d6b2b5e 100644 --- a/spacy/visualization.py +++ b/spacy/visualization.py @@ -18,34 +18,35 @@ class Visualizer: for token in doc ] children_lists = [[] for _ in range(len(doc))] - for child, head in enumerate(h for h in heads): + for child, head in enumerate(heads): if head != -1: children_lists[head].append(child) all_indices_ordered_by_column = [] indices_in_current_column = [i for i, h in enumerate(heads) if h == -1] while len(indices_in_current_column) > 0: - all_indices_ordered_by_column.extend(indices_in_current_column) + all_indices_ordered_by_column = ( + indices_in_current_column + all_indices_ordered_by_column + ) indices_in_current_column = [ i for index_in_current_column in indices_in_current_column for i in children_lists[index_in_current_column] if i not in all_indices_ordered_by_column ] - all_indices_ordered_by_column = reversed(all_indices_ordered_by_column) - # -1: no arrow; None: not yet set - horizontal_arrow_positions = [ + # -1: root token with no arrow; 0: length not yet set + horizontal_line_lengths = [ -1 if heads[i] == -1 - else 0 + else 1 if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1 - else None + else 0 for i in range(len(doc)) ] - while None in horizontal_arrow_positions: + while 0 in horizontal_line_lengths: for working_token_index in ( i for i in all_indices_ordered_by_column - if horizontal_arrow_positions[i] is None + if horizontal_line_lengths[i] == 0 ): first_index_in_relation = min( working_token_index, @@ -55,61 +56,43 @@ class Visualizer: working_token_index, heads[working_token_index], ) - inbetween_indexes = range( - first_index_in_relation + 1, second_index_in_relation - ) - inbetween_horizontal_arrow_positions = horizontal_arrow_positions[ - first_index_in_relation + 1 : second_index_in_relation - ] - if ( - -1 in inbetween_horizontal_arrow_positions - and None in inbetween_horizontal_arrow_positions - ): - continue - if None in ( - horizontal_arrow_positions[i] - for i in children_lists[working_token_index] - ): - continue if len(children_lists[working_token_index]) > 0: - working_horizontal_arrow_position = ( - max( - [ - horizontal_arrow_positions[i] - for i in children_lists[working_token_index] - ] - ) - + 1 + working_horizontal_arrow_position = max( + [ + horizontal_line_lengths[i] + for i in children_lists[working_token_index] + ] ) else: - working_horizontal_arrow_position = -1 + working_horizontal_arrow_position = 0 for inbetween_index in ( i - for i in inbetween_indexes + for i in range( + first_index_in_relation + 1, second_index_in_relation + ) if i not in children_lists[working_token_index] - and horizontal_arrow_positions[i] is not None + and horizontal_line_lengths[i] != 0 ): working_horizontal_arrow_position = max( working_horizontal_arrow_position, - horizontal_arrow_positions[inbetween_index] + horizontal_line_lengths[inbetween_index] - 1 if inbetween_index in children_lists[heads[working_token_index]] - else horizontal_arrow_positions[inbetween_index] + 1, + else horizontal_line_lengths[inbetween_index], ) for child_horizontal_arrow_position in ( - horizontal_arrow_positions[i] + horizontal_line_lengths[i] for i in children_lists[working_token_index] if (i < first_index_in_relation or i > second_index_in_relation) - and horizontal_arrow_positions[i] is not None + and horizontal_line_lengths[i] != 0 ): working_horizontal_arrow_position = max( working_horizontal_arrow_position, - child_horizontal_arrow_position + 1, + child_horizontal_arrow_position, ) - if working_horizontal_arrow_position > -1: - horizontal_arrow_positions[ - working_token_index - ] = working_horizontal_arrow_position - print(horizontal_arrow_positions) + horizontal_line_lengths[working_token_index] = ( + working_horizontal_arrow_position + 1 + ) + print(horizontal_line_lengths) import spacy