mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00
Add types to everything
This commit is contained in:
parent
9d3e3e6be2
commit
9d0ae2407b
|
@ -3,7 +3,7 @@ import re
|
|||
from wasabi import msg
|
||||
import typer
|
||||
from thinc.api import Config
|
||||
from typing import Any, Dict, Iterable
|
||||
from typing import Any, Dict, Iterable, List, Union
|
||||
|
||||
import spacy
|
||||
from spacy.language import Language
|
||||
|
@ -23,7 +23,9 @@ LISTENER_ARCHS = [
|
|||
]
|
||||
|
||||
|
||||
def _deep_get(obj: Dict[str, Any], key: Iterable[str], default: Any):
|
||||
def _deep_get(
|
||||
obj: Union[Dict[str, Any], Config], key: Iterable[str], default: Any
|
||||
) -> Any:
|
||||
"""Given a multi-part key, try to get the key. If at any point this isn't
|
||||
possible, return the default.
|
||||
"""
|
||||
|
@ -36,7 +38,7 @@ def _deep_get(obj: Dict[str, Any], key: Iterable[str], default: Any):
|
|||
return slot
|
||||
|
||||
|
||||
def _get_tok2vecs(config):
|
||||
def _get_tok2vecs(config: Config) -> List[str]:
|
||||
"""Given a pipeline config, return the names of components that are
|
||||
tok2vecs (or Transformers).
|
||||
"""
|
||||
|
@ -52,7 +54,7 @@ def _get_tok2vecs(config):
|
|||
return out
|
||||
|
||||
|
||||
def _has_listener(nlp, pipe_name):
|
||||
def _has_listener(nlp: Language, pipe_name: str):
|
||||
"""Given a pipeline and a component name, check if it has a listener."""
|
||||
arch = _deep_get(
|
||||
nlp.config,
|
||||
|
@ -65,7 +67,7 @@ def _has_listener(nlp, pipe_name):
|
|||
return (ns, model) in LISTENER_ARCHS
|
||||
|
||||
|
||||
def _get_listeners(nlp):
|
||||
def _get_listeners(nlp: Language) -> List[str]:
|
||||
"""Get the name of every component that contains a listener.
|
||||
|
||||
Does not check that they listen to the same thing; assumes a pipeline has
|
||||
|
@ -78,7 +80,7 @@ def _get_listeners(nlp):
|
|||
return out
|
||||
|
||||
|
||||
def _increment_suffix(name):
|
||||
def _increment_suffix(name: str) -> str:
|
||||
"""Given a name, return an incremented version.
|
||||
|
||||
If no numeric suffix is found, return the original with "2" appended.
|
||||
|
@ -90,12 +92,12 @@ def _increment_suffix(name):
|
|||
if res is None:
|
||||
return f"{name}2"
|
||||
else:
|
||||
num = res.match
|
||||
num = res.group()
|
||||
prefix = name[0 : -len(num)]
|
||||
return f"{prefix}{int(num) + 1}"
|
||||
|
||||
|
||||
def _check_single_tok2vec(name, config):
|
||||
def _check_single_tok2vec(name: str, config: Config) -> None:
|
||||
"""Check if there is just one tok2vec in a config.
|
||||
|
||||
A very simple check, but used in multiple functions.
|
||||
|
@ -108,7 +110,7 @@ def _check_single_tok2vec(name, config):
|
|||
msg.fail(fail_msg, exits=1)
|
||||
|
||||
|
||||
def _check_pipeline_names(nlp, nlp2):
|
||||
def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
|
||||
"""Given two pipelines, try to rename any collisions in component names.
|
||||
|
||||
If a simple increment of a numeric suffix doesn't work, will give up.
|
||||
|
@ -139,7 +141,7 @@ def configure_resume_cli(
|
|||
base_model: Path = Arg(..., help="Path or name of base model to use for config"),
|
||||
output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
|
||||
# fmt: on
|
||||
):
|
||||
) -> Config:
|
||||
"""Create a config for resuming training.
|
||||
|
||||
A config for resuming training is the same as the input config, but with
|
||||
|
@ -247,7 +249,9 @@ def use_tok2vec(base_model: str, output_file: Path) -> Config:
|
|||
return nlp.config
|
||||
|
||||
|
||||
def _inner_merge(nlp, nlp2, replace_listeners=False) -> Language:
|
||||
def _inner_merge(
|
||||
nlp: Language, nlp2: Language, replace_listeners: bool = False
|
||||
) -> Language:
|
||||
"""Actually do the merge.
|
||||
|
||||
nlp: Base pipeline to add components to.
|
||||
|
|
Loading…
Reference in New Issue
Block a user