Corrected types

This commit is contained in:
Richard Hudson 2021-12-23 17:01:43 +01:00
parent e713aa0938
commit 5c850b2ac3
2 changed files with 46 additions and 15 deletions

View File

@@ -531,7 +531,7 @@ def fully_featured_doc_two_sentences(en_vocab):
"She",
"loved",
"it",
"."
".",
]
lemmas = [
"sarah",
@@ -547,9 +547,24 @@ def fully_featured_doc_two_sentences(en_vocab):
"she",
"love",
"it",
"."
".",
]
spaces = [
False,
True,
True,
True,
True,
True,
True,
True,
False,
True,
True,
True,
False,
False,
]
spaces = [False, True, True, True, True, True, True, True, False, True, True, True, False, False]
pos = [
"PROPN",
"PART",
@@ -564,9 +579,24 @@ def fully_featured_doc_two_sentences(en_vocab):
"PRON",
"VERB",
"PRON",
"PUNCT"
"PUNCT",
]
tags = [
"NNP",
"POS",
"NN",
"VBD",
"IN",
"NNP",
"NNP",
"IN",
"NNP",
".",
"PRP",
"VBD",
"PRP",
".",
]
tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", ".", "PRP", "VBD", "PRP", "."]
morphs = [
"NounType=prop|Number=sing",
"Poss=yes",

View File

@@ -1,5 +1,5 @@
from os import linesep
from typing import Union
from typing import Optional, Union, cast
import wasabi
from spacy.tokens import Span, Token, Doc
@@ -151,7 +151,7 @@ class Visualizer:
# Check sent is really a sentence
assert sent.start == sent[0].sent.start
assert sent.end == sent[0].sent.end
heads = [
heads: list[Optional[int]] = [
None
if token.dep_.lower() == "root" or token.head.i == token.i
else token.head.i - sent.start
@@ -211,7 +211,7 @@ class Visualizer:
horizontal_line_lengths = [
-1 if heads[i] is None else 1
# length == 1: governed by direct neighbour and has no children itself
if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1 else 0
if len(children_lists[i]) == 0 and abs(cast(int, heads[i]) - i) == 1 else 0
for i in range(sent.end - sent.start)
]
while 0 in horizontal_line_lengths:
@@ -223,11 +223,11 @@ class Visualizer:
# render relation between this token and its head
first_index_in_relation = min(
working_token_index,
heads[working_token_index],
cast(int, heads[working_token_index]),
)
second_index_in_relation = max(
working_token_index,
heads[working_token_index],
cast(int, heads[working_token_index]),
)
# If this token has children, they will already have been rendered.
# The line needs to be one character longer than the longest of the
@@ -254,7 +254,8 @@ class Visualizer:
horizontal_line_lengths[working_token_index] = max(
horizontal_line_lengths[working_token_index],
horizontal_line_lengths[inbetween_index]
if inbetween_index in children_lists[heads[working_token_index]]
if inbetween_index
in children_lists[cast(int, heads[working_token_index])]
and inbetween_index not in children_lists[working_token_index]
else horizontal_line_lengths[inbetween_index] + 1,
)
@@ -383,9 +384,9 @@ class Visualizer:
for column in columns:
# get the values without any color codes
if column.attribute == "tree_left":
width = len(tree_left[0])
width = len(tree_left[0]) # type: ignore
elif column.attribute == "tree_right":
width = len(tree_right[0])
width = len(tree_right[0]) # type: ignore
else:
if len(sent) > 0:
width = max(
@@ -400,9 +401,9 @@ class Visualizer:
widths.append(width)
data = [
[
tree_right[token_index]
tree_right[token_index] # type: ignore
if column.attribute == "tree_right"
else tree_left[token_index]
else tree_left[token_index] # type: ignore
if column.attribute == "tree_left"
else column.render(token, right_pad_to_length=widths[column_index])
for column_index, column in enumerate(columns)