Add types to everything

This commit is contained in:
Paul O'Leary McCann 2023-01-31 13:08:51 +09:00
parent 9d3e3e6be2
commit 9d0ae2407b

View File

@@ -3,7 +3,7 @@ import re
from wasabi import msg
import typer
from thinc.api import Config
from typing import Any, Dict, Iterable
from typing import Any, Dict, Iterable, List, Union
import spacy
from spacy.language import Language
@@ -23,7 +23,9 @@ LISTENER_ARCHS = [
]
def _deep_get(obj: Dict[str, Any], key: Iterable[str], default: Any):
def _deep_get(
obj: Union[Dict[str, Any], Config], key: Iterable[str], default: Any
) -> Any:
"""Given a multi-part key, try to get the key. If at any point this isn't
possible, return the default.
"""
@@ -36,7 +38,7 @@ def _deep_get(obj: Dict[str, Any], key: Iterable[str], default: Any):
return slot
def _get_tok2vecs(config):
def _get_tok2vecs(config: Config) -> List[str]:
"""Given a pipeline config, return the names of components that are
tok2vecs (or Transformers).
"""
@@ -52,7 +54,7 @@ def _get_tok2vecs(config):
return out
def _has_listener(nlp, pipe_name):
def _has_listener(nlp: Language, pipe_name: str):
"""Given a pipeline and a component name, check if it has a listener."""
arch = _deep_get(
nlp.config,
@@ -65,7 +67,7 @@ def _has_listener(nlp, pipe_name):
return (ns, model) in LISTENER_ARCHS
def _get_listeners(nlp):
def _get_listeners(nlp: Language) -> List[str]:
"""Get the name of every component that contains a listener.
Does not check that they listen to the same thing; assumes a pipeline has
@@ -78,7 +80,7 @@ def _get_listeners(nlp):
return out
def _increment_suffix(name):
def _increment_suffix(name: str) -> str:
"""Given a name, return an incremented version.
If no numeric suffix is found, return the original with "2" appended.
@@ -90,12 +92,12 @@ def _increment_suffix(name):
if res is None:
return f"{name}2"
else:
num = res.match
num = res.group()
prefix = name[0 : -len(num)]
return f"{prefix}{int(num) + 1}"
def _check_single_tok2vec(name, config):
def _check_single_tok2vec(name: str, config: Config) -> None:
"""Check if there is just one tok2vec in a config.
A very simple check, but used in multiple functions.
@@ -108,7 +110,7 @@ def _check_single_tok2vec(name, config):
msg.fail(fail_msg, exits=1)
def _check_pipeline_names(nlp, nlp2):
def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
"""Given two pipelines, try to rename any collisions in component names.
If a simple increment of a numeric suffix doesn't work, will give up.
@@ -139,7 +141,7 @@ def configure_resume_cli(
base_model: Path = Arg(..., help="Path or name of base model to use for config"),
output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
# fmt: on
):
) -> Config:
"""Create a config for resuming training.
A config for resuming training is the same as the input config, but with
@@ -247,7 +249,9 @@ def use_tok2vec(base_model: str, output_file: Path) -> Config:
return nlp.config
def _inner_merge(nlp, nlp2, replace_listeners=False) -> Language:
def _inner_merge(
nlp: Language, nlp2: Language, replace_listeners: bool = False
) -> Language:
"""Actually do the merge.
nlp: Base pipeline to add components to.