More type corrections

This commit is contained in:
Richard Hudson 2021-12-23 17:24:28 +01:00
parent 5c850b2ac3
commit 106fb53509

View File

@ -1,5 +1,5 @@
from os import linesep from os import linesep
from typing import Optional, Union, cast from typing import List, Optional, Union, cast
import wasabi import wasabi
from spacy.tokens import Span, Token, Doc from spacy.tokens import Span, Token, Doc
@ -50,12 +50,12 @@ class AttributeFormat:
*, *,
name: str = "", name: str = "",
aligns: str = "l", aligns: str = "l",
max_width: int = None, max_width: Optional[int] = None,
fg_color: Union[str, int] = None, fg_color: Union[str, int, None] = None,
bg_color: Union[str, int] = None, bg_color: Union[str, int, None] = None,
permitted_values: tuple = None, permitted_values: Optional[tuple] = None,
value_dependent_fg_colors: dict[str, Union[str, int]] = None, value_dependent_fg_colors: Optional[dict[str, Union[str, int]]] = None,
value_dependent_bg_colors: dict[str, Union[str, int]] = None, value_dependent_bg_colors: Optional[dict[str, Union[str, int]]] = None,
): ):
""" """
attribute: the token attribute, e.g. lemma_, ._.holmes.lemma attribute: the token attribute, e.g. lemma_, ._.holmes.lemma
@ -86,11 +86,11 @@ class AttributeFormat:
self, self,
token: Token, token: Token,
*, *,
right_pad_to_length: int = None, right_pad_to_length: Optional[int] = None,
ignore_colors: bool = False, ignore_colors: bool = False,
render_all_colors_within_values: bool = False, render_all_colors_within_values: bool = False,
whole_row_fg_color: Union[int, str] = None, whole_row_fg_color: Union[int, str, None] = None,
whole_row_bg_color: Union[int, str] = None, whole_row_bg_color: Union[int, str, None] = None,
) -> str: ) -> str:
""" """
ignore_colors: no colors should be rendered, typically because the values are required to calculate widths ignore_colors: no colors should be rendered, typically because the values are required to calculate widths
@ -136,7 +136,7 @@ class AttributeFormat:
class Visualizer: class Visualizer:
@staticmethod @staticmethod
def render_dependency_tree(sent: Span, root_right: bool) -> list[str]: def render_dependency_tree(sent: Span, root_right: bool) -> List[str]:
""" """
Returns an ASCII rendering of the document with a dependency tree for each sentence. The Returns an ASCII rendering of the document with a dependency tree for each sentence. The
dependency tree output for a given token has the same index within the output list of dependency tree output for a given token has the same index within the output list of
@ -151,7 +151,7 @@ class Visualizer:
# Check sent is really a sentence # Check sent is really a sentence
assert sent.start == sent[0].sent.start assert sent.start == sent[0].sent.start
assert sent.end == sent[0].sent.end assert sent.end == sent[0].sent.end
heads: list[Optional[int]] = [ heads: List[Optional[int]] = [
None None
if token.dep_.lower() == "root" or token.head.i == token.i if token.dep_.lower() == "root" or token.head.i == token.i
else token.head.i - sent.start else token.head.i - sent.start
@ -168,11 +168,11 @@ class Visualizer:
) )
== 0 == 0
) )
children_lists = [[] for _ in range(sent.end - sent.start)] children_lists: List[List[int]] = [[] for _ in range(sent.end - sent.start)]
for child, head in enumerate(heads): for child, head in enumerate(heads):
if head is not None: if head is not None:
children_lists[head].append(child) children_lists[head].append(child)
all_indices_ordered_by_column = [] all_indices_ordered_by_column: List[int] = []
# start with the root column # start with the root column
indices_in_current_column = [i for i, h in enumerate(heads) if h is None] indices_in_current_column = [i for i, h in enumerate(heads) if h is None]
while len(indices_in_current_column) > 0: while len(indices_in_current_column) > 0:
@ -410,6 +410,7 @@ class Visualizer:
] ]
for token_index, token in enumerate(sent) for token_index, token in enumerate(sent)
] ]
header: Optional[List[str]]
if len([1 for c in columns if len(c.name) > 0]) > 0: if len([1 for c in columns if len(c.name) > 0]) > 0:
header = [c.name for c in columns] header = [c.name for c in columns]
else: else:
@ -432,7 +433,7 @@ class Visualizer:
) )
return return_string return return_string
def render_text(self, doc: Doc, attributes: list[AttributeFormat]) -> str: def render_text(self, doc: Doc, attributes: List[AttributeFormat]) -> str:
"""Renders a text interspersed with attribute labels. """Renders a text interspersed with attribute labels.
TODO: specify a specific portion of the document to display. TODO: specify a specific portion of the document to display.
@ -463,8 +464,8 @@ class Visualizer:
self, self,
doc: Doc, doc: Doc,
*, *,
search_attributes: list[AttributeFormat], search_attributes: List[AttributeFormat],
display_columns: list[AttributeFormat], display_columns: List[AttributeFormat],
group: bool, group: bool,
spacing: int, spacing: int,
surrounding_tokens_height: int, surrounding_tokens_height: int,
@ -580,6 +581,7 @@ class Visualizer:
!= matched_tokens[matched_token_index + 1].i != matched_tokens[matched_token_index + 1].i
): ):
rows.append([]) rows.append([])
header: Optional[List[str]]
if len([1 for c in display_columns if len(c.name) > 0]) > 0: if len([1 for c in display_columns if len(c.name) > 0]) > 0:
header = [c.name for c in display_columns] header = [c.name for c in display_columns]
else: else: