mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00
Add types to everything
This commit is contained in:
parent
9d3e3e6be2
commit
9d0ae2407b
|
@ -3,7 +3,7 @@ import re
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
import typer
|
import typer
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
from typing import Any, Dict, Iterable
|
from typing import Any, Dict, Iterable, List, Union
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
|
@ -23,7 +23,9 @@ LISTENER_ARCHS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _deep_get(obj: Dict[str, Any], key: Iterable[str], default: Any):
|
def _deep_get(
|
||||||
|
obj: Union[Dict[str, Any], Config], key: Iterable[str], default: Any
|
||||||
|
) -> Any:
|
||||||
"""Given a multi-part key, try to get the key. If at any point this isn't
|
"""Given a multi-part key, try to get the key. If at any point this isn't
|
||||||
possible, return the default.
|
possible, return the default.
|
||||||
"""
|
"""
|
||||||
|
@ -36,7 +38,7 @@ def _deep_get(obj: Dict[str, Any], key: Iterable[str], default: Any):
|
||||||
return slot
|
return slot
|
||||||
|
|
||||||
|
|
||||||
def _get_tok2vecs(config):
|
def _get_tok2vecs(config: Config) -> List[str]:
|
||||||
"""Given a pipeline config, return the names of components that are
|
"""Given a pipeline config, return the names of components that are
|
||||||
tok2vecs (or Transformers).
|
tok2vecs (or Transformers).
|
||||||
"""
|
"""
|
||||||
|
@ -52,7 +54,7 @@ def _get_tok2vecs(config):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _has_listener(nlp, pipe_name):
|
def _has_listener(nlp: Language, pipe_name: str):
|
||||||
"""Given a pipeline and a component name, check if it has a listener."""
|
"""Given a pipeline and a component name, check if it has a listener."""
|
||||||
arch = _deep_get(
|
arch = _deep_get(
|
||||||
nlp.config,
|
nlp.config,
|
||||||
|
@ -65,7 +67,7 @@ def _has_listener(nlp, pipe_name):
|
||||||
return (ns, model) in LISTENER_ARCHS
|
return (ns, model) in LISTENER_ARCHS
|
||||||
|
|
||||||
|
|
||||||
def _get_listeners(nlp):
|
def _get_listeners(nlp: Language) -> List[str]:
|
||||||
"""Get the name of every component that contains a listener.
|
"""Get the name of every component that contains a listener.
|
||||||
|
|
||||||
Does not check that they listen to the same thing; assumes a pipeline has
|
Does not check that they listen to the same thing; assumes a pipeline has
|
||||||
|
@ -78,7 +80,7 @@ def _get_listeners(nlp):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _increment_suffix(name):
|
def _increment_suffix(name: str) -> str:
|
||||||
"""Given a name, return an incremented version.
|
"""Given a name, return an incremented version.
|
||||||
|
|
||||||
If no numeric suffix is found, return the original with "2" appended.
|
If no numeric suffix is found, return the original with "2" appended.
|
||||||
|
@ -90,12 +92,12 @@ def _increment_suffix(name):
|
||||||
if res is None:
|
if res is None:
|
||||||
return f"{name}2"
|
return f"{name}2"
|
||||||
else:
|
else:
|
||||||
num = res.match
|
num = res.group()
|
||||||
prefix = name[0 : -len(num)]
|
prefix = name[0 : -len(num)]
|
||||||
return f"{prefix}{int(num) + 1}"
|
return f"{prefix}{int(num) + 1}"
|
||||||
|
|
||||||
|
|
||||||
def _check_single_tok2vec(name, config):
|
def _check_single_tok2vec(name: str, config: Config) -> None:
|
||||||
"""Check if there is just one tok2vec in a config.
|
"""Check if there is just one tok2vec in a config.
|
||||||
|
|
||||||
A very simple check, but used in multiple functions.
|
A very simple check, but used in multiple functions.
|
||||||
|
@ -108,7 +110,7 @@ def _check_single_tok2vec(name, config):
|
||||||
msg.fail(fail_msg, exits=1)
|
msg.fail(fail_msg, exits=1)
|
||||||
|
|
||||||
|
|
||||||
def _check_pipeline_names(nlp, nlp2):
|
def _check_pipeline_names(nlp: Language, nlp2: Language) -> Dict[str, str]:
|
||||||
"""Given two pipelines, try to rename any collisions in component names.
|
"""Given two pipelines, try to rename any collisions in component names.
|
||||||
|
|
||||||
If a simple increment of a numeric suffix doesn't work, will give up.
|
If a simple increment of a numeric suffix doesn't work, will give up.
|
||||||
|
@ -139,7 +141,7 @@ def configure_resume_cli(
|
||||||
base_model: Path = Arg(..., help="Path or name of base model to use for config"),
|
base_model: Path = Arg(..., help="Path or name of base model to use for config"),
|
||||||
output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
|
output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
|
||||||
# fmt: on
|
# fmt: on
|
||||||
):
|
) -> Config:
|
||||||
"""Create a config for resuming training.
|
"""Create a config for resuming training.
|
||||||
|
|
||||||
A config for resuming training is the same as the input config, but with
|
A config for resuming training is the same as the input config, but with
|
||||||
|
@ -247,7 +249,9 @@ def use_tok2vec(base_model: str, output_file: Path) -> Config:
|
||||||
return nlp.config
|
return nlp.config
|
||||||
|
|
||||||
|
|
||||||
def _inner_merge(nlp, nlp2, replace_listeners=False) -> Language:
|
def _inner_merge(
|
||||||
|
nlp: Language, nlp2: Language, replace_listeners: bool = False
|
||||||
|
) -> Language:
|
||||||
"""Actually do the merge.
|
"""Actually do the merge.
|
||||||
|
|
||||||
nlp: Base pipeline to add components to.
|
nlp: Base pipeline to add components to.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user