diff --git a/.github/azure-steps.yml b/.github/azure-steps.yml
index 543804b9f..f19236dc8 100644
--- a/.github/azure-steps.yml
+++ b/.github/azure-steps.yml
@@ -25,6 +25,9 @@ steps:
${{ parameters.prefix }} python setup.py sdist --formats=gztar
displayName: "Compile and build sdist"
+ - script: python -m mypy spacy
+ displayName: 'Run mypy'
+
- task: DeleteFiles@1
inputs:
contents: "spacy"
diff --git a/.github/contributors/connorbrinton.md b/.github/contributors/connorbrinton.md
new file mode 100644
index 000000000..25d03b494
--- /dev/null
+++ b/.github/contributors/connorbrinton.md
@@ -0,0 +1,106 @@
+# spaCy contributor agreement
+
+This spaCy Contributor Agreement (**"SCA"**) is based on the
+[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
+The SCA applies to any contribution that you make to any product or project
+managed by us (the **"project"**), and sets out the intellectual property rights
+you grant to us in the contributed materials. The term **"us"** shall mean
+[ExplosionAI GmbH](https://explosion.ai/legal). The term
+**"you"** shall mean the person or entity identified below.
+
+If you agree to be bound by these terms, fill in the information requested
+below and include the filled-in version with your first pull request, under the
+folder [`.github/contributors/`](/.github/contributors/). The name of the file
+should be your GitHub username, with the extension `.md`. For example, the user
+example_user would create the file `.github/contributors/example_user.md`.
+
+Read this agreement carefully before signing. These terms and conditions
+constitute a binding legal agreement.
+
+## Contributor Agreement
+
+1. The term "contribution" or "contributed materials" means any source code,
+object code, patch, tool, sample, graphic, specification, manual,
+documentation, or any other material posted or submitted by you to the project.
+
+2. With respect to any worldwide copyrights, or copyright applications and
+registrations, in your contribution:
+
+ * you hereby assign to us joint ownership, and to the extent that such
+ assignment is or becomes invalid, ineffective or unenforceable, you hereby
+ grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
+ royalty-free, unrestricted license to exercise all rights under those
+ copyrights. This includes, at our option, the right to sublicense these same
+ rights to third parties through multiple levels of sublicensees or other
+ licensing arrangements;
+
+ * you agree that each of us can do all things in relation to your
+ contribution as if each of us were the sole owners, and if one of us makes
+ a derivative work of your contribution, the one who makes the derivative
+    work (or has it made) will be the sole owner of that derivative work;
+
+ * you agree that you will not assert any moral rights in your contribution
+ against us, our licensees or transferees;
+
+ * you agree that we may register a copyright in your contribution and
+ exercise all ownership rights associated with it; and
+
+ * you agree that neither of us has any duty to consult with, obtain the
+ consent of, pay or render an accounting to the other for any use or
+ distribution of your contribution.
+
+3. With respect to any patents you own, or that you can license without payment
+to any third party, you hereby grant to us a perpetual, irrevocable,
+non-exclusive, worldwide, no-charge, royalty-free license to:
+
+ * make, have made, use, sell, offer to sell, import, and otherwise transfer
+ your contribution in whole or in part, alone or in combination with or
+ included in any product, work or materials arising out of the project to
+ which your contribution was submitted, and
+
+ * at our option, to sublicense these same rights to third parties through
+ multiple levels of sublicensees or other licensing arrangements.
+
+4. Except as set out above, you keep all right, title, and interest in your
+contribution. The rights that you grant to us under these terms are effective
+on the date you first submitted a contribution to us, even if your submission
+took place before the date you sign these terms.
+
+5. You covenant, represent, warrant and agree that:
+
+ * Each contribution that you submit is and shall be an original work of
+ authorship and you can legally grant the rights set out in this SCA;
+
+ * to the best of your knowledge, each contribution will not violate any
+ third party's copyrights, trademarks, patents, or other intellectual
+ property rights; and
+
+ * each contribution shall be in compliance with U.S. export control laws and
+ other applicable export and import laws. You agree to notify us if you
+ become aware of any circumstance which would make any of the foregoing
+ representations inaccurate in any respect. We may publicly disclose your
+ participation in the project, including the fact that you have signed the SCA.
+
+6. This SCA is governed by the laws of the State of California and applicable
+U.S. Federal law. Any choice of law rules will not apply.
+
+7. Please place an “x” on one of the applicable statements below. Please do NOT
+mark both statements:
+
+ * [x] I am signing on behalf of myself as an individual and no other person
+ or entity, including my employer, has or will have rights with respect to my
+ contributions.
+
+ * [ ] I am signing on behalf of my employer or a legal entity and I have the
+ actual authority to contractually bind that entity.
+
+## Contributor Details
+
+| Field | Entry |
+|------------------------------- | -------------------- |
+| Name | Connor Brinton |
+| Company name (if applicable) | |
+| Title or role (if applicable) | |
+| Date | July 20th, 2021 |
+| GitHub username | connorbrinton |
+| Website (optional) | |
diff --git a/requirements.txt b/requirements.txt
index 12fdf650f..4511be3fc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,3 +29,7 @@ pytest-timeout>=1.3.0,<2.0.0
mock>=2.0.0,<3.0.0
flake8>=3.8.0,<3.10.0
hypothesis>=3.27.0,<7.0.0
+mypy>=0.910
+types-dataclasses>=0.1.3; python_version < "3.7"
+types-mock>=0.1.1
+types-requests
diff --git a/setup.cfg b/setup.cfg
index 2e7be5e12..5dd0227f2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -129,3 +129,4 @@ markers =
ignore_missing_imports = True
no_implicit_optional = True
plugins = pydantic.mypy, thinc.mypy
+allow_redefinition = True
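Note: `allow_redefinition = True` makes mypy accept re-binding a name to an unrelated
type within the same block and nesting level, matching the "parse, then replace" style
used throughout the CLI code. A minimal illustration (hypothetical function):

    from typing import List

    def parse_sizes(items: List[str]) -> List[int]:
        # Under allow_redefinition, re-binding `items` from List[str] to
        # List[int] in the same block is accepted by mypy.
        items = [int(item) for item in items]
        return items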
diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py
index 127bba55a..fb680d888 100644
--- a/spacy/cli/_util.py
+++ b/spacy/cli/_util.py
@@ -1,4 +1,5 @@
-from typing import Dict, Any, Union, List, Optional, Tuple, Iterable, TYPE_CHECKING
+from typing import Dict, Any, Union, List, Optional, Tuple, Iterable
+from typing import TYPE_CHECKING, overload
import sys
import shutil
from pathlib import Path
@@ -15,6 +16,7 @@ from thinc.util import has_cupy, gpu_is_available
from configparser import InterpolationError
import os
+from ..compat import Literal
from ..schemas import ProjectConfigSchema, validate
from ..util import import_file, run_command, make_tempdir, registry, logger
from ..util import is_compatible_version, SimpleFrozenDict, ENV_VARS
@@ -260,15 +262,16 @@ def get_checksum(path: Union[Path, str]) -> str:
RETURNS (str): The checksum.
"""
path = Path(path)
+ if not (path.is_file() or path.is_dir()):
+ msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
if path.is_file():
return hashlib.md5(Path(path).read_bytes()).hexdigest()
- if path.is_dir():
+ else:
# TODO: this is currently pretty slow
dir_checksum = hashlib.md5()
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
dir_checksum.update(sub_file.read_bytes())
return dir_checksum.hexdigest()
- msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
@contextmanager
@@ -468,12 +471,15 @@ def get_git_version(
RETURNS (Tuple[int, int]): The version as a (major, minor) tuple. Returns
(0, 0) if the version couldn't be determined.
"""
- ret = run_command("git --version", capture=True)
+ try:
+ ret = run_command("git --version", capture=True)
+ except:
+ raise RuntimeError(error)
stdout = ret.stdout.strip()
if not stdout or not stdout.startswith("git version"):
- return (0, 0)
+ return 0, 0
version = stdout[11:].strip().split(".")
- return (int(version[0]), int(version[1]))
+ return int(version[0]), int(version[1])
def _http_to_git(repo: str) -> str:
@@ -500,6 +506,16 @@ def is_subpath_of(parent, child):
return os.path.commonpath([parent_realpath, child_realpath]) == parent_realpath
+@overload
+def string_to_list(value: str, intify: Literal[False] = ...) -> List[str]:
+ ...
+
+
+@overload
+def string_to_list(value: str, intify: Literal[True]) -> List[int]:
+ ...
+
+
def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[int]]:
"""Parse a comma-separated string to a list and account for various
formatting options. Mostly used to handle CLI arguments that take a list of
@@ -510,7 +526,7 @@ def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[in
RETURNS (Union[List[str], List[int]]): A list of strings or ints.
"""
if not value:
- return []
+ return [] # type: ignore[return-value]
if value.startswith("[") and value.endswith("]"):
value = value[1:-1]
result = []
@@ -522,7 +538,7 @@ def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[in
p = p[1:-1]
p = p.strip()
if intify:
- p = int(p)
+ p = int(p) # type: ignore[assignment]
result.append(p)
return result
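Note: the `@overload` pair ties the return type of `string_to_list` to the literal value
of `intify`, so callers get `List[str]` or `List[int]` without casts. Roughly, at call
sites (assumes spaCy is installed):

    from spacy.cli._util import string_to_list

    names = string_to_list("en,de,fr")                  # mypy infers List[str]
    sizes = string_to_list("100,250,500", intify=True)  # mypy infers List[int]
    print(names, sizes)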
diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py
index c84aa6431..95df6bea4 100644
--- a/spacy/cli/convert.py
+++ b/spacy/cli/convert.py
@@ -1,4 +1,4 @@
-from typing import Optional, Any, List, Union
+from typing import Callable, Iterable, Mapping, Optional, Any, List, Union
from enum import Enum
from pathlib import Path
from wasabi import Printer
@@ -9,7 +9,7 @@ import itertools
from ._util import app, Arg, Opt
from ..training import docs_to_json
-from ..tokens import DocBin
+from ..tokens import Doc, DocBin
from ..training.converters import iob_to_docs, conll_ner_to_docs, json_to_docs
from ..training.converters import conllu_to_docs
@@ -19,7 +19,7 @@ from ..training.converters import conllu_to_docs
# entry to this dict with the file extension mapped to the converter function
# imported from /converters.
-CONVERTERS = {
+CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
"conllubio": conllu_to_docs,
"conllu": conllu_to_docs,
"conll": conll_ner_to_docs,
@@ -66,19 +66,16 @@ def convert_cli(
DOCS: https://spacy.io/api/cli#convert
"""
- if isinstance(file_type, FileTypes):
- # We get an instance of the FileTypes from the CLI so we need its string value
- file_type = file_type.value
input_path = Path(input_path)
- output_dir = "-" if output_dir == Path("-") else output_dir
+ output_dir: Union[str, Path] = "-" if output_dir == Path("-") else output_dir
silent = output_dir == "-"
msg = Printer(no_print=silent)
- verify_cli_args(msg, input_path, output_dir, file_type, converter, ner_map)
+ verify_cli_args(msg, input_path, output_dir, file_type.value, converter, ner_map)
converter = _get_converter(msg, converter, input_path)
convert(
input_path,
output_dir,
- file_type=file_type,
+ file_type=file_type.value,
n_sents=n_sents,
seg_sents=seg_sents,
model=model,
@@ -94,7 +91,7 @@ def convert_cli(
def convert(
- input_path: Union[str, Path],
+ input_path: Path,
output_dir: Union[str, Path],
*,
file_type: str = "json",
@@ -114,7 +111,7 @@ def convert(
msg = Printer(no_print=silent)
ner_map = srsly.read_json(ner_map) if ner_map is not None else None
doc_files = []
- for input_loc in walk_directory(Path(input_path), converter):
+ for input_loc in walk_directory(input_path, converter):
with input_loc.open("r", encoding="utf-8") as infile:
input_data = infile.read()
# Use converter function to convert data
@@ -141,7 +138,7 @@ def convert(
else:
db = DocBin(docs=docs, store_user_data=True)
len_docs = len(db)
- data = db.to_bytes()
+ data = db.to_bytes() # type: ignore[assignment]
if output_dir == "-":
_print_docs_to_stdout(data, file_type)
else:
@@ -220,13 +217,12 @@ def walk_directory(path: Path, converter: str) -> List[Path]:
def verify_cli_args(
msg: Printer,
- input_path: Union[str, Path],
+ input_path: Path,
output_dir: Union[str, Path],
- file_type: FileTypes,
+ file_type: str,
converter: str,
ner_map: Optional[Path],
):
- input_path = Path(input_path)
if file_type not in FILE_TYPES_STDOUT and output_dir == "-":
msg.fail(
f"Can't write .{file_type} data to stdout. Please specify an output directory.",
@@ -244,13 +240,13 @@ def verify_cli_args(
msg.fail("No input files in directory", input_path, exits=1)
file_types = list(set([loc.suffix[1:] for loc in input_locs]))
if converter == "auto" and len(file_types) >= 2:
- file_types = ",".join(file_types)
- msg.fail("All input files must be same type", file_types, exits=1)
+ file_types_str = ",".join(file_types)
+ msg.fail("All input files must be same type", file_types_str, exits=1)
if converter != "auto" and converter not in CONVERTERS:
msg.fail(f"Can't find converter for {converter}", exits=1)
-def _get_converter(msg, converter, input_path):
+def _get_converter(msg, converter, input_path: Path):
if input_path.is_dir():
input_path = walk_directory(input_path, converter)[0]
if converter == "auto":
diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py
index 3f368f57d..d53384a25 100644
--- a/spacy/cli/debug_data.py
+++ b/spacy/cli/debug_data.py
@@ -1,4 +1,5 @@
-from typing import List, Sequence, Dict, Any, Tuple, Optional, Set
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
+from typing import cast, overload
from pathlib import Path
from collections import Counter
import sys
@@ -17,6 +18,7 @@ from ..pipeline import Morphologizer
from ..morphology import Morphology
from ..language import Language
from ..util import registry, resolve_dot_names
+from ..compat import Literal
from .. import util
@@ -378,10 +380,11 @@ def debug_data(
if "tagger" in factory_names:
msg.divider("Part-of-speech Tagging")
- labels = [label for label in gold_train_data["tags"]]
+ label_list = [label for label in gold_train_data["tags"]]
model_labels = _get_labels_from_model(nlp, "tagger")
- msg.info(f"{len(labels)} label(s) in train data")
- missing_labels = model_labels - set(labels)
+ msg.info(f"{len(label_list)} label(s) in train data")
+ labels = set(label_list)
+ missing_labels = model_labels - labels
if missing_labels:
msg.warn(
"Some model labels are not present in the train data. The "
@@ -395,10 +398,11 @@ def debug_data(
if "morphologizer" in factory_names:
msg.divider("Morphologizer (POS+Morph)")
- labels = [label for label in gold_train_data["morphs"]]
+ label_list = [label for label in gold_train_data["morphs"]]
model_labels = _get_labels_from_model(nlp, "morphologizer")
- msg.info(f"{len(labels)} label(s) in train data")
- missing_labels = model_labels - set(labels)
+ msg.info(f"{len(label_list)} label(s) in train data")
+ labels = set(label_list)
+ missing_labels = model_labels - labels
if missing_labels:
msg.warn(
"Some model labels are not present in the train data. The "
@@ -565,7 +569,7 @@ def _compile_gold(
nlp: Language,
make_proj: bool,
) -> Dict[str, Any]:
- data = {
+ data: Dict[str, Any] = {
"ner": Counter(),
"cats": Counter(),
"tags": Counter(),
@@ -670,10 +674,28 @@ def _compile_gold(
return data
-def _format_labels(labels: List[Tuple[str, int]], counts: bool = False) -> str:
+@overload
+def _format_labels(labels: Iterable[str], counts: Literal[False] = False) -> str:
+ ...
+
+
+@overload
+def _format_labels(
+ labels: Iterable[Tuple[str, int]],
+ counts: Literal[True],
+) -> str:
+ ...
+
+
+def _format_labels(
+ labels: Union[Iterable[str], Iterable[Tuple[str, int]]],
+ counts: bool = False,
+) -> str:
if counts:
- return ", ".join([f"'{l}' ({c})" for l, c in labels])
- return ", ".join([f"'{l}'" for l in labels])
+ return ", ".join(
+ [f"'{l}' ({c})" for l, c in cast(Iterable[Tuple[str, int]], labels)]
+ )
+ return ", ".join([f"'{l}'" for l in cast(Iterable[str], labels)])
def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
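Note: because `_format_labels` now accepts a Union of iterables, `cast()` is used in each
branch to tell mypy which member applies; `cast` is a no-op at runtime. A standalone
sketch of the same idea (the helper below is illustrative, not spaCy API):

    from typing import Iterable, Tuple, Union, cast

    def format_labels(
        labels: Union[Iterable[str], Iterable[Tuple[str, int]]], counts: bool = False
    ) -> str:
        if counts:
            pairs = cast(Iterable[Tuple[str, int]], labels)  # type-level only
            return ", ".join(f"'{label}' ({count})" for label, count in pairs)
        return ", ".join(f"'{label}'" for label in cast(Iterable[str], labels))

    print(format_labels([("NOUN", 3), ("VERB", 1)], counts=True))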
diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py
index 378911a20..0d08d2c5e 100644
--- a/spacy/cli/evaluate.py
+++ b/spacy/cli/evaluate.py
@@ -136,7 +136,7 @@ def evaluate(
def handle_scores_per_type(
- scores: Union[Scorer, Dict[str, Any]],
+ scores: Dict[str, Any],
data: Dict[str, Any] = {},
*,
spans_key: str = "sc",
diff --git a/spacy/cli/info.py b/spacy/cli/info.py
index 8cc7018ff..e6a1cb616 100644
--- a/spacy/cli/info.py
+++ b/spacy/cli/info.py
@@ -15,7 +15,7 @@ def info_cli(
model: Optional[str] = Arg(None, help="Optional loadable spaCy pipeline"),
markdown: bool = Opt(False, "--markdown", "-md", help="Generate Markdown for GitHub issues"),
silent: bool = Opt(False, "--silent", "-s", "-S", help="Don't print anything (just return)"),
- exclude: Optional[str] = Opt("labels", "--exclude", "-e", help="Comma-separated keys to exclude from the print-out"),
+ exclude: str = Opt("labels", "--exclude", "-e", help="Comma-separated keys to exclude from the print-out"),
# fmt: on
):
"""
@@ -61,7 +61,7 @@ def info(
return raw_data
-def info_spacy() -> Dict[str, any]:
+def info_spacy() -> Dict[str, Any]:
"""Generate info about the current spaCy intallation.
RETURNS (dict): The spaCy info.
diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py
index 55622452b..530b38eb3 100644
--- a/spacy/cli/init_config.py
+++ b/spacy/cli/init_config.py
@@ -28,8 +28,8 @@ class Optimizations(str, Enum):
def init_config_cli(
# fmt: off
output_file: Path = Arg(..., help="File to save config.cfg to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
- lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
- pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
+ lang: str = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
+ pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
gpu: bool = Opt(False, "--gpu", "-G", help="Whether the model can run on GPU. This will impact the choice of architecture, pretrained weights and related hyperparameters."),
pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
@@ -44,8 +44,6 @@ def init_config_cli(
DOCS: https://spacy.io/api/cli#init-config
"""
- if isinstance(optimize, Optimizations): # instance of enum from the CLI
- optimize = optimize.value
pipeline = string_to_list(pipeline)
is_stdout = str(output_file) == "-"
if not is_stdout and output_file.exists() and not force_overwrite:
@@ -57,7 +55,7 @@ def init_config_cli(
config = init_config(
lang=lang,
pipeline=pipeline,
- optimize=optimize,
+ optimize=optimize.value,
gpu=gpu,
pretraining=pretraining,
silent=is_stdout,
@@ -175,8 +173,8 @@ def init_config(
"Pipeline": ", ".join(pipeline),
"Optimize for": optimize,
"Hardware": variables["hardware"].upper(),
- "Transformer": template_vars.transformer.get("name")
- if template_vars.use_transformer
+ "Transformer": template_vars.transformer.get("name") # type: ignore[attr-defined]
+ if template_vars.use_transformer # type: ignore[attr-defined]
else None,
}
msg.info("Generated config template specific for your use case")
diff --git a/spacy/cli/package.py b/spacy/cli/package.py
index 332a51bc7..e76343dc3 100644
--- a/spacy/cli/package.py
+++ b/spacy/cli/package.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union, Any, Dict, List, Tuple
+from typing import Optional, Union, Any, Dict, List, Tuple, cast
import shutil
from pathlib import Path
from wasabi import Printer, MarkdownRenderer, get_raw_input
@@ -215,9 +215,9 @@ def get_third_party_dependencies(
for reg_name, func_names in funcs.items():
for func_name in func_names:
func_info = util.registry.find(reg_name, func_name)
- module_name = func_info.get("module")
+ module_name = func_info.get("module") # type: ignore[attr-defined]
if module_name: # the code is part of a module, not a --code file
- modules.add(func_info["module"].split(".")[0])
+ modules.add(func_info["module"].split(".")[0]) # type: ignore[index]
dependencies = []
for module_name in modules:
if module_name in distributions:
@@ -227,7 +227,7 @@ def get_third_party_dependencies(
if pkg in own_packages or pkg in exclude:
continue
version = util.get_package_version(pkg)
- version_range = util.get_minor_version_range(version)
+ version_range = util.get_minor_version_range(version) # type: ignore[arg-type]
dependencies.append(f"{pkg}{version_range}")
return dependencies
@@ -252,7 +252,7 @@ def create_file(file_path: Path, contents: str) -> None:
def get_meta(
model_path: Union[str, Path], existing_meta: Dict[str, Any]
) -> Dict[str, Any]:
- meta = {
+ meta: Dict[str, Any] = {
"lang": "en",
"name": "pipeline",
"version": "0.0.0",
@@ -324,8 +324,8 @@ def generate_readme(meta: Dict[str, Any]) -> str:
license_name = meta.get("license")
sources = _format_sources(meta.get("sources"))
description = meta.get("description")
- label_scheme = _format_label_scheme(meta.get("labels"))
- accuracy = _format_accuracy(meta.get("performance"))
+ label_scheme = _format_label_scheme(cast(Dict[str, Any], meta.get("labels")))
+ accuracy = _format_accuracy(cast(Dict[str, Any], meta.get("performance")))
table_data = [
(md.bold("Name"), md.code(name)),
(md.bold("Version"), md.code(version)),
diff --git a/spacy/cli/profile.py b/spacy/cli/profile.py
index f4f0d3caf..3c282c73d 100644
--- a/spacy/cli/profile.py
+++ b/spacy/cli/profile.py
@@ -32,7 +32,7 @@ def profile_cli(
DOCS: https://spacy.io/api/cli#debug-profile
"""
- if ctx.parent.command.name == NAME: # called as top-level command
+ if ctx.parent.command.name == NAME: # type: ignore[union-attr] # called as top-level command
msg.warn(
"The profile command is now available via the 'debug profile' "
"subcommand. You can run python -m spacy debug --help for an "
@@ -42,9 +42,9 @@ def profile_cli(
def profile(model: str, inputs: Optional[Path] = None, n_texts: int = 10000) -> None:
-
if inputs is not None:
- inputs = _read_inputs(inputs, msg)
+ texts = _read_inputs(inputs, msg)
+ texts = list(itertools.islice(texts, n_texts))
if inputs is None:
try:
import ml_datasets
@@ -56,16 +56,13 @@ def profile(model: str, inputs: Optional[Path] = None, n_texts: int = 10000) ->
exits=1,
)
- n_inputs = 25000
- with msg.loading("Loading IMDB dataset via Thinc..."):
- imdb_train, _ = ml_datasets.imdb()
- inputs, _ = zip(*imdb_train)
- msg.info(f"Loaded IMDB dataset and using {n_inputs} examples")
- inputs = inputs[:n_inputs]
+ with msg.loading("Loading IMDB dataset via ml_datasets..."):
+ imdb_train, _ = ml_datasets.imdb(train_limit=n_texts, dev_limit=0)
+ texts, _ = zip(*imdb_train)
+ msg.info(f"Loaded IMDB dataset and using {n_texts} examples")
with msg.loading(f"Loading pipeline '{model}'..."):
nlp = load_model(model)
msg.good(f"Loaded pipeline '{model}'")
- texts = list(itertools.islice(inputs, n_texts))
cProfile.runctx("parse_texts(nlp, texts)", globals(), locals(), "Profile.prof")
s = pstats.Stats("Profile.prof")
msg.divider("Profile stats")
@@ -87,7 +84,7 @@ def _read_inputs(loc: Union[Path, str], msg: Printer) -> Iterator[str]:
if not input_path.exists() or not input_path.is_file():
msg.fail("Not a valid input data file", loc, exits=1)
msg.info(f"Using data from {input_path.parts[-1]}")
- file_ = input_path.open()
+ file_ = input_path.open() # type: ignore[assignment]
for line in file_:
data = srsly.json_loads(line)
text = data["text"]
diff --git a/spacy/cli/project/assets.py b/spacy/cli/project/assets.py
index efc93efab..b5057e401 100644
--- a/spacy/cli/project/assets.py
+++ b/spacy/cli/project/assets.py
@@ -133,7 +133,6 @@ def fetch_asset(
# If there's already a file, check for checksum
if checksum == get_checksum(dest_path):
msg.good(f"Skipping download with matching checksum: {dest}")
- return dest_path
# We might as well support the user here and create parent directories in
# case the asset dir isn't listed as a dir to create in the project.yml
if not dest_path.parent.exists():
@@ -150,7 +149,6 @@ def fetch_asset(
msg.good(f"Copied local asset {dest}")
else:
msg.fail(f"Download failed: {dest}", e)
- return
if checksum and checksum != get_checksum(dest_path):
msg.fail(f"Checksum doesn't match value defined in {PROJECT_FILE}: {dest}")
diff --git a/spacy/cli/project/clone.py b/spacy/cli/project/clone.py
index 72d4004f8..360ee3428 100644
--- a/spacy/cli/project/clone.py
+++ b/spacy/cli/project/clone.py
@@ -80,9 +80,9 @@ def check_clone(name: str, dest: Path, repo: str) -> None:
repo (str): URL of the repo to clone from.
"""
git_err = (
- f"Cloning spaCy project templates requires Git and the 'git' command. ",
+ f"Cloning spaCy project templates requires Git and the 'git' command. "
f"To clone a project without Git, copy the files from the '{name}' "
- f"directory in the {repo} to {dest} manually.",
+ f"directory in the {repo} to {dest} manually."
)
get_git_version(error=git_err)
if not dest:
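Note: the change above removes trailing commas that had turned `git_err` into a tuple of
strings instead of one implicitly concatenated message. The difference in plain Python:

    # With trailing commas the parentheses build a tuple; without them,
    # adjacent string literals are concatenated into a single str.
    as_tuple = (
        "first part ",
        "second part",
    )
    as_string = (
        "first part "
        "second part"
    )
    print(type(as_tuple).__name__, type(as_string).__name__)  # tuple str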
diff --git a/spacy/cli/project/dvc.py b/spacy/cli/project/dvc.py
index 7e37712c3..83dc5efbf 100644
--- a/spacy/cli/project/dvc.py
+++ b/spacy/cli/project/dvc.py
@@ -143,8 +143,8 @@ def run_dvc_commands(
easier to pass flags like --quiet that depend on a variable or
command-line setting while avoiding lots of nested conditionals.
"""
- for command in commands:
- command = split_command(command)
+ for c in commands:
+ command = split_command(c)
dvc_command = ["dvc", *command]
# Add the flags if they are set to True
for flag, is_active in flags.items():
diff --git a/spacy/cli/project/remote_storage.py b/spacy/cli/project/remote_storage.py
index 6056458e2..336a4bcb3 100644
--- a/spacy/cli/project/remote_storage.py
+++ b/spacy/cli/project/remote_storage.py
@@ -41,7 +41,7 @@ class RemoteStorage:
raise IOError(f"Cannot push {loc}: does not exist.")
url = self.make_url(path, command_hash, content_hash)
if url.exists():
- return None
+ return url
tmp: Path
with make_tempdir() as tmp:
tar_loc = tmp / self.encode_name(str(path))
@@ -131,8 +131,10 @@ def get_command_hash(
currently installed packages, whatever environment variables have been marked
as relevant, and the command.
"""
- check_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
- spacy_v = GIT_VERSION if check_commit else get_minor_version(about.__version__)
+ if check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION):
+ spacy_v = GIT_VERSION
+ else:
+ spacy_v = str(get_minor_version(about.__version__) or "")
dep_checksums = [get_checksum(dep) for dep in sorted(deps)]
hashes = [spacy_v, site_hash, env_hash] + dep_checksums
hashes.extend(cmd)
diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py
index 74c8b24b6..734803bc4 100644
--- a/spacy/cli/project/run.py
+++ b/spacy/cli/project/run.py
@@ -70,7 +70,7 @@ def project_run(
config = load_project_config(project_dir, overrides=overrides)
commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
workflows = config.get("workflows", {})
- validate_subcommand(commands.keys(), workflows.keys(), subcommand)
+ validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
if subcommand in workflows:
msg.info(f"Running workflow '{subcommand}'")
for cmd in workflows[subcommand]:
@@ -116,7 +116,7 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
workflows = config.get("workflows", {})
project_loc = "" if is_cwd(project_dir) else project_dir
if subcommand:
- validate_subcommand(commands.keys(), workflows.keys(), subcommand)
+ validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
print(f"Usage: {COMMAND} project run {subcommand} {project_loc}")
if subcommand in commands:
help_text = commands[subcommand].get("help")
@@ -164,8 +164,8 @@ def run_commands(
when you want to turn over execution to the command, and capture=True
when you want to run the command more like a function.
"""
- for command in commands:
- command = split_command(command)
+ for c in commands:
+ command = split_command(c)
# Not sure if this is needed or a good idea. Motivation: users may often
# use commands in their config that reference "python" and we want to
# make sure that it's always executing the same Python that spaCy is
@@ -294,7 +294,7 @@ def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]
}
-def get_fileinfo(project_dir: Path, paths: List[str]) -> List[Dict[str, str]]:
+def get_fileinfo(project_dir: Path, paths: List[str]) -> List[Dict[str, Optional[str]]]:
"""Generate the file information for a list of paths (dependencies, outputs).
Includes the file path and the file's checksum.
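Note: wrapping `commands.keys()` and `workflows.keys()` in `list(...)` matters only to the
type checker: `dict.keys()` returns a `KeysView[str]`, which does not satisfy a `List[str]`
parameter annotation. In isolation (the function below is hypothetical):

    from typing import List

    def validate_subcommand(commands: List[str], workflows: List[str], name: str) -> None:
        if name not in commands and name not in workflows:
            raise ValueError(f"Unknown subcommand: {name}")

    cmds = {"train": {}, "evaluate": {}}
    # cmds.keys() is a KeysView[str]; list(...) satisfies the List[str] annotation.
    validate_subcommand(list(cmds.keys()), [], "train")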
diff --git a/spacy/cli/validate.py b/spacy/cli/validate.py
index a727e380e..a918e9a39 100644
--- a/spacy/cli/validate.py
+++ b/spacy/cli/validate.py
@@ -99,7 +99,7 @@ def get_model_pkgs(silent: bool = False) -> Tuple[dict, dict]:
warnings.filterwarnings("ignore", message="\\[W09[45]")
model_meta = get_model_meta(model_path)
spacy_version = model_meta.get("spacy_version", "n/a")
- is_compat = is_compatible_version(about.__version__, spacy_version)
+ is_compat = is_compatible_version(about.__version__, spacy_version) # type: ignore[assignment]
pkgs[pkg_name] = {
"name": package,
"version": version,
diff --git a/spacy/compat.py b/spacy/compat.py
index 92ed23c0e..f4403f036 100644
--- a/spacy/compat.py
+++ b/spacy/compat.py
@@ -5,12 +5,12 @@ from thinc.util import copy_array
try:
import cPickle as pickle
except ImportError:
- import pickle
+ import pickle # type: ignore[no-redef]
try:
import copy_reg
except ImportError:
- import copyreg as copy_reg
+ import copyreg as copy_reg # type: ignore[no-redef]
try:
from cupy.cuda.stream import Stream as CudaStream
@@ -22,10 +22,10 @@ try:
except ImportError:
cupy = None
-try: # Python 3.8+
+if sys.version_info[:2] >= (3, 8): # Python 3.8+
from typing import Literal
-except ImportError:
- from typing_extensions import Literal # noqa: F401
+else:
+ from typing_extensions import Literal # noqa: F401
# Important note: The importlib_metadata "backport" includes functionality
# that's not part of the built-in importlib.metadata. We should treat this
@@ -33,7 +33,7 @@ except ImportError:
try: # Python 3.8+
import importlib.metadata as importlib_metadata
except ImportError:
- from catalogue import _importlib_metadata as importlib_metadata # noqa: F401
+ from catalogue import _importlib_metadata as importlib_metadata # type: ignore[no-redef] # noqa: F401
from thinc.api import Optimizer # noqa: F401
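Note: replacing the try/except import with a `sys.version_info` check lets mypy, which
understands version guards, type-check exactly one branch instead of reporting a
conditional re-definition of `Literal`. The pattern on its own (the alias at the end is
just for illustration):

    import sys

    if sys.version_info[:2] >= (3, 8):
        from typing import Literal
    else:
        from typing_extensions import Literal

    Mode = Literal["strict", "lenient"]  # hypothetical alias

    def check(mode: Mode) -> None:
        print("mode:", mode)

    check("strict")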
diff --git a/spacy/displacy/__init__.py b/spacy/displacy/__init__.py
index 78b83f2e5..d9418f675 100644
--- a/spacy/displacy/__init__.py
+++ b/spacy/displacy/__init__.py
@@ -18,7 +18,7 @@ RENDER_WRAPPER = None
def render(
- docs: Union[Iterable[Union[Doc, Span]], Doc, Span],
+ docs: Union[Iterable[Union[Doc, Span, dict]], Doc, Span, dict],
style: str = "dep",
page: bool = False,
minify: bool = False,
@@ -28,7 +28,8 @@ def render(
) -> str:
"""Render displaCy visualisation.
- docs (Union[Iterable[Doc], Doc]): Document(s) to visualise.
+    docs (Union[Iterable[Union[Doc, Span, dict]], Doc, Span, dict]): Document(s) to visualise.
+        A 'dict' is only allowed here when 'manual' is set to True.
style (str): Visualisation style, 'dep' or 'ent'.
page (bool): Render markup as full HTML page.
minify (bool): Minify HTML markup.
@@ -53,8 +54,8 @@ def render(
raise ValueError(Errors.E096)
renderer_func, converter = factories[style]
renderer = renderer_func(options=options)
- parsed = [converter(doc, options) for doc in docs] if not manual else docs
- _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip()
+ parsed = [converter(doc, options) for doc in docs] if not manual else docs # type: ignore
+ _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip() # type: ignore
html = _html["parsed"]
if RENDER_WRAPPER is not None:
html = RENDER_WRAPPER(html)
@@ -133,7 +134,7 @@ def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
"lemma": np.root.lemma_,
"ent_type": np.root.ent_type_,
}
- retokenizer.merge(np, attrs=attrs)
+ retokenizer.merge(np, attrs=attrs) # type: ignore[arg-type]
if options.get("collapse_punct", True):
spans = []
for word in doc[:-1]:
@@ -148,7 +149,7 @@ def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
with doc.retokenize() as retokenizer:
for span, tag, lemma, ent_type in spans:
attrs = {"tag": tag, "lemma": lemma, "ent_type": ent_type}
- retokenizer.merge(span, attrs=attrs)
+ retokenizer.merge(span, attrs=attrs) # type: ignore[arg-type]
fine_grained = options.get("fine_grained")
add_lemma = options.get("add_lemma")
words = [
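Note: the `# type: ignore[arg-type]` comments are targeted ignores: they silence a single
mypy error code on one line, while everything else on that line and in the file is still
checked. For example:

    # Minimal illustration of an error-code-scoped ignore (not spaCy code);
    # only the "assignment" error on this line is silenced, other checks still run.
    value: int = "not an int"  # type: ignore[assignment]
    print(value)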
diff --git a/spacy/kb.pyx b/spacy/kb.pyx
index d8514b54c..16d63a4d3 100644
--- a/spacy/kb.pyx
+++ b/spacy/kb.pyx
@@ -1,5 +1,5 @@
# cython: infer_types=True, profile=True
-from typing import Iterator, Iterable
+from typing import Iterator, Iterable, Callable, Dict, Any
import srsly
from cymem.cymem cimport Pool
@@ -446,7 +446,7 @@ cdef class KnowledgeBase:
raise ValueError(Errors.E929.format(loc=path))
if not path.is_dir():
raise ValueError(Errors.E928.format(loc=path))
- deserialize = {}
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
deserialize["contents"] = lambda p: self.read_contents(p)
deserialize["strings.json"] = lambda p: self.vocab.strings.from_disk(p)
util.from_disk(path, deserialize, exclude)
diff --git a/spacy/lang/af/__init__.py b/spacy/lang/af/__init__.py
index 91917daee..553fcbf4c 100644
--- a/spacy/lang/af/__init__.py
+++ b/spacy/lang/af/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class AfrikaansDefaults(Language.Defaults):
+class AfrikaansDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/am/__init__.py b/spacy/lang/am/__init__.py
index ed21b55ee..ddae556d6 100644
--- a/spacy/lang/am/__init__.py
+++ b/spacy/lang/am/__init__.py
@@ -4,12 +4,12 @@ from .punctuation import TOKENIZER_SUFFIXES
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...attrs import LANG
from ...util import update_exc
-class AmharicDefaults(Language.Defaults):
+class AmharicDefaults(BaseDefaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters.update(LEX_ATTRS)
lex_attr_getters[LANG] = lambda text: "am"
diff --git a/spacy/lang/ar/__init__.py b/spacy/lang/ar/__init__.py
index 6abb65efb..18c1f90ed 100644
--- a/spacy/lang/ar/__init__.py
+++ b/spacy/lang/ar/__init__.py
@@ -2,10 +2,10 @@ from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_SUFFIXES
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class ArabicDefaults(Language.Defaults):
+class ArabicDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
suffixes = TOKENIZER_SUFFIXES
stop_words = STOP_WORDS
diff --git a/spacy/lang/az/__init__.py b/spacy/lang/az/__init__.py
index 2937e2ecf..476898364 100644
--- a/spacy/lang/az/__init__.py
+++ b/spacy/lang/az/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class AzerbaijaniDefaults(Language.Defaults):
+class AzerbaijaniDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/bg/__init__.py b/spacy/lang/bg/__init__.py
index 6fa539a28..559cc34c4 100644
--- a/spacy/lang/bg/__init__.py
+++ b/spacy/lang/bg/__init__.py
@@ -3,12 +3,12 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...attrs import LANG
from ...util import update_exc
-class BulgarianDefaults(Language.Defaults):
+class BulgarianDefaults(BaseDefaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters[LANG] = lambda text: "bg"
diff --git a/spacy/lang/bn/__init__.py b/spacy/lang/bn/__init__.py
index 23c3ff485..4eb9735df 100644
--- a/spacy/lang/bn/__init__.py
+++ b/spacy/lang/bn/__init__.py
@@ -3,11 +3,11 @@ from thinc.api import Model
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Lemmatizer
-class BengaliDefaults(Language.Defaults):
+class BengaliDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/ca/__init__.py b/spacy/lang/ca/__init__.py
index 81f39b13c..250ae9463 100644
--- a/spacy/lang/ca/__init__.py
+++ b/spacy/lang/ca/__init__.py
@@ -7,11 +7,11 @@ from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
from .lemmatizer import CatalanLemmatizer
-class CatalanDefaults(Language.Defaults):
+class CatalanDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/ca/syntax_iterators.py b/spacy/lang/ca/syntax_iterators.py
index c70d53e80..917e07c93 100644
--- a/spacy/lang/ca/syntax_iterators.py
+++ b/spacy/lang/ca/syntax_iterators.py
@@ -1,8 +1,10 @@
+from typing import Union, Iterator, Tuple
+from ...tokens import Doc, Span
from ...symbols import NOUN, PROPN
from ...errors import Errors
-def noun_chunks(doclike):
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# fmt: off
labels = ["nsubj", "nsubj:pass", "obj", "obl", "iobj", "ROOT", "appos", "nmod", "nmod:poss"]
diff --git a/spacy/lang/cs/__init__.py b/spacy/lang/cs/__init__.py
index 26f5845cc..3e70e4078 100644
--- a/spacy/lang/cs/__init__.py
+++ b/spacy/lang/cs/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class CzechDefaults(Language.Defaults):
+class CzechDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/da/__init__.py b/spacy/lang/da/__init__.py
index c5260ccdd..e148a7b4f 100644
--- a/spacy/lang/da/__init__.py
+++ b/spacy/lang/da/__init__.py
@@ -3,10 +3,10 @@ from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class DanishDefaults(Language.Defaults):
+class DanishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/da/syntax_iterators.py b/spacy/lang/da/syntax_iterators.py
index 39181d753..a0b70f004 100644
--- a/spacy/lang/da/syntax_iterators.py
+++ b/spacy/lang/da/syntax_iterators.py
@@ -1,8 +1,10 @@
+from typing import Union, Iterator, Tuple
+from ...tokens import Doc, Span
from ...symbols import NOUN, PROPN, PRON, VERB, AUX
from ...errors import Errors
-def noun_chunks(doclike):
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
def is_verb_token(tok):
return tok.pos in [VERB, AUX]
@@ -32,7 +34,7 @@ def noun_chunks(doclike):
def get_bounds(doc, root):
return get_left_bound(doc, root), get_right_bound(doc, root)
- doc = doclike.doc
+    doc = doclike.doc  # Ensure this works on both Doc and Span.
if not doc.has_annotation("DEP"):
raise ValueError(Errors.E029)
diff --git a/spacy/lang/de/__init__.py b/spacy/lang/de/__init__.py
index b645d3480..65863c098 100644
--- a/spacy/lang/de/__init__.py
+++ b/spacy/lang/de/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class GermanDefaults(Language.Defaults):
+class GermanDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/de/syntax_iterators.py b/spacy/lang/de/syntax_iterators.py
index aba0e8024..e80504998 100644
--- a/spacy/lang/de/syntax_iterators.py
+++ b/spacy/lang/de/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# this iterator extracts spans headed by NOUNs starting from the left-most
# syntactic dependent until the NOUN itself for close apposition and
diff --git a/spacy/lang/el/__init__.py b/spacy/lang/el/__init__.py
index be59a3500..258b37a8a 100644
--- a/spacy/lang/el/__init__.py
+++ b/spacy/lang/el/__init__.py
@@ -7,10 +7,10 @@ from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .lemmatizer import GreekLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class GreekDefaults(Language.Defaults):
+class GreekDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/el/syntax_iterators.py b/spacy/lang/el/syntax_iterators.py
index 89cfd8b72..18fa46695 100644
--- a/spacy/lang/el/syntax_iterators.py
+++ b/spacy/lang/el/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# It follows the logic of the noun chunks finder of English language,
# adjusted to some Greek language special characteristics.
diff --git a/spacy/lang/en/__init__.py b/spacy/lang/en/__init__.py
index eea522908..854f59224 100644
--- a/spacy/lang/en/__init__.py
+++ b/spacy/lang/en/__init__.py
@@ -7,10 +7,10 @@ from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
from .punctuation import TOKENIZER_INFIXES
from .lemmatizer import EnglishLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class EnglishDefaults(Language.Defaults):
+class EnglishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/en/lex_attrs.py b/spacy/lang/en/lex_attrs.py
index b630a317d..ab9353919 100644
--- a/spacy/lang/en/lex_attrs.py
+++ b/spacy/lang/en/lex_attrs.py
@@ -19,7 +19,7 @@ _ordinal_words = [
# fmt: on
-def like_num(text: str) -> bool:
+def like_num(text):
if text.startswith(("+", "-", "±", "~")):
text = text[1:]
text = text.replace(",", "").replace(".", "")
diff --git a/spacy/lang/en/syntax_iterators.py b/spacy/lang/en/syntax_iterators.py
index 00a1bac42..7904e5621 100644
--- a/spacy/lang/en/syntax_iterators.py
+++ b/spacy/lang/en/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on both Doc and Span.
"""
diff --git a/spacy/lang/en/tokenizer_exceptions.py b/spacy/lang/en/tokenizer_exceptions.py
index d69508470..55b544e42 100644
--- a/spacy/lang/en/tokenizer_exceptions.py
+++ b/spacy/lang/en/tokenizer_exceptions.py
@@ -1,9 +1,10 @@
+from typing import Dict, List
from ..tokenizer_exceptions import BASE_EXCEPTIONS
from ...symbols import ORTH, NORM
from ...util import update_exc
-_exc = {}
+_exc: Dict[str, List[Dict]] = {}
_exclude = [
"Ill",
"ill",
@@ -294,9 +295,9 @@ for verb_data in [
{ORTH: "has", NORM: "has"},
{ORTH: "dare", NORM: "dare"},
]:
- verb_data_tc = dict(verb_data)
+ verb_data_tc = dict(verb_data) # type: ignore[call-overload]
verb_data_tc[ORTH] = verb_data_tc[ORTH].title()
- for data in [verb_data, verb_data_tc]:
+ for data in [verb_data, verb_data_tc]: # type: ignore[assignment]
_exc[data[ORTH] + "n't"] = [
dict(data),
{ORTH: "n't", NORM: "not"},
diff --git a/spacy/lang/es/__init__.py b/spacy/lang/es/__init__.py
index 4b329b6f7..f5d1eb97a 100644
--- a/spacy/lang/es/__init__.py
+++ b/spacy/lang/es/__init__.py
@@ -6,10 +6,10 @@ from .lex_attrs import LEX_ATTRS
from .lemmatizer import SpanishLemmatizer
from .syntax_iterators import SYNTAX_ITERATORS
from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SpanishDefaults(Language.Defaults):
+class SpanishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/es/lemmatizer.py b/spacy/lang/es/lemmatizer.py
index 56f74068d..ca5fc08c8 100644
--- a/spacy/lang/es/lemmatizer.py
+++ b/spacy/lang/es/lemmatizer.py
@@ -52,7 +52,7 @@ class SpanishLemmatizer(Lemmatizer):
rule_pos = "verb"
else:
rule_pos = pos
- rule = self.select_rule(rule_pos, features)
+ rule = self.select_rule(rule_pos, list(features))
index = self.lookups.get_table("lemma_index").get(rule_pos, [])
lemmas = getattr(self, "lemmatize_" + rule_pos)(
string, features, rule, index
@@ -191,6 +191,8 @@ class SpanishLemmatizer(Lemmatizer):
return selected_lemmas
else:
return possible_lemmas
+ else:
+ return []
def lemmatize_noun(
self, word: str, features: List[str], rule: str, index: List[str]
@@ -268,7 +270,7 @@ class SpanishLemmatizer(Lemmatizer):
return [word]
def lemmatize_pron(
- self, word: str, features: List[str], rule: str, index: List[str]
+ self, word: str, features: List[str], rule: Optional[str], index: List[str]
) -> List[str]:
"""
Lemmatize a pronoun.
@@ -319,9 +321,11 @@ class SpanishLemmatizer(Lemmatizer):
return selected_lemmas
else:
return possible_lemmas
+ else:
+ return []
def lemmatize_verb(
- self, word: str, features: List[str], rule: str, index: List[str]
+ self, word: str, features: List[str], rule: Optional[str], index: List[str]
) -> List[str]:
"""
Lemmatize a verb.
@@ -342,6 +346,7 @@ class SpanishLemmatizer(Lemmatizer):
selected_lemmas = []
# Apply lemmatization rules
+ rule = str(rule or "")
for old, new in self.lookups.get_table("lemma_rules").get(rule, []):
possible_lemma = re.sub(old + "$", new, word)
if possible_lemma != word:
@@ -389,11 +394,11 @@ class SpanishLemmatizer(Lemmatizer):
return [word]
def lemmatize_verb_pron(
- self, word: str, features: List[str], rule: str, index: List[str]
+ self, word: str, features: List[str], rule: Optional[str], index: List[str]
) -> List[str]:
# Strip and collect pronouns
pron_patt = "^(.*?)([mts]e|l[aeo]s?|n?os)$"
- prons = []
+ prons: List[str] = []
verb = word
m = re.search(pron_patt, verb)
while m is not None and len(prons) <= 3:
@@ -410,7 +415,7 @@ class SpanishLemmatizer(Lemmatizer):
else:
rule = self.select_rule("verb", features)
verb_lemma = self.lemmatize_verb(
- verb, features - {"PronType=Prs"}, rule, index
+ verb, features - {"PronType=Prs"}, rule, index # type: ignore[operator]
)[0]
pron_lemmas = []
for pron in prons:
diff --git a/spacy/lang/es/syntax_iterators.py b/spacy/lang/es/syntax_iterators.py
index e753a3f98..8b385a1b9 100644
--- a/spacy/lang/es/syntax_iterators.py
+++ b/spacy/lang/es/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON, VERB, AUX
from ...errors import Errors
from ...tokens import Doc, Span, Token
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
doc = doclike.doc
if not doc.has_annotation("DEP"):
diff --git a/spacy/lang/et/__init__.py b/spacy/lang/et/__init__.py
index 9f71882d2..274bc1309 100644
--- a/spacy/lang/et/__init__.py
+++ b/spacy/lang/et/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class EstonianDefaults(Language.Defaults):
+class EstonianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/eu/__init__.py b/spacy/lang/eu/__init__.py
index 89550be96..3346468bd 100644
--- a/spacy/lang/eu/__init__.py
+++ b/spacy/lang/eu/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_SUFFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class BasqueDefaults(Language.Defaults):
+class BasqueDefaults(BaseDefaults):
suffixes = TOKENIZER_SUFFIXES
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/fa/__init__.py b/spacy/lang/fa/__init__.py
index 77a0a28b9..6db64ff62 100644
--- a/spacy/lang/fa/__init__.py
+++ b/spacy/lang/fa/__init__.py
@@ -5,11 +5,11 @@ from .lex_attrs import LEX_ATTRS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_SUFFIXES
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Lemmatizer
-class PersianDefaults(Language.Defaults):
+class PersianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
suffixes = TOKENIZER_SUFFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/fa/generate_verbs_exc.py b/spacy/lang/fa/generate_verbs_exc.py
index 62094c6de..a6d79a386 100644
--- a/spacy/lang/fa/generate_verbs_exc.py
+++ b/spacy/lang/fa/generate_verbs_exc.py
@@ -639,10 +639,12 @@ for verb_root in verb_roots:
)
if past.startswith("آ"):
- conjugations = set(
- map(
- lambda item: item.replace("بآ", "بیا").replace("نآ", "نیا"),
- conjugations,
+ conjugations = list(
+ set(
+ map(
+ lambda item: item.replace("بآ", "بیا").replace("نآ", "نیا"),
+ conjugations,
+ )
)
)
diff --git a/spacy/lang/fa/syntax_iterators.py b/spacy/lang/fa/syntax_iterators.py
index 0be06e73c..8207884b0 100644
--- a/spacy/lang/fa/syntax_iterators.py
+++ b/spacy/lang/fa/syntax_iterators.py
@@ -1,8 +1,10 @@
+from typing import Union, Iterator, Tuple
+from ...tokens import Doc, Span
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
-def noun_chunks(doclike):
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on both Doc and Span.
"""
diff --git a/spacy/lang/fi/__init__.py b/spacy/lang/fi/__init__.py
index 9233c6547..86a834170 100644
--- a/spacy/lang/fi/__init__.py
+++ b/spacy/lang/fi/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class FinnishDefaults(Language.Defaults):
+class FinnishDefaults(BaseDefaults):
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
diff --git a/spacy/lang/fr/__init__.py b/spacy/lang/fr/__init__.py
index d69a5a718..e7267dc61 100644
--- a/spacy/lang/fr/__init__.py
+++ b/spacy/lang/fr/__init__.py
@@ -9,10 +9,10 @@ from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
from .lemmatizer import FrenchLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class FrenchDefaults(Language.Defaults):
+class FrenchDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
infixes = TOKENIZER_INFIXES
diff --git a/spacy/lang/fr/syntax_iterators.py b/spacy/lang/fr/syntax_iterators.py
index 68117a54d..d86662693 100644
--- a/spacy/lang/fr/syntax_iterators.py
+++ b/spacy/lang/fr/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# fmt: off
labels = ["nsubj", "nsubj:pass", "obj", "iobj", "ROOT", "appos", "nmod", "nmod:poss"]
diff --git a/spacy/lang/fr/tokenizer_exceptions.py b/spacy/lang/fr/tokenizer_exceptions.py
index 060f81879..2e88b58cf 100644
--- a/spacy/lang/fr/tokenizer_exceptions.py
+++ b/spacy/lang/fr/tokenizer_exceptions.py
@@ -115,7 +115,7 @@ for s, verb, pronoun in [("s", "est", "il"), ("S", "EST", "IL")]:
]
-_infixes_exc = []
+_infixes_exc = [] # type: ignore[var-annotated]
orig_elision = "'"
orig_hyphen = "-"
diff --git a/spacy/lang/ga/__init__.py b/spacy/lang/ga/__init__.py
index 80131368b..90735d749 100644
--- a/spacy/lang/ga/__init__.py
+++ b/spacy/lang/ga/__init__.py
@@ -1,9 +1,9 @@
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class IrishDefaults(Language.Defaults):
+class IrishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
stop_words = STOP_WORDS
diff --git a/spacy/lang/grc/__init__.py b/spacy/lang/grc/__init__.py
index e29252da9..e83f0c5a5 100644
--- a/spacy/lang/grc/__init__.py
+++ b/spacy/lang/grc/__init__.py
@@ -1,10 +1,10 @@
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class AncientGreekDefaults(Language.Defaults):
+class AncientGreekDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/grc/tokenizer_exceptions.py b/spacy/lang/grc/tokenizer_exceptions.py
index 230a58fd2..bcee70f32 100644
--- a/spacy/lang/grc/tokenizer_exceptions.py
+++ b/spacy/lang/grc/tokenizer_exceptions.py
@@ -108,8 +108,4 @@ _other_exc = {
_exc.update(_other_exc)
-_exc_data = {}
-
-_exc.update(_exc_data)
-
TOKENIZER_EXCEPTIONS = update_exc(BASE_EXCEPTIONS, _exc)
diff --git a/spacy/lang/gu/__init__.py b/spacy/lang/gu/__init__.py
index 67228ac40..e6fbc9d18 100644
--- a/spacy/lang/gu/__init__.py
+++ b/spacy/lang/gu/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class GujaratiDefaults(Language.Defaults):
+class GujaratiDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/he/__init__.py b/spacy/lang/he/__init__.py
index e0adc3293..dd2ee478d 100644
--- a/spacy/lang/he/__init__.py
+++ b/spacy/lang/he/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class HebrewDefaults(Language.Defaults):
+class HebrewDefaults(BaseDefaults):
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
writing_system = {"direction": "rtl", "has_case": False, "has_letters": True}
diff --git a/spacy/lang/hi/__init__.py b/spacy/lang/hi/__init__.py
index 384f040c8..4c8ae446d 100644
--- a/spacy/lang/hi/__init__.py
+++ b/spacy/lang/hi/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class HindiDefaults(Language.Defaults):
+class HindiDefaults(BaseDefaults):
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/hr/__init__.py b/spacy/lang/hr/__init__.py
index 118e0946a..30870b522 100644
--- a/spacy/lang/hr/__init__.py
+++ b/spacy/lang/hr/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class CroatianDefaults(Language.Defaults):
+class CroatianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/hu/__init__.py b/spacy/lang/hu/__init__.py
index 8962603a6..9426bacea 100644
--- a/spacy/lang/hu/__init__.py
+++ b/spacy/lang/hu/__init__.py
@@ -1,10 +1,10 @@
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS, TOKEN_MATCH
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class HungarianDefaults(Language.Defaults):
+class HungarianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/hy/__init__.py b/spacy/lang/hy/__init__.py
index 4577ab641..481eaae0a 100644
--- a/spacy/lang/hy/__init__.py
+++ b/spacy/lang/hy/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class ArmenianDefaults(Language.Defaults):
+class ArmenianDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/id/__init__.py b/spacy/lang/id/__init__.py
index 87373551c..0d72cfa9d 100644
--- a/spacy/lang/id/__init__.py
+++ b/spacy/lang/id/__init__.py
@@ -3,10 +3,10 @@ from .punctuation import TOKENIZER_SUFFIXES, TOKENIZER_PREFIXES, TOKENIZER_INFIX
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class IndonesianDefaults(Language.Defaults):
+class IndonesianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/id/syntax_iterators.py b/spacy/lang/id/syntax_iterators.py
index 0f29bfe16..fa984d411 100644
--- a/spacy/lang/id/syntax_iterators.py
+++ b/spacy/lang/id/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on both Doc and Span.
"""
diff --git a/spacy/lang/is/__init__.py b/spacy/lang/is/__init__.py
index be5de5981..318363beb 100644
--- a/spacy/lang/is/__init__.py
+++ b/spacy/lang/is/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class IcelandicDefaults(Language.Defaults):
+class IcelandicDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/it/__init__.py b/spacy/lang/it/__init__.py
index 672a8698e..863ed8e2f 100644
--- a/spacy/lang/it/__init__.py
+++ b/spacy/lang/it/__init__.py
@@ -4,11 +4,11 @@ from thinc.api import Model
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_INFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
from .lemmatizer import ItalianLemmatizer
-class ItalianDefaults(Language.Defaults):
+class ItalianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
stop_words = STOP_WORDS
prefixes = TOKENIZER_PREFIXES
diff --git a/spacy/lang/ja/__init__.py b/spacy/lang/ja/__init__.py
index 4e6bf9d3c..8499fc73e 100644
--- a/spacy/lang/ja/__init__.py
+++ b/spacy/lang/ja/__init__.py
@@ -10,7 +10,7 @@ from .tag_orth_map import TAG_ORTH_MAP
from .tag_bigram_map import TAG_BIGRAM_MAP
from ...compat import copy_reg
from ...errors import Errors
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...scorer import Scorer
from ...symbols import POS
from ...tokens import Doc
@@ -154,7 +154,7 @@ class JapaneseTokenizer(DummyTokenizer):
def to_disk(self, path: Union[str, Path], **kwargs) -> None:
path = util.ensure_path(path)
serializers = {"cfg": lambda p: srsly.write_json(p, self._get_config())}
- return util.to_disk(path, serializers, [])
+ util.to_disk(path, serializers, [])
def from_disk(self, path: Union[str, Path], **kwargs) -> "JapaneseTokenizer":
path = util.ensure_path(path)
@@ -164,7 +164,7 @@ class JapaneseTokenizer(DummyTokenizer):
return self
-class JapaneseDefaults(Language.Defaults):
+class JapaneseDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
stop_words = STOP_WORDS
syntax_iterators = SYNTAX_ITERATORS
diff --git a/spacy/lang/ja/syntax_iterators.py b/spacy/lang/ja/syntax_iterators.py
index cca4902ab..588a9ba03 100644
--- a/spacy/lang/ja/syntax_iterators.py
+++ b/spacy/lang/ja/syntax_iterators.py
@@ -1,4 +1,4 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple, Set
from ...symbols import NOUN, PROPN, PRON, VERB
from ...tokens import Doc, Span
@@ -10,13 +10,13 @@ labels = ["nsubj", "nmod", "ddoclike", "nsubjpass", "pcomp", "pdoclike", "doclik
# fmt: on
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
doc = doclike.doc # Ensure works on both Doc and Span.
np_deps = [doc.vocab.strings.add(label) for label in labels]
doc.vocab.strings.add("conj")
np_label = doc.vocab.strings.add("NP")
- seen = set()
+ seen: Set[int] = set()
for i, word in enumerate(doclike):
if word.pos not in (NOUN, PROPN, PRON):
continue
diff --git a/spacy/lang/kn/__init__.py b/spacy/lang/kn/__init__.py
index 8e53989e6..ccd46a394 100644
--- a/spacy/lang/kn/__init__.py
+++ b/spacy/lang/kn/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class KannadaDefaults(Language.Defaults):
+class KannadaDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/ko/__init__.py b/spacy/lang/ko/__init__.py
index 83c9f4962..dfb311136 100644
--- a/spacy/lang/ko/__init__.py
+++ b/spacy/lang/ko/__init__.py
@@ -1,9 +1,9 @@
-from typing import Optional, Any, Dict
+from typing import Iterator, Any, Dict
from .stop_words import STOP_WORDS
from .tag_map import TAG_MAP
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...tokens import Doc
from ...compat import copy_reg
from ...scorer import Scorer
@@ -29,9 +29,9 @@ def create_tokenizer():
class KoreanTokenizer(DummyTokenizer):
- def __init__(self, nlp: Optional[Language] = None):
+ def __init__(self, nlp: Language):
self.vocab = nlp.vocab
- MeCab = try_mecab_import()
+ MeCab = try_mecab_import() # type: ignore[func-returns-value]
self.mecab_tokenizer = MeCab("-F%f[0],%f[7]")
def __del__(self):
@@ -49,7 +49,7 @@ class KoreanTokenizer(DummyTokenizer):
doc.user_data["full_tags"] = [dt["tag"] for dt in dtokens]
return doc
- def detailed_tokens(self, text: str) -> Dict[str, Any]:
+ def detailed_tokens(self, text: str) -> Iterator[Dict[str, Any]]:
# 품사 태그(POS)[0], 의미 부류(semantic class)[1], 종성 유무(jongseong)[2], 읽기(reading)[3],
# 타입(type)[4], 첫번째 품사(start pos)[5], 마지막 품사(end pos)[6], 표현(expression)[7], *
for node in self.mecab_tokenizer.parse(text, as_nodes=True):
@@ -68,7 +68,7 @@ class KoreanTokenizer(DummyTokenizer):
return Scorer.score_tokenization(examples)
-class KoreanDefaults(Language.Defaults):
+class KoreanDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/ky/__init__.py b/spacy/lang/ky/__init__.py
index a333db035..ccca384bd 100644
--- a/spacy/lang/ky/__init__.py
+++ b/spacy/lang/ky/__init__.py
@@ -2,10 +2,10 @@ from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class KyrgyzDefaults(Language.Defaults):
+class KyrgyzDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/lb/__init__.py b/spacy/lang/lb/__init__.py
index da6fe55d7..7827e7762 100644
--- a/spacy/lang/lb/__init__.py
+++ b/spacy/lang/lb/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_INFIXES
from .lex_attrs import LEX_ATTRS
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class LuxembourgishDefaults(Language.Defaults):
+class LuxembourgishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/lij/__init__.py b/spacy/lang/lij/__init__.py
index 5ae280324..b7e11f77e 100644
--- a/spacy/lang/lij/__init__.py
+++ b/spacy/lang/lij/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_INFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class LigurianDefaults(Language.Defaults):
+class LigurianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
stop_words = STOP_WORDS
diff --git a/spacy/lang/lt/__init__.py b/spacy/lang/lt/__init__.py
index e395a8f62..3ae000e5f 100644
--- a/spacy/lang/lt/__init__.py
+++ b/spacy/lang/lt/__init__.py
@@ -2,10 +2,10 @@ from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class LithuanianDefaults(Language.Defaults):
+class LithuanianDefaults(BaseDefaults):
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
diff --git a/spacy/lang/lv/__init__.py b/spacy/lang/lv/__init__.py
index 142bc706e..a05e5b939 100644
--- a/spacy/lang/lv/__init__.py
+++ b/spacy/lang/lv/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class LatvianDefaults(Language.Defaults):
+class LatvianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/mk/__init__.py b/spacy/lang/mk/__init__.py
index 2f6097f16..376afb552 100644
--- a/spacy/lang/mk/__init__.py
+++ b/spacy/lang/mk/__init__.py
@@ -6,13 +6,13 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...attrs import LANG
from ...util import update_exc
from ...lookups import Lookups
-class MacedonianDefaults(Language.Defaults):
+class MacedonianDefaults(BaseDefaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters[LANG] = lambda text: "mk"
diff --git a/spacy/lang/ml/__init__.py b/spacy/lang/ml/__init__.py
index cfad52261..9f90605f0 100644
--- a/spacy/lang/ml/__init__.py
+++ b/spacy/lang/ml/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class MalayalamDefaults(Language.Defaults):
+class MalayalamDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/mr/__init__.py b/spacy/lang/mr/__init__.py
index af0c49878..3e172fa60 100644
--- a/spacy/lang/mr/__init__.py
+++ b/spacy/lang/mr/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class MarathiDefaults(Language.Defaults):
+class MarathiDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/nb/__init__.py b/spacy/lang/nb/__init__.py
index 0bfde7d28..e27754e55 100644
--- a/spacy/lang/nb/__init__.py
+++ b/spacy/lang/nb/__init__.py
@@ -5,11 +5,11 @@ from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_INFIXES
from .punctuation import TOKENIZER_SUFFIXES
from .stop_words import STOP_WORDS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Lemmatizer
-class NorwegianDefaults(Language.Defaults):
+class NorwegianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
infixes = TOKENIZER_INFIXES
diff --git a/spacy/lang/nb/syntax_iterators.py b/spacy/lang/nb/syntax_iterators.py
index 68117a54d..d86662693 100644
--- a/spacy/lang/nb/syntax_iterators.py
+++ b/spacy/lang/nb/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# fmt: off
labels = ["nsubj", "nsubj:pass", "obj", "iobj", "ROOT", "appos", "nmod", "nmod:poss"]
diff --git a/spacy/lang/ne/__init__.py b/spacy/lang/ne/__init__.py
index 68632e9ad..0028d1b0b 100644
--- a/spacy/lang/ne/__init__.py
+++ b/spacy/lang/ne/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class NepaliDefaults(Language.Defaults):
+class NepaliDefaults(BaseDefaults):
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/nl/__init__.py b/spacy/lang/nl/__init__.py
index 5e95b4a8b..8f370eaaf 100644
--- a/spacy/lang/nl/__init__.py
+++ b/spacy/lang/nl/__init__.py
@@ -9,10 +9,10 @@ from .punctuation import TOKENIZER_SUFFIXES
from .stop_words import STOP_WORDS
from .syntax_iterators import SYNTAX_ITERATORS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class DutchDefaults(Language.Defaults):
+class DutchDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
infixes = TOKENIZER_INFIXES
diff --git a/spacy/lang/nl/syntax_iterators.py b/spacy/lang/nl/syntax_iterators.py
index 1959ba6e2..1ab5e7cff 100644
--- a/spacy/lang/nl/syntax_iterators.py
+++ b/spacy/lang/nl/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on Doc and Span.
The definition is inspired by https://www.nltk.org/book/ch07.html
diff --git a/spacy/lang/pl/__init__.py b/spacy/lang/pl/__init__.py
index 585e08c60..4b8c88bd7 100644
--- a/spacy/lang/pl/__init__.py
+++ b/spacy/lang/pl/__init__.py
@@ -8,7 +8,7 @@ from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .lemmatizer import PolishLemmatizer
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
TOKENIZER_EXCEPTIONS = {
@@ -16,7 +16,7 @@ TOKENIZER_EXCEPTIONS = {
}
-class PolishDefaults(Language.Defaults):
+class PolishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
infixes = TOKENIZER_INFIXES
diff --git a/spacy/lang/pt/__init__.py b/spacy/lang/pt/__init__.py
index 0447099f0..9ae6501fb 100644
--- a/spacy/lang/pt/__init__.py
+++ b/spacy/lang/pt/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES, TOKENIZER_PREFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class PortugueseDefaults(Language.Defaults):
+class PortugueseDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
prefixes = TOKENIZER_PREFIXES
diff --git a/spacy/lang/ro/__init__.py b/spacy/lang/ro/__init__.py
index f0d8d8d31..50027ffd2 100644
--- a/spacy/lang/ro/__init__.py
+++ b/spacy/lang/ro/__init__.py
@@ -3,14 +3,14 @@ from .stop_words import STOP_WORDS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_INFIXES
from .punctuation import TOKENIZER_SUFFIXES
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
# Lemma data note:
# Original pairs downloaded from http://www.lexiconista.com/datasets/lemmatization/
# Replaced characters using cedillas with the correct ones (ș and ț)
-class RomanianDefaults(Language.Defaults):
+class RomanianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/ru/__init__.py b/spacy/lang/ru/__init__.py
index 4287cc288..16ae5eef5 100644
--- a/spacy/lang/ru/__init__.py
+++ b/spacy/lang/ru/__init__.py
@@ -5,10 +5,10 @@ from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from .lemmatizer import RussianLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class RussianDefaults(Language.Defaults):
+class RussianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/sa/__init__.py b/spacy/lang/sa/__init__.py
index 345137817..61398af6c 100644
--- a/spacy/lang/sa/__init__.py
+++ b/spacy/lang/sa/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SanskritDefaults(Language.Defaults):
+class SanskritDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/si/__init__.py b/spacy/lang/si/__init__.py
index d77e3bb8b..971cee3c6 100644
--- a/spacy/lang/si/__init__.py
+++ b/spacy/lang/si/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SinhalaDefaults(Language.Defaults):
+class SinhalaDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/sk/__init__.py b/spacy/lang/sk/__init__.py
index 4003c7340..da6e3048e 100644
--- a/spacy/lang/sk/__init__.py
+++ b/spacy/lang/sk/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SlovakDefaults(Language.Defaults):
+class SlovakDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/sl/__init__.py b/spacy/lang/sl/__init__.py
index 0330cc4d0..9ddd676bf 100644
--- a/spacy/lang/sl/__init__.py
+++ b/spacy/lang/sl/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SlovenianDefaults(Language.Defaults):
+class SlovenianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/sq/__init__.py b/spacy/lang/sq/__init__.py
index a4bacfa49..5e32a0cbe 100644
--- a/spacy/lang/sq/__init__.py
+++ b/spacy/lang/sq/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class AlbanianDefaults(Language.Defaults):
+class AlbanianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/sr/__init__.py b/spacy/lang/sr/__init__.py
index 165e54975..fd0c8c832 100644
--- a/spacy/lang/sr/__init__.py
+++ b/spacy/lang/sr/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SerbianDefaults(Language.Defaults):
+class SerbianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/sv/__init__.py b/spacy/lang/sv/__init__.py
index 1b1b69fac..518ee0db7 100644
--- a/spacy/lang/sv/__init__.py
+++ b/spacy/lang/sv/__init__.py
@@ -4,7 +4,7 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Lemmatizer
@@ -12,7 +12,7 @@ from ...pipeline import Lemmatizer
from ..da.punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
-class SwedishDefaults(Language.Defaults):
+class SwedishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/sv/syntax_iterators.py b/spacy/lang/sv/syntax_iterators.py
index d5ae47853..06ad016ac 100644
--- a/spacy/lang/sv/syntax_iterators.py
+++ b/spacy/lang/sv/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# fmt: off
labels = ["nsubj", "nsubj:pass", "dobj", "obj", "iobj", "ROOT", "appos", "nmod", "nmod:poss"]
diff --git a/spacy/lang/ta/__init__.py b/spacy/lang/ta/__init__.py
index ac5fc7124..4929a4b97 100644
--- a/spacy/lang/ta/__init__.py
+++ b/spacy/lang/ta/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TamilDefaults(Language.Defaults):
+class TamilDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/te/__init__.py b/spacy/lang/te/__init__.py
index e6dc80e28..77cc2fe9b 100644
--- a/spacy/lang/te/__init__.py
+++ b/spacy/lang/te/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TeluguDefaults(Language.Defaults):
+class TeluguDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/th/__init__.py b/spacy/lang/th/__init__.py
index 219c50c1a..10d466bd3 100644
--- a/spacy/lang/th/__init__.py
+++ b/spacy/lang/th/__init__.py
@@ -1,6 +1,6 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...tokens import Doc
from ...util import DummyTokenizer, registry, load_config_from_str
@@ -39,7 +39,7 @@ class ThaiTokenizer(DummyTokenizer):
return Doc(self.vocab, words=words, spaces=spaces)
-class ThaiDefaults(Language.Defaults):
+class ThaiDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/ti/__init__.py b/spacy/lang/ti/__init__.py
index 709fb21cb..c74c081b5 100644
--- a/spacy/lang/ti/__init__.py
+++ b/spacy/lang/ti/__init__.py
@@ -4,12 +4,12 @@ from .punctuation import TOKENIZER_SUFFIXES
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...attrs import LANG
from ...util import update_exc
-class TigrinyaDefaults(Language.Defaults):
+class TigrinyaDefaults(BaseDefaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters.update(LEX_ATTRS)
lex_attr_getters[LANG] = lambda text: "ti"
diff --git a/spacy/lang/tl/__init__.py b/spacy/lang/tl/__init__.py
index 61530dc30..30838890a 100644
--- a/spacy/lang/tl/__init__.py
+++ b/spacy/lang/tl/__init__.py
@@ -1,10 +1,10 @@
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TagalogDefaults(Language.Defaults):
+class TagalogDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/tn/__init__.py b/spacy/lang/tn/__init__.py
index 99907c28a..28e887eea 100644
--- a/spacy/lang/tn/__init__.py
+++ b/spacy/lang/tn/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SetswanaDefaults(Language.Defaults):
+class SetswanaDefaults(BaseDefaults):
infixes = TOKENIZER_INFIXES
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/tr/__init__.py b/spacy/lang/tr/__init__.py
index 679411acf..02b5c7bf4 100644
--- a/spacy/lang/tr/__init__.py
+++ b/spacy/lang/tr/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS, TOKEN_MATCH
from .stop_words import STOP_WORDS
from .syntax_iterators import SYNTAX_ITERATORS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TurkishDefaults(Language.Defaults):
+class TurkishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/tr/syntax_iterators.py b/spacy/lang/tr/syntax_iterators.py
index 3fd726fb5..769af1223 100644
--- a/spacy/lang/tr/syntax_iterators.py
+++ b/spacy/lang/tr/syntax_iterators.py
@@ -1,8 +1,10 @@
+from typing import Union, Iterator, Tuple
+from ...tokens import Doc, Span
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
-def noun_chunks(doclike):
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on both Doc and Span.
"""
diff --git a/spacy/lang/tt/__init__.py b/spacy/lang/tt/__init__.py
index c8e293f29..d5e1e87ef 100644
--- a/spacy/lang/tt/__init__.py
+++ b/spacy/lang/tt/__init__.py
@@ -2,10 +2,10 @@ from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TatarDefaults(Language.Defaults):
+class TatarDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/uk/__init__.py b/spacy/lang/uk/__init__.py
index 677281ec6..1fa568292 100644
--- a/spacy/lang/uk/__init__.py
+++ b/spacy/lang/uk/__init__.py
@@ -6,10 +6,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .lemmatizer import UkrainianLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class UkrainianDefaults(Language.Defaults):
+class UkrainianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/ur/__init__.py b/spacy/lang/ur/__init__.py
index e3dee5805..266c5a73d 100644
--- a/spacy/lang/ur/__init__.py
+++ b/spacy/lang/ur/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_SUFFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class UrduDefaults(Language.Defaults):
+class UrduDefaults(BaseDefaults):
suffixes = TOKENIZER_SUFFIXES
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/vi/__init__.py b/spacy/lang/vi/__init__.py
index b6d873a13..9d5fd8d9d 100644
--- a/spacy/lang/vi/__init__.py
+++ b/spacy/lang/vi/__init__.py
@@ -6,7 +6,7 @@ import string
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...tokens import Doc
from ...util import DummyTokenizer, registry, load_config_from_str
from ... import util
@@ -141,7 +141,7 @@ class VietnameseTokenizer(DummyTokenizer):
def to_disk(self, path: Union[str, Path], **kwargs) -> None:
path = util.ensure_path(path)
serializers = {"cfg": lambda p: srsly.write_json(p, self._get_config())}
- return util.to_disk(path, serializers, [])
+ util.to_disk(path, serializers, [])
def from_disk(self, path: Union[str, Path], **kwargs) -> "VietnameseTokenizer":
path = util.ensure_path(path)
@@ -150,7 +150,7 @@ class VietnameseTokenizer(DummyTokenizer):
return self
-class VietnameseDefaults(Language.Defaults):
+class VietnameseDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/yo/__init__.py b/spacy/lang/yo/__init__.py
index df6bb7d4a..6c38ec8af 100644
--- a/spacy/lang/yo/__init__.py
+++ b/spacy/lang/yo/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class YorubaDefaults(Language.Defaults):
+class YorubaDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/zh/__init__.py b/spacy/lang/zh/__init__.py
index 9a8a21a63..755a294e2 100644
--- a/spacy/lang/zh/__init__.py
+++ b/spacy/lang/zh/__init__.py
@@ -6,7 +6,7 @@ import warnings
from pathlib import Path
from ...errors import Warnings, Errors
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...scorer import Scorer
from ...tokens import Doc
from ...training import validate_examples, Example
@@ -56,21 +56,21 @@ def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char):
class ChineseTokenizer(DummyTokenizer):
def __init__(self, nlp: Language, segmenter: Segmenter = Segmenter.char):
self.vocab = nlp.vocab
- if isinstance(segmenter, Segmenter):
- segmenter = segmenter.value
- self.segmenter = segmenter
+ self.segmenter = (
+ segmenter.value if isinstance(segmenter, Segmenter) else segmenter
+ )
self.pkuseg_seg = None
self.jieba_seg = None
- if segmenter not in Segmenter.values():
+ if self.segmenter not in Segmenter.values():
warn_msg = Warnings.W103.format(
lang="Chinese",
- segmenter=segmenter,
+ segmenter=self.segmenter,
supported=", ".join(Segmenter.values()),
default="'char' (character segmentation)",
)
warnings.warn(warn_msg)
self.segmenter = Segmenter.char
- if segmenter == Segmenter.jieba:
+ if self.segmenter == Segmenter.jieba:
self.jieba_seg = try_jieba_import()
def initialize(
@@ -90,7 +90,7 @@ class ChineseTokenizer(DummyTokenizer):
def __call__(self, text: str) -> Doc:
if self.segmenter == Segmenter.jieba:
- words = list([x for x in self.jieba_seg.cut(text, cut_all=False) if x])
+ words = list([x for x in self.jieba_seg.cut(text, cut_all=False) if x]) # type: ignore[union-attr]
(words, spaces) = util.get_words_and_spaces(words, text)
return Doc(self.vocab, words=words, spaces=spaces)
elif self.segmenter == Segmenter.pkuseg:
@@ -121,7 +121,7 @@ class ChineseTokenizer(DummyTokenizer):
try:
import spacy_pkuseg
- self.pkuseg_seg.preprocesser = spacy_pkuseg.Preprocesser(None)
+ self.pkuseg_seg.preprocesser = spacy_pkuseg.Preprocesser(None) # type: ignore[attr-defined]
except ImportError:
msg = (
"spacy_pkuseg not installed: unable to reset pkuseg "
@@ -129,7 +129,7 @@ class ChineseTokenizer(DummyTokenizer):
)
raise ImportError(msg) from None
for word in words:
- self.pkuseg_seg.preprocesser.insert(word.strip(), "")
+ self.pkuseg_seg.preprocesser.insert(word.strip(), "") # type: ignore[attr-defined]
else:
warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
warnings.warn(warn_msg)
@@ -282,7 +282,7 @@ class ChineseTokenizer(DummyTokenizer):
util.from_disk(path, serializers, [])
-class ChineseDefaults(Language.Defaults):
+class ChineseDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
@@ -294,7 +294,7 @@ class Chinese(Language):
Defaults = ChineseDefaults
-def try_jieba_import() -> None:
+def try_jieba_import():
try:
import jieba
@@ -310,7 +310,7 @@ def try_jieba_import() -> None:
raise ImportError(msg) from None
-def try_pkuseg_import(pkuseg_model: str, pkuseg_user_dict: str) -> None:
+def try_pkuseg_import(pkuseg_model: Optional[str], pkuseg_user_dict: Optional[str]):
try:
import spacy_pkuseg
@@ -318,9 +318,9 @@ def try_pkuseg_import(pkuseg_model: str, pkuseg_user_dict: str) -> None:
msg = "spacy-pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
raise ImportError(msg) from None
try:
- return spacy_pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
+ return spacy_pkuseg.pkuseg(pkuseg_model, user_dict=pkuseg_user_dict)
except FileNotFoundError:
- msg = "Unable to load pkuseg model from: " + pkuseg_model
+ msg = "Unable to load pkuseg model from: " + str(pkuseg_model or "")
raise FileNotFoundError(msg) from None
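Most suppressions in this diff use mypy's error-code syntax (# type: ignore[attr-defined], # type: ignore[union-attr], and so on) rather than a bare # type: ignore, so only the named error is silenced and any other error reported on the same line still surfaces. A minimal illustration of the union-attr case, with my own example function:

from typing import Optional


def shout(text: Optional[str]) -> str:
    # Without the comment, mypy reports:
    #   Item "None" of "Optional[str]" has no attribute "upper"  [union-attr]
    # The scoped ignore silences only that error code on this line.
    return text.upper()  # type: ignore[union-attr]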
diff --git a/spacy/language.py b/spacy/language.py
index 81d740d74..37fdf9e0d 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -1,6 +1,7 @@
-from typing import Iterator, Optional, Any, Dict, Callable, Iterable, TypeVar
-from typing import Union, List, Pattern, overload
-from typing import Tuple
+from typing import Iterator, Optional, Any, Dict, Callable, Iterable
+from typing import Union, Tuple, List, Set, Pattern, Sequence
+from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
+
from dataclasses import dataclass
import random
import itertools
@@ -37,6 +38,11 @@ from .git_info import GIT_VERSION
from . import util
from . import about
from .lookups import load_lookups
+from .compat import Literal
+
+
+if TYPE_CHECKING:
+ from .pipeline import Pipe # noqa: F401
# This is the base config will all settings (training etc.)
@@ -46,6 +52,9 @@ DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
# in the main config and only added via the 'init fill-config' command
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
+# Type variable for contexts piped with documents
+_AnyContext = TypeVar("_AnyContext")
+
class BaseDefaults:
"""Language data defaults, available via Language.Defaults. Can be
@@ -55,14 +64,14 @@ class BaseDefaults:
config: Config = Config(section_order=CONFIG_SECTION_ORDER)
tokenizer_exceptions: Dict[str, List[dict]] = BASE_EXCEPTIONS
- prefixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_PREFIXES
- suffixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_SUFFIXES
- infixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_INFIXES
- token_match: Optional[Pattern] = None
- url_match: Optional[Pattern] = URL_MATCH
+ prefixes: Optional[Sequence[Union[str, Pattern]]] = TOKENIZER_PREFIXES
+ suffixes: Optional[Sequence[Union[str, Pattern]]] = TOKENIZER_SUFFIXES
+ infixes: Optional[Sequence[Union[str, Pattern]]] = TOKENIZER_INFIXES
+ token_match: Optional[Callable] = None
+ url_match: Optional[Callable] = URL_MATCH
syntax_iterators: Dict[str, Callable] = {}
lex_attr_getters: Dict[int, Callable[[str], Any]] = {}
- stop_words = set()
+ stop_words: Set[str] = set()
writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}
@@ -111,7 +120,7 @@ class Language:
"""
Defaults = BaseDefaults
- lang: str = None
+ lang: Optional[str] = None
default_config = DEFAULT_CONFIG
factories = SimpleFrozenDict(error=Errors.E957)
@@ -154,7 +163,7 @@ class Language:
self._config = DEFAULT_CONFIG.merge(self.default_config)
self._meta = dict(meta)
self._path = None
- self._optimizer = None
+ self._optimizer: Optional[Optimizer] = None
# Component meta and configs are only needed on the instance
self._pipe_meta: Dict[str, "FactoryMeta"] = {} # meta by component
self._pipe_configs: Dict[str, Config] = {} # config by component
@@ -170,8 +179,8 @@ class Language:
self.vocab: Vocab = vocab
if self.lang is None:
self.lang = self.vocab.lang
- self._components = []
- self._disabled = set()
+ self._components: List[Tuple[str, "Pipe"]] = []
+ self._disabled: Set[str] = set()
self.max_length = max_length
# Create the default tokenizer from the default config
if not create_tokenizer:
@@ -291,7 +300,7 @@ class Language:
return SimpleFrozenList(names)
@property
- def components(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
+ def components(self) -> List[Tuple[str, "Pipe"]]:
"""Get all (name, component) tuples in the pipeline, including the
currently disabled components.
"""
@@ -310,12 +319,12 @@ class Language:
return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
@property
- def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
+ def pipeline(self) -> List[Tuple[str, "Pipe"]]:
"""The processing pipeline consisting of (name, component) tuples. The
components are called on the Doc in order as it passes through the
pipeline.
- RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
+ RETURNS (List[Tuple[str, Pipe]]): The pipeline.
"""
pipes = [(n, p) for n, p in self._components if n not in self._disabled]
return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
@@ -423,7 +432,7 @@ class Language:
assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False,
- default_score_weights: Dict[str, float] = SimpleFrozenDict(),
+ default_score_weights: Dict[str, Optional[float]] = SimpleFrozenDict(),
func: Optional[Callable] = None,
) -> Callable:
"""Register a new pipeline component factory. Can be used as a decorator
@@ -440,7 +449,7 @@ class Language:
e.g. "token.ent_id". Used for pipeline analysis.
retokenizes (bool): Whether the component changes the tokenization.
Used for pipeline analysis.
- default_score_weights (Dict[str, float]): The scores to report during
+ default_score_weights (Dict[str, Optional[float]]): The scores to report during
training, and their default weight towards the final score used to
select the best model. Weights should sum to 1.0 per component and
will be combined and normalized for the whole pipeline. If None,
@@ -505,12 +514,12 @@ class Language:
@classmethod
def component(
cls,
- name: Optional[str] = None,
+ name: str,
*,
assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False,
- func: Optional[Callable[[Doc], Doc]] = None,
+ func: Optional["Pipe"] = None,
) -> Callable:
"""Register a new pipeline component. Can be used for stateless function
components that don't require a separate factory. Can be used as a
@@ -533,11 +542,11 @@ class Language:
raise ValueError(Errors.E963.format(decorator="component"))
component_name = name if name is not None else util.get_object_name(func)
- def add_component(component_func: Callable[[Doc], Doc]) -> Callable:
+ def add_component(component_func: "Pipe") -> Callable:
if isinstance(func, type): # function is a class
raise ValueError(Errors.E965.format(name=component_name))
- def factory_func(nlp: cls, name: str) -> Callable[[Doc], Doc]:
+ def factory_func(nlp, name: str) -> "Pipe":
return component_func
internal_name = cls.get_factory_name(name)
@@ -587,7 +596,7 @@ class Language:
print_pipe_analysis(analysis, keys=keys)
return analysis
- def get_pipe(self, name: str) -> Callable[[Doc], Doc]:
+ def get_pipe(self, name: str) -> "Pipe":
"""Get a pipeline component for a given component name.
name (str): Name of pipeline component to get.
@@ -608,7 +617,7 @@ class Language:
config: Dict[str, Any] = SimpleFrozenDict(),
raw_config: Optional[Config] = None,
validate: bool = True,
- ) -> Callable[[Doc], Doc]:
+ ) -> "Pipe":
"""Create a pipeline component. Mostly used internally. To create and
add a component to the pipeline, you can use nlp.add_pipe.
@@ -620,7 +629,7 @@ class Language:
raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
- RETURNS (Callable[[Doc], Doc]): The pipeline component.
+ RETURNS (Pipe): The pipeline component.
DOCS: https://spacy.io/api/language#create_pipe
"""
@@ -675,7 +684,7 @@ class Language:
def create_pipe_from_source(
self, source_name: str, source: "Language", *, name: str
- ) -> Tuple[Callable[[Doc], Doc], str]:
+ ) -> Tuple["Pipe", str]:
"""Create a pipeline component by copying it from an existing model.
source_name (str): Name of the component in the source pipeline.
@@ -725,7 +734,7 @@ class Language:
config: Dict[str, Any] = SimpleFrozenDict(),
raw_config: Optional[Config] = None,
validate: bool = True,
- ) -> Callable[[Doc], Doc]:
+ ) -> "Pipe":
"""Add a component to the processing pipeline. Valid components are
callables that take a `Doc` object, modify it and return it. Only one
of before/after/first/last can be set. Default behaviour is "last".
@@ -748,7 +757,7 @@ class Language:
raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
- RETURNS (Callable[[Doc], Doc]): The pipeline component.
+ RETURNS (Pipe): The pipeline component.
DOCS: https://spacy.io/api/language#add_pipe
"""
@@ -859,7 +868,7 @@ class Language:
*,
config: Dict[str, Any] = SimpleFrozenDict(),
validate: bool = True,
- ) -> Callable[[Doc], Doc]:
+ ) -> "Pipe":
"""Replace a component in the pipeline.
name (str): Name of the component to replace.
@@ -868,7 +877,7 @@ class Language:
component. Will be merged with default config, if available.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
- RETURNS (Callable[[Doc], Doc]): The new pipeline component.
+ RETURNS (Pipe): The new pipeline component.
DOCS: https://spacy.io/api/language#replace_pipe
"""
@@ -920,7 +929,7 @@ class Language:
init_cfg = self._config["initialize"]["components"].pop(old_name)
self._config["initialize"]["components"][new_name] = init_cfg
- def remove_pipe(self, name: str) -> Tuple[str, Callable[[Doc], Doc]]:
+ def remove_pipe(self, name: str) -> Tuple[str, "Pipe"]:
"""Remove a component from the pipeline.
name (str): Name of the component to remove.
@@ -978,7 +987,7 @@ class Language:
is preserved.
text (str): The text to be processed.
- disable (list): Names of the pipeline components to disable.
+ disable (List[str]): Names of the pipeline components to disable.
component_cfg (Dict[str, dict]): An optional dictionary with extra
keyword arguments for specific components.
RETURNS (Doc): A container for accessing the annotations.
@@ -997,7 +1006,7 @@ class Language:
if hasattr(proc, "get_error_handler"):
error_handler = proc.get_error_handler()
try:
- doc = proc(doc, **component_cfg.get(name, {}))
+ doc = proc(doc, **component_cfg.get(name, {})) # type: ignore[call-arg]
except KeyError as e:
# This typically happens if a component is not initialized
raise ValueError(Errors.E109.format(name=name)) from e
@@ -1017,7 +1026,7 @@ class Language:
"""
warnings.warn(Warnings.W096, DeprecationWarning)
if len(names) == 1 and isinstance(names[0], (list, tuple)):
- names = names[0] # support list of names instead of spread
+ names = names[0] # type: ignore[assignment] # support list of names instead of spread
return self.select_pipes(disable=names)
def select_pipes(
@@ -1052,6 +1061,7 @@ class Language:
)
)
disable = to_disable
+ assert disable is not None
# DisabledPipes will restore the pipes in 'disable' when it's done, so we need to exclude
# those pipes that were already disabled.
disable = [d for d in disable if d not in self._disabled]
@@ -1102,7 +1112,7 @@ class Language:
raise ValueError(Errors.E989)
if losses is None:
losses = {}
- if len(examples) == 0:
+ if isinstance(examples, list) and len(examples) == 0:
return losses
validate_examples(examples, "Language.update")
examples = _copy_examples(examples)
@@ -1119,16 +1129,17 @@ class Language:
component_cfg[name].setdefault("drop", drop)
pipe_kwargs[name].setdefault("batch_size", self.batch_size)
for name, proc in self.pipeline:
+ # ignore statements are used here because mypy ignores hasattr
if name not in exclude and hasattr(proc, "update"):
- proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
+ proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) # type: ignore
if sgd not in (None, False):
if (
name not in exclude
and hasattr(proc, "is_trainable")
and proc.is_trainable
- and proc.model not in (True, False, None)
+ and proc.model not in (True, False, None) # type: ignore
):
- proc.finish_update(sgd)
+ proc.finish_update(sgd) # type: ignore
if name in annotates:
for doc, eg in zip(
_pipe(
@@ -1174,8 +1185,10 @@ class Language:
DOCS: https://spacy.io/api/language#rehearse
"""
- if len(examples) == 0:
- return
+ if losses is None:
+ losses = {}
+ if isinstance(examples, list) and len(examples) == 0:
+ return losses
validate_examples(examples, "Language.rehearse")
if sgd is None:
if self._optimizer is None:
@@ -1190,18 +1203,18 @@ class Language:
def get_grads(W, dW, key=None):
grads[key] = (W, dW)
- get_grads.learn_rate = sgd.learn_rate
- get_grads.b1 = sgd.b1
- get_grads.b2 = sgd.b2
+ get_grads.learn_rate = sgd.learn_rate # type: ignore[attr-defined, union-attr]
+ get_grads.b1 = sgd.b1 # type: ignore[attr-defined, union-attr]
+ get_grads.b2 = sgd.b2 # type: ignore[attr-defined, union-attr]
for name, proc in pipes:
if name in exclude or not hasattr(proc, "rehearse"):
continue
grads = {}
- proc.rehearse(
+ proc.rehearse( # type: ignore[attr-defined]
examples, sgd=get_grads, losses=losses, **component_cfg.get(name, {})
)
for key, (W, dW) in grads.items():
- sgd(W, dW, key=key)
+ sgd(W, dW, key=key) # type: ignore[call-arg, misc]
return losses
def begin_training(
@@ -1258,19 +1271,19 @@ class Language:
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
if hasattr(self.tokenizer, "initialize"):
tok_settings = validate_init_settings(
- self.tokenizer.initialize,
+ self.tokenizer.initialize, # type: ignore[union-attr]
I["tokenizer"],
section="tokenizer",
name="tokenizer",
)
- self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)
+ self.tokenizer.initialize(get_examples, nlp=self, **tok_settings) # type: ignore[union-attr]
for name, proc in self.pipeline:
if hasattr(proc, "initialize"):
p_settings = I["components"].get(name, {})
p_settings = validate_init_settings(
proc.initialize, p_settings, section="components", name=name
)
- proc.initialize(get_examples, nlp=self, **p_settings)
+ proc.initialize(get_examples, nlp=self, **p_settings) # type: ignore[call-arg]
pretrain_cfg = config.get("pretraining")
if pretrain_cfg:
P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
@@ -1304,7 +1317,7 @@ class Language:
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
for name, proc in self.pipeline:
if hasattr(proc, "_rehearsal_model"):
- proc._rehearsal_model = deepcopy(proc.model)
+ proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined]
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:
@@ -1313,14 +1326,12 @@ class Language:
def set_error_handler(
self,
- error_handler: Callable[
- [str, Callable[[Doc], Doc], List[Doc], Exception], None
- ],
+ error_handler: Callable[[str, "Pipe", List[Doc], Exception], NoReturn],
):
"""Set an error handler object for all the components in the pipeline that implement
a set_error_handler function.
- error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]):
+ error_handler (Callable[[str, Pipe, List[Doc], Exception], NoReturn]):
Function that deals with a failing batch of documents. This callable function should take in
the component's name, the component itself, the offending batch of documents, and the exception
that was thrown.
@@ -1339,7 +1350,7 @@ class Language:
scorer: Optional[Scorer] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
scorer_cfg: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Union[float, dict]]:
+ ) -> Dict[str, Any]:
"""Evaluate a model's pipeline components.
examples (Iterable[Example]): `Example` objects.
@@ -1417,7 +1428,7 @@ class Language:
yield
else:
contexts = [
- pipe.use_params(params)
+ pipe.use_params(params) # type: ignore[attr-defined]
for name, pipe in self.pipeline
if hasattr(pipe, "use_params") and hasattr(pipe, "model")
]
@@ -1435,14 +1446,25 @@ class Language:
except StopIteration:
pass
- _AnyContext = TypeVar("_AnyContext")
-
@overload
def pipe(
+ self,
+ texts: Iterable[str],
+ *,
+ as_tuples: Literal[False] = ...,
+ batch_size: Optional[int] = ...,
+ disable: Iterable[str] = ...,
+ component_cfg: Optional[Dict[str, Dict[str, Any]]] = ...,
+ n_process: int = ...,
+ ) -> Iterator[Doc]:
+ ...
+
+ @overload
+ def pipe( # noqa: F811
self,
texts: Iterable[Tuple[str, _AnyContext]],
*,
- as_tuples: bool = ...,
+ as_tuples: Literal[True] = ...,
batch_size: Optional[int] = ...,
disable: Iterable[str] = ...,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = ...,
@@ -1452,14 +1474,14 @@ class Language:
def pipe( # noqa: F811
self,
- texts: Iterable[str],
+ texts: Union[Iterable[str], Iterable[Tuple[str, _AnyContext]]],
*,
as_tuples: bool = False,
batch_size: Optional[int] = None,
disable: Iterable[str] = SimpleFrozenList(),
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
n_process: int = 1,
- ) -> Iterator[Doc]:
+ ) -> Union[Iterator[Doc], Iterator[Tuple[Doc, _AnyContext]]]:
"""Process texts as a stream, and yield `Doc` objects in order.
texts (Iterable[str]): A sequence of texts to process.
@@ -1475,9 +1497,9 @@ class Language:
DOCS: https://spacy.io/api/language#pipe
"""
- if n_process == -1:
- n_process = mp.cpu_count()
+ # Handle texts with context as tuples
if as_tuples:
+ texts = cast(Iterable[Tuple[str, _AnyContext]], texts)
text_context1, text_context2 = itertools.tee(texts)
texts = (tc[0] for tc in text_context1)
contexts = (tc[1] for tc in text_context2)
@@ -1491,6 +1513,13 @@ class Language:
for doc, context in zip(docs, contexts):
yield (doc, context)
return
+
+ # At this point, we know that we're dealing with an iterable of plain texts
+ texts = cast(Iterable[str], texts)
+
+ # Set argument defaults
+ if n_process == -1:
+ n_process = mp.cpu_count()
if component_cfg is None:
component_cfg = {}
if batch_size is None:
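Splitting pipe into two @overload stubs keyed on Literal[False] and Literal[True] for as_tuples lets the type checker report Iterator[Doc] for plain text input and Iterator[Tuple[Doc, context]] when (text, context) tuples are passed; the diff pulls Literal from spaCy's compat module, presumably to keep pre-3.8 Pythons working. A stand-alone sketch of the same idiom with a hypothetical process function (not spaCy API):

from typing import Iterable, Iterator, Tuple, TypeVar, overload
from typing import Literal  # Python 3.8+; the diff imports it from spacy's compat module instead

_AnyContext = TypeVar("_AnyContext")


@overload
def process(texts: Iterable[str], *, as_tuples: Literal[False] = ...) -> Iterator[str]:
    ...


@overload
def process(
    texts: Iterable[Tuple[str, _AnyContext]], *, as_tuples: Literal[True] = ...
) -> Iterator[Tuple[str, _AnyContext]]:
    ...


def process(texts, *, as_tuples=False):
    # Runtime behaviour is shared; only the static signatures differ.
    for item in texts:
        yield item


plain = process(["one", "two"])                 # checker sees Iterator[str]
paired = process([("one", 1)], as_tuples=True)  # checker sees Iterator[Tuple[str, int]]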
@@ -1527,14 +1556,14 @@ class Language:
def _multiprocessing_pipe(
self,
texts: Iterable[str],
- pipes: Iterable[Callable[[Doc], Doc]],
+ pipes: Iterable[Callable[..., Iterator[Doc]]],
n_process: int,
batch_size: int,
- ) -> None:
+ ) -> Iterator[Doc]:
# raw_texts is used later to stop iteration.
texts, raw_texts = itertools.tee(texts)
# for sending texts to worker
- texts_q = [mp.Queue() for _ in range(n_process)]
+ texts_q: List[mp.Queue] = [mp.Queue() for _ in range(n_process)]
# for receiving byte-encoded docs from worker
bytedocs_recv_ch, bytedocs_send_ch = zip(
*[mp.Pipe(False) for _ in range(n_process)]
@@ -1595,7 +1624,7 @@ class Language:
for i, (name1, proc1) in enumerate(self.pipeline):
if hasattr(proc1, "find_listeners"):
for name2, proc2 in self.pipeline[i + 1 :]:
- proc1.find_listeners(proc2)
+ proc1.find_listeners(proc2) # type: ignore[attr-defined]
@classmethod
def from_config(
@@ -1787,8 +1816,8 @@ class Language:
ll for ll in listener_names if ll not in nlp.pipe_names
]
for listener_name in unused_listener_names:
- for listener in proc.listener_map.get(listener_name, []):
- proc.remove_listener(listener, listener_name)
+ for listener in proc.listener_map.get(listener_name, []): # type: ignore[attr-defined]
+ proc.remove_listener(listener, listener_name) # type: ignore[attr-defined]
for listener in getattr(
proc, "listening_components", []
@@ -1849,7 +1878,6 @@ class Language:
raise ValueError(err)
tok2vec = self.get_pipe(tok2vec_name)
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
- tok2vec_model = tok2vec.model
if (
not hasattr(tok2vec, "model")
or not hasattr(tok2vec, "listener_map")
@@ -1857,12 +1885,13 @@ class Language:
or "model" not in tok2vec_cfg
):
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
- pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
+ tok2vec_model = tok2vec.model # type: ignore[attr-defined]
+ pipe_listeners = tok2vec.listener_map.get(pipe_name, []) # type: ignore[attr-defined]
pipe = self.get_pipe(pipe_name)
pipe_cfg = self._pipe_configs[pipe_name]
if listeners:
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
- if len(listeners) != len(pipe_listeners):
+ if len(list(listeners)) != len(pipe_listeners):
# The number of listeners defined in the component model doesn't
# match the listeners to replace, so we won't be able to update
# the nodes and generate a matching config
@@ -1896,8 +1925,8 @@ class Language:
new_model = tok2vec_model.copy()
if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model)
- util.replace_model_node(pipe.model, listener, new_model)
- tok2vec.remove_listener(listener, pipe_name)
+ util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
+ tok2vec.remove_listener(listener, pipe_name) # type: ignore[attr-defined]
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
@@ -1907,13 +1936,13 @@ class Language:
path (str / Path): Path to a directory, which will be created if
it doesn't exist.
- exclude (list): Names of components or serialization fields to exclude.
+ exclude (Iterable[str]): Names of components or serialization fields to exclude.
DOCS: https://spacy.io/api/language#to_disk
"""
path = util.ensure_path(path)
serializers = {}
- serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(
+ serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr]
p, exclude=["vocab"]
)
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
@@ -1923,7 +1952,7 @@ class Language:
continue
if not hasattr(proc, "to_disk"):
continue
- serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"])
+ serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"]) # type: ignore[misc]
serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
util.to_disk(path, serializers, exclude)
@@ -1939,7 +1968,7 @@ class Language:
model will be loaded.
path (str / Path): A path to a directory.
- exclude (list): Names of components or serialization fields to exclude.
+ exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (Language): The modified `Language` object.
DOCS: https://spacy.io/api/language#from_disk
@@ -1959,13 +1988,13 @@ class Language:
path = util.ensure_path(path)
deserializers = {}
- if Path(path / "config.cfg").exists():
+ if Path(path / "config.cfg").exists(): # type: ignore[operator]
deserializers["config.cfg"] = lambda p: self.config.from_disk(
p, interpolate=False, overrides=overrides
)
- deserializers["meta.json"] = deserialize_meta
- deserializers["vocab"] = deserialize_vocab
- deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
+ deserializers["meta.json"] = deserialize_meta # type: ignore[assignment]
+ deserializers["vocab"] = deserialize_vocab # type: ignore[assignment]
+ deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk( # type: ignore[union-attr]
p, exclude=["vocab"]
)
for name, proc in self._components:
@@ -1973,28 +2002,28 @@ class Language:
continue
if not hasattr(proc, "from_disk"):
continue
- deserializers[name] = lambda p, proc=proc: proc.from_disk(
+ deserializers[name] = lambda p, proc=proc: proc.from_disk( # type: ignore[misc]
p, exclude=["vocab"]
)
- if not (path / "vocab").exists() and "vocab" not in exclude:
+ if not (path / "vocab").exists() and "vocab" not in exclude: # type: ignore[operator]
# Convert to list here in case exclude is (default) tuple
exclude = list(exclude) + ["vocab"]
- util.from_disk(path, deserializers, exclude)
- self._path = path
+ util.from_disk(path, deserializers, exclude) # type: ignore[arg-type]
+ self._path = path # type: ignore[assignment]
self._link_components()
return self
def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the current state to a binary string.
- exclude (list): Names of components or serialization fields to exclude.
+ exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (bytes): The serialized form of the `Language` object.
DOCS: https://spacy.io/api/language#to_bytes
"""
- serializers = {}
+ serializers: Dict[str, Callable[[], bytes]] = {}
serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
- serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
+ serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr]
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self._components:
@@ -2002,7 +2031,7 @@ class Language:
continue
if not hasattr(proc, "to_bytes"):
continue
- serializers[name] = lambda proc=proc: proc.to_bytes(exclude=["vocab"])
+ serializers[name] = lambda proc=proc: proc.to_bytes(exclude=["vocab"]) # type: ignore[misc]
return util.to_bytes(serializers, exclude)
def from_bytes(
@@ -2011,7 +2040,7 @@ class Language:
"""Load state from a binary string.
bytes_data (bytes): The data to load from.
- exclude (list): Names of components or serialization fields to exclude.
+ exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (Language): The `Language` object.
DOCS: https://spacy.io/api/language#from_bytes
@@ -2024,13 +2053,13 @@ class Language:
# from self.vocab.vectors, so set the name directly
self.vocab.vectors.name = data.get("vectors", {}).get("name")
- deserializers = {}
+ deserializers: Dict[str, Callable[[bytes], Any]] = {}
deserializers["config.cfg"] = lambda b: self.config.from_bytes(
b, interpolate=False
)
deserializers["meta.json"] = deserialize_meta
deserializers["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
- deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
+ deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes( # type: ignore[union-attr]
b, exclude=["vocab"]
)
for name, proc in self._components:
@@ -2038,7 +2067,7 @@ class Language:
continue
if not hasattr(proc, "from_bytes"):
continue
- deserializers[name] = lambda b, proc=proc: proc.from_bytes(
+ deserializers[name] = lambda b, proc=proc: proc.from_bytes( # type: ignore[misc]
b, exclude=["vocab"]
)
util.from_bytes(bytes_data, deserializers, exclude)
@@ -2060,7 +2089,7 @@ class FactoryMeta:
requires: Iterable[str] = tuple()
retokenizes: bool = False
scores: Iterable[str] = tuple()
- default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
+ default_score_weights: Optional[Dict[str, Optional[float]]] = None # noqa: E704
class DisabledPipes(list):
@@ -2100,7 +2129,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:
def _apply_pipes(
make_doc: Callable[[str], Doc],
- pipes: Iterable[Callable[[Doc], Doc]],
+ pipes: Iterable[Callable[..., Iterator[Doc]]],
receiver,
sender,
underscore_state: Tuple[dict, dict, dict],
@@ -2108,7 +2137,7 @@ def _apply_pipes(
"""Worker for Language.pipe
make_doc (Callable[[str], Doc]): Function to create Doc from text.
- pipes (Iterable[Callable[[Doc], Doc]]): The components to apply.
+ pipes (Iterable[Pipe]): The components to apply.
receiver (multiprocessing.Connection): Pipe to receive text. Usually
created by `multiprocessing.Pipe()`
sender (multiprocessing.Connection): Pipe to send doc. Usually created by
@@ -2122,11 +2151,11 @@ def _apply_pipes(
texts = receiver.get()
docs = (make_doc(text) for text in texts)
for pipe in pipes:
- docs = pipe(docs)
+ docs = pipe(docs) # type: ignore[arg-type, assignment]
# Connection does not accept unpicklable objects, so send list.
byte_docs = [(doc.to_bytes(), None) for doc in docs]
padding = [(None, None)] * (len(texts) - len(byte_docs))
- sender.send(byte_docs + padding)
+ sender.send(byte_docs + padding) # type: ignore[operator]
except Exception:
error_msg = [(None, srsly.msgpack_dumps(traceback.format_exc()))]
padding = [(None, None)] * (len(texts) - 1)
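
The ignores above use mypy's bracketed error codes rather than bare "# type: ignore" comments, so only the named error is silenced and any other problem on the line still gets reported. A minimal standalone sketch of the difference (illustrative only, not taken from the diff):

from typing import Optional

class Tokenizer:
    def to_bytes(self) -> bytes:
        return b""

tokenizer: Optional[Tokenizer] = Tokenizer()

# A bare "# type: ignore" would hide every error on this line; the bracketed
# form only suppresses the union-attr complaint about calling a method on an
# Optional value, mirroring the self.tokenizer ignores above.
data = tokenizer.to_bytes()  # type: ignore[union-attr]
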
diff --git a/spacy/lookups.py b/spacy/lookups.py
index 025afa04b..b2f3dc15e 100644
--- a/spacy/lookups.py
+++ b/spacy/lookups.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Union, Optional
+from typing import Any, List, Union, Optional, Dict
from pathlib import Path
import srsly
from preshed.bloom import BloomFilter
@@ -34,9 +34,9 @@ def load_lookups(lang: str, tables: List[str], strict: bool = True) -> "Lookups"
if table not in data:
if strict:
raise ValueError(Errors.E955.format(table=table, lang=lang))
- language_data = {}
+ language_data = {} # type: ignore[var-annotated]
else:
- language_data = load_language_data(data[table])
+ language_data = load_language_data(data[table]) # type: ignore[assignment]
lookups.add_table(table, language_data)
return lookups
@@ -116,7 +116,7 @@ class Table(OrderedDict):
key = get_string_id(key)
return OrderedDict.get(self, key, default)
- def __contains__(self, key: Union[str, int]) -> bool:
+ def __contains__(self, key: Union[str, int]) -> bool: # type: ignore[override]
"""Check whether a key is in the table. String keys will be hashed.
key (str / int): The key to check.
@@ -172,7 +172,7 @@ class Lookups:
DOCS: https://spacy.io/api/lookups#init
"""
- self._tables = {}
+ self._tables: Dict[str, Table] = {}
def __contains__(self, name: str) -> bool:
"""Check if the lookups contain a table of a given name. Delegates to
diff --git a/spacy/matcher/matcher.pyi b/spacy/matcher/matcher.pyi
index 3be065bcd..ec4a88eaf 100644
--- a/spacy/matcher/matcher.pyi
+++ b/spacy/matcher/matcher.pyi
@@ -9,7 +9,7 @@ class Matcher:
def __contains__(self, key: str) -> bool: ...
def add(
self,
- key: str,
+ key: Union[str, int],
patterns: List[List[Dict[str, Any]]],
*,
on_match: Optional[
@@ -39,3 +39,4 @@ class Matcher:
allow_missing: bool = ...,
with_alignments: bool = ...
) -> Union[List[Tuple[int, int, int]], List[Span]]: ...
+ def _normalize_key(self, key: Any) -> Any: ...
diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx
index 05c55c9a7..718349ad6 100644
--- a/spacy/matcher/matcher.pyx
+++ b/spacy/matcher/matcher.pyx
@@ -101,7 +101,7 @@ cdef class Matcher:
number of arguments). The on_match callback becomes an optional keyword
argument.
- key (str): The match ID.
+ key (Union[str, int]): The match ID.
patterns (list): The patterns to add for the given key.
on_match (callable): Optional callback executed on match.
greedy (str): Optional filter: "FIRST" or "LONGEST".
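
Widening key to Union[str, int] matches how callers such as AttributeRuler pass a StringStore hash instead of the raw string. A minimal sketch of both call styles (illustrative only):

import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)

# string key
matcher.add("GREETING", [[{"LOWER": "hello"}, {"LOWER": "world"}]])
# hashed integer key, as AttributeRuler does via vocab.strings.add
matcher.add(nlp.vocab.strings.add("BYE"), [[{"LOWER": "bye"}]])

doc = nlp("hello world")
print(matcher(doc))  # list of (match_id, start, end) tuples
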
diff --git a/spacy/matcher/phrasematcher.pyi b/spacy/matcher/phrasematcher.pyi
new file mode 100644
index 000000000..d73633ec0
--- /dev/null
+++ b/spacy/matcher/phrasematcher.pyi
@@ -0,0 +1,25 @@
+from typing import List, Tuple, Union, Optional, Callable, Any, Dict
+
+from . import Matcher
+from ..vocab import Vocab
+from ..tokens import Doc, Span
+
+class PhraseMatcher:
+ def __init__(
+ self, vocab: Vocab, attr: Optional[Union[int, str]], validate: bool = ...
+ ) -> None: ...
+ def __call__(
+ self,
+ doclike: Union[Doc, Span],
+ *,
+ as_spans: bool = ...,
+ ) -> Union[List[Tuple[int, int, int]], List[Span]]: ...
+ def add(
+ self,
+ key: str,
+ docs: List[Doc],
+ *,
+ on_match: Optional[
+ Callable[[Matcher, Doc, int, List[Tuple[Any, ...]]], Any]
+ ] = ...,
+ ) -> None: ...
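
For reference, a usage sketch of the API the new stub describes: PhraseMatcher patterns are Doc objects, typically created with nlp.make_doc (example only, not from the diff):

import spacy
from spacy.matcher import PhraseMatcher

nlp = spacy.blank("en")
matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
matcher.add("CITY", [nlp.make_doc("New York"), nlp.make_doc("Berlin")])

doc = nlp("I flew from new york to berlin.")
for match_id, start, end in matcher(doc):
    print(nlp.vocab.strings[match_id], doc[start:end].text)
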
diff --git a/spacy/ml/_character_embed.py b/spacy/ml/_character_embed.py
index 0ed28b859..e46735102 100644
--- a/spacy/ml/_character_embed.py
+++ b/spacy/ml/_character_embed.py
@@ -44,7 +44,7 @@ def forward(model: Model, docs: List[Doc], is_train: bool):
# Let's say I have a 2d array of indices, and a 3d table of data. What numpy
# incantation do I chant to get
# output[i, j, k] == data[j, ids[i, j], k]?
- doc_vectors[:, nCv] = E[nCv, doc_ids[:, nCv]]
+ doc_vectors[:, nCv] = E[nCv, doc_ids[:, nCv]] # type: ignore[call-overload, index]
output.append(doc_vectors.reshape((len(doc), nO)))
ids.append(doc_ids)
diff --git a/spacy/ml/extract_ngrams.py b/spacy/ml/extract_ngrams.py
index c1c2929fd..c9c82f369 100644
--- a/spacy/ml/extract_ngrams.py
+++ b/spacy/ml/extract_ngrams.py
@@ -6,7 +6,7 @@ from ..attrs import LOWER
@registry.layers("spacy.extract_ngrams.v1")
def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model:
- model = Model("extract_ngrams", forward)
+ model: Model = Model("extract_ngrams", forward)
model.attrs["ngram_size"] = ngram_size
model.attrs["attr"] = attr
return model
@@ -19,7 +19,7 @@ def forward(model: Model, docs, is_train: bool):
unigrams = model.ops.asarray(doc.to_array([model.attrs["attr"]]))
ngrams = [unigrams]
for n in range(2, model.attrs["ngram_size"] + 1):
- ngrams.append(model.ops.ngrams(n, unigrams))
+ ngrams.append(model.ops.ngrams(n, unigrams)) # type: ignore[arg-type]
keys = model.ops.xp.concatenate(ngrams)
keys, vals = model.ops.xp.unique(keys, return_counts=True)
batch_keys.append(keys)
diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py
index 8afd1a3cc..9bc972032 100644
--- a/spacy/ml/extract_spans.py
+++ b/spacy/ml/extract_spans.py
@@ -28,13 +28,13 @@ def forward(
X, spans = source_spans
assert spans.dataXd.ndim == 2
indices = _get_span_indices(ops, spans, X.lengths)
- Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0])
+ Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index]
x_shape = X.dataXd.shape
x_lengths = X.lengths
def backprop_windows(dY: Ragged) -> Tuple[Ragged, Ragged]:
dX = Ragged(ops.alloc2f(*x_shape), x_lengths)
- ops.scatter_add(dX.dataXd, indices, dY.dataXd)
+ ops.scatter_add(dX.dataXd, indices, dY.dataXd) # type: ignore[arg-type]
return (dX, spans)
return Y, backprop_windows
@@ -51,7 +51,7 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d:
for i, length in enumerate(lengths):
spans_i = spans[i].dataXd + offset
for j in range(spans_i.shape[0]):
- indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1]))
+ indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index]
offset += length
return ops.flatten(indices)
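
A tiny worked example (plain numpy, illustrative only) of the index gathering that the annotated lines above perform: for one doc of five tokens and spans (0, 2) and (3, 5), the gathered indices are [0, 1, 3, 4] and the ragged output stacks those token rows with lengths [2, 2].

import numpy

X = numpy.arange(15, dtype="f").reshape(5, 3)       # 5 tokens, width 3
spans = numpy.asarray([[0, 2], [3, 5]])
indices = numpy.concatenate([numpy.arange(s, e) for s, e in spans])  # [0 1 3 4]
Y_data = X[indices]                                  # stacked token rows for both spans
Y_lengths = spans[:, 1] - spans[:, 0]                # [2 2]
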
diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py
index 645b67c62..831fee90f 100644
--- a/spacy/ml/models/entity_linker.py
+++ b/spacy/ml/models/entity_linker.py
@@ -1,16 +1,19 @@
from pathlib import Path
-from typing import Optional, Callable, Iterable
+from typing import Optional, Callable, Iterable, List
+from thinc.types import Floats2d
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear
from ...util import registry
from ...kb import KnowledgeBase, Candidate, get_candidates
from ...vocab import Vocab
-from ...tokens import Span
+from ...tokens import Span, Doc
@registry.architectures("spacy.EntityLinker.v1")
-def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
+def build_nel_encoder(
+ tok2vec: Model, nO: Optional[int] = None
+) -> Model[List[Doc], Floats2d]:
with Model.define_operators({">>": chain, "**": clone}):
token_width = tok2vec.maybe_get_dim("nO")
output_layer = Linear(nO=nO, nI=token_width)
@@ -18,7 +21,7 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
tok2vec
>> list2ragged()
>> reduce_mean()
- >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))
+ >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) # type: ignore[arg-type]
>> output_layer
)
model.set_ref("output_layer", output_layer)
diff --git a/spacy/ml/models/multi_task.py b/spacy/ml/models/multi_task.py
index 97bef2d0e..37473b7f4 100644
--- a/spacy/ml/models/multi_task.py
+++ b/spacy/ml/models/multi_task.py
@@ -1,7 +1,9 @@
-from typing import Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING
+from typing import Any, Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING, cast
+from thinc.types import Floats2d
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
from thinc.api import MultiSoftmax, list2array
from thinc.api import to_categorical, CosineDistance, L2Distance
+from thinc.loss import Loss
from ...util import registry, OOV_RANK
from ...errors import Errors
@@ -30,6 +32,7 @@ def create_pretrain_vectors(
return model
def create_vectors_loss() -> Callable:
+ distance: Loss
if loss == "cosine":
distance = CosineDistance(normalize=True, ignore_zeros=True)
return partial(get_vectors_loss, distance=distance)
@@ -115,7 +118,7 @@ def build_cloze_multi_task_model(
) -> Model:
nO = vocab.vectors.data.shape[1]
output_layer = chain(
- list2array(),
+ cast(Model[List["Floats2d"], Floats2d], list2array()),
Maxout(
nO=hidden_size,
nI=tok2vec.get_dim("nO"),
@@ -136,10 +139,10 @@ def build_cloze_characters_multi_task_model(
vocab: "Vocab", tok2vec: Model, maxout_pieces: int, hidden_size: int, nr_char: int
) -> Model:
output_layer = chain(
- list2array(),
+ cast(Model[List["Floats2d"], Floats2d], list2array()),
Maxout(nO=hidden_size, nP=maxout_pieces),
LayerNorm(nI=hidden_size),
- MultiSoftmax([256] * nr_char, nI=hidden_size),
+ MultiSoftmax([256] * nr_char, nI=hidden_size), # type: ignore[arg-type]
)
model = build_masked_language_model(vocab, chain(tok2vec, output_layer))
model.set_ref("tok2vec", tok2vec)
@@ -171,7 +174,7 @@ def build_masked_language_model(
if wrapped.has_dim(dim):
model.set_dim(dim, wrapped.get_dim(dim))
- mlm_model = Model(
+ mlm_model: Model = Model(
"masked-language-model",
mlm_forward,
layers=[wrapped_model],
@@ -185,13 +188,19 @@ def build_masked_language_model(
class _RandomWords:
def __init__(self, vocab: "Vocab") -> None:
+ # Extract lexeme representations
self.words = [lex.text for lex in vocab if lex.prob != 0.0]
- self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
self.words = self.words[:10000]
- self.probs = self.probs[:10000]
- self.probs = numpy.exp(numpy.array(self.probs, dtype="f"))
- self.probs /= self.probs.sum()
- self._cache = []
+
+ # Compute normalized lexeme probabilities
+ probs = [lex.prob for lex in vocab if lex.prob != 0.0]
+ probs = probs[:10000]
+ probs: numpy.ndarray = numpy.exp(numpy.array(probs, dtype="f"))
+ probs /= probs.sum()
+ self.probs = probs
+
+ # Initialize cache
+ self._cache: List[int] = []
def next(self) -> str:
if not self._cache:
diff --git a/spacy/ml/models/parser.py b/spacy/ml/models/parser.py
index 97137313d..aaa6b6b81 100644
--- a/spacy/ml/models/parser.py
+++ b/spacy/ml/models/parser.py
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional, List, cast
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.types import Floats2d
@@ -70,7 +70,11 @@ def build_tb_parser_model(
else:
raise ValueError(Errors.E917.format(value=state_type))
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
- tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))
+ tok2vec = chain(
+ tok2vec,
+ cast(Model[List["Floats2d"], Floats2d], list2array()),
+ Linear(hidden_width, t2v_width),
+ )
tok2vec.set_dim("nO", hidden_width)
lower = _define_lower(
nO=hidden_width if use_upper else nO,
diff --git a/spacy/ml/models/spancat.py b/spacy/ml/models/spancat.py
index 5c49fef40..893db2e6d 100644
--- a/spacy/ml/models/spancat.py
+++ b/spacy/ml/models/spancat.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, cast
from thinc.api import Model, with_getitem, chain, list2ragged, Logistic
from thinc.api import Maxout, Linear, concatenate, glorot_uniform_init
from thinc.api import reduce_mean, reduce_max, reduce_first, reduce_last
@@ -9,7 +9,7 @@ from ...tokens import Doc
from ..extract_spans import extract_spans
-@registry.layers.register("spacy.LinearLogistic.v1")
+@registry.layers("spacy.LinearLogistic.v1")
def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
"""An output layer for multi-label classification. It uses a linear layer
followed by a logistic activation.
@@ -17,18 +17,23 @@ def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
return chain(Linear(nO=nO, nI=nI, init_W=glorot_uniform_init), Logistic())
-@registry.layers.register("spacy.mean_max_reducer.v1")
+@registry.layers("spacy.mean_max_reducer.v1")
def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
"""Reduce sequences by concatenating their mean and max pooled vectors,
and then combine the concatenated vectors with a hidden layer.
"""
return chain(
- concatenate(reduce_last(), reduce_first(), reduce_mean(), reduce_max()),
+ concatenate(
+ cast(Model[Ragged, Floats2d], reduce_last()),
+ cast(Model[Ragged, Floats2d], reduce_first()),
+ reduce_mean(),
+ reduce_max(),
+ ),
Maxout(nO=hidden_size, normalize=True, dropout=0.0),
)
-@registry.architectures.register("spacy.SpanCategorizer.v1")
+@registry.architectures("spacy.SpanCategorizer.v1")
def build_spancat_model(
tok2vec: Model[List[Doc], List[Floats2d]],
reducer: Model[Ragged, Floats2d],
@@ -43,7 +48,12 @@ def build_spancat_model(
scorer (Model[Floats2d, Floats2d]): The scorer model.
"""
model = chain(
- with_getitem(0, chain(tok2vec, list2ragged())),
+ cast(
+ Model[Tuple[List[Doc], Ragged], Tuple[Ragged, Ragged]],
+ with_getitem(
+ 0, chain(tok2vec, cast(Model[List[Floats2d], Ragged], list2ragged()))
+ ),
+ ),
extract_spans(),
reducer,
scorer,
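
The cast(...) calls in this file (and in parser.py, multi_task.py and tok2vec.py below) all follow the same pattern: generic Thinc layers such as list2ragged() or reduce_last() come back with input/output types mypy cannot refine, so the code pins down the concrete signature it relies on. A minimal sketch of the idea, using a simple mean-pooling chain rather than a real spaCy architecture:

from typing import List, cast
from thinc.api import Model, chain, list2ragged, reduce_mean
from thinc.types import Floats2d, Ragged

def build_mean_pooler() -> Model[List[Floats2d], Floats2d]:
    # list2ragged() is generic; cast it to the signature this chain relies on
    to_ragged = cast(Model[List[Floats2d], Ragged], list2ragged())
    return chain(to_ragged, reduce_mean())
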
diff --git a/spacy/ml/models/tagger.py b/spacy/ml/models/tagger.py
index 87944e305..9c7fe042d 100644
--- a/spacy/ml/models/tagger.py
+++ b/spacy/ml/models/tagger.py
@@ -20,7 +20,7 @@ def build_tagger_model(
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
- softmax = with_array(output_layer)
+ softmax = with_array(output_layer) # type: ignore
model = chain(tok2vec, softmax)
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", output_layer)
diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py
index e3f6e944a..c8c146f02 100644
--- a/spacy/ml/models/textcat.py
+++ b/spacy/ml/models/textcat.py
@@ -37,7 +37,7 @@ def build_simple_cnn_text_classifier(
if exclusive_classes:
output_layer = Softmax(nO=nO, nI=nI)
fill_defaults["b"] = NEG_VALUE
- resizable_layer = resizable(
+ resizable_layer: Model = resizable(
output_layer,
resize_layer=partial(
resize_linear_weighted, fill_defaults=fill_defaults
@@ -59,7 +59,7 @@ def build_simple_cnn_text_classifier(
resizable_layer=resizable_layer,
)
model.set_ref("tok2vec", tok2vec)
- model.set_dim("nO", nO)
+ model.set_dim("nO", nO) # type: ignore # TODO: remove type ignore once Thinc has been updated
model.attrs["multi_label"] = not exclusive_classes
return model
@@ -85,7 +85,7 @@ def build_bow_text_classifier(
if not no_output_layer:
fill_defaults["b"] = NEG_VALUE
output_layer = softmax_activation() if exclusive_classes else Logistic()
- resizable_layer = resizable(
+ resizable_layer = resizable( # type: ignore[var-annotated]
sparse_linear,
resize_layer=partial(resize_linear_weighted, fill_defaults=fill_defaults),
)
@@ -93,7 +93,7 @@ def build_bow_text_classifier(
model = with_cpu(model, model.ops)
if output_layer:
model = model >> with_cpu(output_layer, output_layer.ops)
- model.set_dim("nO", nO)
+ model.set_dim("nO", nO) # type: ignore[arg-type]
model.set_ref("output_layer", sparse_linear)
model.attrs["multi_label"] = not exclusive_classes
model.attrs["resize_output"] = partial(
@@ -130,14 +130,14 @@ def build_text_classifier_v2(
model = (linear_model | cnn_model) >> output_layer
model.set_ref("tok2vec", tok2vec)
if model.has_dim("nO") is not False:
- model.set_dim("nO", nO)
+ model.set_dim("nO", nO) # type: ignore[arg-type]
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.set_ref("attention_layer", attention_layer)
model.set_ref("maxout_layer", maxout_layer)
model.set_ref("norm_layer", norm_layer)
model.attrs["multi_label"] = not exclusive_classes
- model.init = init_ensemble_textcat
+ model.init = init_ensemble_textcat # type: ignore[assignment]
return model
@@ -164,7 +164,7 @@ def build_text_classifier_lowdata(
>> list2ragged()
>> ParametricAttention(width)
>> reduce_sum()
- >> residual(Relu(width, width)) ** 2
+ >> residual(Relu(width, width)) ** 2 # type: ignore[arg-type]
>> Linear(nO, width)
)
if dropout:
diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py
index 76ec87054..8d78e418f 100644
--- a/spacy/ml/models/tok2vec.py
+++ b/spacy/ml/models/tok2vec.py
@@ -1,5 +1,5 @@
-from typing import Optional, List, Union
-from thinc.types import Floats2d
+from typing import Optional, List, Union, cast
+from thinc.types import Floats2d, Ints2d, Ragged
from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
@@ -158,26 +158,30 @@ def MultiHashEmbed(
embeddings = [make_hash_embed(i) for i in range(len(attrs))]
concat_size = width * (len(embeddings) + include_static_vectors)
+ max_out: Model[Ragged, Ragged] = with_array(
+ Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True) # type: ignore
+ )
if include_static_vectors:
+ feature_extractor: Model[List[Doc], Ragged] = chain(
+ FeatureExtractor(attrs),
+ cast(Model[List[Ints2d], Ragged], list2ragged()),
+ with_array(concatenate(*embeddings)),
+ )
model = chain(
concatenate(
- chain(
- FeatureExtractor(attrs),
- list2ragged(),
- with_array(concatenate(*embeddings)),
- ),
+ feature_extractor,
StaticVectors(width, dropout=0.0),
),
- with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
- ragged2list(),
+ max_out,
+ cast(Model[Ragged, List[Floats2d]], ragged2list()),
)
else:
model = chain(
FeatureExtractor(list(attrs)),
- list2ragged(),
+ cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)),
- with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
- ragged2list(),
+ max_out,
+ cast(Model[Ragged, List[Floats2d]], ragged2list()),
)
return model
@@ -220,37 +224,41 @@ def CharacterEmbed(
"""
feature = intify_attr(feature)
if feature is None:
- raise ValueError(Errors.E911(feat=feature))
+ raise ValueError(Errors.E911.format(feat=feature))
+ char_embed = chain(
+ _character_embed.CharacterEmbed(nM=nM, nC=nC),
+ cast(Model[List[Floats2d], Ragged], list2ragged()),
+ )
+ feature_extractor: Model[List[Doc], Ragged] = chain(
+ FeatureExtractor([feature]),
+ cast(Model[List[Ints2d], Ragged], list2ragged()),
+ with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), # type: ignore
+ )
+ max_out: Model[Ragged, Ragged]
if include_static_vectors:
+ max_out = with_array(
+ Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0) # type: ignore
+ )
model = chain(
concatenate(
- chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
- chain(
- FeatureExtractor([feature]),
- list2ragged(),
- with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
- ),
+ char_embed,
+ feature_extractor,
StaticVectors(width, dropout=0.0),
),
- with_array(
- Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)
- ),
- ragged2list(),
+ max_out,
+ cast(Model[Ragged, List[Floats2d]], ragged2list()),
)
else:
+ max_out = with_array(
+ Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0) # type: ignore
+ )
model = chain(
concatenate(
- chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
- chain(
- FeatureExtractor([feature]),
- list2ragged(),
- with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
- ),
+ char_embed,
+ feature_extractor,
),
- with_array(
- Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)
- ),
- ragged2list(),
+ max_out,
+ cast(Model[Ragged, List[Floats2d]], ragged2list()),
)
return model
@@ -281,10 +289,10 @@ def MaxoutWindowEncoder(
normalize=True,
),
)
- model = clone(residual(cnn), depth)
+ model = clone(residual(cnn), depth) # type: ignore[arg-type]
model.set_dim("nO", width)
receptive_field = window_size * depth
- return with_array(model, pad=receptive_field)
+ return with_array(model, pad=receptive_field) # type: ignore[arg-type]
@registry.architectures("spacy.MishWindowEncoder.v2")
@@ -305,9 +313,9 @@ def MishWindowEncoder(
expand_window(window_size=window_size),
Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
)
- model = clone(residual(cnn), depth)
+ model = clone(residual(cnn), depth) # type: ignore[arg-type]
model.set_dim("nO", width)
- return with_array(model)
+ return with_array(model) # type: ignore[arg-type]
@registry.architectures("spacy.TorchBiLSTMEncoder.v1")
diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py
index 4e7262e7d..bc0189873 100644
--- a/spacy/ml/staticvectors.py
+++ b/spacy/ml/staticvectors.py
@@ -49,7 +49,7 @@ def forward(
# Convert negative indices to 0-vectors (TODO: more options for UNK tokens)
vectors_data[rows < 0] = 0
output = Ragged(
- vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")
+ vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i") # type: ignore
)
mask = None
if is_train:
@@ -62,7 +62,9 @@ def forward(
d_output.data *= mask
model.inc_grad(
"W",
- model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
+ model.ops.gemm(
+ cast(Floats2d, d_output.data), model.ops.as_contig(V[rows]), trans1=True
+ ),
)
return []
@@ -97,4 +99,8 @@ def _handle_empty(ops: Ops, nO: int):
def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
- return ops.get_dropout_mask((nO,), rate) if rate is not None else None
+ if rate is not None:
+ mask = ops.get_dropout_mask((nO,), rate)
+ assert isinstance(mask, Floats1d)
+ return mask
+ return None
diff --git a/spacy/pipe_analysis.py b/spacy/pipe_analysis.py
index d0362e7e1..245747061 100644
--- a/spacy/pipe_analysis.py
+++ b/spacy/pipe_analysis.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Iterable, Optional, Union, TYPE_CHECKING
+from typing import List, Set, Dict, Iterable, ItemsView, Union, TYPE_CHECKING
from wasabi import msg
from .tokens import Doc, Token, Span
@@ -67,7 +67,7 @@ def get_attr_info(nlp: "Language", attr: str) -> Dict[str, List[str]]:
RETURNS (Dict[str, List[str]]): A dict keyed by "assigns" and "requires",
mapped to a list of component names.
"""
- result = {"assigns": [], "requires": []}
+ result: Dict[str, List[str]] = {"assigns": [], "requires": []}
for pipe_name in nlp.pipe_names:
meta = nlp.get_pipe_meta(pipe_name)
if attr in meta.assigns:
@@ -79,7 +79,7 @@ def get_attr_info(nlp: "Language", attr: str) -> Dict[str, List[str]]:
def analyze_pipes(
nlp: "Language", *, keys: List[str] = DEFAULT_KEYS
-) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
+) -> Dict[str, Dict[str, Union[List[str], Dict]]]:
"""Print a formatted summary for the current nlp object's pipeline. Shows
a table with the pipeline components and what they assign and require, as
well as any problems if available.
@@ -88,8 +88,11 @@ def analyze_pipes(
keys (List[str]): The meta keys to show in the table.
RETURNS (dict): A dict with "summary" and "problems".
"""
- result = {"summary": {}, "problems": {}}
- all_attrs = set()
+ result: Dict[str, Dict[str, Union[List[str], Dict]]] = {
+ "summary": {},
+ "problems": {},
+ }
+ all_attrs: Set[str] = set()
for i, name in enumerate(nlp.pipe_names):
meta = nlp.get_pipe_meta(name)
all_attrs.update(meta.assigns)
@@ -102,19 +105,18 @@ def analyze_pipes(
prev_meta = nlp.get_pipe_meta(prev_name)
for annot in prev_meta.assigns:
requires[annot] = True
- result["problems"][name] = []
- for annot, fulfilled in requires.items():
- if not fulfilled:
- result["problems"][name].append(annot)
+ result["problems"][name] = [
+ annot for annot, fulfilled in requires.items() if not fulfilled
+ ]
result["attrs"] = {attr: get_attr_info(nlp, attr) for attr in all_attrs}
return result
def print_pipe_analysis(
- analysis: Dict[str, Union[List[str], Dict[str, List[str]]]],
+ analysis: Dict[str, Dict[str, Union[List[str], Dict]]],
*,
keys: List[str] = DEFAULT_KEYS,
-) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
+) -> None:
"""Print a formatted version of the pipe analysis produced by analyze_pipes.
analysis (Dict[str, Union[List[str], Dict[str, List[str]]]]): The analysis.
@@ -122,7 +124,7 @@ def print_pipe_analysis(
"""
msg.divider("Pipeline Overview")
header = ["#", "Component", *[key.capitalize() for key in keys]]
- summary = analysis["summary"].items()
+ summary: ItemsView = analysis["summary"].items()
body = [[i, n, *[v for v in m.values()]] for i, (n, m) in enumerate(summary)]
msg.table(body, header=header, divider=True, multiline=True)
n_problems = sum(len(p) for p in analysis["problems"].values())
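
The tightened return type describes the structure handed back by Language.analyze_pipes, which wraps these helpers; a quick usage sketch (output values are indicative only):

import spacy

nlp = spacy.blank("en")
nlp.add_pipe("tagger")

analysis = nlp.analyze_pipes()
print(analysis["summary"])   # per-component dicts for assigns/requires/scores/retokenizes
print(analysis["problems"])  # e.g. {"tagger": []} -- unmet "requires" per component
print(analysis["attrs"])     # attribute -> {"assigns": [...], "requires": [...]}
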
diff --git a/spacy/pipeline/attributeruler.py b/spacy/pipeline/attributeruler.py
index f95a5a48c..331eaa4d8 100644
--- a/spacy/pipeline/attributeruler.py
+++ b/spacy/pipeline/attributeruler.py
@@ -54,9 +54,9 @@ class AttributeRuler(Pipe):
self.vocab = vocab
self.matcher = Matcher(self.vocab, validate=validate)
self.validate = validate
- self.attrs = []
- self._attrs_unnormed = [] # store for reference
- self.indices = []
+ self.attrs: List[Dict] = []
+ self._attrs_unnormed: List[Dict] = [] # store for reference
+ self.indices: List[int] = []
def clear(self) -> None:
"""Reset all patterns."""
@@ -102,13 +102,13 @@ class AttributeRuler(Pipe):
self.set_annotations(doc, matches)
return doc
except Exception as e:
- error_handler(self.name, self, [doc], e)
+ return error_handler(self.name, self, [doc], e)
def match(self, doc: Doc):
- matches = self.matcher(doc, allow_missing=True)
+ matches = self.matcher(doc, allow_missing=True, as_spans=False)
# Sort by the attribute ID, so that later rules have precedence
matches = [
- (int(self.vocab.strings[m_id]), m_id, s, e) for m_id, s, e in matches
+ (int(self.vocab.strings[m_id]), m_id, s, e) for m_id, s, e in matches # type: ignore
]
matches.sort()
return matches
@@ -154,7 +154,7 @@ class AttributeRuler(Pipe):
else:
morph = self.vocab.morphology.add(attrs["MORPH"])
attrs["MORPH"] = self.vocab.strings[morph]
- self.add([pattern], attrs)
+ self.add([pattern], attrs) # type: ignore[list-item]
def load_from_morph_rules(
self, morph_rules: Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]]
@@ -178,7 +178,7 @@ class AttributeRuler(Pipe):
elif morph_attrs:
morph = self.vocab.morphology.add(morph_attrs)
attrs["MORPH"] = self.vocab.strings[morph]
- self.add([pattern], attrs)
+ self.add([pattern], attrs) # type: ignore[list-item]
def add(
self, patterns: Iterable[MatcherPatternType], attrs: Dict, index: int = 0
@@ -198,7 +198,7 @@ class AttributeRuler(Pipe):
# We need to make a string here, because otherwise the ID we pass back
# will be interpreted as the hash of a string, rather than an ordinal.
key = str(len(self.attrs))
- self.matcher.add(self.vocab.strings.add(key), patterns)
+ self.matcher.add(self.vocab.strings.add(key), patterns) # type: ignore[arg-type]
self._attrs_unnormed.append(attrs)
attrs = normalize_token_attrs(self.vocab, attrs)
self.attrs.append(attrs)
@@ -214,7 +214,7 @@ class AttributeRuler(Pipe):
DOCS: https://spacy.io/api/attributeruler#add_patterns
"""
for p in patterns:
- self.add(**p)
+ self.add(**p) # type: ignore[arg-type]
@property
def patterns(self) -> List[AttributeRulerPatternType]:
@@ -223,10 +223,10 @@ class AttributeRuler(Pipe):
for i in range(len(self.attrs)):
p = {}
p["patterns"] = self.matcher.get(str(i))[1]
- p["attrs"] = self._attrs_unnormed[i]
- p["index"] = self.indices[i]
+ p["attrs"] = self._attrs_unnormed[i] # type: ignore
+ p["index"] = self.indices[i] # type: ignore
all_patterns.append(p)
- return all_patterns
+ return all_patterns # type: ignore[return-value]
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
"""Score a batch of examples.
@@ -244,7 +244,7 @@ class AttributeRuler(Pipe):
validate_examples(examples, "AttributeRuler.score")
results = {}
- attrs = set()
+ attrs = set() # type: ignore
for token_attrs in self.attrs:
attrs.update(token_attrs)
for attr in attrs:
diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py
index 7b52025bc..4a0902444 100644
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -1,4 +1,5 @@
-from typing import Optional, Iterable, Callable, Dict, Union, List
+from typing import Optional, Iterable, Callable, Dict, Union, List, Any
+from thinc.types import Floats2d
from pathlib import Path
from itertools import islice
import srsly
@@ -140,7 +141,7 @@ class EntityLinker(TrainablePipe):
self.incl_prior = incl_prior
self.incl_context = incl_context
self.get_candidates = get_candidates
- self.cfg = {}
+ self.cfg: Dict[str, Any] = {}
self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
@@ -166,7 +167,7 @@ class EntityLinker(TrainablePipe):
get_examples: Callable[[], Iterable[Example]],
*,
nlp: Optional[Language] = None,
- kb_loader: Callable[[Vocab], KnowledgeBase] = None,
+ kb_loader: Optional[Callable[[Vocab], KnowledgeBase]] = None,
):
"""Initialize the pipe for training, using a representative set
of data examples.
@@ -261,7 +262,7 @@ class EntityLinker(TrainablePipe):
losses[self.name] += loss
return losses
- def get_loss(self, examples: Iterable[Example], sentence_encodings):
+ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
validate_examples(examples, "EntityLinker.get_loss")
entity_encodings = []
for eg in examples:
@@ -277,8 +278,9 @@ class EntityLinker(TrainablePipe):
method="get_loss", msg="gold entities do not match up"
)
raise RuntimeError(err)
- gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
- loss = self.distance.get_loss(sentence_encodings, entity_encodings)
+ # TODO: fix typing issue here
+ gradients = self.distance.get_grad(sentence_encodings, entity_encodings) # type: ignore
+ loss = self.distance.get_loss(sentence_encodings, entity_encodings) # type: ignore
loss = loss / len(entity_encodings)
return float(loss), gradients
@@ -288,13 +290,13 @@ class EntityLinker(TrainablePipe):
no prediction.
docs (Iterable[Doc]): The documents to predict.
- RETURNS (List[int]): The models prediction for each document.
+ RETURNS (List[str]): The model's prediction for each document.
DOCS: https://spacy.io/api/entitylinker#predict
"""
self.validate_kb()
entity_count = 0
- final_kb_ids = []
+ final_kb_ids: List[str] = []
if not docs:
return final_kb_ids
if isinstance(docs, Doc):
@@ -324,7 +326,7 @@ class EntityLinker(TrainablePipe):
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
else:
- candidates = self.get_candidates(self.kb, ent)
+ candidates = list(self.get_candidates(self.kb, ent))
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
@@ -478,7 +480,7 @@ class EntityLinker(TrainablePipe):
except AttributeError:
raise ValueError(Errors.E149) from None
- deserialize = {}
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
deserialize["kb"] = lambda p: self.kb.from_disk(p)
diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py
index 1dea8fba0..b8f32b4d3 100644
--- a/spacy/pipeline/entityruler.py
+++ b/spacy/pipeline/entityruler.py
@@ -1,5 +1,6 @@
import warnings
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable, Sequence
+from typing import cast
from collections import defaultdict
from pathlib import Path
import srsly
@@ -100,8 +101,8 @@ class EntityRuler(Pipe):
self.nlp = nlp
self.name = name
self.overwrite = overwrite_ents
- self.token_patterns = defaultdict(list)
- self.phrase_patterns = defaultdict(list)
+ self.token_patterns = defaultdict(list) # type: ignore
+ self.phrase_patterns = defaultdict(list) # type: ignore
self._validate = validate
self.matcher = Matcher(nlp.vocab, validate=validate)
self.phrase_matcher_attr = phrase_matcher_attr
@@ -109,7 +110,7 @@ class EntityRuler(Pipe):
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
)
self.ent_id_sep = ent_id_sep
- self._ent_ids = defaultdict(dict)
+ self._ent_ids = defaultdict(tuple) # type: ignore
if patterns is not None:
self.add_patterns(patterns)
@@ -137,19 +138,22 @@ class EntityRuler(Pipe):
self.set_annotations(doc, matches)
return doc
except Exception as e:
- error_handler(self.name, self, [doc], e)
+ return error_handler(self.name, self, [doc], e)
def match(self, doc: Doc):
self._require_patterns()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="\\[W036")
- matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
- matches = set(
+ matches = cast(
+ List[Tuple[int, int, int]],
+ list(self.matcher(doc)) + list(self.phrase_matcher(doc)),
+ )
+ final_matches = set(
[(m_id, start, end) for m_id, start, end in matches if start != end]
)
get_sort_key = lambda m: (m[2] - m[1], -m[1])
- matches = sorted(matches, key=get_sort_key, reverse=True)
- return matches
+ final_matches = sorted(final_matches, key=get_sort_key, reverse=True)
+ return final_matches
def set_annotations(self, doc, matches):
"""Modify the document in place"""
@@ -214,10 +218,10 @@ class EntityRuler(Pipe):
"""
self.clear()
if patterns:
- self.add_patterns(patterns)
+ self.add_patterns(patterns) # type: ignore[arg-type]
@property
- def ent_ids(self) -> Tuple[str, ...]:
+ def ent_ids(self) -> Tuple[Optional[str], ...]:
"""All entity ids present in the match patterns `id` properties
RETURNS (set): The string entity ids.
@@ -302,17 +306,17 @@ class EntityRuler(Pipe):
if ent_id:
phrase_pattern["id"] = ent_id
phrase_patterns.append(phrase_pattern)
- for entry in token_patterns + phrase_patterns:
+ for entry in token_patterns + phrase_patterns: # type: ignore[operator]
label = entry["label"]
if "id" in entry:
ent_label = label
label = self._create_label(label, entry["id"])
key = self.matcher._normalize_key(label)
self._ent_ids[key] = (ent_label, entry["id"])
- pattern = entry["pattern"]
+ pattern = entry["pattern"] # type: ignore
if isinstance(pattern, Doc):
self.phrase_patterns[label].append(pattern)
- self.phrase_matcher.add(label, [pattern])
+ self.phrase_matcher.add(label, [pattern]) # type: ignore
elif isinstance(pattern, list):
self.token_patterns[label].append(pattern)
self.matcher.add(label, [pattern])
@@ -323,7 +327,7 @@ class EntityRuler(Pipe):
"""Reset all patterns."""
self.token_patterns = defaultdict(list)
self.phrase_patterns = defaultdict(list)
- self._ent_ids = defaultdict(dict)
+ self._ent_ids = defaultdict(tuple)
self.matcher = Matcher(self.nlp.vocab, validate=self._validate)
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
@@ -334,7 +338,7 @@ class EntityRuler(Pipe):
if len(self) == 0:
warnings.warn(Warnings.W036.format(name=self.name))
- def _split_label(self, label: str) -> Tuple[str, str]:
+ def _split_label(self, label: str) -> Tuple[str, Optional[str]]:
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
label (str): The value of label in a pattern entry
@@ -344,11 +348,12 @@ class EntityRuler(Pipe):
ent_label, ent_id = label.rsplit(self.ent_id_sep, 1)
else:
ent_label = label
- ent_id = None
+ ent_id = None # type: ignore
return ent_label, ent_id
- def _create_label(self, label: str, ent_id: str) -> str:
+ def _create_label(self, label: Any, ent_id: Any) -> str:
"""Join Entity label with ent_id if the pattern has an `id` attribute
+ If ent_id is not a string, the label is returned as is.
label (str): The label to set for ent.label_
ent_id (str): The label
diff --git a/spacy/pipeline/functions.py b/spacy/pipeline/functions.py
index 03c7db422..f0a75dc2c 100644
--- a/spacy/pipeline/functions.py
+++ b/spacy/pipeline/functions.py
@@ -25,7 +25,7 @@ def merge_noun_chunks(doc: Doc) -> Doc:
with doc.retokenize() as retokenizer:
for np in doc.noun_chunks:
attrs = {"tag": np.root.tag, "dep": np.root.dep}
- retokenizer.merge(np, attrs=attrs)
+ retokenizer.merge(np, attrs=attrs) # type: ignore[arg-type]
return doc
@@ -45,7 +45,7 @@ def merge_entities(doc: Doc):
with doc.retokenize() as retokenizer:
for ent in doc.ents:
attrs = {"tag": ent.root.tag, "dep": ent.root.dep, "ent_type": ent.label}
- retokenizer.merge(ent, attrs=attrs)
+ retokenizer.merge(ent, attrs=attrs) # type: ignore[arg-type]
return doc
@@ -63,7 +63,7 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
merger = Matcher(doc.vocab)
merger.add("SUBTOK", [[{"DEP": label, "op": "+"}]])
matches = merger(doc)
- spans = util.filter_spans([doc[start : end + 1] for _, start, end in matches])
+ spans = util.filter_spans([doc[start : end + 1] for _, start, end in matches]) # type: ignore[misc, operator]
with doc.retokenize() as retokenizer:
for span in spans:
retokenizer.merge(span)
@@ -93,11 +93,11 @@ class TokenSplitter:
if len(t.text) >= self.min_length:
orths = []
heads = []
- attrs = {}
+ attrs = {} # type: ignore[var-annotated]
for i in range(0, len(t.text), self.split_length):
orths.append(t.text[i : i + self.split_length])
heads.append((t, i / self.split_length))
- retokenizer.split(t, orths, heads, attrs)
+ retokenizer.split(t, orths, heads, attrs) # type: ignore[arg-type]
return doc
def _get_config(self) -> Dict[str, Any]:
diff --git a/spacy/pipeline/lemmatizer.py b/spacy/pipeline/lemmatizer.py
index b2338724d..ad227d240 100644
--- a/spacy/pipeline/lemmatizer.py
+++ b/spacy/pipeline/lemmatizer.py
@@ -88,7 +88,7 @@ class Lemmatizer(Pipe):
if not hasattr(self, mode_attr):
raise ValueError(Errors.E1003.format(mode=mode))
self.lemmatize = getattr(self, mode_attr)
- self.cache = {}
+ self.cache = {} # type: ignore[var-annotated]
@property
def mode(self):
@@ -177,7 +177,7 @@ class Lemmatizer(Pipe):
DOCS: https://spacy.io/api/lemmatizer#rule_lemmatize
"""
- cache_key = (token.orth, token.pos, token.morph.key)
+ cache_key = (token.orth, token.pos, token.morph.key) # type: ignore[attr-defined]
if cache_key in self.cache:
return self.cache[cache_key]
string = token.text
@@ -284,7 +284,7 @@ class Lemmatizer(Pipe):
DOCS: https://spacy.io/api/lemmatizer#from_disk
"""
- deserialize = {}
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
util.from_disk(path, deserialize, exclude)
@@ -315,7 +315,7 @@ class Lemmatizer(Pipe):
DOCS: https://spacy.io/api/lemmatizer#from_bytes
"""
- deserialize = {}
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
util.from_bytes(bytes_data, deserialize, exclude)
diff --git a/spacy/pipeline/pipe.pyi b/spacy/pipeline/pipe.pyi
new file mode 100644
index 000000000..a8d4e40ce
--- /dev/null
+++ b/spacy/pipeline/pipe.pyi
@@ -0,0 +1,32 @@
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, Iterator, List, NoReturn, Optional, Tuple, Union
+
+from ..tokens.doc import Doc
+
+from ..training import Example
+from ..language import Language
+
+class Pipe:
+ def __call__(self, doc: Doc) -> Doc: ...
+ def pipe(self, stream: Iterable[Doc], *, batch_size: int = ...) -> Iterator[Doc]: ...
+ def initialize(
+ self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language = ...,
+ ) -> None: ...
+ def score(
+ self, examples: Iterable[Example], **kwargs: Any
+ ) -> Dict[str, Union[float, Dict[str, float]]]: ...
+ @property
+ def is_trainable(self) -> bool: ...
+ @property
+ def labels(self) -> Tuple[str, ...]: ...
+ @property
+ def label_data(self) -> Any: ...
+ def _require_labels(self) -> None: ...
+ def set_error_handler(
+ self, error_handler: Callable[[str, "Pipe", List[Doc], Exception], NoReturn]
+ ) -> None: ...
+ def get_error_handler(
+ self
+ ) -> Callable[[str, "Pipe", List[Doc], Exception], NoReturn]: ...
+
+def deserialize_config(path: Path) -> Any: ...
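
The stub types error-handler callbacks as Callable[[str, Pipe, List[Doc], Exception], NoReturn]: a function that receives the component name, the component, the batch of docs and the exception, and is expected to raise. A sketch of a handler matching that signature (hypothetical example, not part of the PR):

from typing import List, NoReturn

from spacy.pipeline import Pipe
from spacy.tokens import Doc

def strict_error_handler(name: str, component: Pipe, docs: List[Doc], e: Exception) -> NoReturn:
    previews = [doc.text[:30] for doc in docs]
    raise ValueError(f"Component {name!r} failed on {previews}") from e

# e.g. nlp.get_pipe("ner").set_error_handler(strict_error_handler)
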
diff --git a/spacy/pipeline/pipe.pyx b/spacy/pipeline/pipe.pyx
index 0d298ce4f..4372645af 100644
--- a/spacy/pipeline/pipe.pyx
+++ b/spacy/pipeline/pipe.pyx
@@ -88,7 +88,7 @@ cdef class Pipe:
return False
@property
- def labels(self) -> Optional[Tuple[str]]:
+ def labels(self) -> Tuple[str, ...]:
return tuple()
@property
@@ -115,7 +115,7 @@ cdef class Pipe:
"""
self.error_handler = error_handler
- def get_error_handler(self) -> Optional[Callable]:
+ def get_error_handler(self) -> Callable:
"""Retrieve the error handler function.
RETURNS (Callable): The error handler, or if it's not set a default function that just reraises.
diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py
index 052bd2874..4e9a82423 100644
--- a/spacy/pipeline/spancat.py
+++ b/spacy/pipeline/spancat.py
@@ -1,8 +1,9 @@
import numpy
-from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any
+from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
+from typing_extensions import Protocol, runtime_checkable
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer
-from thinc.types import Ragged, Ints2d, Floats2d
+from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
from ..scorer import Scorer
from ..language import Language
@@ -44,13 +45,19 @@ depth = 4
DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
+@runtime_checkable
+class Suggester(Protocol):
+ def __call__(self, docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
+ ...
+
+
@registry.misc("spacy.ngram_suggester.v1")
-def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
+def build_ngram_suggester(sizes: List[int]) -> Suggester:
"""Suggest all spans of the given lengths. Spans are returned as a ragged
array of integers. The array has two columns, indicating the start and end
position."""
- def ngram_suggester(docs: List[Doc], *, ops: Optional[Ops] = None) -> Ragged:
+ def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
if ops is None:
ops = get_current_ops()
spans = []
@@ -67,10 +74,11 @@ def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
if spans:
assert spans[-1].ndim == 2, spans[-1].shape
lengths.append(length)
+ lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
if len(spans) > 0:
- output = Ragged(ops.xp.vstack(spans), ops.asarray(lengths, dtype="i"))
+ output = Ragged(ops.xp.vstack(spans), lengths_array)
else:
- output = Ragged(ops.xp.zeros((0, 0)), ops.asarray(lengths, dtype="i"))
+ output = Ragged(ops.xp.zeros((0, 0)), lengths_array)
assert output.dataXd.ndim == 2
return output
@@ -79,13 +87,11 @@ def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
@registry.misc("spacy.ngram_range_suggester.v1")
-def build_ngram_range_suggester(
- min_size: int, max_size: int
-) -> Callable[[List[Doc]], Ragged]:
+def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
"""Suggest all spans of the given lengths between a given min and max value - both inclusive.
Spans are returned as a ragged array of integers. The array has two columns,
indicating the start and end position."""
- sizes = range(min_size, max_size + 1)
+ sizes = list(range(min_size, max_size + 1))
return build_ngram_suggester(sizes)
@@ -104,7 +110,7 @@ def build_ngram_range_suggester(
def make_spancat(
nlp: Language,
name: str,
- suggester: Callable[[List[Doc]], Ragged],
+ suggester: Suggester,
model: Model[Tuple[List[Doc], Ragged], Floats2d],
spans_key: str,
threshold: float = 0.5,
@@ -114,7 +120,7 @@ def make_spancat(
parts: a suggester function that proposes candidate spans, and a labeller
model that predicts one or more labels for each span.
- suggester (Callable[List[Doc], Ragged]): A function that suggests spans.
+ suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
Spans are returned as a ragged array with two integer columns, for the
start and end positions.
model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
@@ -151,7 +157,7 @@ class SpanCategorizer(TrainablePipe):
self,
vocab: Vocab,
model: Model[Tuple[List[Doc], Ragged], Floats2d],
- suggester: Callable[[List[Doc]], Ragged],
+ suggester: Suggester,
name: str = "spancat",
*,
spans_key: str = "spans",
@@ -179,7 +185,7 @@ class SpanCategorizer(TrainablePipe):
initialization and training, the component will look for spans on the
reference document under the same key.
"""
- return self.cfg["spans_key"]
+ return str(self.cfg["spans_key"])
def add_label(self, label: str) -> int:
"""Add a new label to the pipe.
@@ -194,7 +200,7 @@ class SpanCategorizer(TrainablePipe):
if label in self.labels:
return 0
self._allow_extra_label()
- self.cfg["labels"].append(label)
+ self.cfg["labels"].append(label) # type: ignore
self.vocab.strings.add(label)
return 1
@@ -204,7 +210,7 @@ class SpanCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/spancategorizer#labels
"""
- return tuple(self.cfg["labels"])
+ return tuple(self.cfg["labels"]) # type: ignore
@property
def label_data(self) -> List[str]:
@@ -223,8 +229,8 @@ class SpanCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/spancategorizer#predict
"""
indices = self.suggester(docs, ops=self.model.ops)
- scores = self.model.predict((docs, indices))
- return (indices, scores)
+ scores = self.model.predict((docs, indices)) # type: ignore
+ return indices, scores
def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None:
"""Modify a batch of Doc objects, using pre-computed scores.
@@ -240,7 +246,7 @@ class SpanCategorizer(TrainablePipe):
for i, doc in enumerate(docs):
indices_i = indices[i].dataXd
doc.spans[self.key] = self._make_span_group(
- doc, indices_i, scores[offset : offset + indices.lengths[i]], labels
+ doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type]
)
offset += indices.lengths[i]
@@ -279,14 +285,14 @@ class SpanCategorizer(TrainablePipe):
set_dropout_rate(self.model, drop)
scores, backprop_scores = self.model.begin_update((docs, spans))
loss, d_scores = self.get_loss(examples, (spans, scores))
- backprop_scores(d_scores)
+ backprop_scores(d_scores) # type: ignore
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
return losses
def get_loss(
- self, examples: Iterable[Example], spans_scores: Tuple[Ragged, Ragged]
+ self, examples: Iterable[Example], spans_scores: Tuple[Ragged, Floats2d]
) -> Tuple[float, float]:
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
@@ -311,8 +317,8 @@ class SpanCategorizer(TrainablePipe):
spans_index = {}
spans_i = spans[i].dataXd
for j in range(spans.lengths[i]):
- start = int(spans_i[j, 0])
- end = int(spans_i[j, 1])
+ start = int(spans_i[j, 0]) # type: ignore
+ end = int(spans_i[j, 1]) # type: ignore
spans_index[(start, end)] = offset + j
for gold_span in self._get_aligned_spans(eg):
key = (gold_span.start, gold_span.end)
@@ -323,7 +329,7 @@ class SpanCategorizer(TrainablePipe):
# The target is a flat array for all docs. Track the position
# we're at within the flat array.
offset += spans.lengths[i]
- target = self.model.ops.asarray(target, dtype="f")
+ target = self.model.ops.asarray(target, dtype="f") # type: ignore
# The target will have the values 0 (for untrue predictions) or 1
# (for true predictions).
# The scores should be in the range [0, 1].
@@ -339,7 +345,7 @@ class SpanCategorizer(TrainablePipe):
self,
get_examples: Callable[[], Iterable[Example]],
*,
- nlp: Language = None,
+ nlp: Optional[Language] = None,
labels: Optional[List[str]] = None,
) -> None:
"""Initialize the pipe for training, using a representative set
@@ -347,14 +353,14 @@ class SpanCategorizer(TrainablePipe):
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
- nlp (Language): The current nlp object the component is part of.
- labels: The labels to add to the component, typically generated by the
+ nlp (Optional[Language]): The current nlp object the component is part of.
+ labels (Optional[List[str]]): The labels to add to the component, typically generated by the
`init labels` command. If no labels are provided, the get_examples
callback is used to extract the labels from the data.
DOCS: https://spacy.io/api/spancategorizer#initialize
"""
- subbatch = []
+ subbatch: List[Example] = []
if labels is not None:
for label in labels:
self.add_label(label)
@@ -393,7 +399,7 @@ class SpanCategorizer(TrainablePipe):
kwargs.setdefault("has_annotation", lambda doc: self.key in doc.spans)
return Scorer.score_spans(examples, **kwargs)
- def _validate_categories(self, examples):
+ def _validate_categories(self, examples: Iterable[Example]):
# TODO
pass
@@ -410,10 +416,11 @@ class SpanCategorizer(TrainablePipe):
threshold = self.cfg["threshold"]
keeps = scores >= threshold
- ranked = (scores * -1).argsort()
+ ranked = (scores * -1).argsort() # type: ignore
if max_positive is not None:
- filter = ranked[:, max_positive:]
- for i, row in enumerate(filter):
+ assert isinstance(max_positive, int)
+ span_filter = ranked[:, max_positive:]
+ for i, row in enumerate(span_filter):
keeps[i, row] = False
spans.attrs["scores"] = scores[keeps].flatten()
diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index 0dde5de82..085b949cc 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -131,7 +131,7 @@ class TextCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/textcategorizer#labels
"""
- return tuple(self.cfg["labels"])
+ return tuple(self.cfg["labels"]) # type: ignore[arg-type, return-value]
@property
def label_data(self) -> List[str]:
@@ -139,7 +139,7 @@ class TextCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/textcategorizer#label_data
"""
- return self.labels
+ return self.labels # type: ignore[return-value]
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
@@ -153,7 +153,7 @@ class TextCategorizer(TrainablePipe):
# Handle cases where there are no tokens in any docs.
tensors = [doc.tensor for doc in docs]
xp = get_array_module(tensors)
- scores = xp.zeros((len(docs), len(self.labels)))
+ scores = xp.zeros((len(list(docs)), len(self.labels)))
return scores
scores = self.model.predict(docs)
scores = self.model.ops.asarray(scores)
@@ -230,8 +230,9 @@ class TextCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/textcategorizer#rehearse
"""
- if losses is not None:
- losses.setdefault(self.name, 0.0)
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
if self._rehearsal_model is None:
return losses
validate_examples(examples, "TextCategorizer.rehearse")
@@ -247,23 +248,23 @@ class TextCategorizer(TrainablePipe):
bp_scores(gradient)
if sgd is not None:
self.finish_update(sgd)
- if losses is not None:
- losses[self.name] += (gradient ** 2).sum()
+ losses[self.name] += (gradient ** 2).sum()
return losses
def _examples_to_truth(
- self, examples: List[Example]
+ self, examples: Iterable[Example]
) -> Tuple[numpy.ndarray, numpy.ndarray]:
- truths = numpy.zeros((len(examples), len(self.labels)), dtype="f")
- not_missing = numpy.ones((len(examples), len(self.labels)), dtype="f")
+ nr_examples = len(list(examples))
+ truths = numpy.zeros((nr_examples, len(self.labels)), dtype="f")
+ not_missing = numpy.ones((nr_examples, len(self.labels)), dtype="f")
for i, eg in enumerate(examples):
for j, label in enumerate(self.labels):
if label in eg.reference.cats:
truths[i, j] = eg.reference.cats[label]
else:
not_missing[i, j] = 0.0
- truths = self.model.ops.asarray(truths)
- return truths, not_missing
+ truths = self.model.ops.asarray(truths) # type: ignore
+ return truths, not_missing # type: ignore
def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
"""Find the loss and gradient of loss for the batch of documents and
@@ -278,7 +279,7 @@ class TextCategorizer(TrainablePipe):
validate_examples(examples, "TextCategorizer.get_loss")
self._validate_categories(examples)
truths, not_missing = self._examples_to_truth(examples)
- not_missing = self.model.ops.asarray(not_missing)
+ not_missing = self.model.ops.asarray(not_missing) # type: ignore
d_scores = (scores - truths) / scores.shape[0]
d_scores *= not_missing
mean_square_error = (d_scores ** 2).sum(axis=1).mean()
@@ -297,11 +298,9 @@ class TextCategorizer(TrainablePipe):
if label in self.labels:
return 0
self._allow_extra_label()
- self.cfg["labels"].append(label)
+ self.cfg["labels"].append(label) # type: ignore[attr-defined]
if self.model and "resize_output" in self.model.attrs:
- self.model = self.model.attrs["resize_output"](
- self.model, len(self.cfg["labels"])
- )
+ self.model = self.model.attrs["resize_output"](self.model, len(self.labels))
self.vocab.strings.add(label)
return 1
@@ -374,7 +373,7 @@ class TextCategorizer(TrainablePipe):
**kwargs,
)
- def _validate_categories(self, examples: List[Example]):
+ def _validate_categories(self, examples: Iterable[Example]):
"""Check whether the provided examples all have single-label cats annotations."""
for ex in examples:
if list(ex.reference.cats.values()).count(1.0) > 1:
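
The `rehearse` change above replaces scattered `if losses is not None` guards with one normalisation step at the top: default to a fresh dict, seed the component's entry, and accumulate unconditionally. A small standalone sketch of that pattern (the function and argument names here are illustrative, not spaCy's API):

from typing import Dict, Optional

def accumulate_loss(name: str, batch_loss: float,
                    losses: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    # Normalise the optional argument once instead of guarding every later use.
    if losses is None:
        losses = {}
    losses.setdefault(name, 0.0)
    losses[name] += batch_loss
    return losses

print(accumulate_loss("textcat", 0.25))                    # {'textcat': 0.25}
print(accumulate_loss("textcat", 0.1, {"textcat": 1.0}))   # {'textcat': 1.1}
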
diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py
index ba36881af..65961a38c 100644
--- a/spacy/pipeline/textcat_multilabel.py
+++ b/spacy/pipeline/textcat_multilabel.py
@@ -131,7 +131,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
cfg = {"labels": [], "threshold": threshold}
self.cfg = dict(cfg)
- def initialize(
+ def initialize( # type: ignore[override]
self,
get_examples: Callable[[], Iterable[Example]],
*,
@@ -184,7 +184,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
**kwargs,
)
- def _validate_categories(self, examples: List[Example]):
+ def _validate_categories(self, examples: Iterable[Example]):
"""This component allows any type of single- or multi-label annotations.
This method overwrites the more strict one from 'textcat'."""
pass
diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index 00d9548a4..cb601e5dc 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Iterable, Optional, Dict, Callable, List
+from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
from thinc.api import Model, set_dropout_rate, Optimizer, Config
from itertools import islice
@@ -60,8 +60,8 @@ class Tok2Vec(TrainablePipe):
self.vocab = vocab
self.model = model
self.name = name
- self.listener_map = {}
- self.cfg = {}
+ self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
+ self.cfg: Dict[str, Any] = {}
@property
def listeners(self) -> List["Tok2VecListener"]:
@@ -245,12 +245,12 @@ class Tok2VecListener(Model):
"""
Model.__init__(self, name=self.name, forward=forward, dims={"nO": width})
self.upstream_name = upstream_name
- self._batch_id = None
+ self._batch_id: Optional[int] = None
self._outputs = None
self._backprop = None
@classmethod
- def get_batch_id(cls, inputs: List[Doc]) -> int:
+ def get_batch_id(cls, inputs: Iterable[Doc]) -> int:
"""Calculate a content-sensitive hash of the batch of documents, to check
whether the next batch of documents is unexpected.
"""
diff --git a/spacy/schemas.py b/spacy/schemas.py
index 83623b104..73ddc45b1 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -44,7 +44,7 @@ def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
for error in errors:
err_loc = " -> ".join([str(p) for p in error.get("loc", [])])
data[err_loc].append(error.get("msg"))
- return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
+ return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()] # type: ignore[arg-type]
# Initialization
@@ -82,7 +82,7 @@ def get_arg_model(
except ValueError:
# Typically happens if the method is part of a Cython module without
# binding=True. Here we just use an empty model that allows everything.
- return create_model(name, __config__=ArgSchemaConfigExtra)
+ return create_model(name, __config__=ArgSchemaConfigExtra) # type: ignore[arg-type, return-value]
has_variable = False
for param in sig.parameters.values():
if param.name in exclude:
@@ -102,8 +102,8 @@ def get_arg_model(
default = param.default if param.default != param.empty else default_empty
sig_args[param.name] = (annotation, default)
is_strict = strict and not has_variable
- sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra
- return create_model(name, **sig_args)
+ sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra # type: ignore[assignment]
+ return create_model(name, **sig_args) # type: ignore[arg-type, return-value]
def validate_init_settings(
@@ -198,10 +198,10 @@ class TokenPatternNumber(BaseModel):
class TokenPatternOperator(str, Enum):
- plus: StrictStr = "+"
- start: StrictStr = "*"
- question: StrictStr = "?"
- exclamation: StrictStr = "!"
+ plus: StrictStr = StrictStr("+")
+ start: StrictStr = StrictStr("*")
+ question: StrictStr = StrictStr("?")
+ exclamation: StrictStr = StrictStr("!")
StringValue = Union[TokenPatternString, StrictStr]
@@ -385,7 +385,7 @@ class ConfigSchemaInit(BaseModel):
class ConfigSchema(BaseModel):
training: ConfigSchemaTraining
nlp: ConfigSchemaNlp
- pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
+ pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {} # type: ignore[assignment]
components: Dict[str, Dict[str, Any]]
corpora: Dict[str, Reader]
initialize: ConfigSchemaInit
diff --git a/spacy/scorer.py b/spacy/scorer.py
index f4ccb2269..ebab2382d 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -1,4 +1,5 @@
-from typing import Optional, Iterable, Dict, Set, Any, Callable, TYPE_CHECKING
+from typing import Optional, Iterable, Dict, Set, List, Any, Callable, Tuple
+from typing import TYPE_CHECKING
import numpy as np
from collections import defaultdict
@@ -74,8 +75,8 @@ class ROCAUCScore:
may throw an error."""
def __init__(self) -> None:
- self.golds = []
- self.cands = []
+ self.golds: List[Any] = []
+ self.cands: List[Any] = []
self.saved_score = 0.0
self.saved_score_at_len = 0
@@ -111,9 +112,10 @@ class Scorer:
DOCS: https://spacy.io/api/scorer#init
"""
- self.nlp = nlp
self.cfg = cfg
- if not nlp:
+ if nlp:
+ self.nlp = nlp
+ else:
nlp = get_lang_class(default_lang)()
for pipe in default_pipeline:
nlp.add_pipe(pipe)
@@ -129,7 +131,7 @@ class Scorer:
"""
scores = {}
if hasattr(self.nlp.tokenizer, "score"):
- scores.update(self.nlp.tokenizer.score(examples, **self.cfg))
+ scores.update(self.nlp.tokenizer.score(examples, **self.cfg)) # type: ignore
for name, component in self.nlp.pipeline:
if hasattr(component, "score"):
scores.update(component.score(examples, **self.cfg))
@@ -191,7 +193,7 @@ class Scorer:
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
- missing_values: Set[Any] = MISSING_VALUES,
+ missing_values: Set[Any] = MISSING_VALUES, # type: ignore[assignment]
**cfg,
) -> Dict[str, Any]:
"""Returns an accuracy score for a token-level attribute.
@@ -201,6 +203,8 @@ class Scorer:
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an
individual token.
+ missing_values (Set[Any]): Attribute values to treat as missing annotation
+ in the reference annotation.
RETURNS (Dict[str, Any]): A dictionary containing the accuracy score
under the key attr_acc.
@@ -240,7 +244,7 @@ class Scorer:
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
- missing_values: Set[Any] = MISSING_VALUES,
+ missing_values: Set[Any] = MISSING_VALUES, # type: ignore[assignment]
**cfg,
) -> Dict[str, Any]:
"""Return PRF scores per feat for a token attribute in UFEATS format.
@@ -250,6 +254,8 @@ class Scorer:
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an
individual token.
+ missing_values (Set[Any]): Attribute values to treat as missing annotation
+ in the reference annotation.
RETURNS (dict): A dictionary containing the per-feat PRF scores under
the key attr_per_feat.
"""
@@ -258,7 +264,7 @@ class Scorer:
pred_doc = example.predicted
gold_doc = example.reference
align = example.alignment
- gold_per_feat = {}
+ gold_per_feat: Dict[str, Set] = {}
missing_indices = set()
for gold_i, token in enumerate(gold_doc):
value = getter(token, attr)
@@ -273,7 +279,7 @@ class Scorer:
gold_per_feat[field].add((gold_i, feat))
else:
missing_indices.add(gold_i)
- pred_per_feat = {}
+ pred_per_feat: Dict[str, Set] = {}
for token in pred_doc:
if token.orth_.isspace():
continue
@@ -350,7 +356,7 @@ class Scorer:
+ [k.label_ for k in getter(pred_doc, attr)]
)
# Set up all labels for per type scoring and prepare gold per type
- gold_per_type = {label: set() for label in labels}
+ gold_per_type: Dict[str, Set] = {label: set() for label in labels}
for label in labels:
if label not in score_per_type:
score_per_type[label] = PRFScore()
@@ -358,16 +364,18 @@ class Scorer:
gold_spans = set()
pred_spans = set()
for span in getter(gold_doc, attr):
+ gold_span: Tuple
if labeled:
gold_span = (span.label_, span.start, span.end - 1)
else:
gold_span = (span.start, span.end - 1)
gold_spans.add(gold_span)
gold_per_type[span.label_].add(gold_span)
- pred_per_type = {label: set() for label in labels}
+ pred_per_type: Dict[str, Set] = {label: set() for label in labels}
for span in example.get_aligned_spans_x2y(
getter(pred_doc, attr), allow_overlap
):
+ pred_span: Tuple
if labeled:
pred_span = (span.label_, span.start, span.end - 1)
else:
@@ -382,7 +390,7 @@ class Scorer:
# Score for all labels
score.score_set(pred_spans, gold_spans)
# Assemble final result
- final_scores = {
+ final_scores: Dict[str, Any] = {
f"{attr}_p": None,
f"{attr}_r": None,
f"{attr}_f": None,
@@ -508,7 +516,7 @@ class Scorer:
sum(auc.score if auc.is_binary() else 0.0 for auc in auc_per_type.values())
/ n_cats
)
- results = {
+ results: Dict[str, Any] = {
f"{attr}_score": None,
f"{attr}_score_desc": None,
f"{attr}_micro_p": micro_prf.precision,
@@ -613,7 +621,7 @@ class Scorer:
head_attr: str = "head",
head_getter: Callable[[Token, str], Token] = getattr,
ignore_labels: Iterable[str] = SimpleFrozenList(),
- missing_values: Set[Any] = MISSING_VALUES,
+ missing_values: Set[Any] = MISSING_VALUES, # type: ignore[assignment]
**cfg,
) -> Dict[str, Any]:
"""Returns the UAS, LAS, and LAS per type scores for dependency
@@ -630,6 +638,8 @@ class Scorer:
head_getter(token, attr) should return the value of the head for an
individual token.
ignore_labels (Tuple): Labels to ignore while scoring (e.g., punct).
+ missing_values (Set[Any]): Attribute values to treat as missing annotation
+ in the reference annotation.
RETURNS (Dict[str, Any]): A dictionary containing the scores:
attr_uas, attr_las, and attr_las_per_type.
@@ -644,7 +654,7 @@ class Scorer:
pred_doc = example.predicted
align = example.alignment
gold_deps = set()
- gold_deps_per_dep = {}
+ gold_deps_per_dep: Dict[str, Set] = {}
for gold_i, token in enumerate(gold_doc):
dep = getter(token, attr)
head = head_getter(token, head_attr)
@@ -659,12 +669,12 @@ class Scorer:
else:
missing_indices.add(gold_i)
pred_deps = set()
- pred_deps_per_dep = {}
+ pred_deps_per_dep: Dict[str, Set] = {}
for token in pred_doc:
if token.orth_.isspace():
continue
if align.x2y.lengths[token.i] != 1:
- gold_i = None
+ gold_i = None # type: ignore
else:
gold_i = align.x2y[token.i].dataXd[0, 0]
if gold_i not in missing_indices:
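
The per-feat scoring hunks above collect one set of (token index, feature) pairs per UFEATS field, with explicit `Dict[str, Set]` annotations so mypy accepts the initially empty literals. A compact sketch of that bookkeeping for a single document; the feature strings are invented for illustration, and the real scorer also handles alignment and missing values:

from typing import Dict, Set, Tuple

morphs = ["Case=Nom|Number=Sing", "", "Case=Acc"]   # one UFEATS string per token

gold_per_feat: Dict[str, Set[Tuple[int, str]]] = {}
for i, value in enumerate(morphs):
    if not value:
        continue                                    # empty annotation counts as missing
    for feat in value.split("|"):
        field = feat.split("=", 1)[0]
        gold_per_feat.setdefault(field, set()).add((i, feat))

print(gold_per_feat)
# e.g. {'Case': {(0, 'Case=Nom'), (2, 'Case=Acc')}, 'Number': {(0, 'Number=Sing')}}
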
diff --git a/spacy/strings.pyi b/spacy/strings.pyi
index 57bf71b93..5b4147e12 100644
--- a/spacy/strings.pyi
+++ b/spacy/strings.pyi
@@ -1,7 +1,7 @@
from typing import Optional, Iterable, Iterator, Union, Any
from pathlib import Path
-def get_string_id(key: str) -> int: ...
+def get_string_id(key: Union[str, int]) -> int: ...
class StringStore:
def __init__(
diff --git a/spacy/tests/lang/fr/test_prefix_suffix_infix.py b/spacy/tests/lang/fr/test_prefix_suffix_infix.py
index 2ead34069..7770f807b 100644
--- a/spacy/tests/lang/fr/test_prefix_suffix_infix.py
+++ b/spacy/tests/lang/fr/test_prefix_suffix_infix.py
@@ -1,5 +1,5 @@
import pytest
-from spacy.language import Language
+from spacy.language import Language, BaseDefaults
from spacy.lang.punctuation import TOKENIZER_INFIXES
from spacy.lang.char_classes import ALPHA
@@ -12,7 +12,7 @@ def test_issue768(text, expected_tokens):
SPLIT_INFIX = r"(?<=[{a}]\')(?=[{a}])".format(a=ALPHA)
class FrenchTest(Language):
- class Defaults(Language.Defaults):
+ class Defaults(BaseDefaults):
infixes = TOKENIZER_INFIXES + [SPLIT_INFIX]
fr_tokenizer_w_infix = FrenchTest().tokenizer
diff --git a/spacy/tests/lang/hu/test_tokenizer.py b/spacy/tests/lang/hu/test_tokenizer.py
index fd3acd0a0..0488474ae 100644
--- a/spacy/tests/lang/hu/test_tokenizer.py
+++ b/spacy/tests/lang/hu/test_tokenizer.py
@@ -294,7 +294,7 @@ WIKI_TESTS = [
]
EXTRA_TESTS = (
- DOT_TESTS + QUOTE_TESTS + NUMBER_TESTS + HYPHEN_TESTS + WIKI_TESTS + TYPO_TESTS
+ DOT_TESTS + QUOTE_TESTS + NUMBER_TESTS + HYPHEN_TESTS + WIKI_TESTS + TYPO_TESTS # type: ignore[operator]
)
# normal: default tests + 10% of extra tests
diff --git a/spacy/tests/package/test_requirements.py b/spacy/tests/package/test_requirements.py
index 8e042c9cf..1d51bd609 100644
--- a/spacy/tests/package/test_requirements.py
+++ b/spacy/tests/package/test_requirements.py
@@ -12,6 +12,10 @@ def test_build_dependencies():
"flake8",
"hypothesis",
"pre-commit",
+ "mypy",
+ "types-dataclasses",
+ "types-mock",
+ "types-requests",
]
# ignore language-specific packages that shouldn't be installed by all
libs_ignore_setup = [
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index a30001b27..21094bcb1 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -2,14 +2,14 @@ import pytest
from numpy.testing import assert_equal
from spacy.attrs import ENT_IOB
-from spacy import util
+from spacy import util, registry
from spacy.lang.en import English
from spacy.language import Language
from spacy.lookups import Lookups
from spacy.pipeline._parser_internals.ner import BiluoPushDown
from spacy.training import Example
from spacy.tokens import Doc, Span
-from spacy.vocab import Vocab, registry
+from spacy.vocab import Vocab
import logging
from ..util import make_tempdir
diff --git a/spacy/tests/pipeline/test_pipe_factories.py b/spacy/tests/pipeline/test_pipe_factories.py
index f1f0c8a6e..631b7c162 100644
--- a/spacy/tests/pipeline/test_pipe_factories.py
+++ b/spacy/tests/pipeline/test_pipe_factories.py
@@ -135,8 +135,8 @@ def test_pipe_class_component_defaults():
self,
nlp: Language,
name: str,
- value1: StrictInt = 10,
- value2: StrictStr = "hello",
+ value1: StrictInt = StrictInt(10),
+ value2: StrictStr = StrictStr("hello"),
):
self.nlp = nlp
self.value1 = value1
@@ -196,7 +196,7 @@ def test_pipe_class_component_model_custom():
@Language.factory(name, default_config=default_config)
class Component:
def __init__(
- self, nlp: Language, model: Model, name: str, value1: StrictInt = 10
+ self, nlp: Language, model: Model, name: str, value1: StrictInt = StrictInt(10)
):
self.nlp = nlp
self.model = model
diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py
index d4d0617d7..5c3a9d27d 100644
--- a/spacy/tests/pipeline/test_spancat.py
+++ b/spacy/tests/pipeline/test_spancat.py
@@ -6,8 +6,8 @@ from thinc.api import get_current_ops
from spacy import util
from spacy.lang.en import English
from spacy.language import Language
-from spacy.tokens.doc import SpanGroups
from spacy.tokens import SpanGroup
+from spacy.tokens._dict_proxies import SpanGroups
from spacy.training import Example
from spacy.util import fix_random_seed, registry, make_tempdir
diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py
index e51c7f45b..23afaf26c 100644
--- a/spacy/tests/serialize/test_serialize_doc.py
+++ b/spacy/tests/serialize/test_serialize_doc.py
@@ -1,5 +1,5 @@
import pytest
-from spacy.tokens.doc import Underscore
+from spacy.tokens.underscore import Underscore
import spacy
from spacy.lang.en import English
diff --git a/spacy/tests/training/test_readers.py b/spacy/tests/training/test_readers.py
index 1f262c011..c0c51b287 100644
--- a/spacy/tests/training/test_readers.py
+++ b/spacy/tests/training/test_readers.py
@@ -28,7 +28,7 @@ def test_readers():
"""
@registry.readers("myreader.v1")
- def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
+ def myreader() -> Dict[str, Callable[[Language], Iterable[Example]]]:
annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
def reader(nlp: Language):
diff --git a/spacy/tests/vocab_vectors/test_vectors.py b/spacy/tests/vocab_vectors/test_vectors.py
index 8a7dd22c3..23597455f 100644
--- a/spacy/tests/vocab_vectors/test_vectors.py
+++ b/spacy/tests/vocab_vectors/test_vectors.py
@@ -5,7 +5,7 @@ from thinc.api import get_current_ops
from spacy.vocab import Vocab
from spacy.vectors import Vectors
from spacy.tokenizer import Tokenizer
-from spacy.strings import hash_string
+from spacy.strings import hash_string # type: ignore
from spacy.tokens import Doc
from ..util import add_vecs_to_vocab, get_cosine, make_tempdir
diff --git a/spacy/tokens/_dict_proxies.py b/spacy/tokens/_dict_proxies.py
index 9ee1ad02f..470d3430f 100644
--- a/spacy/tokens/_dict_proxies.py
+++ b/spacy/tokens/_dict_proxies.py
@@ -1,9 +1,10 @@
-from typing import Iterable, Tuple, Union, TYPE_CHECKING
+from typing import Iterable, Tuple, Union, Optional, TYPE_CHECKING
import weakref
from collections import UserDict
import srsly
from .span_group import SpanGroup
+from ..errors import Errors
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
@@ -13,7 +14,7 @@ if TYPE_CHECKING:
# Why inherit from UserDict instead of dict here?
# Well, the 'dict' class doesn't necessarily delegate everything nicely,
-# for performance reasons. The UserDict is slower by better behaved.
+# for performance reasons. The UserDict is slower but better behaved.
# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/
class SpanGroups(UserDict):
"""A dict-like proxy held by the Doc, to control access to span groups."""
@@ -22,7 +23,7 @@ class SpanGroups(UserDict):
self, doc: "Doc", items: Iterable[Tuple[str, SpanGroup]] = tuple()
) -> None:
self.doc_ref = weakref.ref(doc)
- UserDict.__init__(self, items)
+ UserDict.__init__(self, items) # type: ignore[arg-type]
def __setitem__(self, key: str, value: Union[SpanGroup, Iterable["Span"]]) -> None:
if not isinstance(value, SpanGroup):
@@ -31,11 +32,12 @@ class SpanGroups(UserDict):
UserDict.__setitem__(self, key, value)
def _make_span_group(self, name: str, spans: Iterable["Span"]) -> SpanGroup:
- return SpanGroup(self.doc_ref(), name=name, spans=spans)
+ doc = self._ensure_doc()
+ return SpanGroup(doc, name=name, spans=spans)
- def copy(self, doc: "Doc" = None) -> "SpanGroups":
+ def copy(self, doc: Optional["Doc"] = None) -> "SpanGroups":
if doc is None:
- doc = self.doc_ref()
+ doc = self._ensure_doc()
return SpanGroups(doc).from_bytes(self.to_bytes())
def to_bytes(self) -> bytes:
@@ -47,8 +49,14 @@ class SpanGroups(UserDict):
def from_bytes(self, bytes_data: bytes) -> "SpanGroups":
msg = srsly.msgpack_loads(bytes_data)
self.clear()
- doc = self.doc_ref()
+ doc = self._ensure_doc()
for value_bytes in msg:
group = SpanGroup(doc).from_bytes(value_bytes)
self[group.name] = group
return self
+
+ def _ensure_doc(self) -> "Doc":
+ doc = self.doc_ref()
+ if doc is None:
+ raise ValueError(Errors.E866)
+ return doc
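
The `_ensure_doc` helper added above converts a possibly-dead weak reference into either a concrete object or an explicit error, which is what lets mypy treat the result as `Doc` rather than `Optional[Doc]`. A minimal standalone sketch of the same guard; `Owner` and `Proxy` are invented names, not spaCy classes:

import weakref

class Owner:
    pass

class Proxy:
    def __init__(self, owner: Owner) -> None:
        self.owner_ref = weakref.ref(owner)   # weak reference avoids a cycle

    def _ensure_owner(self) -> Owner:
        owner = self.owner_ref()              # None once the owner is garbage-collected
        if owner is None:
            raise ValueError("owner no longer exists")
        return owner

owner = Owner()
proxy = Proxy(owner)
assert proxy._ensure_owner() is owner
del owner
# proxy._ensure_owner() would now raise ValueError instead of silently returning None
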
diff --git a/spacy/tokens/_retokenize.pyi b/spacy/tokens/_retokenize.pyi
index b829b71a3..8834d38c0 100644
--- a/spacy/tokens/_retokenize.pyi
+++ b/spacy/tokens/_retokenize.pyi
@@ -2,6 +2,7 @@ from typing import Dict, Any, Union, List, Tuple
from .doc import Doc
from .span import Span
from .token import Token
+from .. import Vocab
class Retokenizer:
def __init__(self, doc: Doc) -> None: ...
@@ -15,3 +16,6 @@ class Retokenizer:
) -> None: ...
def __enter__(self) -> Retokenizer: ...
def __exit__(self, *args: Any) -> None: ...
+
+def normalize_token_attrs(vocab: Vocab, attrs: Dict): ...
+def set_token_attrs(py_token: Token, attrs: Dict): ...
diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py
index e7799d230..510a2ea71 100644
--- a/spacy/tokens/_serialize.py
+++ b/spacy/tokens/_serialize.py
@@ -1,6 +1,7 @@
-from typing import Iterable, Iterator, Union
+from typing import List, Dict, Set, Iterable, Iterator, Union, Optional
from pathlib import Path
import numpy
+from numpy import ndarray
import zlib
import srsly
from thinc.api import NumpyOps
@@ -74,13 +75,13 @@ class DocBin:
self.version = "0.1"
self.attrs = [attr for attr in attrs if attr != ORTH and attr != SPACY]
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
- self.tokens = []
- self.spaces = []
- self.cats = []
- self.span_groups = []
- self.user_data = []
- self.flags = []
- self.strings = set()
+ self.tokens: List[ndarray] = []
+ self.spaces: List[ndarray] = []
+ self.cats: List[Dict] = []
+ self.span_groups: List[bytes] = []
+ self.user_data: List[Optional[bytes]] = []
+ self.flags: List[Dict] = []
+ self.strings: Set[str] = set()
self.store_user_data = store_user_data
for doc in docs:
self.add(doc)
@@ -138,11 +139,11 @@ class DocBin:
for i in range(len(self.tokens)):
flags = self.flags[i]
tokens = self.tokens[i]
- spaces = self.spaces[i]
+ spaces: Optional[ndarray] = self.spaces[i]
if flags.get("has_unknown_spaces"):
spaces = None
- doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
- doc = doc.from_array(self.attrs, tokens)
+ doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces) # type: ignore
+ doc = doc.from_array(self.attrs, tokens) # type: ignore
doc.cats = self.cats[i]
if self.span_groups[i]:
doc.spans.from_bytes(self.span_groups[i])
diff --git a/spacy/tokens/doc.pyi b/spacy/tokens/doc.pyi
index 8688fb91f..2b18cee7a 100644
--- a/spacy/tokens/doc.pyi
+++ b/spacy/tokens/doc.pyi
@@ -1,16 +1,5 @@
-from typing import (
- Callable,
- Protocol,
- Iterable,
- Iterator,
- Optional,
- Union,
- Tuple,
- List,
- Dict,
- Any,
- overload,
-)
+from typing import Callable, Protocol, Iterable, Iterator, Optional
+from typing import Union, Tuple, List, Dict, Any, overload
from cymem.cymem import Pool
from thinc.types import Floats1d, Floats2d, Ints2d
from .span import Span
@@ -24,7 +13,7 @@ from pathlib import Path
import numpy
class DocMethod(Protocol):
- def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ...
+ def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
class Doc:
vocab: Vocab
@@ -150,6 +139,7 @@ class Doc:
self, attr_id: int, exclude: Optional[Any] = ..., counts: Optional[Any] = ...
) -> Dict[Any, int]: ...
def from_array(self, attrs: List[int], array: Ints2d) -> Doc: ...
+ def to_array(self, py_attr_ids: List[int]) -> numpy.ndarray: ...
@staticmethod
def from_docs(
docs: List[Doc],
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index b3eda26e1..5ea3e1e3b 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -914,7 +914,7 @@ cdef class Doc:
can specify attributes by integer ID (e.g. spacy.attrs.LEMMA) or
string name (e.g. 'LEMMA' or 'lemma').
- attr_ids (list[]): A list of attributes (int IDs or string names).
+ py_attr_ids (list[]): A list of attributes (int IDs or string names).
RETURNS (numpy.ndarray[long, ndim=2]): A feature matrix, with one row
per word, and one column per attribute indicated in the input
`attr_ids`.
diff --git a/spacy/tokens/morphanalysis.pyi b/spacy/tokens/morphanalysis.pyi
index c7e05e58f..b86203cc4 100644
--- a/spacy/tokens/morphanalysis.pyi
+++ b/spacy/tokens/morphanalysis.pyi
@@ -11,8 +11,8 @@ class MorphAnalysis:
def __iter__(self) -> Iterator[str]: ...
def __len__(self) -> int: ...
def __hash__(self) -> int: ...
- def __eq__(self, other: MorphAnalysis) -> bool: ...
- def __ne__(self, other: MorphAnalysis) -> bool: ...
+ def __eq__(self, other: MorphAnalysis) -> bool: ... # type: ignore[override]
+ def __ne__(self, other: MorphAnalysis) -> bool: ... # type: ignore[override]
def get(self, field: Any) -> List[str]: ...
def to_json(self) -> str: ...
def to_dict(self) -> Dict[str, str]: ...
diff --git a/spacy/tokens/span.pyi b/spacy/tokens/span.pyi
index 4f65abace..697051e81 100644
--- a/spacy/tokens/span.pyi
+++ b/spacy/tokens/span.pyi
@@ -7,7 +7,7 @@ from ..lexeme import Lexeme
from ..vocab import Vocab
class SpanMethod(Protocol):
- def __call__(self: Span, *args: Any, **kwargs: Any) -> Any: ...
+ def __call__(self: Span, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
class Span:
@classmethod
@@ -45,7 +45,7 @@ class Span:
doc: Doc,
start: int,
end: int,
- label: int = ...,
+ label: Union[str, int] = ...,
vector: Optional[Floats1d] = ...,
vector_norm: Optional[float] = ...,
kb_id: Optional[int] = ...,
@@ -65,6 +65,8 @@ class Span:
def get_lca_matrix(self) -> Ints2d: ...
def similarity(self, other: Union[Doc, Span, Token, Lexeme]) -> float: ...
@property
+ def doc(self) -> Doc: ...
+ @property
def vocab(self) -> Vocab: ...
@property
def sent(self) -> Span: ...
diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx
index 050a70d02..c9c807d7d 100644
--- a/spacy/tokens/span.pyx
+++ b/spacy/tokens/span.pyx
@@ -88,7 +88,7 @@ cdef class Span:
doc (Doc): The parent document.
start (int): The index of the first token of the span.
end (int): The index of the first token after the span.
- label (uint64): A label to attach to the Span, e.g. for named entities.
+ label (int or str): A label to attach to the Span, e.g. for named entities.
vector (ndarray[ndim=1, dtype='float32']): A meaning representation
of the span.
vector_norm (float): The L2 norm of the span's vector representation.
diff --git a/spacy/tokens/span_group.pyi b/spacy/tokens/span_group.pyi
index 4bd6bec27..26efc3ba0 100644
--- a/spacy/tokens/span_group.pyi
+++ b/spacy/tokens/span_group.pyi
@@ -3,6 +3,8 @@ from .doc import Doc
from .span import Span
class SpanGroup:
+ name: str
+ attrs: Dict[str, Any]
def __init__(
self,
doc: Doc,
diff --git a/spacy/tokens/token.pyi b/spacy/tokens/token.pyi
index 23d028ffd..bd585d034 100644
--- a/spacy/tokens/token.pyi
+++ b/spacy/tokens/token.pyi
@@ -16,7 +16,7 @@ from ..vocab import Vocab
from .underscore import Underscore
class TokenMethod(Protocol):
- def __call__(self: Token, *args: Any, **kwargs: Any) -> Any: ...
+ def __call__(self: Token, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
class Token:
i: int
diff --git a/spacy/tokens/token.pyx b/spacy/tokens/token.pyx
index 9277eb6fa..c5baae510 100644
--- a/spacy/tokens/token.pyx
+++ b/spacy/tokens/token.pyx
@@ -600,7 +600,7 @@ cdef class Token:
yield from word.subtree
@property
- def left_edge(self):
+ def left_edge(self) -> "Token":
"""The leftmost token of this token's syntactic descendants.
RETURNS (Token): The first token such that `self.is_ancestor(token)`.
@@ -608,7 +608,7 @@ cdef class Token:
return self.doc[self.c.l_edge]
@property
- def right_edge(self):
+ def right_edge(self) -> "Token":
"""The rightmost token of this token's syntactic descendants.
RETURNS (Token): The last token such that `self.is_ancestor(token)`.
diff --git a/spacy/tokens/underscore.py b/spacy/tokens/underscore.py
index b7966fd6e..7fa7bf095 100644
--- a/spacy/tokens/underscore.py
+++ b/spacy/tokens/underscore.py
@@ -1,3 +1,4 @@
+from typing import Dict, Any
import functools
import copy
@@ -6,9 +7,9 @@ from ..errors import Errors
class Underscore:
mutable_types = (dict, list, set)
- doc_extensions = {}
- span_extensions = {}
- token_extensions = {}
+ doc_extensions: Dict[Any, Any] = {}
+ span_extensions: Dict[Any, Any] = {}
+ token_extensions: Dict[Any, Any] = {}
def __init__(self, extensions, obj, start=None, end=None):
object.__setattr__(self, "_extensions", extensions)
diff --git a/spacy/training/__init__.py b/spacy/training/__init__.py
index 34cde0ba9..99fe7c19f 100644
--- a/spacy/training/__init__.py
+++ b/spacy/training/__init__.py
@@ -7,5 +7,5 @@ from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F40
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
from .gold_io import docs_to_json, read_json_file # noqa: F401
from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
-from .loggers import console_logger, wandb_logger # noqa: F401
+from .loggers import console_logger, wandb_logger_v3 as wandb_logger # noqa: F401
from .callbacks import create_copy_from_base_model # noqa: F401
diff --git a/spacy/training/augment.py b/spacy/training/augment.py
index 0dae92143..63b54034c 100644
--- a/spacy/training/augment.py
+++ b/spacy/training/augment.py
@@ -22,8 +22,8 @@ class OrthVariantsPaired(BaseModel):
class OrthVariants(BaseModel):
- paired: List[OrthVariantsPaired] = {}
- single: List[OrthVariantsSingle] = {}
+ paired: List[OrthVariantsPaired] = []
+ single: List[OrthVariantsSingle] = []
@registry.augmenters("spacy.orth_variants.v1")
@@ -76,7 +76,7 @@ def lower_casing_augmenter(
def orth_variants_augmenter(
nlp: "Language",
example: Example,
- orth_variants: dict,
+ orth_variants: Dict,
*,
level: float = 0.0,
lower: float = 0.0,
diff --git a/spacy/training/batchers.py b/spacy/training/batchers.py
index e79ba79b0..f0b6c3123 100644
--- a/spacy/training/batchers.py
+++ b/spacy/training/batchers.py
@@ -1,4 +1,4 @@
-from typing import Union, Iterable, Sequence, TypeVar, List, Callable
+from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
from typing import Optional, Any
from functools import partial
import itertools
@@ -6,7 +6,7 @@ import itertools
from ..util import registry, minibatch
-Sizing = Union[Iterable[int], int]
+Sizing = Union[Sequence[int], int]
ItemT = TypeVar("ItemT")
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
@@ -24,7 +24,7 @@ def configure_minibatch_by_padded_size(
The padded size is defined as the maximum length of sequences within the
batch multiplied by the number of sequences in the batch.
- size (int or Iterable[int]): The largest padded size to batch sequences into.
+ size (int or Sequence[int]): The largest padded size to batch sequences into.
Can be a single integer, or a sequence, allowing for variable batch sizes.
buffer (int): The number of sequences to accumulate before sorting by length.
A larger buffer will result in more even sizing, but if the buffer is
@@ -56,7 +56,7 @@ def configure_minibatch_by_words(
) -> BatcherT:
"""Create a batcher that uses the "minibatch by words" strategy.
- size (int or Iterable[int]): The target number of words per batch.
+ size (int or Sequence[int]): The target number of words per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
tolerance (float): What percentage of the size to allow batches to exceed.
discard_oversize (bool): Whether to discard sequences that by themselves
@@ -80,7 +80,7 @@ def configure_minibatch(
) -> BatcherT:
"""Create a batcher that creates batches of the specified size.
- size (int or Iterable[int]): The target number of items per batch.
+ size (int or Sequence[int]): The target number of items per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
"""
optionals = {"get_length": get_length} if get_length is not None else {}
@@ -100,7 +100,7 @@ def minibatch_by_padded_size(
The padded size is defined as the maximum length of sequences within the
batch multiplied by the number of sequences in the batch.
- size (int): The largest padded size to batch sequences into.
+ size (int or Sequence[int]): The largest padded size to batch sequences into.
buffer (int): The number of sequences to accumulate before sorting by length.
A larger buffer will result in more even sizing, but if the buffer is
very large, the iteration order will be less random, which can result
@@ -111,9 +111,9 @@ def minibatch_by_padded_size(
The `len` function is used by default.
"""
if isinstance(size, int):
- size_ = itertools.repeat(size)
+ size_ = itertools.repeat(size) # type: Iterator[int]
else:
- size_ = size
+ size_ = iter(size)
for outer_batch in minibatch(seqs, size=buffer):
outer_batch = list(outer_batch)
target_size = next(size_)
@@ -138,7 +138,7 @@ def minibatch_by_words(
themselves, or be discarded if discard_oversize=True.
seqs (Iterable[Sequence]): The sequences to minibatch.
- size (int or Iterable[int]): The target number of words per batch.
+ size (int or Sequence[int]): The target number of words per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
tolerance (float): What percentage of the size to allow batches to exceed.
discard_oversize (bool): Whether to discard sequences that by themselves
@@ -147,11 +147,9 @@ def minibatch_by_words(
item. The `len` function is used by default.
"""
if isinstance(size, int):
- size_ = itertools.repeat(size)
- elif isinstance(size, List):
- size_ = iter(size)
+ size_ = itertools.repeat(size) # type: Iterator[int]
else:
- size_ = size
+ size_ = iter(size)
target_size = next(size_)
tol_size = target_size * tolerance
batch = []
@@ -216,7 +214,7 @@ def _batch_by_length(
lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
lengths_indices.sort()
batches = []
- batch = []
+ batch: List[int] = []
for length, i in lengths_indices:
if not batch:
batch.append(i)
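
The batcher hunks above normalise `size` into a single `Iterator[int]`: an int becomes an endless `itertools.repeat`, anything else is wrapped with `iter()`, so the batching loops can simply call `next()` for each batch. A short sketch of that schedule pattern; `size_schedule` is an illustrative name, not a spaCy helper:

import itertools
from typing import Iterator, Sequence, Union

Sizing = Union[Sequence[int], int]

def size_schedule(size: Sizing) -> Iterator[int]:
    if isinstance(size, int):
        return itertools.repeat(size)   # constant batch size, forever
    return iter(size)                   # caller-provided schedule of sizes

growing = size_schedule([1000, 2000, 3000])
print(next(growing), next(growing))     # 1000 2000

constant = size_schedule(128)
print(next(constant), next(constant))   # 128 128
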
diff --git a/spacy/training/callbacks.py b/spacy/training/callbacks.py
index 2a21be98c..426fddf90 100644
--- a/spacy/training/callbacks.py
+++ b/spacy/training/callbacks.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Callable, Optional
from ..errors import Errors
from ..language import Language
from ..util import load_model, registry, logger
@@ -8,7 +8,7 @@ from ..util import load_model, registry, logger
def create_copy_from_base_model(
tokenizer: Optional[str] = None,
vocab: Optional[str] = None,
-) -> Language:
+) -> Callable[[Language], Language]:
def copy_from_base_model(nlp):
if tokenizer:
logger.info(f"Copying tokenizer from: {tokenizer}")
diff --git a/spacy/training/corpus.py b/spacy/training/corpus.py
index 606dbfb4a..b30d918fd 100644
--- a/spacy/training/corpus.py
+++ b/spacy/training/corpus.py
@@ -41,8 +41,8 @@ def create_docbin_reader(
@util.registry.readers("spacy.JsonlCorpus.v1")
def create_jsonl_reader(
- path: Optional[Path], min_length: int = 0, max_length: int = 0, limit: int = 0
-) -> Callable[["Language"], Iterable[Doc]]:
+ path: Union[str, Path], min_length: int = 0, max_length: int = 0, limit: int = 0
+) -> Callable[["Language"], Iterable[Example]]:
return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit)
@@ -129,15 +129,15 @@ class Corpus:
"""
ref_docs = self.read_docbin(nlp.vocab, walk_corpus(self.path, FILE_TYPE))
if self.shuffle:
- ref_docs = list(ref_docs)
- random.shuffle(ref_docs)
+ ref_docs = list(ref_docs) # type: ignore
+ random.shuffle(ref_docs) # type: ignore
if self.gold_preproc:
examples = self.make_examples_gold_preproc(nlp, ref_docs)
else:
examples = self.make_examples(nlp, ref_docs)
for real_eg in examples:
- for augmented_eg in self.augmenter(nlp, real_eg):
+ for augmented_eg in self.augmenter(nlp, real_eg): # type: ignore[operator]
yield augmented_eg
def _make_example(
@@ -190,7 +190,7 @@ class Corpus:
i = 0
for loc in locs:
loc = util.ensure_path(loc)
- if loc.parts[-1].endswith(FILE_TYPE):
+ if loc.parts[-1].endswith(FILE_TYPE): # type: ignore[union-attr]
doc_bin = DocBin().from_disk(loc)
docs = doc_bin.get_docs(vocab)
for doc in docs:
@@ -202,7 +202,7 @@ class Corpus:
class JsonlCorpus:
- """Iterate Doc objects from a file or directory of jsonl
+ """Iterate Example objects from a file or directory of jsonl
formatted raw text files.
path (Path): The directory or filename to read from.
diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py
index 4eb8ea276..96abcc7cd 100644
--- a/spacy/training/initialize.py
+++ b/spacy/training/initialize.py
@@ -106,7 +106,7 @@ def init_vocab(
data: Optional[Path] = None,
lookups: Optional[Lookups] = None,
vectors: Optional[str] = None,
-) -> "Language":
+) -> None:
if lookups:
nlp.vocab.lookups = lookups
logger.info(f"Added vocab lookups: {', '.join(lookups.tables)}")
@@ -164,7 +164,7 @@ def load_vectors_into_model(
logger.warning(Warnings.W112.format(name=name))
for lex in nlp.vocab:
- lex.rank = nlp.vocab.vectors.key2row.get(lex.orth, OOV_RANK)
+ lex.rank = nlp.vocab.vectors.key2row.get(lex.orth, OOV_RANK) # type: ignore[attr-defined]
def init_tok2vec(
@@ -203,7 +203,7 @@ def convert_vectors(
nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
for lex in nlp.vocab:
if lex.rank and lex.rank != OOV_RANK:
- nlp.vocab.vectors.add(lex.orth, row=lex.rank)
+ nlp.vocab.vectors.add(lex.orth, row=lex.rank) # type: ignore[attr-defined]
else:
if vectors_loc:
logger.info(f"Reading vectors from {vectors_loc}")
@@ -251,14 +251,14 @@ def open_file(loc: Union[str, Path]) -> IO:
"""Handle .gz, .tar.gz or unzipped files"""
loc = ensure_path(loc)
if tarfile.is_tarfile(str(loc)):
- return tarfile.open(str(loc), "r:gz")
+ return tarfile.open(str(loc), "r:gz") # type: ignore[return-value]
elif loc.parts[-1].endswith("gz"):
- return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
+ return (line.decode("utf8") for line in gzip.open(str(loc), "r")) # type: ignore[return-value]
elif loc.parts[-1].endswith("zip"):
zip_file = zipfile.ZipFile(str(loc))
names = zip_file.namelist()
file_ = zip_file.open(names[0])
- return (line.decode("utf8") for line in file_)
+ return (line.decode("utf8") for line in file_) # type: ignore[return-value]
else:
return loc.open("r", encoding="utf8")
diff --git a/spacy/training/iob_utils.py b/spacy/training/iob_utils.py
index 42dae8fc4..64492c2bc 100644
--- a/spacy/training/iob_utils.py
+++ b/spacy/training/iob_utils.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Iterable, Union, Iterator
+from typing import List, Dict, Tuple, Iterable, Union, Iterator
import warnings
from ..errors import Errors, Warnings
@@ -6,7 +6,7 @@ from ..tokens import Span, Doc
def iob_to_biluo(tags: Iterable[str]) -> List[str]:
- out = []
+ out: List[str] = []
tags = list(tags)
while tags:
out.extend(_consume_os(tags))
@@ -90,7 +90,7 @@ def offsets_to_biluo_tags(
>>> assert tags == ["O", "O", 'U-LOC', "O"]
"""
# Ensure no overlapping entity labels exist
- tokens_in_ents = {}
+ tokens_in_ents: Dict[int, Tuple[int, int, Union[str, int]]] = {}
starts = {token.idx: token.i for token in doc}
ends = {token.idx + len(token): token.i for token in doc}
biluo = ["-" for _ in doc]
@@ -199,14 +199,18 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
pass
elif tag.startswith("I"):
if start is None:
- raise ValueError(Errors.E067.format(start="I", tags=tags[: i + 1]))
+ raise ValueError(
+ Errors.E067.format(start="I", tags=list(tags)[: i + 1])
+ )
elif tag.startswith("U"):
entities.append((tag[2:], i, i))
elif tag.startswith("B"):
start = i
elif tag.startswith("L"):
if start is None:
- raise ValueError(Errors.E067.format(start="L", tags=tags[: i + 1]))
+ raise ValueError(
+ Errors.E067.format(start="L", tags=list(tags)[: i + 1])
+ )
entities.append((tag[2:], start, i))
start = None
else:
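
The error paths above now materialise the iterable of tags with `list(tags)` before slicing it into the message, since a plain `Iterable` does not support slicing. For reference, the happy path of `tags_to_entities` turns BILUO tags into (label, start, end) token offsets, roughly like this:

from spacy.training import tags_to_entities

tags = ["O", "B-PER", "L-PER", "U-LOC", "O"]
print(tags_to_entities(tags))
# expected, given the logic above: [('PER', 1, 2), ('LOC', 3, 3)]
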
diff --git a/spacy/training/loggers.py b/spacy/training/loggers.py
index 137e89e56..602e0ff3e 100644
--- a/spacy/training/loggers.py
+++ b/spacy/training/loggers.py
@@ -102,7 +102,7 @@ def console_logger(progress_bar: bool = False):
@registry.loggers("spacy.WandbLogger.v2")
-def wandb_logger(
+def wandb_logger_v2(
project_name: str,
remove_config_values: List[str] = [],
model_log_interval: Optional[int] = None,
@@ -180,7 +180,7 @@ def wandb_logger(
@registry.loggers("spacy.WandbLogger.v3")
-def wandb_logger(
+def wandb_logger_v3(
project_name: str,
remove_config_values: List[str] = [],
model_log_interval: Optional[int] = None,
diff --git a/spacy/training/loop.py b/spacy/training/loop.py
index 09c54fc9f..06372cbb0 100644
--- a/spacy/training/loop.py
+++ b/spacy/training/loop.py
@@ -32,7 +32,7 @@ def train(
"""Train a pipeline.
nlp (Language): The initialized nlp object with the full config.
- output_path (Path): Optional output path to save trained model to.
+ output_path (Optional[Path]): Optional output path to save trained model to.
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
before calling this function.
stdout (file): A file-like object to write output messages. To disable
@@ -194,17 +194,17 @@ def train_while_improving(
else:
dropouts = dropout
results = []
- losses = {}
+ losses: Dict[str, float] = {}
words_seen = 0
start_time = timer()
for step, (epoch, batch) in enumerate(train_data):
- dropout = next(dropouts)
+ dropout = next(dropouts) # type: ignore
for subbatch in subdivide_batch(batch, accumulate_gradient):
nlp.update(
subbatch,
drop=dropout,
losses=losses,
- sgd=False,
+ sgd=False, # type: ignore[arg-type]
exclude=exclude,
annotates=annotating_components,
)
@@ -214,9 +214,9 @@ def train_while_improving(
name not in exclude
and hasattr(proc, "is_trainable")
and proc.is_trainable
- and proc.model not in (True, False, None)
+ and proc.model not in (True, False, None) # type: ignore[attr-defined]
):
- proc.finish_update(optimizer)
+ proc.finish_update(optimizer) # type: ignore[attr-defined]
optimizer.step_schedules()
if not (step % eval_frequency):
if optimizer.averages:
@@ -310,13 +310,13 @@ def create_train_batches(
):
epoch = 0
if max_epochs >= 0:
- examples = list(corpus(nlp))
+ examples = list(corpus(nlp)) # type: Iterable[Example]
if not examples:
# Raise error if no data
raise ValueError(Errors.E986)
while max_epochs < 1 or epoch != max_epochs:
if max_epochs >= 0:
- random.shuffle(examples)
+ random.shuffle(examples) # type: ignore
else:
examples = corpus(nlp)
for batch in batcher(examples):
@@ -353,7 +353,7 @@ def create_before_to_disk_callback(
return before_to_disk
-def clean_output_dir(path: Union[str, Path]) -> None:
+def clean_output_dir(path: Optional[Path]) -> None:
"""Remove an existing output directory. Typically used to ensure that that
a directory like model-best and its contents aren't just being overwritten
by nlp.to_disk, which could preserve existing subdirectories (e.g.
diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py
index 0228f2947..2328ebbc7 100644
--- a/spacy/training/pretrain.py
+++ b/spacy/training/pretrain.py
@@ -93,7 +93,7 @@ def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]:
def _resume_model(
- model: Model, resume_path: Path, epoch_resume: int, silent: bool = True
+ model: Model, resume_path: Path, epoch_resume: Optional[int], silent: bool = True
) -> int:
msg = Printer(no_print=silent)
msg.info(f"Resume training tok2vec from: {resume_path}")
diff --git a/spacy/util.py b/spacy/util.py
index 0aa7c4c17..cf62a4ecd 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -1,4 +1,5 @@
-from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple
+from typing import List, Mapping, NoReturn, Union, Dict, Any, Set
+from typing import Optional, Iterable, Callable, Tuple, Type
from typing import Iterator, Type, Pattern, Generator, TYPE_CHECKING
from types import ModuleType
import os
@@ -50,6 +51,7 @@ from . import about
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401
+ from .pipeline import Pipe # noqa: F401
from .tokens import Doc, Span # noqa: F401
from .vocab import Vocab # noqa: F401
@@ -255,7 +257,7 @@ def lang_class_is_loaded(lang: str) -> bool:
return lang in registry.languages
-def get_lang_class(lang: str) -> "Language":
+def get_lang_class(lang: str) -> Type["Language"]:
"""Import and load a Language class.
lang (str): Two-letter language code, e.g. 'en'.
@@ -269,7 +271,7 @@ def get_lang_class(lang: str) -> "Language":
module = importlib.import_module(f".lang.{lang}", "spacy")
except ImportError as err:
raise ImportError(Errors.E048.format(lang=lang, err=err)) from err
- set_lang_class(lang, getattr(module, module.__all__[0]))
+ set_lang_class(lang, getattr(module, module.__all__[0])) # type: ignore[attr-defined]
return registry.languages.get(lang)
@@ -344,13 +346,13 @@ def load_model(
if name.startswith("blank:"): # shortcut for blank model
return get_lang_class(name.replace("blank:", ""))()
if is_package(name): # installed as package
- return load_model_from_package(name, **kwargs)
+ return load_model_from_package(name, **kwargs) # type: ignore[arg-type]
if Path(name).exists(): # path to model data directory
- return load_model_from_path(Path(name), **kwargs)
+ return load_model_from_path(Path(name), **kwargs) # type: ignore[arg-type]
elif hasattr(name, "exists"): # Path or Path-like to model data
- return load_model_from_path(name, **kwargs)
+ return load_model_from_path(name, **kwargs) # type: ignore[arg-type]
if name in OLD_MODEL_SHORTCUTS:
- raise IOError(Errors.E941.format(name=name, full=OLD_MODEL_SHORTCUTS[name]))
+ raise IOError(Errors.E941.format(name=name, full=OLD_MODEL_SHORTCUTS[name])) # type: ignore[index]
raise IOError(Errors.E050.format(name=name))
@@ -377,11 +379,11 @@ def load_model_from_package(
RETURNS (Language): The loaded nlp object.
"""
cls = importlib.import_module(name)
- return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config)
+ return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config) # type: ignore[attr-defined]
def load_model_from_path(
- model_path: Union[str, Path],
+ model_path: Path,
*,
meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True,
@@ -392,7 +394,7 @@ def load_model_from_path(
"""Load a model from a data directory path. Creates Language class with
pipeline from config.cfg and then calls from_disk() with path.
- name (str): Package name or model path.
+ model_path (Path): Model path.
meta (Dict[str, Any]): Optional model meta.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
@@ -474,7 +476,9 @@ def get_sourced_components(
}
-def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]:
+def resolve_dot_names(
+ config: Config, dot_names: List[Optional[str]]
+) -> Tuple[Any, ...]:
"""Resolve one or more "dot notation" names, e.g. corpora.train.
The paths could point anywhere into the config, so we don't know which
top-level section we'll be looking within.
@@ -484,7 +488,7 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
"""
# TODO: include schema?
resolved = {}
- output = []
+ output: List[Any] = []
errors = []
for name in dot_names:
if name is None:
@@ -500,7 +504,7 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
result = registry.resolve(config[section])
resolved[section] = result
try:
- output.append(dot_to_object(resolved, name))
+ output.append(dot_to_object(resolved, name)) # type: ignore[arg-type]
except KeyError:
msg = f"not a valid section reference: {name}"
errors.append({"loc": name.split("."), "msg": msg})
@@ -604,8 +608,8 @@ def get_package_version(name: str) -> Optional[str]:
RETURNS (str / None): The version or None if package not installed.
"""
try:
- return importlib_metadata.version(name)
- except importlib_metadata.PackageNotFoundError:
+ return importlib_metadata.version(name) # type: ignore[attr-defined]
+ except importlib_metadata.PackageNotFoundError: # type: ignore[attr-defined]
return None
@@ -628,7 +632,7 @@ def is_compatible_version(
constraint = f"=={constraint}"
try:
spec = SpecifierSet(constraint)
- version = Version(version)
+ version = Version(version) # type: ignore[assignment]
except (InvalidSpecifier, InvalidVersion):
return None
spec.prereleases = prereleases
@@ -742,7 +746,7 @@ def load_meta(path: Union[str, Path]) -> Dict[str, Any]:
if "spacy_version" in meta:
if not is_compatible_version(about.__version__, meta["spacy_version"]):
lower_version = get_model_lower_version(meta["spacy_version"])
- lower_version = get_minor_version(lower_version)
+ lower_version = get_minor_version(lower_version) # type: ignore[arg-type]
if lower_version is not None:
lower_version = "v" + lower_version
elif "spacy_git_version" in meta:
@@ -784,7 +788,7 @@ def is_package(name: str) -> bool:
RETURNS (bool): True if installed package, False if not.
"""
try:
- importlib_metadata.distribution(name)
+ importlib_metadata.distribution(name) # type: ignore[attr-defined]
return True
except: # noqa: E722
return False
@@ -845,7 +849,7 @@ def run_command(
*,
stdin: Optional[Any] = None,
capture: bool = False,
-) -> Optional[subprocess.CompletedProcess]:
+) -> subprocess.CompletedProcess:
"""Run a command on the command line as a subprocess. If the subprocess
returns a non-zero exit code, a system exit is performed.
@@ -888,8 +892,8 @@ def run_command(
message += f"\n\nProcess log (stdout and stderr):\n\n"
message += ret.stdout
error = subprocess.SubprocessError(message)
- error.ret = ret
- error.command = cmd_str
+ error.ret = ret # type: ignore[attr-defined]
+ error.command = cmd_str # type: ignore[attr-defined]
raise error
elif ret.returncode != 0:
sys.exit(ret.returncode)
@@ -897,7 +901,7 @@ def run_command(
@contextmanager
-def working_dir(path: Union[str, Path]) -> None:
+def working_dir(path: Union[str, Path]) -> Iterator[Path]:
"""Change current working directory and returns to previous on exit.
path (str / Path): The directory to navigate to.
@@ -945,7 +949,7 @@ def is_in_jupyter() -> bool:
"""
# https://stackoverflow.com/a/39662359/6400719
try:
- shell = get_ipython().__class__.__name__
+ shell = get_ipython().__class__.__name__ # type: ignore[name-defined]
if shell == "ZMQInteractiveShell":
return True # Jupyter notebook or qtconsole
except NameError:
@@ -1027,7 +1031,7 @@ def compile_prefix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
spacy.lang.punctuation.TOKENIZER_PREFIXES.
RETURNS (Pattern): The regex object to be used for Tokenizer.prefix_search.
"""
- expression = "|".join(["^" + piece for piece in entries if piece.strip()])
+ expression = "|".join(["^" + piece for piece in entries if piece.strip()]) # type: ignore[operator, union-attr]
return re.compile(expression)
@@ -1038,7 +1042,7 @@ def compile_suffix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
spacy.lang.punctuation.TOKENIZER_SUFFIXES.
RETURNS (Pattern): The regex object to be used for Tokenizer.suffix_search.
"""
- expression = "|".join([piece + "$" for piece in entries if piece.strip()])
+ expression = "|".join([piece + "$" for piece in entries if piece.strip()]) # type: ignore[operator, union-attr]
return re.compile(expression)
@@ -1049,7 +1053,7 @@ def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
spacy.lang.punctuation.TOKENIZER_INFIXES.
RETURNS (regex object): The regex object to be used for Tokenizer.infix_finditer.
"""
- expression = "|".join([piece for piece in entries if piece.strip()])
+ expression = "|".join([piece for piece in entries if piece.strip()]) # type: ignore[misc, union-attr]
return re.compile(expression)
@@ -1071,7 +1075,7 @@ def _get_attr_unless_lookup(
) -> Any:
for lookup in lookups:
if string in lookup:
- return lookup[string]
+ return lookup[string] # type: ignore[index]
return default_func(string)
@@ -1153,7 +1157,7 @@ def filter_spans(spans: Iterable["Span"]) -> List["Span"]:
get_sort_key = lambda span: (span.end - span.start, -span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
result = []
- seen_tokens = set()
+ seen_tokens: Set[int] = set()
for span in sorted_spans:
# Check for end - 1 here because boundaries are inclusive
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
@@ -1172,7 +1176,7 @@ def from_bytes(
setters: Dict[str, Callable[[bytes], Any]],
exclude: Iterable[str],
) -> None:
- return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude)
+ return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude) # type: ignore[return-value]
def to_dict(
@@ -1234,8 +1238,8 @@ def import_file(name: str, loc: Union[str, Path]) -> ModuleType:
RETURNS: The loaded module.
"""
spec = importlib.util.spec_from_file_location(name, str(loc))
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
+ module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
+ spec.loader.exec_module(module) # type: ignore[union-attr]
return module
@@ -1325,7 +1329,7 @@ def dot_to_dict(values: Dict[str, Any]) -> Dict[str, dict]:
values (Dict[str, Any]): The key/value pairs to convert.
RETURNS (Dict[str, dict]): The converted values.
"""
- result = {}
+ result: Dict[str, dict] = {}
for key, value in values.items():
path = result
parts = key.lower().split(".")
@@ -1407,9 +1411,9 @@ def get_arg_names(func: Callable) -> List[str]:
def combine_score_weights(
- weights: List[Dict[str, float]],
- overrides: Dict[str, Optional[Union[float, int]]] = SimpleFrozenDict(),
-) -> Dict[str, float]:
+ weights: List[Dict[str, Optional[float]]],
+ overrides: Dict[str, Optional[float]] = SimpleFrozenDict(),
+) -> Dict[str, Optional[float]]:
"""Combine and normalize score weights defined by components, e.g.
{"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
@@ -1421,7 +1425,9 @@ def combine_score_weights(
# We divide each weight by the total weight sum.
# We first need to extract all None/null values for score weights that
# shouldn't be shown in the table *or* be weighted
- result = {key: value for w_dict in weights for (key, value) in w_dict.items()}
+ result: Dict[str, Optional[float]] = {
+ key: value for w_dict in weights for (key, value) in w_dict.items()
+ }
result.update(overrides)
weight_sum = sum([v if v else 0.0 for v in result.values()])
for key, value in result.items():
@@ -1443,13 +1449,13 @@ class DummyTokenizer:
def to_bytes(self, **kwargs):
return b""
- def from_bytes(self, _bytes_data, **kwargs):
+ def from_bytes(self, data: bytes, **kwargs) -> "DummyTokenizer":
return self
- def to_disk(self, _path, **kwargs):
+ def to_disk(self, path: Union[str, Path], **kwargs) -> None:
return None
- def from_disk(self, _path, **kwargs):
+ def from_disk(self, path: Union[str, Path], **kwargs) -> "DummyTokenizer":
return self
@@ -1511,7 +1517,13 @@ def check_bool_env_var(env_var: str) -> bool:
return bool(value)
-def _pipe(docs, proc, name, default_error_handler, kwargs):
+def _pipe(
+ docs: Iterable["Doc"],
+ proc: "Pipe",
+ name: str,
+ default_error_handler: Callable[[str, "Pipe", List["Doc"], Exception], NoReturn],
+ kwargs: Mapping[str, Any],
+) -> Iterator["Doc"]:
if hasattr(proc, "pipe"):
yield from proc.pipe(docs, **kwargs)
else:
@@ -1525,7 +1537,7 @@ def _pipe(docs, proc, name, default_error_handler, kwargs):
kwargs.pop(arg)
for doc in docs:
try:
- doc = proc(doc, **kwargs)
+ doc = proc(doc, **kwargs) # type: ignore[call-arg]
yield doc
except Exception as e:
error_handler(name, proc, [doc], e)
@@ -1589,7 +1601,7 @@ def packages_distributions() -> Dict[str, List[str]]:
it's not available in the builtin importlib.metadata.
"""
pkg_to_dist = defaultdict(list)
- for dist in importlib_metadata.distributions():
+ for dist in importlib_metadata.distributions(): # type: ignore[attr-defined]
for pkg in (dist.read_text("top_level.txt") or "").split():
pkg_to_dist[pkg].append(dist.metadata["Name"])
return dict(pkg_to_dist)
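
`combine_score_weights` above merges the per-component weight dicts, applies overrides, and divides each weight by the total so the kept weights sum to 1.0, while `None` entries pass through unweighted. A rough standalone sketch of that normalisation, with rounding and edge cases simplified relative to the real helper:

from typing import Dict, List, Optional

def combine(weights: List[Dict[str, Optional[float]]],
            overrides: Optional[Dict[str, Optional[float]]] = None) -> Dict[str, Optional[float]]:
    result: Dict[str, Optional[float]] = {
        key: value for w_dict in weights for key, value in w_dict.items()
    }
    result.update(overrides or {})
    total = sum(v for v in result.values() if v)          # None/0.0 don't contribute
    return {k: (v / total if v else v) for k, v in result.items()}

print(combine([{"ents_f": 0.5, "ents_p": 0.3}, {"speed": None}, {"tag_acc": 0.2}]))
# {'ents_f': 0.5, 'ents_p': 0.3, 'speed': None, 'tag_acc': 0.2}
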
diff --git a/spacy/vocab.pyi b/spacy/vocab.pyi
index 0a8ef6198..304ac62df 100644
--- a/spacy/vocab.pyi
+++ b/spacy/vocab.pyi
@@ -1,26 +1,28 @@
-from typing import (
- Callable,
- Iterator,
- Optional,
- Union,
- Tuple,
- List,
- Dict,
- Any,
-)
+from typing import Callable, Iterator, Optional, Union, List, Dict
+from typing import Any, Iterable
from thinc.types import Floats1d, FloatsXd
from . import Language
from .strings import StringStore
from .lexeme import Lexeme
from .lookups import Lookups
+from .morphology import Morphology
from .tokens import Doc, Span
+from .vectors import Vectors
from pathlib import Path
def create_vocab(
- lang: Language, defaults: Any, vectors_name: Optional[str] = ...
+ lang: Optional[str], defaults: Any, vectors_name: Optional[str] = ...
) -> Vocab: ...
class Vocab:
+ cfg: Dict[str, Any]
+ get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]]
+ lookups: Lookups
+ morphology: Morphology
+ strings: StringStore
+ vectors: Vectors
+ writing_system: Dict[str, Any]
+
def __init__(
self,
lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ...,
@@ -32,7 +34,7 @@ class Vocab:
get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]] = ...,
) -> None: ...
@property
- def lang(self) -> Language: ...
+ def lang(self) -> str: ...
def __len__(self) -> int: ...
def add_flag(
self, flag_getter: Callable[[str], bool], flag_id: int = ...
@@ -54,16 +56,15 @@ class Vocab:
) -> FloatsXd: ...
def set_vector(self, orth: Union[int, str], vector: Floats1d) -> None: ...
def has_vector(self, orth: Union[int, str]) -> bool: ...
- lookups: Lookups
def to_disk(
- self, path: Union[str, Path], *, exclude: Union[List[str], Tuple[str]] = ...
+ self, path: Union[str, Path], *, exclude: Iterable[str] = ...
) -> None: ...
def from_disk(
- self, path: Union[str, Path], *, exclude: Union[List[str], Tuple[str]] = ...
+ self, path: Union[str, Path], *, exclude: Iterable[str] = ...
) -> Vocab: ...
- def to_bytes(self, *, exclude: Union[List[str], Tuple[str]] = ...) -> bytes: ...
+ def to_bytes(self, *, exclude: Iterable[str] = ...) -> bytes: ...
def from_bytes(
- self, bytes_data: bytes, *, exclude: Union[List[str], Tuple[str]] = ...
+ self, bytes_data: bytes, *, exclude: Iterable[str] = ...
) -> Vocab: ...
def pickle_vocab(vocab: Vocab) -> Any: ...
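A quick sketch of the attribute access the expanded stub now covers when type-checking user code; the printed values are illustrative:

```python
import spacy

nlp = spacy.blank("en")
vocab = nlp.vocab

print(vocab.lang)            # "en", a plain string, matching the corrected `lang` stub
print(vocab.writing_system)  # e.g. {"direction": "ltr", "has_case": True, "has_letters": True}
print(type(vocab.strings).__name__, type(vocab.vectors).__name__, type(vocab.lookups).__name__)
```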
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index 7af780457..b4dfd22f5 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -61,7 +61,7 @@ cdef class Vocab:
lookups (Lookups): Container for large lookup tables and dictionaries.
oov_prob (float): Default OOV probability.
vectors_name (unicode): Optional name to identify the vectors table.
- get_noun_chunks (Optional[Callable[[Union[Doc, Span], Iterator[Span]]]]):
+ get_noun_chunks (Optional[Callable[[Union[Doc, Span]], Iterator[Tuple[int, int, int]]]]):
A function that yields base noun phrases used for Doc.noun_chunks.
"""
lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
@@ -450,7 +450,7 @@ cdef class Vocab:
path (unicode or Path): A path to a directory, which will be created if
it doesn't exist.
- exclude (list): String names of serialization fields to exclude.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
DOCS: https://spacy.io/api/vocab#to_disk
"""
@@ -470,7 +470,7 @@ cdef class Vocab:
returns it.
path (unicode or Path): A path to a directory.
- exclude (list): String names of serialization fields to exclude.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Vocab): The modified `Vocab` object.
DOCS: https://spacy.io/api/vocab#to_disk
@@ -495,7 +495,7 @@ cdef class Vocab:
def to_bytes(self, *, exclude=tuple()):
"""Serialize the current state to a binary string.
- exclude (list): String names of serialization fields to exclude.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (bytes): The serialized form of the `Vocab` object.
DOCS: https://spacy.io/api/vocab#to_bytes
@@ -517,7 +517,7 @@ cdef class Vocab:
"""Load state from a binary string.
bytes_data (bytes): The data to load from.
- exclude (list): String names of serialization fields to exclude.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Vocab): The `Vocab` object.
DOCS: https://spacy.io/api/vocab#from_bytes
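Since `exclude` is documented as any iterable of field names, a quick sketch of what that permits, using field names from the Vocab serialization fields:

```python
import spacy
from spacy.vocab import Vocab

nlp = spacy.blank("en")

# `exclude` accepts any iterable of field names, not only a list or tuple.
data = nlp.vocab.to_bytes(exclude={"vectors"})
vocab = Vocab().from_bytes(data, exclude=["lookups"])
```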
diff --git a/website/docs/api/entityruler.md b/website/docs/api/entityruler.md
index 48c279914..c9c3ec365 100644
--- a/website/docs/api/entityruler.md
+++ b/website/docs/api/entityruler.md
@@ -288,7 +288,7 @@ All labels present in the match patterns.
| ----------- | -------------------------------------- |
| **RETURNS** | The string labels. ~~Tuple[str, ...]~~ |
-## EntityRuler.ent_ids {#labels tag="property" new="2.2.2"}
+## EntityRuler.ent_ids {#ent_ids tag="property" new="2.2.2"}
All entity IDs present in the `id` properties of the match patterns.
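To make the property behind the corrected anchor concrete, a short example with patterns that carry `id` values; the outputs in the comments are illustrative:

```python
import spacy

nlp = spacy.blank("en")
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns([
    {"label": "ORG", "pattern": "Apple", "id": "apple"},
    {"label": "GPE", "pattern": [{"LOWER": "san"}, {"LOWER": "francisco"}], "id": "san-francisco"},
])
print(ruler.labels)   # e.g. ("GPE", "ORG")
print(ruler.ent_ids)  # e.g. ("apple", "san-francisco")
```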
diff --git a/website/docs/api/language.md b/website/docs/api/language.md
index 0aa33b281..d0d6b9514 100644
--- a/website/docs/api/language.md
+++ b/website/docs/api/language.md
@@ -1077,9 +1077,9 @@ customize the default language data:
| --------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `stop_words` | List of stop words, used for `Token.is_stop`. <br />**Example:** [`stop_words.py`](%%GITHUB_SPACY/spacy/lang/en/stop_words.py) ~~Set[str]~~ |
| `tokenizer_exceptions` | Tokenizer exception rules, string mapped to list of token attributes. <br />**Example:** [`de/tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/de/tokenizer_exceptions.py) ~~Dict[str, List[dict]]~~ |
-| `prefixes`, `suffixes`, `infixes` | Prefix, suffix and infix rules for the default tokenizer. <br />**Example:** [`puncutation.py`](%%GITHUB_SPACY/spacy/lang/punctuation.py) ~~Optional[List[Union[str, Pattern]]]~~ |
-| `token_match` | Optional regex for matching strings that should never be split, overriding the infix rules. <br />**Example:** [`fr/tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/fr/tokenizer_exceptions.py) ~~Optional[Pattern]~~ |
-| `url_match` | Regular expression for matching URLs. Prefixes and suffixes are removed before applying the match. <br />**Example:** [`tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/tokenizer_exceptions.py) ~~Optional[Pattern]~~ |
+| `prefixes`, `suffixes`, `infixes` | Prefix, suffix and infix rules for the default tokenizer. <br />**Example:** [`punctuation.py`](%%GITHUB_SPACY/spacy/lang/punctuation.py) ~~Optional[Sequence[Union[str, Pattern]]]~~ |
+| `token_match` | Optional regex for matching strings that should never be split, overriding the infix rules. <br />**Example:** [`fr/tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/fr/tokenizer_exceptions.py) ~~Optional[Callable]~~ |
+| `url_match` | Regular expression for matching URLs. Prefixes and suffixes are removed before applying the match. <br />**Example:** [`tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/tokenizer_exceptions.py) ~~Optional[Callable]~~ |
| `lex_attr_getters` | Custom functions for setting lexical attributes on tokens, e.g. `like_num`. <br />**Example:** [`lex_attrs.py`](%%GITHUB_SPACY/spacy/lang/en/lex_attrs.py) ~~Dict[int, Callable[[str], Any]]~~ |
| `syntax_iterators` | Functions that compute views of a `Doc` object based on its syntax. At the moment, only used for [noun chunks](/usage/linguistic-features#noun-chunks). <br />**Example:** [`syntax_iterators.py`](%%GITHUB_SPACY/spacy/lang/en/syntax_iterators.py). ~~Dict[str, Callable[[Union[Doc, Span]], Iterator[Span]]]~~ |
| `writing_system` | Information about the language's writing system, available via `Vocab.writing_system`. Defaults to: `{"direction": "ltr", "has_case": True, "has_letters": True}.`. <br />**Example:** [`zh/__init__.py`](%%GITHUB_SPACY/spacy/lang/zh/__init__.py) ~~Dict[str, Any]~~ |
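The settings in the table above are consumed via `Language.Defaults`; for context, a brief sketch of overriding one of them by subclassing, mirroring the customization pattern from the usage docs:

```python
from spacy.lang.en import English

class CustomEnglishDefaults(English.Defaults):
    stop_words = set(["custom", "stop"])

class CustomEnglish(English):
    lang = "custom_en"
    Defaults = CustomEnglishDefaults

nlp = CustomEnglish()
print(nlp.lang, [token.is_stop for token in nlp("custom stop")])
```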
diff --git a/website/docs/api/phrasematcher.md b/website/docs/api/phrasematcher.md
index 71ee4b7d1..2cef9ac2a 100644
--- a/website/docs/api/phrasematcher.md
+++ b/website/docs/api/phrasematcher.md
@@ -150,7 +150,7 @@ patterns = [nlp("health care reform"), nlp("healthcare reform")]
| Name | Description |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `match_id` | An ID for the thing you're matching. ~~str~~ | |
+| `key` | An ID for the thing you're matching. ~~str~~ |
| `docs` | `Doc` objects of the phrases to match. ~~List[Doc]~~ |
| _keyword-only_ | |
| `on_match` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. ~~Optional[Callable[[Matcher, Doc, int, List[tuple], Any]]~~ |
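A short usage sketch showing the `key` argument in context, reusing the example phrases above; the callback name is arbitrary:

```python
import spacy
from spacy.matcher import PhraseMatcher

nlp = spacy.blank("en")
matcher = PhraseMatcher(nlp.vocab)

def on_match(matcher, doc, i, matches):
    # matches[i] is a (match_id, start, end) triple
    print("Matched:", doc[matches[i][1]:matches[i][2]].text)

patterns = [nlp("health care reform"), nlp("healthcare reform")]
matcher.add("HEALTH", patterns, on_match=on_match)  # "HEALTH" is the key
matches = matcher(nlp("He urged congress to pass healthcare reform."))
```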
diff --git a/website/docs/api/spancategorizer.md b/website/docs/api/spancategorizer.md
index d37b2f23a..4edc6fb5b 100644
--- a/website/docs/api/spancategorizer.md
+++ b/website/docs/api/spancategorizer.md
@@ -54,7 +54,7 @@ architectures and their arguments and hyperparameters.
| Setting | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[List[Doc], Ragged]~~ |
+| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"spans"`. ~~str~~ |
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
@@ -89,7 +89,7 @@ shortcut for this and instantiate the component using its string name and
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `vocab` | The shared vocabulary. ~~Vocab~~ |
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
-| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[List[Doc], Ragged]~~ |
+| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | |
| `spans_key` | Key of the [`Doc.spans`](/api/doc#sans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"spans"`. ~~str~~ |
@@ -251,11 +251,11 @@ predicted scores.
> loss, d_loss = spancat.get_loss(examples, scores)
> ```
-| Name | Description |
-| ----------- | --------------------------------------------------------------------------- |
-| `examples` | The batch of examples. ~~Iterable[Example]~~ |
-| `scores` | Scores representing the model's predictions. |
-| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+| Name | Description |
+| -------------- | --------------------------------------------------------------------------- |
+| `examples` | The batch of examples. ~~Iterable[Example]~~ |
+| `spans_scores` | Scores representing the model's predictions. ~~Tuple[Ragged, Floats2d]~~ |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## SpanCategorizer.score {#score tag="method"}
@@ -466,7 +466,7 @@ integers. The array has two columns, indicating the start and end position.
| Name | Description |
| ----------- | -------------------------------------------------------------------------------------------------------------------- |
| `sizes` | The phrase lengths to suggest. For example, `[1, 2]` will suggest phrases consisting of 1 or 2 tokens. ~~List[int]~~ |
-| **CREATES** | The suggester function. ~~Callable[[List[Doc]], Ragged]~~ |
+| **CREATES** | The suggester function. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
### spacy.ngram_range_suggester.v1 {#ngram_range_suggester}
@@ -483,8 +483,8 @@ Suggest all spans of at least length `min_size` and at most length `max_size`
(both inclusive). Spans are returned as a ragged array of integers. The array
has two columns, indicating the start and end position.
-| Name | Description |
-| ----------- | ------------------------------------------------------------ |
-| `min_size` | The minimal phrase lengths to suggest (inclusive). ~~[int]~~ |
-| `max_size` | The maximal phrase lengths to suggest (exclusive). ~~[int]~~ |
-| **CREATES** | The suggester function. ~~Callable[[List[Doc]], Ragged]~~ |
+| Name | Description |
+| ----------- | ---------------------------------------------------------------------------- |
+| `min_size` | The minimal phrase lengths to suggest (inclusive). ~~[int]~~ |
+| `max_size` | The maximal phrase lengths to suggest (inclusive). ~~[int]~~ |
+| **CREATES** | The suggester function. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
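For reference, a minimal sketch of a custom suggester matching the updated signature. It assumes the incoming docs have sentence boundaries set, and the function name is hypothetical:

```python
from typing import Iterable, List, Optional

from thinc.api import Ops, get_current_ops
from thinc.types import Ragged

from spacy.tokens import Doc

def sentence_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
    # Suggest each sentence as a candidate span (requires sentence boundaries).
    if ops is None:
        ops = get_current_ops()
    spans: List[List[int]] = []
    lengths: List[int] = []
    for doc in docs:
        sents = [[sent.start, sent.end] for sent in doc.sents]
        lengths.append(len(sents))
        spans.extend(sents)
    data = ops.asarray2i(spans) if spans else ops.xp.zeros((0, 2), dtype="int32")
    return Ragged(data, ops.asarray1i(lengths))
```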
diff --git a/website/docs/api/vocab.md b/website/docs/api/vocab.md
index c37b27a0e..c0a269d95 100644
--- a/website/docs/api/vocab.md
+++ b/website/docs/api/vocab.md
@@ -21,15 +21,15 @@ Create the vocabulary.
> vocab = Vocab(strings=["hello", "world"])
> ```
-| Name | Description |
-| ------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
-| `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ |
-| `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
-| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
-| `vectors_name` 2.2 | A name to identify the vectors table. ~~str~~ |
-| `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |
-| `get_noun_chunks` | A function that yields base noun phrases used for [`Doc.noun_chunks`](/api/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span], Iterator[Span]]]]~~ |
+| Name | Description |
+| ------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
+| `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ |
+| `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
+| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
+| `vectors_name` 2.2 | A name to identify the vectors table. ~~str~~ |
+| `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |
+| `get_noun_chunks`                           | A function that yields base noun phrases used for [`Doc.noun_chunks`](/api/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span]], Iterator[Tuple[int, int, int]]]]~~                    |
## Vocab.\_\_len\_\_ {#len tag="method"}
@@ -300,14 +300,14 @@ Load state from a binary string.
> assert type(PERSON) == int
> ```
-| Name | Description |
-| ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `strings` | A table managing the string-to-int mapping. ~~StringStore~~ |
-| `vectors` 2 | A table associating word IDs to word vectors. ~~Vectors~~ |
-| `vectors_length` | Number of dimensions for each word vector. ~~int~~ |
-| `lookups` | The available lookup tables in this vocab. ~~Lookups~~ |
-| `writing_system` 2.1 | A dict with information about the language's writing system. ~~Dict[str, Any]~~ |
-| `get_noun_chunks` 3.0 | A function that yields base noun phrases used for [`Doc.noun_chunks`](/ap/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span], Iterator[Span]]]]~~ |
+| Name | Description |
+| ---------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `strings` | A table managing the string-to-int mapping. ~~StringStore~~ |
+| `vectors` 2 | A table associating word IDs to word vectors. ~~Vectors~~ |
+| `vectors_length` | Number of dimensions for each word vector. ~~int~~ |
+| `lookups` | The available lookup tables in this vocab. ~~Lookups~~ |
+| `writing_system` 2.1 | A dict with information about the language's writing system. ~~Dict[str, Any]~~ |
+| `get_noun_chunks` <Tag variant="new">3.0</Tag> | A function that yields base noun phrases used for [`Doc.noun_chunks`](/api/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span]], Iterator[Tuple[int, int, int]]]]~~ |
## Serialization fields {#serialization-fields}
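To make the ~~Iterator[Tuple[int, int, int]]~~ shape concrete, a simplified sketch of a noun-chunk iterator of the kind registered in a language's `syntax_iterators`; it assumes POS tags are set and is not the actual English implementation:

```python
from typing import Iterator, Tuple, Union

from spacy.tokens import Doc, Span

def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
    # Yield (start token index, end token index, label hash) triples;
    # Doc.noun_chunks turns each triple into a Span.
    label = doclike.vocab.strings.add("NP")
    for token in doclike:
        if token.pos_ == "NOUN":  # assumes POS tags have been set
            yield token.left_edge.i, token.i + 1, label
```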
diff --git a/website/docs/usage/layers-architectures.md b/website/docs/usage/layers-architectures.md
index 17043d599..2e23b3684 100644
--- a/website/docs/usage/layers-architectures.md
+++ b/website/docs/usage/layers-architectures.md
@@ -833,7 +833,7 @@ retrieve and add to them.
self.cfg = {"labels": []}
@property
- def labels(self) -> Tuple[str]:
+ def labels(self) -> Tuple[str, ...]:
"""Returns the labels currently added to the component."""
return tuple(self.cfg["labels"])
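The distinction matters because `Tuple[str]` means a tuple of exactly one string, while the property returns however many labels have been added; a tiny sketch of the corrected annotation in isolation:

```python
from typing import Tuple

cfg = {"labels": ["PERSON", "ORG", "GPE"]}

def labels() -> Tuple[str, ...]:
    # Tuple[str] would mean "a tuple of exactly one str";
    # Tuple[str, ...] allows any number of labels, including none.
    return tuple(cfg["labels"])
```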