diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py index c021788b9..0bc2604bb 100644 --- a/spacy/tests/conftest.py +++ b/spacy/tests/conftest.py @@ -531,7 +531,7 @@ def fully_featured_doc_two_sentences(en_vocab): "She", "loved", "it", - "." + ".", ] lemmas = [ "sarah", @@ -547,9 +547,24 @@ def fully_featured_doc_two_sentences(en_vocab): "she", "love", "it", - "." + ".", + ] + spaces = [ + False, + True, + True, + True, + True, + True, + True, + True, + False, + True, + True, + True, + False, + False, ] - spaces = [False, True, True, True, True, True, True, True, False, True, True, True, False, False] pos = [ "PROPN", "PART", @@ -564,9 +579,24 @@ def fully_featured_doc_two_sentences(en_vocab): "PRON", "VERB", "PRON", - "PUNCT" + "PUNCT", + ] + tags = [ + "NNP", + "POS", + "NN", + "VBD", + "IN", + "NNP", + "NNP", + "IN", + "NNP", + ".", + "PRP", + "VBD", + "PRP", + ".", ] - tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", ".", "PRP", "VBD", "PRP", "."] morphs = [ "NounType=prop|Number=sing", "Poss=yes", diff --git a/spacy/visualization.py b/spacy/visualization.py index 0075d9f41..7903d1c4b 100644 --- a/spacy/visualization.py +++ b/spacy/visualization.py @@ -1,5 +1,5 @@ from os import linesep -from typing import Union +from typing import Optional, Union, cast import wasabi from spacy.tokens import Span, Token, Doc @@ -151,7 +151,7 @@ class Visualizer: # Check sent is really a sentence assert sent.start == sent[0].sent.start assert sent.end == sent[0].sent.end - heads = [ + heads: list[Optional[int]] = [ None if token.dep_.lower() == "root" or token.head.i == token.i else token.head.i - sent.start @@ -211,7 +211,7 @@ class Visualizer: horizontal_line_lengths = [ -1 if heads[i] is None else 1 # length == 1: governed by direct neighbour and has no children itself - if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1 else 0 + if len(children_lists[i]) == 0 and abs(cast(int, heads[i]) - i) == 1 else 0 for i in range(sent.end - sent.start) ] while 0 in horizontal_line_lengths: @@ -223,11 +223,11 @@ class Visualizer: # render relation between this token and its head first_index_in_relation = min( working_token_index, - heads[working_token_index], + cast(int, heads[working_token_index]), ) second_index_in_relation = max( working_token_index, - heads[working_token_index], + cast(int, heads[working_token_index]), ) # If this token has children, they will already have been rendered. # The line needs to be one character longer than the longest of the @@ -254,7 +254,8 @@ class Visualizer: horizontal_line_lengths[working_token_index] = max( horizontal_line_lengths[working_token_index], horizontal_line_lengths[inbetween_index] - if inbetween_index in children_lists[heads[working_token_index]] + if inbetween_index + in children_lists[cast(int, heads[working_token_index])] and inbetween_index not in children_lists[working_token_index] else horizontal_line_lengths[inbetween_index] + 1, ) @@ -383,9 +384,9 @@ class Visualizer: for column in columns: # get the values without any color codes if column.attribute == "tree_left": - width = len(tree_left[0]) + width = len(tree_left[0]) # type: ignore elif column.attribute == "tree_right": - width = len(tree_right[0]) + width = len(tree_right[0]) # type: ignore else: if len(sent) > 0: width = max( @@ -400,9 +401,9 @@ class Visualizer: widths.append(width) data = [ [ - tree_right[token_index] + tree_right[token_index] # type: ignore if column.attribute == "tree_right" - else tree_left[token_index] + else tree_left[token_index] # type: ignore if column.attribute == "tree_left" else column.render(token, right_pad_to_length=widths[column_index]) for column_index, column in enumerate(columns)