Corrected types

This commit is contained in:
Richard Hudson 2021-12-23 17:01:43 +01:00
parent e713aa0938
commit 5c850b2ac3
2 changed files with 46 additions and 15 deletions

View File

@@ -531,7 +531,7 @@ def fully_featured_doc_two_sentences(en_vocab):
"She", "She",
"loved", "loved",
"it", "it",
"." ".",
] ]
lemmas = [ lemmas = [
"sarah", "sarah",
@@ -547,9 +547,24 @@ def fully_featured_doc_two_sentences(en_vocab):
"she", "she",
"love", "love",
"it", "it",
"." ".",
]
spaces = [
False,
True,
True,
True,
True,
True,
True,
True,
False,
True,
True,
True,
False,
False,
] ]
spaces = [False, True, True, True, True, True, True, True, False, True, True, True, False, False]
pos = [ pos = [
"PROPN", "PROPN",
"PART", "PART",
@@ -564,9 +579,24 @@ def fully_featured_doc_two_sentences(en_vocab):
"PRON", "PRON",
"VERB", "VERB",
"PRON", "PRON",
"PUNCT" "PUNCT",
]
tags = [
"NNP",
"POS",
"NN",
"VBD",
"IN",
"NNP",
"NNP",
"IN",
"NNP",
".",
"PRP",
"VBD",
"PRP",
".",
] ]
tags = ["NNP", "POS", "NN", "VBD", "IN", "NNP", "NNP", "IN", "NNP", ".", "PRP", "VBD", "PRP", "."]
morphs = [ morphs = [
"NounType=prop|Number=sing", "NounType=prop|Number=sing",
"Poss=yes", "Poss=yes",

View File

@@ -1,5 +1,5 @@
from os import linesep from os import linesep
from typing import Union from typing import Optional, Union, cast
import wasabi import wasabi
from spacy.tokens import Span, Token, Doc from spacy.tokens import Span, Token, Doc
@@ -151,7 +151,7 @@ class Visualizer:
# Check sent is really a sentence # Check sent is really a sentence
assert sent.start == sent[0].sent.start assert sent.start == sent[0].sent.start
assert sent.end == sent[0].sent.end assert sent.end == sent[0].sent.end
heads = [ heads: list[Optional[int]] = [
None None
if token.dep_.lower() == "root" or token.head.i == token.i if token.dep_.lower() == "root" or token.head.i == token.i
else token.head.i - sent.start else token.head.i - sent.start
@@ -211,7 +211,7 @@ class Visualizer:
horizontal_line_lengths = [ horizontal_line_lengths = [
-1 if heads[i] is None else 1 -1 if heads[i] is None else 1
# length == 1: governed by direct neighbour and has no children itself # length == 1: governed by direct neighbour and has no children itself
if len(children_lists[i]) == 0 and abs(heads[i] - i) == 1 else 0 if len(children_lists[i]) == 0 and abs(cast(int, heads[i]) - i) == 1 else 0
for i in range(sent.end - sent.start) for i in range(sent.end - sent.start)
] ]
while 0 in horizontal_line_lengths: while 0 in horizontal_line_lengths:
@@ -223,11 +223,11 @@ class Visualizer:
# render relation between this token and its head # render relation between this token and its head
first_index_in_relation = min( first_index_in_relation = min(
working_token_index, working_token_index,
heads[working_token_index], cast(int, heads[working_token_index]),
) )
second_index_in_relation = max( second_index_in_relation = max(
working_token_index, working_token_index,
heads[working_token_index], cast(int, heads[working_token_index]),
) )
# If this token has children, they will already have been rendered. # If this token has children, they will already have been rendered.
# The line needs to be one character longer than the longest of the # The line needs to be one character longer than the longest of the
@@ -254,7 +254,8 @@ class Visualizer:
horizontal_line_lengths[working_token_index] = max( horizontal_line_lengths[working_token_index] = max(
horizontal_line_lengths[working_token_index], horizontal_line_lengths[working_token_index],
horizontal_line_lengths[inbetween_index] horizontal_line_lengths[inbetween_index]
if inbetween_index in children_lists[heads[working_token_index]] if inbetween_index
in children_lists[cast(int, heads[working_token_index])]
and inbetween_index not in children_lists[working_token_index] and inbetween_index not in children_lists[working_token_index]
else horizontal_line_lengths[inbetween_index] + 1, else horizontal_line_lengths[inbetween_index] + 1,
) )
@@ -383,9 +384,9 @@ class Visualizer:
for column in columns: for column in columns:
# get the values without any color codes # get the values without any color codes
if column.attribute == "tree_left": if column.attribute == "tree_left":
width = len(tree_left[0]) width = len(tree_left[0]) # type: ignore
elif column.attribute == "tree_right": elif column.attribute == "tree_right":
width = len(tree_right[0]) width = len(tree_right[0]) # type: ignore
else: else:
if len(sent) > 0: if len(sent) > 0:
width = max( width = max(
@@ -400,9 +401,9 @@ class Visualizer:
widths.append(width) widths.append(width)
data = [ data = [
[ [
tree_right[token_index] tree_right[token_index] # type: ignore
if column.attribute == "tree_right" if column.attribute == "tree_right"
else tree_left[token_index] else tree_left[token_index] # type: ignore
if column.attribute == "tree_left" if column.attribute == "tree_left"
else column.render(token, right_pad_to_length=widths[column_index]) else column.render(token, right_pad_to_length=widths[column_index])
for column_index, column in enumerate(columns) for column_index, column in enumerate(columns)