From 9d0ae2407b82f0b46b19a927b13ff743f9b3051b Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 31 Jan 2023 13:08:51 +0900 Subject: [PATCH] Add types to everything --- spacy/cli/configure.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/spacy/cli/configure.py b/spacy/cli/configure.py index 488978827..884ba68fe 100644 --- a/spacy/cli/configure.py +++ b/spacy/cli/configure.py @@ -3,7 +3,7 @@ import re from wasabi import msg import typer from thinc.api import Config -from typing import Any, Dict, Iterable +from typing import Any, Dict, Iterable, List, Union import spacy from spacy.language import Language @@ -23,7 +23,9 @@ LISTENER_ARCHS = [ ] -def _deep_get(obj: Dict[str, Any], key: Iterable[str], default: Any): +def _deep_get( + obj: Union[Dict[str, Any], Config], key: Iterable[str], default: Any +) -> Any: """Given a multi-part key, try to get the key. If at any point this isn't possible, return the default. """ @@ -36,7 +38,7 @@ def _deep_get(obj: Dict[str, Any], key: Iterable[str], default: Any): return slot -def _get_tok2vecs(config): +def _get_tok2vecs(config: Config) -> List[str]: """Given a pipeline config, return the names of components that are tok2vecs (or Transformers). """ @@ -52,7 +54,7 @@ def _get_tok2vecs(config): return out -def _has_listener(nlp, pipe_name): +def _has_listener(nlp: Language, pipe_name: str): """Given a pipeline and a component name, check if it has a listener.""" arch = _deep_get( nlp.config, @@ -65,7 +67,7 @@ def _has_listener(nlp, pipe_name): return (ns, model) in LISTENER_ARCHS -def _get_listeners(nlp): +def _get_listeners(nlp: Language) -> List[str]: """Get the name of every component that contains a listener. Does not check that they listen to the same thing; assumes a pipeline has @@ -78,7 +80,7 @@ def _get_listeners(nlp): return out -def _increment_suffix(name): +def _increment_suffix(name: str) -> str: """Given a name, return an incremented version. If no numeric suffix is found, return the original with "2" appended. @@ -90,12 +92,12 @@ def _increment_suffix(name): if res is None: return f"{name}2" else: - num = res.match + num = res.group() prefix = name[0 : -len(num)] return f"{prefix}{int(num) + 1}" -def _check_single_tok2vec(name, config): +def _check_single_tok2vec(name: str, config: Config) -> None: """Check if there is just one tok2vec in a config. A very simple check, but used in multiple functions. @@ -108,7 +110,7 @@ def _check_single_tok2vec(name, config): msg.fail(fail_msg, exits=1) -def _check_pipeline_names(nlp, nlp2): +def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]: """Given two pipelines, try to rename any collisions in component names. If a simple increment of a numeric suffix doesn't work, will give up. @@ -139,7 +141,7 @@ def configure_resume_cli( base_model: Path = Arg(..., help="Path or name of base model to use for config"), output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True), # fmt: on -): +) -> Config: """Create a config for resuming training. A config for resuming training is the same as the input config, but with @@ -247,7 +249,9 @@ def use_tok2vec(base_model: str, output_file: Path) -> Config: return nlp.config -def _inner_merge(nlp, nlp2, replace_listeners=False) -> Language: +def _inner_merge( + nlp: Language, nlp2: Language, replace_listeners: bool = False +) -> Language: """Actually do the merge. nlp: Base pipeline to add components to.