diff --git a/.github/azure-steps.yml b/.github/azure-steps.yml
index 823509888..cdcd1fa6f 100644
--- a/.github/azure-steps.yml
+++ b/.github/azure-steps.yml
@@ -25,6 +25,9 @@ steps:
${{ parameters.prefix }} python setup.py sdist --formats=gztar
displayName: "Compile and build sdist"
+ - script: python -m mypy spacy
+ displayName: 'Run mypy'
+
- task: DeleteFiles@1
inputs:
contents: "spacy"
@@ -109,3 +112,9 @@ steps:
python .github/validate_universe_json.py website/meta/universe.json
displayName: 'Test website/meta/universe.json'
condition: eq(variables['python_version'], '3.8')
+
+ - script: |
+ ${{ parameters.prefix }} python -m pip install thinc-apple-ops
+ ${{ parameters.prefix }} python -m pytest --pyargs spacy
+ displayName: "Run CPU tests with thinc-apple-ops"
+ condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.9'))
diff --git a/.github/contributors/connorbrinton.md b/.github/contributors/connorbrinton.md
new file mode 100644
index 000000000..25d03b494
--- /dev/null
+++ b/.github/contributors/connorbrinton.md
@@ -0,0 +1,106 @@
+# spaCy contributor agreement
+
+This spaCy Contributor Agreement (**"SCA"**) is based on the
+[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
+The SCA applies to any contribution that you make to any product or project
+managed by us (the **"project"**), and sets out the intellectual property rights
+you grant to us in the contributed materials. The term **"us"** shall mean
+[ExplosionAI GmbH](https://explosion.ai/legal). The term
+**"you"** shall mean the person or entity identified below.
+
+If you agree to be bound by these terms, fill in the information requested
+below and include the filled-in version with your first pull request, under the
+folder [`.github/contributors/`](/.github/contributors/). The name of the file
+should be your GitHub username, with the extension `.md`. For example, the user
+example_user would create the file `.github/contributors/example_user.md`.
+
+Read this agreement carefully before signing. These terms and conditions
+constitute a binding legal agreement.
+
+## Contributor Agreement
+
+1. The term "contribution" or "contributed materials" means any source code,
+object code, patch, tool, sample, graphic, specification, manual,
+documentation, or any other material posted or submitted by you to the project.
+
+2. With respect to any worldwide copyrights, or copyright applications and
+registrations, in your contribution:
+
+ * you hereby assign to us joint ownership, and to the extent that such
+ assignment is or becomes invalid, ineffective or unenforceable, you hereby
+ grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
+ royalty-free, unrestricted license to exercise all rights under those
+ copyrights. This includes, at our option, the right to sublicense these same
+ rights to third parties through multiple levels of sublicensees or other
+ licensing arrangements;
+
+ * you agree that each of us can do all things in relation to your
+ contribution as if each of us were the sole owners, and if one of us makes
+ a derivative work of your contribution, the one who makes the derivative
+    work (or has it made) will be the sole owner of that derivative work;
+
+ * you agree that you will not assert any moral rights in your contribution
+ against us, our licensees or transferees;
+
+ * you agree that we may register a copyright in your contribution and
+ exercise all ownership rights associated with it; and
+
+ * you agree that neither of us has any duty to consult with, obtain the
+ consent of, pay or render an accounting to the other for any use or
+ distribution of your contribution.
+
+3. With respect to any patents you own, or that you can license without payment
+to any third party, you hereby grant to us a perpetual, irrevocable,
+non-exclusive, worldwide, no-charge, royalty-free license to:
+
+ * make, have made, use, sell, offer to sell, import, and otherwise transfer
+ your contribution in whole or in part, alone or in combination with or
+ included in any product, work or materials arising out of the project to
+ which your contribution was submitted, and
+
+ * at our option, to sublicense these same rights to third parties through
+ multiple levels of sublicensees or other licensing arrangements.
+
+4. Except as set out above, you keep all right, title, and interest in your
+contribution. The rights that you grant to us under these terms are effective
+on the date you first submitted a contribution to us, even if your submission
+took place before the date you sign these terms.
+
+5. You covenant, represent, warrant and agree that:
+
+ * Each contribution that you submit is and shall be an original work of
+ authorship and you can legally grant the rights set out in this SCA;
+
+ * to the best of your knowledge, each contribution will not violate any
+ third party's copyrights, trademarks, patents, or other intellectual
+ property rights; and
+
+ * each contribution shall be in compliance with U.S. export control laws and
+ other applicable export and import laws. You agree to notify us if you
+ become aware of any circumstance which would make any of the foregoing
+ representations inaccurate in any respect. We may publicly disclose your
+ participation in the project, including the fact that you have signed the SCA.
+
+6. This SCA is governed by the laws of the State of California and applicable
+U.S. Federal law. Any choice of law rules will not apply.
+
+7. Please place an “x” on one of the applicable statements below. Please do NOT
+mark both statements:
+
+ * [x] I am signing on behalf of myself as an individual and no other person
+ or entity, including my employer, has or will have rights with respect to my
+ contributions.
+
+ * [ ] I am signing on behalf of my employer or a legal entity and I have the
+ actual authority to contractually bind that entity.
+
+## Contributor Details
+
+| Field | Entry |
+|------------------------------- | -------------------- |
+| Name | Connor Brinton |
+| Company name (if applicable) | |
+| Title or role (if applicable) | |
+| Date | July 20th, 2021 |
+| GitHub username | connorbrinton |
+| Website (optional) | |
diff --git a/.github/lock.yml b/.github/lock.yml
deleted file mode 100644
index 593e88397..000000000
--- a/.github/lock.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-# Configuration for lock-threads - https://github.com/dessant/lock-threads
-
-# Number of days of inactivity before a closed issue or pull request is locked
-daysUntilLock: 30
-
-# Issues and pull requests with these labels will not be locked. Set to `[]` to disable
-exemptLabels: []
-
-# Label to add before locking, such as `outdated`. Set to `false` to disable
-lockLabel: false
-
-# Comment to post before locking. Set to `false` to disable
-lockComment: >
- This thread has been automatically locked since there has not been
- any recent activity after it was closed. Please open a new issue for
- related bugs.
-
-# Limit to only `issues` or `pulls`
-only: issues
diff --git a/.github/workflows/explosionbot.yml b/.github/workflows/explosionbot.yml
index 7d9ee45e9..e29ce8fe8 100644
--- a/.github/workflows/explosionbot.yml
+++ b/.github/workflows/explosionbot.yml
@@ -23,5 +23,5 @@ jobs:
env:
INPUT_TOKEN: ${{ secrets.EXPLOSIONBOT_TOKEN }}
INPUT_BK_TOKEN: ${{ secrets.BUILDKITE_SECRET }}
- ENABLED_COMMANDS: "test_gpu"
- ALLOWED_TEAMS: "spaCy"
\ No newline at end of file
+ ENABLED_COMMANDS: "test_gpu,test_slow"
+ ALLOWED_TEAMS: "spaCy"
diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml
new file mode 100644
index 000000000..c9833cdba
--- /dev/null
+++ b/.github/workflows/lock.yml
@@ -0,0 +1,25 @@
+name: 'Lock Threads'
+
+on:
+ schedule:
+ - cron: '0 0 * * *' # check every day
+ workflow_dispatch:
+
+permissions:
+ issues: write
+
+concurrency:
+ group: lock
+
+jobs:
+ action:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: dessant/lock-threads@v3
+ with:
+ process-only: 'issues'
+ issue-inactive-days: '30'
+ issue-comment: >
+ This thread has been automatically locked since there
+ has not been any recent activity after it was closed.
+ Please open a new issue for related bugs.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 3e2b3927b..a4d321aa3 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -419,7 +419,7 @@ simply click on the "Suggest edits" button at the bottom of a page.
## Publishing spaCy extensions and plugins
We're very excited about all the new possibilities for **community extensions**
-and plugins in spaCy v2.0, and we can't wait to see what you build with it!
+and plugins in spaCy v3.0, and we can't wait to see what you build with it!
- An extension or plugin should add substantial functionality, be
**well-documented** and **open-source**. It should be available for users to download
diff --git a/MANIFEST.in b/MANIFEST.in
index d022223cd..c1524d460 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,5 @@
recursive-include include *.h
-recursive-include spacy *.pyx *.pxd *.txt *.cfg *.jinja *.toml
+recursive-include spacy *.pyi *.pyx *.pxd *.txt *.cfg *.jinja *.toml
include LICENSE
include README.md
include pyproject.toml
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 245407189..6bf591bee 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -12,15 +12,11 @@ trigger:
- "website/*"
- "*.md"
pr:
- paths:
- include:
- - "*.cfg"
- - "*.py"
- - "*.toml"
- - "*.yml"
- - ".github/azure-steps.yml"
- - "spacy/*"
- - "website/meta/universe.json"
+ paths:
+ exclude:
+ - "*.md"
+ - "website/docs/*"
+ - "website/src/*"
jobs:
# Perform basic checks for most important errors (syntax etc.) Uses the config
diff --git a/pyproject.toml b/pyproject.toml
index 7328cd6c2..cb103de0a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
- "thinc>=8.0.10,<8.1.0",
+ "thinc>=8.0.11,<8.1.0",
"blis>=0.4.0,<0.8.0",
"pathy",
"numpy>=1.15.0",
diff --git a/requirements.txt b/requirements.txt
index 85de453b7..c800e5ea9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ spacy-legacy>=3.0.8,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
-thinc>=8.0.10,<8.1.0
+thinc>=8.0.11,<8.1.0
blis>=0.4.0,<0.8.0
ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0
@@ -18,6 +18,7 @@ requests>=2.13.0,<3.0.0
tqdm>=4.38.0,<5.0.0
pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
jinja2
+langcodes>=3.2.0,<4.0.0
# Official Python utilities
setuptools
packaging>=20.0
@@ -30,4 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0
mock>=2.0.0,<3.0.0
flake8>=3.8.0,<3.10.0
hypothesis>=3.27.0,<7.0.0
-langcodes>=3.2.0,<4.0.0
+mypy>=0.910
+types-dataclasses>=0.1.3; python_version < "3.7"
+types-mock>=0.1.1
+types-requests
diff --git a/setup.cfg b/setup.cfg
index e3a9af5c1..d007fb160 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
- thinc>=8.0.10,<8.1.0
+ thinc>=8.0.11,<8.1.0
install_requires =
# Our libraries
spacy-legacy>=3.0.8,<3.1.0
@@ -45,7 +45,7 @@ install_requires =
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
- thinc>=8.0.10,<8.1.0
+ thinc>=8.0.11,<8.1.0
blis>=0.4.0,<0.8.0
wasabi>=0.8.1,<1.1.0
srsly>=2.4.1,<3.0.0
@@ -72,7 +72,7 @@ console_scripts =
lookups =
spacy_lookups_data>=1.0.3,<1.1.0
transformers =
- spacy_transformers>=1.0.1,<1.1.0
+ spacy_transformers>=1.0.1,<1.2.0
ray =
spacy_ray>=0.1.0,<1.0.0
cuda =
@@ -131,3 +131,4 @@ markers =
ignore_missing_imports = True
no_implicit_optional = True
plugins = pydantic.mypy, thinc.mypy
+allow_redefinition = True
diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py
index 127bba55a..fb680d888 100644
--- a/spacy/cli/_util.py
+++ b/spacy/cli/_util.py
@@ -1,4 +1,5 @@
-from typing import Dict, Any, Union, List, Optional, Tuple, Iterable, TYPE_CHECKING
+from typing import Dict, Any, Union, List, Optional, Tuple, Iterable
+from typing import TYPE_CHECKING, overload
import sys
import shutil
from pathlib import Path
@@ -15,6 +16,7 @@ from thinc.util import has_cupy, gpu_is_available
from configparser import InterpolationError
import os
+from ..compat import Literal
from ..schemas import ProjectConfigSchema, validate
from ..util import import_file, run_command, make_tempdir, registry, logger
from ..util import is_compatible_version, SimpleFrozenDict, ENV_VARS
@@ -260,15 +262,16 @@ def get_checksum(path: Union[Path, str]) -> str:
RETURNS (str): The checksum.
"""
path = Path(path)
+ if not (path.is_file() or path.is_dir()):
+ msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
if path.is_file():
return hashlib.md5(Path(path).read_bytes()).hexdigest()
- if path.is_dir():
+ else:
# TODO: this is currently pretty slow
dir_checksum = hashlib.md5()
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
dir_checksum.update(sub_file.read_bytes())
return dir_checksum.hexdigest()
- msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
@contextmanager
@@ -468,12 +471,15 @@ def get_git_version(
RETURNS (Tuple[int, int]): The version as a (major, minor) tuple. Returns
(0, 0) if the version couldn't be determined.
"""
- ret = run_command("git --version", capture=True)
+ try:
+ ret = run_command("git --version", capture=True)
+ except:
+ raise RuntimeError(error)
stdout = ret.stdout.strip()
if not stdout or not stdout.startswith("git version"):
- return (0, 0)
+ return 0, 0
version = stdout[11:].strip().split(".")
- return (int(version[0]), int(version[1]))
+ return int(version[0]), int(version[1])
def _http_to_git(repo: str) -> str:
@@ -500,6 +506,16 @@ def is_subpath_of(parent, child):
return os.path.commonpath([parent_realpath, child_realpath]) == parent_realpath
+@overload
+def string_to_list(value: str, intify: Literal[False] = ...) -> List[str]:
+ ...
+
+
+@overload
+def string_to_list(value: str, intify: Literal[True]) -> List[int]:
+ ...
+
+
def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[int]]:
"""Parse a comma-separated string to a list and account for various
formatting options. Mostly used to handle CLI arguments that take a list of
@@ -510,7 +526,7 @@ def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[in
RETURNS (Union[List[str], List[int]]): A list of strings or ints.
"""
if not value:
- return []
+ return [] # type: ignore[return-value]
if value.startswith("[") and value.endswith("]"):
value = value[1:-1]
result = []
@@ -522,7 +538,7 @@ def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[in
p = p[1:-1]
p = p.strip()
if intify:
- p = int(p)
+ p = int(p) # type: ignore[assignment]
result.append(p)
return result
diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py
index c84aa6431..04eb7078f 100644
--- a/spacy/cli/convert.py
+++ b/spacy/cli/convert.py
@@ -1,4 +1,4 @@
-from typing import Optional, Any, List, Union
+from typing import Callable, Iterable, Mapping, Optional, Any, List, Union
from enum import Enum
from pathlib import Path
from wasabi import Printer
@@ -9,7 +9,7 @@ import itertools
from ._util import app, Arg, Opt
from ..training import docs_to_json
-from ..tokens import DocBin
+from ..tokens import Doc, DocBin
from ..training.converters import iob_to_docs, conll_ner_to_docs, json_to_docs
from ..training.converters import conllu_to_docs
@@ -19,7 +19,7 @@ from ..training.converters import conllu_to_docs
# entry to this dict with the file extension mapped to the converter function
# imported from /converters.
-CONVERTERS = {
+CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
"conllubio": conllu_to_docs,
"conllu": conllu_to_docs,
"conll": conll_ner_to_docs,
@@ -66,19 +66,16 @@ def convert_cli(
DOCS: https://spacy.io/api/cli#convert
"""
- if isinstance(file_type, FileTypes):
- # We get an instance of the FileTypes from the CLI so we need its string value
- file_type = file_type.value
input_path = Path(input_path)
- output_dir = "-" if output_dir == Path("-") else output_dir
+ output_dir: Union[str, Path] = "-" if output_dir == Path("-") else output_dir
silent = output_dir == "-"
msg = Printer(no_print=silent)
- verify_cli_args(msg, input_path, output_dir, file_type, converter, ner_map)
+ verify_cli_args(msg, input_path, output_dir, file_type.value, converter, ner_map)
converter = _get_converter(msg, converter, input_path)
convert(
input_path,
output_dir,
- file_type=file_type,
+ file_type=file_type.value,
n_sents=n_sents,
seg_sents=seg_sents,
model=model,
@@ -94,7 +91,7 @@ def convert_cli(
def convert(
- input_path: Union[str, Path],
+ input_path: Path,
output_dir: Union[str, Path],
*,
file_type: str = "json",
@@ -108,13 +105,14 @@ def convert(
lang: Optional[str] = None,
concatenate: bool = False,
silent: bool = True,
- msg: Optional[Printer],
+ msg: Optional[Printer] = None,
) -> None:
+ input_path = Path(input_path)
if not msg:
msg = Printer(no_print=silent)
ner_map = srsly.read_json(ner_map) if ner_map is not None else None
doc_files = []
- for input_loc in walk_directory(Path(input_path), converter):
+ for input_loc in walk_directory(input_path, converter):
with input_loc.open("r", encoding="utf-8") as infile:
input_data = infile.read()
# Use converter function to convert data
@@ -141,7 +139,7 @@ def convert(
else:
db = DocBin(docs=docs, store_user_data=True)
len_docs = len(db)
- data = db.to_bytes()
+ data = db.to_bytes() # type: ignore[assignment]
if output_dir == "-":
_print_docs_to_stdout(data, file_type)
else:
@@ -220,13 +218,12 @@ def walk_directory(path: Path, converter: str) -> List[Path]:
def verify_cli_args(
msg: Printer,
- input_path: Union[str, Path],
+ input_path: Path,
output_dir: Union[str, Path],
- file_type: FileTypes,
+ file_type: str,
converter: str,
ner_map: Optional[Path],
):
- input_path = Path(input_path)
if file_type not in FILE_TYPES_STDOUT and output_dir == "-":
msg.fail(
f"Can't write .{file_type} data to stdout. Please specify an output directory.",
@@ -244,13 +241,13 @@ def verify_cli_args(
msg.fail("No input files in directory", input_path, exits=1)
file_types = list(set([loc.suffix[1:] for loc in input_locs]))
if converter == "auto" and len(file_types) >= 2:
- file_types = ",".join(file_types)
- msg.fail("All input files must be same type", file_types, exits=1)
+ file_types_str = ",".join(file_types)
+ msg.fail("All input files must be same type", file_types_str, exits=1)
if converter != "auto" and converter not in CONVERTERS:
msg.fail(f"Can't find converter for {converter}", exits=1)
-def _get_converter(msg, converter, input_path):
+def _get_converter(msg, converter, input_path: Path):
if input_path.is_dir():
input_path = walk_directory(input_path, converter)[0]
if converter == "auto":
diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py
index 3f368f57d..3143e2c62 100644
--- a/spacy/cli/debug_data.py
+++ b/spacy/cli/debug_data.py
@@ -1,4 +1,5 @@
-from typing import List, Sequence, Dict, Any, Tuple, Optional, Set
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
+from typing import cast, overload
from pathlib import Path
from collections import Counter
import sys
@@ -17,6 +18,7 @@ from ..pipeline import Morphologizer
from ..morphology import Morphology
from ..language import Language
from ..util import registry, resolve_dot_names
+from ..compat import Literal
from .. import util
@@ -201,7 +203,6 @@ def debug_data(
has_low_data_warning = False
has_no_neg_warning = False
has_ws_ents_error = False
- has_punct_ents_warning = False
msg.divider("Named Entity Recognition")
msg.info(f"{len(model_labels)} label(s)")
@@ -228,10 +229,6 @@ def debug_data(
msg.fail(f"{gold_train_data['ws_ents']} invalid whitespace entity spans")
has_ws_ents_error = True
- if gold_train_data["punct_ents"]:
- msg.warn(f"{gold_train_data['punct_ents']} entity span(s) with punctuation")
- has_punct_ents_warning = True
-
for label in labels:
if label_counts[label] <= NEW_LABEL_THRESHOLD:
msg.warn(
@@ -251,8 +248,6 @@ def debug_data(
msg.good("Examples without occurrences available for all labels")
if not has_ws_ents_error:
msg.good("No entities consisting of or starting/ending with whitespace")
- if not has_punct_ents_warning:
- msg.good("No entities consisting of or starting/ending with punctuation")
if has_low_data_warning:
msg.text(
@@ -268,15 +263,9 @@ def debug_data(
show=verbose,
)
if has_ws_ents_error:
- msg.text(
- "As of spaCy v2.1.0, entity spans consisting of or starting/ending "
- "with whitespace characters are considered invalid."
- )
-
- if has_punct_ents_warning:
msg.text(
"Entity spans consisting of or starting/ending "
- "with punctuation can not be trained with a noise level > 0."
+ "with whitespace characters are considered invalid."
)
if "textcat" in factory_names:
@@ -378,10 +367,11 @@ def debug_data(
if "tagger" in factory_names:
msg.divider("Part-of-speech Tagging")
- labels = [label for label in gold_train_data["tags"]]
+ label_list = [label for label in gold_train_data["tags"]]
model_labels = _get_labels_from_model(nlp, "tagger")
- msg.info(f"{len(labels)} label(s) in train data")
- missing_labels = model_labels - set(labels)
+ msg.info(f"{len(label_list)} label(s) in train data")
+ labels = set(label_list)
+ missing_labels = model_labels - labels
if missing_labels:
msg.warn(
"Some model labels are not present in the train data. The "
@@ -395,10 +385,11 @@ def debug_data(
if "morphologizer" in factory_names:
msg.divider("Morphologizer (POS+Morph)")
- labels = [label for label in gold_train_data["morphs"]]
+ label_list = [label for label in gold_train_data["morphs"]]
model_labels = _get_labels_from_model(nlp, "morphologizer")
- msg.info(f"{len(labels)} label(s) in train data")
- missing_labels = model_labels - set(labels)
+ msg.info(f"{len(label_list)} label(s) in train data")
+ labels = set(label_list)
+ missing_labels = model_labels - labels
if missing_labels:
msg.warn(
"Some model labels are not present in the train data. The "
@@ -565,7 +556,7 @@ def _compile_gold(
nlp: Language,
make_proj: bool,
) -> Dict[str, Any]:
- data = {
+ data: Dict[str, Any] = {
"ner": Counter(),
"cats": Counter(),
"tags": Counter(),
@@ -574,7 +565,6 @@ def _compile_gold(
"words": Counter(),
"roots": Counter(),
"ws_ents": 0,
- "punct_ents": 0,
"n_words": 0,
"n_misaligned_words": 0,
"words_missing_vectors": Counter(),
@@ -609,16 +599,6 @@ def _compile_gold(
if label.startswith(("B-", "U-", "L-")) and doc[i].is_space:
# "Illegal" whitespace entity
data["ws_ents"] += 1
- if label.startswith(("B-", "U-", "L-")) and doc[i].text in [
- ".",
- "'",
- "!",
- "?",
- ",",
- ]:
- # punctuation entity: could be replaced by whitespace when training with noise,
- # so add a warning to alert the user to this unexpected side effect.
- data["punct_ents"] += 1
if label.startswith(("B-", "U-")):
combined_label = label.split("-")[1]
data["ner"][combined_label] += 1
@@ -670,10 +650,28 @@ def _compile_gold(
return data
-def _format_labels(labels: List[Tuple[str, int]], counts: bool = False) -> str:
+@overload
+def _format_labels(labels: Iterable[str], counts: Literal[False] = False) -> str:
+ ...
+
+
+@overload
+def _format_labels(
+ labels: Iterable[Tuple[str, int]],
+ counts: Literal[True],
+) -> str:
+ ...
+
+
+def _format_labels(
+ labels: Union[Iterable[str], Iterable[Tuple[str, int]]],
+ counts: bool = False,
+) -> str:
if counts:
- return ", ".join([f"'{l}' ({c})" for l, c in labels])
- return ", ".join([f"'{l}'" for l in labels])
+ return ", ".join(
+ [f"'{l}' ({c})" for l, c in cast(Iterable[Tuple[str, int]], labels)]
+ )
+ return ", ".join([f"'{l}'" for l in cast(Iterable[str], labels)])
def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py
index 378911a20..0d08d2c5e 100644
--- a/spacy/cli/evaluate.py
+++ b/spacy/cli/evaluate.py
@@ -136,7 +136,7 @@ def evaluate(
def handle_scores_per_type(
- scores: Union[Scorer, Dict[str, Any]],
+ scores: Dict[str, Any],
data: Dict[str, Any] = {},
*,
spans_key: str = "sc",
diff --git a/spacy/cli/info.py b/spacy/cli/info.py
index 8cc7018ff..e6a1cb616 100644
--- a/spacy/cli/info.py
+++ b/spacy/cli/info.py
@@ -15,7 +15,7 @@ def info_cli(
model: Optional[str] = Arg(None, help="Optional loadable spaCy pipeline"),
markdown: bool = Opt(False, "--markdown", "-md", help="Generate Markdown for GitHub issues"),
silent: bool = Opt(False, "--silent", "-s", "-S", help="Don't print anything (just return)"),
- exclude: Optional[str] = Opt("labels", "--exclude", "-e", help="Comma-separated keys to exclude from the print-out"),
+ exclude: str = Opt("labels", "--exclude", "-e", help="Comma-separated keys to exclude from the print-out"),
# fmt: on
):
"""
@@ -61,7 +61,7 @@ def info(
return raw_data
-def info_spacy() -> Dict[str, any]:
+def info_spacy() -> Dict[str, Any]:
"""Generate info about the current spaCy intallation.
RETURNS (dict): The spaCy info.
diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py
index 55622452b..530b38eb3 100644
--- a/spacy/cli/init_config.py
+++ b/spacy/cli/init_config.py
@@ -28,8 +28,8 @@ class Optimizations(str, Enum):
def init_config_cli(
# fmt: off
output_file: Path = Arg(..., help="File to save config.cfg to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
- lang: Optional[str] = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
- pipeline: Optional[str] = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
+ lang: str = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
+ pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
gpu: bool = Opt(False, "--gpu", "-G", help="Whether the model can run on GPU. This will impact the choice of architecture, pretrained weights and related hyperparameters."),
pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
@@ -44,8 +44,6 @@ def init_config_cli(
DOCS: https://spacy.io/api/cli#init-config
"""
- if isinstance(optimize, Optimizations): # instance of enum from the CLI
- optimize = optimize.value
pipeline = string_to_list(pipeline)
is_stdout = str(output_file) == "-"
if not is_stdout and output_file.exists() and not force_overwrite:
@@ -57,7 +55,7 @@ def init_config_cli(
config = init_config(
lang=lang,
pipeline=pipeline,
- optimize=optimize,
+ optimize=optimize.value,
gpu=gpu,
pretraining=pretraining,
silent=is_stdout,
@@ -175,8 +173,8 @@ def init_config(
"Pipeline": ", ".join(pipeline),
"Optimize for": optimize,
"Hardware": variables["hardware"].upper(),
- "Transformer": template_vars.transformer.get("name")
- if template_vars.use_transformer
+ "Transformer": template_vars.transformer.get("name") # type: ignore[attr-defined]
+ if template_vars.use_transformer # type: ignore[attr-defined]
else None,
}
msg.info("Generated config template specific for your use case")
diff --git a/spacy/cli/package.py b/spacy/cli/package.py
index 332a51bc7..e76343dc3 100644
--- a/spacy/cli/package.py
+++ b/spacy/cli/package.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union, Any, Dict, List, Tuple
+from typing import Optional, Union, Any, Dict, List, Tuple, cast
import shutil
from pathlib import Path
from wasabi import Printer, MarkdownRenderer, get_raw_input
@@ -215,9 +215,9 @@ def get_third_party_dependencies(
for reg_name, func_names in funcs.items():
for func_name in func_names:
func_info = util.registry.find(reg_name, func_name)
- module_name = func_info.get("module")
+ module_name = func_info.get("module") # type: ignore[attr-defined]
if module_name: # the code is part of a module, not a --code file
- modules.add(func_info["module"].split(".")[0])
+ modules.add(func_info["module"].split(".")[0]) # type: ignore[index]
dependencies = []
for module_name in modules:
if module_name in distributions:
@@ -227,7 +227,7 @@ def get_third_party_dependencies(
if pkg in own_packages or pkg in exclude:
continue
version = util.get_package_version(pkg)
- version_range = util.get_minor_version_range(version)
+ version_range = util.get_minor_version_range(version) # type: ignore[arg-type]
dependencies.append(f"{pkg}{version_range}")
return dependencies
@@ -252,7 +252,7 @@ def create_file(file_path: Path, contents: str) -> None:
def get_meta(
model_path: Union[str, Path], existing_meta: Dict[str, Any]
) -> Dict[str, Any]:
- meta = {
+ meta: Dict[str, Any] = {
"lang": "en",
"name": "pipeline",
"version": "0.0.0",
@@ -324,8 +324,8 @@ def generate_readme(meta: Dict[str, Any]) -> str:
license_name = meta.get("license")
sources = _format_sources(meta.get("sources"))
description = meta.get("description")
- label_scheme = _format_label_scheme(meta.get("labels"))
- accuracy = _format_accuracy(meta.get("performance"))
+ label_scheme = _format_label_scheme(cast(Dict[str, Any], meta.get("labels")))
+ accuracy = _format_accuracy(cast(Dict[str, Any], meta.get("performance")))
table_data = [
(md.bold("Name"), md.code(name)),
(md.bold("Version"), md.code(version)),
diff --git a/spacy/cli/profile.py b/spacy/cli/profile.py
index f4f0d3caf..3c282c73d 100644
--- a/spacy/cli/profile.py
+++ b/spacy/cli/profile.py
@@ -32,7 +32,7 @@ def profile_cli(
DOCS: https://spacy.io/api/cli#debug-profile
"""
- if ctx.parent.command.name == NAME: # called as top-level command
+ if ctx.parent.command.name == NAME: # type: ignore[union-attr] # called as top-level command
msg.warn(
"The profile command is now available via the 'debug profile' "
"subcommand. You can run python -m spacy debug --help for an "
@@ -42,9 +42,9 @@ def profile_cli(
def profile(model: str, inputs: Optional[Path] = None, n_texts: int = 10000) -> None:
-
if inputs is not None:
- inputs = _read_inputs(inputs, msg)
+ texts = _read_inputs(inputs, msg)
+ texts = list(itertools.islice(texts, n_texts))
if inputs is None:
try:
import ml_datasets
@@ -56,16 +56,13 @@ def profile(model: str, inputs: Optional[Path] = None, n_texts: int = 10000) ->
exits=1,
)
- n_inputs = 25000
- with msg.loading("Loading IMDB dataset via Thinc..."):
- imdb_train, _ = ml_datasets.imdb()
- inputs, _ = zip(*imdb_train)
- msg.info(f"Loaded IMDB dataset and using {n_inputs} examples")
- inputs = inputs[:n_inputs]
+ with msg.loading("Loading IMDB dataset via ml_datasets..."):
+ imdb_train, _ = ml_datasets.imdb(train_limit=n_texts, dev_limit=0)
+ texts, _ = zip(*imdb_train)
+ msg.info(f"Loaded IMDB dataset and using {n_texts} examples")
with msg.loading(f"Loading pipeline '{model}'..."):
nlp = load_model(model)
msg.good(f"Loaded pipeline '{model}'")
- texts = list(itertools.islice(inputs, n_texts))
cProfile.runctx("parse_texts(nlp, texts)", globals(), locals(), "Profile.prof")
s = pstats.Stats("Profile.prof")
msg.divider("Profile stats")
@@ -87,7 +84,7 @@ def _read_inputs(loc: Union[Path, str], msg: Printer) -> Iterator[str]:
if not input_path.exists() or not input_path.is_file():
msg.fail("Not a valid input data file", loc, exits=1)
msg.info(f"Using data from {input_path.parts[-1]}")
- file_ = input_path.open()
+ file_ = input_path.open() # type: ignore[assignment]
for line in file_:
data = srsly.json_loads(line)
text = data["text"]
diff --git a/spacy/cli/project/assets.py b/spacy/cli/project/assets.py
index efc93efab..b5057e401 100644
--- a/spacy/cli/project/assets.py
+++ b/spacy/cli/project/assets.py
@@ -133,7 +133,6 @@ def fetch_asset(
# If there's already a file, check for checksum
if checksum == get_checksum(dest_path):
msg.good(f"Skipping download with matching checksum: {dest}")
- return dest_path
# We might as well support the user here and create parent directories in
# case the asset dir isn't listed as a dir to create in the project.yml
if not dest_path.parent.exists():
@@ -150,7 +149,6 @@ def fetch_asset(
msg.good(f"Copied local asset {dest}")
else:
msg.fail(f"Download failed: {dest}", e)
- return
if checksum and checksum != get_checksum(dest_path):
msg.fail(f"Checksum doesn't match value defined in {PROJECT_FILE}: {dest}")
diff --git a/spacy/cli/project/clone.py b/spacy/cli/project/clone.py
index 72d4004f8..360ee3428 100644
--- a/spacy/cli/project/clone.py
+++ b/spacy/cli/project/clone.py
@@ -80,9 +80,9 @@ def check_clone(name: str, dest: Path, repo: str) -> None:
repo (str): URL of the repo to clone from.
"""
git_err = (
- f"Cloning spaCy project templates requires Git and the 'git' command. ",
+ f"Cloning spaCy project templates requires Git and the 'git' command. "
f"To clone a project without Git, copy the files from the '{name}' "
- f"directory in the {repo} to {dest} manually.",
+ f"directory in the {repo} to {dest} manually."
)
get_git_version(error=git_err)
if not dest:
diff --git a/spacy/cli/project/dvc.py b/spacy/cli/project/dvc.py
index 7e37712c3..83dc5efbf 100644
--- a/spacy/cli/project/dvc.py
+++ b/spacy/cli/project/dvc.py
@@ -143,8 +143,8 @@ def run_dvc_commands(
easier to pass flags like --quiet that depend on a variable or
command-line setting while avoiding lots of nested conditionals.
"""
- for command in commands:
- command = split_command(command)
+ for c in commands:
+ command = split_command(c)
dvc_command = ["dvc", *command]
# Add the flags if they are set to True
for flag, is_active in flags.items():
diff --git a/spacy/cli/project/remote_storage.py b/spacy/cli/project/remote_storage.py
index 6056458e2..336a4bcb3 100644
--- a/spacy/cli/project/remote_storage.py
+++ b/spacy/cli/project/remote_storage.py
@@ -41,7 +41,7 @@ class RemoteStorage:
raise IOError(f"Cannot push {loc}: does not exist.")
url = self.make_url(path, command_hash, content_hash)
if url.exists():
- return None
+ return url
tmp: Path
with make_tempdir() as tmp:
tar_loc = tmp / self.encode_name(str(path))
@@ -131,8 +131,10 @@ def get_command_hash(
currently installed packages, whatever environment variables have been marked
as relevant, and the command.
"""
- check_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
- spacy_v = GIT_VERSION if check_commit else get_minor_version(about.__version__)
+ if check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION):
+ spacy_v = GIT_VERSION
+ else:
+ spacy_v = str(get_minor_version(about.__version__) or "")
dep_checksums = [get_checksum(dep) for dep in sorted(deps)]
hashes = [spacy_v, site_hash, env_hash] + dep_checksums
hashes.extend(cmd)
diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py
index 74c8b24b6..734803bc4 100644
--- a/spacy/cli/project/run.py
+++ b/spacy/cli/project/run.py
@@ -70,7 +70,7 @@ def project_run(
config = load_project_config(project_dir, overrides=overrides)
commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
workflows = config.get("workflows", {})
- validate_subcommand(commands.keys(), workflows.keys(), subcommand)
+ validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
if subcommand in workflows:
msg.info(f"Running workflow '{subcommand}'")
for cmd in workflows[subcommand]:
@@ -116,7 +116,7 @@ def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
workflows = config.get("workflows", {})
project_loc = "" if is_cwd(project_dir) else project_dir
if subcommand:
- validate_subcommand(commands.keys(), workflows.keys(), subcommand)
+ validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
print(f"Usage: {COMMAND} project run {subcommand} {project_loc}")
if subcommand in commands:
help_text = commands[subcommand].get("help")
@@ -164,8 +164,8 @@ def run_commands(
when you want to turn over execution to the command, and capture=True
when you want to run the command more like a function.
"""
- for command in commands:
- command = split_command(command)
+ for c in commands:
+ command = split_command(c)
# Not sure if this is needed or a good idea. Motivation: users may often
# use commands in their config that reference "python" and we want to
# make sure that it's always executing the same Python that spaCy is
@@ -294,7 +294,7 @@ def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]
}
-def get_fileinfo(project_dir: Path, paths: List[str]) -> List[Dict[str, str]]:
+def get_fileinfo(project_dir: Path, paths: List[str]) -> List[Dict[str, Optional[str]]]:
"""Generate the file information for a list of paths (dependencies, outputs).
Includes the file path and the file's checksum.
diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja
index 339fb1e96..50dbc6e42 100644
--- a/spacy/cli/templates/quickstart_training.jinja
+++ b/spacy/cli/templates/quickstart_training.jinja
@@ -16,7 +16,8 @@ gpu_allocator = null
[nlp]
lang = "{{ lang }}"
-{%- if "tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "entity_linker" in components or (("textcat" in components or "textcat_multilabel" in components) and optimize == "accuracy") -%}
+{%- set no_tok2vec = components|length == 1 and (("textcat" in components or "textcat_multilabel" in components) and optimize == "efficiency")-%}
+{%- if not no_tok2vec and ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "entity_linker" in components or "textcat" in components or "textcat_multilabel" in components) -%}
{%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components %}
{%- else -%}
{%- set full_pipeline = components %}
@@ -32,7 +33,7 @@ batch_size = {{ 128 if hardware == "gpu" else 1000 }}
factory = "transformer"
[components.transformer.model]
-@architectures = "spacy-transformers.TransformerModel.v1"
+@architectures = "spacy-transformers.TransformerModel.v3"
name = "{{ transformer["name"] }}"
tokenizer_config = {"use_fast": true}
@@ -198,7 +199,7 @@ no_output_layer = false
{# NON-TRANSFORMER PIPELINE #}
{% else -%}
-
+{% if not no_tok2vec-%}
[components.tok2vec]
factory = "tok2vec"
@@ -223,6 +224,7 @@ width = {{ 96 if optimize == "efficiency" else 256 }}
depth = {{ 4 if optimize == "efficiency" else 8 }}
window_size = 1
maxout_pieces = 3
+{% endif -%}
{% if "morphologizer" in components %}
[components.morphologizer]
diff --git a/spacy/cli/validate.py b/spacy/cli/validate.py
index a727e380e..a918e9a39 100644
--- a/spacy/cli/validate.py
+++ b/spacy/cli/validate.py
@@ -99,7 +99,7 @@ def get_model_pkgs(silent: bool = False) -> Tuple[dict, dict]:
warnings.filterwarnings("ignore", message="\\[W09[45]")
model_meta = get_model_meta(model_path)
spacy_version = model_meta.get("spacy_version", "n/a")
- is_compat = is_compatible_version(about.__version__, spacy_version)
+ is_compat = is_compatible_version(about.__version__, spacy_version) # type: ignore[assignment]
pkgs[pkg_name] = {
"name": package,
"version": version,
diff --git a/spacy/compat.py b/spacy/compat.py
index 92ed23c0e..89132735d 100644
--- a/spacy/compat.py
+++ b/spacy/compat.py
@@ -5,12 +5,12 @@ from thinc.util import copy_array
try:
import cPickle as pickle
except ImportError:
- import pickle
+ import pickle # type: ignore[no-redef]
try:
import copy_reg
except ImportError:
- import copyreg as copy_reg
+ import copyreg as copy_reg # type: ignore[no-redef]
try:
from cupy.cuda.stream import Stream as CudaStream
@@ -22,10 +22,10 @@ try:
except ImportError:
cupy = None
-try: # Python 3.8+
- from typing import Literal
-except ImportError:
- from typing_extensions import Literal # noqa: F401
+if sys.version_info[:2] >= (3, 8): # Python 3.8+
+ from typing import Literal, Protocol, runtime_checkable
+else:
+ from typing_extensions import Literal, Protocol, runtime_checkable # noqa: F401
# Important note: The importlib_metadata "backport" includes functionality
# that's not part of the built-in importlib.metadata. We should treat this
@@ -33,7 +33,7 @@ except ImportError:
try: # Python 3.8+
import importlib.metadata as importlib_metadata
except ImportError:
- from catalogue import _importlib_metadata as importlib_metadata # noqa: F401
+ from catalogue import _importlib_metadata as importlib_metadata # type: ignore[no-redef] # noqa: F401
from thinc.api import Optimizer # noqa: F401
diff --git a/spacy/displacy/__init__.py b/spacy/displacy/__init__.py
index 78b83f2e5..d9418f675 100644
--- a/spacy/displacy/__init__.py
+++ b/spacy/displacy/__init__.py
@@ -18,7 +18,7 @@ RENDER_WRAPPER = None
def render(
- docs: Union[Iterable[Union[Doc, Span]], Doc, Span],
+ docs: Union[Iterable[Union[Doc, Span, dict]], Doc, Span, dict],
style: str = "dep",
page: bool = False,
minify: bool = False,
@@ -28,7 +28,8 @@ def render(
) -> str:
"""Render displaCy visualisation.
- docs (Union[Iterable[Doc], Doc]): Document(s) to visualise.
+    docs (Union[Iterable[Union[Doc, Span, dict]], Doc, Span, dict]): Document(s) to visualise.
+        A 'dict' is only allowed here when 'manual' is set to True.
style (str): Visualisation style, 'dep' or 'ent'.
page (bool): Render markup as full HTML page.
minify (bool): Minify HTML markup.
@@ -53,8 +54,8 @@ def render(
raise ValueError(Errors.E096)
renderer_func, converter = factories[style]
renderer = renderer_func(options=options)
- parsed = [converter(doc, options) for doc in docs] if not manual else docs
- _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip()
+ parsed = [converter(doc, options) for doc in docs] if not manual else docs # type: ignore
+ _html["parsed"] = renderer.render(parsed, page=page, minify=minify).strip() # type: ignore
html = _html["parsed"]
if RENDER_WRAPPER is not None:
html = RENDER_WRAPPER(html)
@@ -133,7 +134,7 @@ def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
"lemma": np.root.lemma_,
"ent_type": np.root.ent_type_,
}
- retokenizer.merge(np, attrs=attrs)
+ retokenizer.merge(np, attrs=attrs) # type: ignore[arg-type]
if options.get("collapse_punct", True):
spans = []
for word in doc[:-1]:
@@ -148,7 +149,7 @@ def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
with doc.retokenize() as retokenizer:
for span, tag, lemma, ent_type in spans:
attrs = {"tag": tag, "lemma": lemma, "ent_type": ent_type}
- retokenizer.merge(span, attrs=attrs)
+ retokenizer.merge(span, attrs=attrs) # type: ignore[arg-type]
fine_grained = options.get("fine_grained")
add_lemma = options.get("add_lemma")
words = [
diff --git a/spacy/errors.py b/spacy/errors.py
index 4b617ecf3..4fe3e9003 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -190,6 +190,8 @@ class Warnings:
"vectors. This is almost certainly a mistake.")
W113 = ("Sourced component '{name}' may not work as expected: source "
"vectors are not identical to current pipeline vectors.")
+ W114 = ("Using multiprocessing with GPU models is not recommended and may "
+ "lead to errors.")
@add_codes
diff --git a/spacy/kb.pyx b/spacy/kb.pyx
index 421a8241a..9a765c8e4 100644
--- a/spacy/kb.pyx
+++ b/spacy/kb.pyx
@@ -1,5 +1,5 @@
# cython: infer_types=True, profile=True
-from typing import Iterator, Iterable
+from typing import Iterator, Iterable, Callable, Dict, Any
import srsly
from cymem.cymem cimport Pool
@@ -96,6 +96,8 @@ cdef class KnowledgeBase:
def initialize_entities(self, int64_t nr_entities):
self._entry_index = PreshMap(nr_entities + 1)
self._entries = entry_vec(nr_entities + 1)
+
+ def initialize_vectors(self, int64_t nr_entities):
self._vectors_table = float_matrix(nr_entities + 1)
def initialize_aliases(self, int64_t nr_aliases):
@@ -154,6 +156,7 @@ cdef class KnowledgeBase:
nr_entities = len(set(entity_list))
self.initialize_entities(nr_entities)
+ self.initialize_vectors(nr_entities)
i = 0
cdef KBEntryC entry
@@ -172,8 +175,8 @@ cdef class KnowledgeBase:
entry.entity_hash = entity_hash
entry.freq = freq_list[i]
- vector_index = self.c_add_vector(entity_vector=vector_list[i])
- entry.vector_index = vector_index
+ self._vectors_table[i] = entity_vector
+ entry.vector_index = i
entry.feats_row = -1 # Features table currently not implemented
@@ -386,6 +389,7 @@ cdef class KnowledgeBase:
nr_aliases = header[1]
entity_vector_length = header[2]
self.initialize_entities(nr_entities)
+ self.initialize_vectors(nr_entities)
self.initialize_aliases(nr_aliases)
self.entity_vector_length = entity_vector_length
@@ -446,7 +450,7 @@ cdef class KnowledgeBase:
raise ValueError(Errors.E929.format(loc=path))
if not path.is_dir():
raise ValueError(Errors.E928.format(loc=path))
- deserialize = {}
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
deserialize["contents"] = lambda p: self.read_contents(p)
deserialize["strings.json"] = lambda p: self.vocab.strings.from_disk(p)
util.from_disk(path, deserialize, exclude)
@@ -509,6 +513,7 @@ cdef class KnowledgeBase:
reader.read_header(&nr_entities, &entity_vector_length)
self.initialize_entities(nr_entities)
+ self.initialize_vectors(nr_entities)
self.entity_vector_length = entity_vector_length
# STEP 1: load entity vectors
diff --git a/spacy/lang/af/__init__.py b/spacy/lang/af/__init__.py
index 91917daee..553fcbf4c 100644
--- a/spacy/lang/af/__init__.py
+++ b/spacy/lang/af/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class AfrikaansDefaults(Language.Defaults):
+class AfrikaansDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/am/__init__.py b/spacy/lang/am/__init__.py
index ed21b55ee..ddae556d6 100644
--- a/spacy/lang/am/__init__.py
+++ b/spacy/lang/am/__init__.py
@@ -4,12 +4,12 @@ from .punctuation import TOKENIZER_SUFFIXES
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...attrs import LANG
from ...util import update_exc
-class AmharicDefaults(Language.Defaults):
+class AmharicDefaults(BaseDefaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters.update(LEX_ATTRS)
lex_attr_getters[LANG] = lambda text: "am"
diff --git a/spacy/lang/ar/__init__.py b/spacy/lang/ar/__init__.py
index 6abb65efb..18c1f90ed 100644
--- a/spacy/lang/ar/__init__.py
+++ b/spacy/lang/ar/__init__.py
@@ -2,10 +2,10 @@ from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_SUFFIXES
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class ArabicDefaults(Language.Defaults):
+class ArabicDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
suffixes = TOKENIZER_SUFFIXES
stop_words = STOP_WORDS
diff --git a/spacy/lang/az/__init__.py b/spacy/lang/az/__init__.py
index 2937e2ecf..476898364 100644
--- a/spacy/lang/az/__init__.py
+++ b/spacy/lang/az/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class AzerbaijaniDefaults(Language.Defaults):
+class AzerbaijaniDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/bg/__init__.py b/spacy/lang/bg/__init__.py
index 6fa539a28..559cc34c4 100644
--- a/spacy/lang/bg/__init__.py
+++ b/spacy/lang/bg/__init__.py
@@ -3,12 +3,12 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...attrs import LANG
from ...util import update_exc
-class BulgarianDefaults(Language.Defaults):
+class BulgarianDefaults(BaseDefaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters[LANG] = lambda text: "bg"
diff --git a/spacy/lang/bn/__init__.py b/spacy/lang/bn/__init__.py
index 945560aac..6d0331e00 100644
--- a/spacy/lang/bn/__init__.py
+++ b/spacy/lang/bn/__init__.py
@@ -3,11 +3,11 @@ from thinc.api import Model
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Lemmatizer
-class BengaliDefaults(Language.Defaults):
+class BengaliDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/ca/__init__.py b/spacy/lang/ca/__init__.py
index 802c7e4cc..a3def660d 100755
--- a/spacy/lang/ca/__init__.py
+++ b/spacy/lang/ca/__init__.py
@@ -7,11 +7,11 @@ from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES, TOKENIZER_PREFIX
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
from .lemmatizer import CatalanLemmatizer
-class CatalanDefaults(Language.Defaults):
+class CatalanDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/ca/syntax_iterators.py b/spacy/lang/ca/syntax_iterators.py
index c70d53e80..917e07c93 100644
--- a/spacy/lang/ca/syntax_iterators.py
+++ b/spacy/lang/ca/syntax_iterators.py
@@ -1,8 +1,10 @@
+from typing import Union, Iterator, Tuple
+from ...tokens import Doc, Span
from ...symbols import NOUN, PROPN
from ...errors import Errors
-def noun_chunks(doclike):
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# fmt: off
labels = ["nsubj", "nsubj:pass", "obj", "obl", "iobj", "ROOT", "appos", "nmod", "nmod:poss"]
diff --git a/spacy/lang/cs/__init__.py b/spacy/lang/cs/__init__.py
index 26f5845cc..3e70e4078 100644
--- a/spacy/lang/cs/__init__.py
+++ b/spacy/lang/cs/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class CzechDefaults(Language.Defaults):
+class CzechDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/da/__init__.py b/spacy/lang/da/__init__.py
index c5260ccdd..e148a7b4f 100644
--- a/spacy/lang/da/__init__.py
+++ b/spacy/lang/da/__init__.py
@@ -3,10 +3,10 @@ from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class DanishDefaults(Language.Defaults):
+class DanishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/da/syntax_iterators.py b/spacy/lang/da/syntax_iterators.py
index 39181d753..a0b70f004 100644
--- a/spacy/lang/da/syntax_iterators.py
+++ b/spacy/lang/da/syntax_iterators.py
@@ -1,8 +1,10 @@
+from typing import Union, Iterator, Tuple
+from ...tokens import Doc, Span
from ...symbols import NOUN, PROPN, PRON, VERB, AUX
from ...errors import Errors
-def noun_chunks(doclike):
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
def is_verb_token(tok):
return tok.pos in [VERB, AUX]
@@ -32,7 +34,7 @@ def noun_chunks(doclike):
def get_bounds(doc, root):
return get_left_bound(doc, root), get_right_bound(doc, root)
- doc = doclike.doc
+ doc = doclike.doc # Ensure works on both Doc and Span.
if not doc.has_annotation("DEP"):
raise ValueError(Errors.E029)
diff --git a/spacy/lang/de/__init__.py b/spacy/lang/de/__init__.py
index b645d3480..65863c098 100644
--- a/spacy/lang/de/__init__.py
+++ b/spacy/lang/de/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class GermanDefaults(Language.Defaults):
+class GermanDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/de/syntax_iterators.py b/spacy/lang/de/syntax_iterators.py
index aba0e8024..e80504998 100644
--- a/spacy/lang/de/syntax_iterators.py
+++ b/spacy/lang/de/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# this iterator extracts spans headed by NOUNs starting from the left-most
# syntactic dependent until the NOUN itself for close apposition and
diff --git a/spacy/lang/el/__init__.py b/spacy/lang/el/__init__.py
index e843114fc..53dd9be8e 100644
--- a/spacy/lang/el/__init__.py
+++ b/spacy/lang/el/__init__.py
@@ -7,10 +7,10 @@ from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .lemmatizer import GreekLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class GreekDefaults(Language.Defaults):
+class GreekDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/el/syntax_iterators.py b/spacy/lang/el/syntax_iterators.py
index 89cfd8b72..18fa46695 100644
--- a/spacy/lang/el/syntax_iterators.py
+++ b/spacy/lang/el/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# It follows the logic of the noun chunks finder of English language,
# adjusted to some Greek language special characteristics.
diff --git a/spacy/lang/en/__init__.py b/spacy/lang/en/__init__.py
index a84b50476..876186979 100644
--- a/spacy/lang/en/__init__.py
+++ b/spacy/lang/en/__init__.py
@@ -7,10 +7,10 @@ from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
from .punctuation import TOKENIZER_INFIXES
from .lemmatizer import EnglishLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class EnglishDefaults(Language.Defaults):
+class EnglishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/en/lex_attrs.py b/spacy/lang/en/lex_attrs.py
index b630a317d..ab9353919 100644
--- a/spacy/lang/en/lex_attrs.py
+++ b/spacy/lang/en/lex_attrs.py
@@ -19,7 +19,7 @@ _ordinal_words = [
# fmt: on
-def like_num(text: str) -> bool:
+def like_num(text):
if text.startswith(("+", "-", "±", "~")):
text = text[1:]
text = text.replace(",", "").replace(".", "")
diff --git a/spacy/lang/en/syntax_iterators.py b/spacy/lang/en/syntax_iterators.py
index 00a1bac42..7904e5621 100644
--- a/spacy/lang/en/syntax_iterators.py
+++ b/spacy/lang/en/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on both Doc and Span.
"""
diff --git a/spacy/lang/en/tokenizer_exceptions.py b/spacy/lang/en/tokenizer_exceptions.py
index d69508470..55b544e42 100644
--- a/spacy/lang/en/tokenizer_exceptions.py
+++ b/spacy/lang/en/tokenizer_exceptions.py
@@ -1,9 +1,10 @@
+from typing import Dict, List
from ..tokenizer_exceptions import BASE_EXCEPTIONS
from ...symbols import ORTH, NORM
from ...util import update_exc
-_exc = {}
+_exc: Dict[str, List[Dict]] = {}
_exclude = [
"Ill",
"ill",
@@ -294,9 +295,9 @@ for verb_data in [
{ORTH: "has", NORM: "has"},
{ORTH: "dare", NORM: "dare"},
]:
- verb_data_tc = dict(verb_data)
+ verb_data_tc = dict(verb_data) # type: ignore[call-overload]
verb_data_tc[ORTH] = verb_data_tc[ORTH].title()
- for data in [verb_data, verb_data_tc]:
+ for data in [verb_data, verb_data_tc]: # type: ignore[assignment]
_exc[data[ORTH] + "n't"] = [
dict(data),
{ORTH: "n't", NORM: "not"},
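
Two recurring fixes appear in this file and several below: empty module-level containers get an explicit annotation so mypy knows their element types, and unavoidable errors are silenced with error-code-scoped ignores rather than bare type: ignore comments. Illustrative examples, not taken from the spaCy sources:

    from typing import Dict, List

    _exceptions: Dict[str, List[Dict]] = {}      # an empty {} literal needs a declared type

    count = int("3")
    count = "three"  # type: ignore[assignment]  # silences only this error code on this line
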
diff --git a/spacy/lang/es/__init__.py b/spacy/lang/es/__init__.py
index 2f246a678..e75955202 100644
--- a/spacy/lang/es/__init__.py
+++ b/spacy/lang/es/__init__.py
@@ -6,10 +6,10 @@ from .lex_attrs import LEX_ATTRS
from .lemmatizer import SpanishLemmatizer
from .syntax_iterators import SYNTAX_ITERATORS
from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SpanishDefaults(Language.Defaults):
+class SpanishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/es/lemmatizer.py b/spacy/lang/es/lemmatizer.py
index 56f74068d..ca5fc08c8 100644
--- a/spacy/lang/es/lemmatizer.py
+++ b/spacy/lang/es/lemmatizer.py
@@ -52,7 +52,7 @@ class SpanishLemmatizer(Lemmatizer):
rule_pos = "verb"
else:
rule_pos = pos
- rule = self.select_rule(rule_pos, features)
+ rule = self.select_rule(rule_pos, list(features))
index = self.lookups.get_table("lemma_index").get(rule_pos, [])
lemmas = getattr(self, "lemmatize_" + rule_pos)(
string, features, rule, index
@@ -191,6 +191,8 @@ class SpanishLemmatizer(Lemmatizer):
return selected_lemmas
else:
return possible_lemmas
+ else:
+ return []
def lemmatize_noun(
self, word: str, features: List[str], rule: str, index: List[str]
@@ -268,7 +270,7 @@ class SpanishLemmatizer(Lemmatizer):
return [word]
def lemmatize_pron(
- self, word: str, features: List[str], rule: str, index: List[str]
+ self, word: str, features: List[str], rule: Optional[str], index: List[str]
) -> List[str]:
"""
Lemmatize a pronoun.
@@ -319,9 +321,11 @@ class SpanishLemmatizer(Lemmatizer):
return selected_lemmas
else:
return possible_lemmas
+ else:
+ return []
def lemmatize_verb(
- self, word: str, features: List[str], rule: str, index: List[str]
+ self, word: str, features: List[str], rule: Optional[str], index: List[str]
) -> List[str]:
"""
Lemmatize a verb.
@@ -342,6 +346,7 @@ class SpanishLemmatizer(Lemmatizer):
selected_lemmas = []
# Apply lemmatization rules
+ rule = str(rule or "")
for old, new in self.lookups.get_table("lemma_rules").get(rule, []):
possible_lemma = re.sub(old + "$", new, word)
if possible_lemma != word:
@@ -389,11 +394,11 @@ class SpanishLemmatizer(Lemmatizer):
return [word]
def lemmatize_verb_pron(
- self, word: str, features: List[str], rule: str, index: List[str]
+ self, word: str, features: List[str], rule: Optional[str], index: List[str]
) -> List[str]:
# Strip and collect pronouns
pron_patt = "^(.*?)([mts]e|l[aeo]s?|n?os)$"
- prons = []
+ prons: List[str] = []
verb = word
m = re.search(pron_patt, verb)
while m is not None and len(prons) <= 3:
@@ -410,7 +415,7 @@ class SpanishLemmatizer(Lemmatizer):
else:
rule = self.select_rule("verb", features)
verb_lemma = self.lemmatize_verb(
- verb, features - {"PronType=Prs"}, rule, index
+ verb, features - {"PronType=Prs"}, rule, index # type: ignore[operator]
)[0]
pron_lemmas = []
for pron in prons:
diff --git a/spacy/lang/es/syntax_iterators.py b/spacy/lang/es/syntax_iterators.py
index e753a3f98..8b385a1b9 100644
--- a/spacy/lang/es/syntax_iterators.py
+++ b/spacy/lang/es/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON, VERB, AUX
from ...errors import Errors
from ...tokens import Doc, Span, Token
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
doc = doclike.doc
if not doc.has_annotation("DEP"):
diff --git a/spacy/lang/et/__init__.py b/spacy/lang/et/__init__.py
index 9f71882d2..274bc1309 100644
--- a/spacy/lang/et/__init__.py
+++ b/spacy/lang/et/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class EstonianDefaults(Language.Defaults):
+class EstonianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/eu/__init__.py b/spacy/lang/eu/__init__.py
index 89550be96..3346468bd 100644
--- a/spacy/lang/eu/__init__.py
+++ b/spacy/lang/eu/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_SUFFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class BasqueDefaults(Language.Defaults):
+class BasqueDefaults(BaseDefaults):
suffixes = TOKENIZER_SUFFIXES
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/fa/__init__.py b/spacy/lang/fa/__init__.py
index 0c3100f2b..914e4c27d 100644
--- a/spacy/lang/fa/__init__.py
+++ b/spacy/lang/fa/__init__.py
@@ -5,11 +5,11 @@ from .lex_attrs import LEX_ATTRS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_SUFFIXES
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Lemmatizer
-class PersianDefaults(Language.Defaults):
+class PersianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
suffixes = TOKENIZER_SUFFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/fa/generate_verbs_exc.py b/spacy/lang/fa/generate_verbs_exc.py
index 62094c6de..a6d79a386 100644
--- a/spacy/lang/fa/generate_verbs_exc.py
+++ b/spacy/lang/fa/generate_verbs_exc.py
@@ -639,10 +639,12 @@ for verb_root in verb_roots:
)
if past.startswith("آ"):
- conjugations = set(
- map(
- lambda item: item.replace("بآ", "بیا").replace("نآ", "نیا"),
- conjugations,
+ conjugations = list(
+ set(
+ map(
+ lambda item: item.replace("بآ", "بیا").replace("نآ", "نیا"),
+ conjugations,
+ )
)
)
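
The Persian conjugation fix keeps the conjugations variable at a single inferred type: the deduplicated result is converted back to a list instead of rebinding the name to a set, which would otherwise trip mypy's incompatible-reassignment check. The same idea in miniature, with toy data:

    words = ["go", "goes", "go"]                 # inferred as List[str]
    words = list({w.upper() for w in words})     # dedupe through a set, rebind as a list again
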
diff --git a/spacy/lang/fa/syntax_iterators.py b/spacy/lang/fa/syntax_iterators.py
index 0be06e73c..8207884b0 100644
--- a/spacy/lang/fa/syntax_iterators.py
+++ b/spacy/lang/fa/syntax_iterators.py
@@ -1,8 +1,10 @@
+from typing import Union, Iterator, Tuple
+from ...tokens import Doc, Span
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
-def noun_chunks(doclike):
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on both Doc and Span.
"""
diff --git a/spacy/lang/fi/__init__.py b/spacy/lang/fi/__init__.py
index 9233c6547..86a834170 100644
--- a/spacy/lang/fi/__init__.py
+++ b/spacy/lang/fi/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class FinnishDefaults(Language.Defaults):
+class FinnishDefaults(BaseDefaults):
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
diff --git a/spacy/lang/fr/__init__.py b/spacy/lang/fr/__init__.py
index 254e1651b..27d2a915e 100644
--- a/spacy/lang/fr/__init__.py
+++ b/spacy/lang/fr/__init__.py
@@ -9,10 +9,10 @@ from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
from .lemmatizer import FrenchLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class FrenchDefaults(Language.Defaults):
+class FrenchDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
infixes = TOKENIZER_INFIXES
diff --git a/spacy/lang/fr/syntax_iterators.py b/spacy/lang/fr/syntax_iterators.py
index 68117a54d..d86662693 100644
--- a/spacy/lang/fr/syntax_iterators.py
+++ b/spacy/lang/fr/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# fmt: off
labels = ["nsubj", "nsubj:pass", "obj", "iobj", "ROOT", "appos", "nmod", "nmod:poss"]
diff --git a/spacy/lang/fr/tokenizer_exceptions.py b/spacy/lang/fr/tokenizer_exceptions.py
index 060f81879..2e88b58cf 100644
--- a/spacy/lang/fr/tokenizer_exceptions.py
+++ b/spacy/lang/fr/tokenizer_exceptions.py
@@ -115,7 +115,7 @@ for s, verb, pronoun in [("s", "est", "il"), ("S", "EST", "IL")]:
]
-_infixes_exc = []
+_infixes_exc = [] # type: ignore[var-annotated]
orig_elision = "'"
orig_hyphen = "-"
diff --git a/spacy/lang/ga/__init__.py b/spacy/lang/ga/__init__.py
index 167edf939..3be53bc7a 100644
--- a/spacy/lang/ga/__init__.py
+++ b/spacy/lang/ga/__init__.py
@@ -4,11 +4,11 @@ from thinc.api import Model
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
from .lemmatizer import IrishLemmatizer
-class IrishDefaults(Language.Defaults):
+class IrishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
stop_words = STOP_WORDS
diff --git a/spacy/lang/grc/__init__.py b/spacy/lang/grc/__init__.py
index e29252da9..e83f0c5a5 100644
--- a/spacy/lang/grc/__init__.py
+++ b/spacy/lang/grc/__init__.py
@@ -1,10 +1,10 @@
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class AncientGreekDefaults(Language.Defaults):
+class AncientGreekDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/grc/tokenizer_exceptions.py b/spacy/lang/grc/tokenizer_exceptions.py
index 230a58fd2..bcee70f32 100644
--- a/spacy/lang/grc/tokenizer_exceptions.py
+++ b/spacy/lang/grc/tokenizer_exceptions.py
@@ -108,8 +108,4 @@ _other_exc = {
_exc.update(_other_exc)
-_exc_data = {}
-
-_exc.update(_exc_data)
-
TOKENIZER_EXCEPTIONS = update_exc(BASE_EXCEPTIONS, _exc)
diff --git a/spacy/lang/gu/__init__.py b/spacy/lang/gu/__init__.py
index 67228ac40..e6fbc9d18 100644
--- a/spacy/lang/gu/__init__.py
+++ b/spacy/lang/gu/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class GujaratiDefaults(Language.Defaults):
+class GujaratiDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/he/__init__.py b/spacy/lang/he/__init__.py
index e0adc3293..dd2ee478d 100644
--- a/spacy/lang/he/__init__.py
+++ b/spacy/lang/he/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class HebrewDefaults(Language.Defaults):
+class HebrewDefaults(BaseDefaults):
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
writing_system = {"direction": "rtl", "has_case": False, "has_letters": True}
diff --git a/spacy/lang/hi/__init__.py b/spacy/lang/hi/__init__.py
index 384f040c8..4c8ae446d 100644
--- a/spacy/lang/hi/__init__.py
+++ b/spacy/lang/hi/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class HindiDefaults(Language.Defaults):
+class HindiDefaults(BaseDefaults):
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/hr/__init__.py b/spacy/lang/hr/__init__.py
index 118e0946a..30870b522 100644
--- a/spacy/lang/hr/__init__.py
+++ b/spacy/lang/hr/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class CroatianDefaults(Language.Defaults):
+class CroatianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/hu/__init__.py b/spacy/lang/hu/__init__.py
index 8962603a6..9426bacea 100644
--- a/spacy/lang/hu/__init__.py
+++ b/spacy/lang/hu/__init__.py
@@ -1,10 +1,10 @@
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS, TOKEN_MATCH
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class HungarianDefaults(Language.Defaults):
+class HungarianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/hy/__init__.py b/spacy/lang/hy/__init__.py
index 4577ab641..481eaae0a 100644
--- a/spacy/lang/hy/__init__.py
+++ b/spacy/lang/hy/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class ArmenianDefaults(Language.Defaults):
+class ArmenianDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/id/__init__.py b/spacy/lang/id/__init__.py
index 87373551c..0d72cfa9d 100644
--- a/spacy/lang/id/__init__.py
+++ b/spacy/lang/id/__init__.py
@@ -3,10 +3,10 @@ from .punctuation import TOKENIZER_SUFFIXES, TOKENIZER_PREFIXES, TOKENIZER_INFIX
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class IndonesianDefaults(Language.Defaults):
+class IndonesianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/id/syntax_iterators.py b/spacy/lang/id/syntax_iterators.py
index 0f29bfe16..fa984d411 100644
--- a/spacy/lang/id/syntax_iterators.py
+++ b/spacy/lang/id/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on both Doc and Span.
"""
diff --git a/spacy/lang/is/__init__.py b/spacy/lang/is/__init__.py
index be5de5981..318363beb 100644
--- a/spacy/lang/is/__init__.py
+++ b/spacy/lang/is/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class IcelandicDefaults(Language.Defaults):
+class IcelandicDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/it/__init__.py b/spacy/lang/it/__init__.py
index fc74789a3..1edebc837 100644
--- a/spacy/lang/it/__init__.py
+++ b/spacy/lang/it/__init__.py
@@ -4,11 +4,11 @@ from thinc.api import Model
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_INFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
from .lemmatizer import ItalianLemmatizer
-class ItalianDefaults(Language.Defaults):
+class ItalianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
stop_words = STOP_WORDS
prefixes = TOKENIZER_PREFIXES
diff --git a/spacy/lang/ja/__init__.py b/spacy/lang/ja/__init__.py
index e701ecfdf..33335a189 100644
--- a/spacy/lang/ja/__init__.py
+++ b/spacy/lang/ja/__init__.py
@@ -11,7 +11,7 @@ from .tag_map import TAG_MAP
from .tag_orth_map import TAG_ORTH_MAP
from .tag_bigram_map import TAG_BIGRAM_MAP
from ...errors import Errors
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Morphologizer
from ...pipeline.morphologizer import DEFAULT_MORPH_MODEL
from ...scorer import Scorer
@@ -172,7 +172,7 @@ class JapaneseTokenizer(DummyTokenizer):
def to_disk(self, path: Union[str, Path], **kwargs) -> None:
path = util.ensure_path(path)
serializers = {"cfg": lambda p: srsly.write_json(p, self._get_config())}
- return util.to_disk(path, serializers, [])
+ util.to_disk(path, serializers, [])
def from_disk(self, path: Union[str, Path], **kwargs) -> "JapaneseTokenizer":
path = util.ensure_path(path)
@@ -182,7 +182,7 @@ class JapaneseTokenizer(DummyTokenizer):
return self
-class JapaneseDefaults(Language.Defaults):
+class JapaneseDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
stop_words = STOP_WORDS
syntax_iterators = SYNTAX_ITERATORS
diff --git a/spacy/lang/ja/syntax_iterators.py b/spacy/lang/ja/syntax_iterators.py
index cca4902ab..588a9ba03 100644
--- a/spacy/lang/ja/syntax_iterators.py
+++ b/spacy/lang/ja/syntax_iterators.py
@@ -1,4 +1,4 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple, Set
from ...symbols import NOUN, PROPN, PRON, VERB
from ...tokens import Doc, Span
@@ -10,13 +10,13 @@ labels = ["nsubj", "nmod", "ddoclike", "nsubjpass", "pcomp", "pdoclike", "doclik
# fmt: on
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
doc = doclike.doc # Ensure works on both Doc and Span.
np_deps = [doc.vocab.strings.add(label) for label in labels]
doc.vocab.strings.add("conj")
np_label = doc.vocab.strings.add("NP")
- seen = set()
+ seen: Set[int] = set()
for i, word in enumerate(doclike):
if word.pos not in (NOUN, PROPN, PRON):
continue
diff --git a/spacy/lang/kn/__init__.py b/spacy/lang/kn/__init__.py
index 8e53989e6..ccd46a394 100644
--- a/spacy/lang/kn/__init__.py
+++ b/spacy/lang/kn/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class KannadaDefaults(Language.Defaults):
+class KannadaDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/ko/__init__.py b/spacy/lang/ko/__init__.py
index daa445e09..05fc67e79 100644
--- a/spacy/lang/ko/__init__.py
+++ b/spacy/lang/ko/__init__.py
@@ -1,9 +1,9 @@
-from typing import Any, Dict
+from typing import Iterator, Any, Dict
from .stop_words import STOP_WORDS
from .tag_map import TAG_MAP
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...tokens import Doc
from ...scorer import Scorer
from ...symbols import POS
@@ -31,7 +31,7 @@ def create_tokenizer():
class KoreanTokenizer(DummyTokenizer):
def __init__(self, vocab: Vocab):
self.vocab = vocab
- MeCab = try_mecab_import()
+ MeCab = try_mecab_import() # type: ignore[func-returns-value]
self.mecab_tokenizer = MeCab("-F%f[0],%f[7]")
def __reduce__(self):
@@ -52,7 +52,7 @@ class KoreanTokenizer(DummyTokenizer):
doc.user_data["full_tags"] = [dt["tag"] for dt in dtokens]
return doc
- def detailed_tokens(self, text: str) -> Dict[str, Any]:
+ def detailed_tokens(self, text: str) -> Iterator[Dict[str, Any]]:
# 품사 태그(POS)[0], 의미 부류(semantic class)[1], 종성 유무(jongseong)[2], 읽기(reading)[3],
# 타입(type)[4], 첫번째 품사(start pos)[5], 마지막 품사(end pos)[6], 표현(expression)[7], *
for node in self.mecab_tokenizer.parse(text, as_nodes=True):
@@ -71,7 +71,7 @@ class KoreanTokenizer(DummyTokenizer):
return Scorer.score_tokenization(examples)
-class KoreanDefaults(Language.Defaults):
+class KoreanDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
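
detailed_tokens is a generator, so its return annotation is the iterator over the yielded dictionaries rather than a single Dict. A stripped-down illustration (the real method walks MeCab nodes; this one just splits on whitespace):

    from typing import Any, Dict, Iterator

    def detailed_tokens(text: str) -> Iterator[Dict[str, Any]]:
        for surface in text.split():
            yield {"surface": surface, "tag": "UNK", "lemma": surface}
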
diff --git a/spacy/lang/ky/__init__.py b/spacy/lang/ky/__init__.py
index a333db035..ccca384bd 100644
--- a/spacy/lang/ky/__init__.py
+++ b/spacy/lang/ky/__init__.py
@@ -2,10 +2,10 @@ from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class KyrgyzDefaults(Language.Defaults):
+class KyrgyzDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/lb/__init__.py b/spacy/lang/lb/__init__.py
index da6fe55d7..7827e7762 100644
--- a/spacy/lang/lb/__init__.py
+++ b/spacy/lang/lb/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_INFIXES
from .lex_attrs import LEX_ATTRS
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class LuxembourgishDefaults(Language.Defaults):
+class LuxembourgishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/lij/__init__.py b/spacy/lang/lij/__init__.py
index 5ae280324..b7e11f77e 100644
--- a/spacy/lang/lij/__init__.py
+++ b/spacy/lang/lij/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .punctuation import TOKENIZER_INFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class LigurianDefaults(Language.Defaults):
+class LigurianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
stop_words = STOP_WORDS
diff --git a/spacy/lang/lt/__init__.py b/spacy/lang/lt/__init__.py
index e395a8f62..3ae000e5f 100644
--- a/spacy/lang/lt/__init__.py
+++ b/spacy/lang/lt/__init__.py
@@ -2,10 +2,10 @@ from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class LithuanianDefaults(Language.Defaults):
+class LithuanianDefaults(BaseDefaults):
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
diff --git a/spacy/lang/lv/__init__.py b/spacy/lang/lv/__init__.py
index 142bc706e..a05e5b939 100644
--- a/spacy/lang/lv/__init__.py
+++ b/spacy/lang/lv/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class LatvianDefaults(Language.Defaults):
+class LatvianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/mk/__init__.py b/spacy/lang/mk/__init__.py
index a8464f3b7..fa07cfef9 100644
--- a/spacy/lang/mk/__init__.py
+++ b/spacy/lang/mk/__init__.py
@@ -6,13 +6,13 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...attrs import LANG
from ...util import update_exc
from ...lookups import Lookups
-class MacedonianDefaults(Language.Defaults):
+class MacedonianDefaults(BaseDefaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters[LANG] = lambda text: "mk"
diff --git a/spacy/lang/ml/__init__.py b/spacy/lang/ml/__init__.py
index cfad52261..9f90605f0 100644
--- a/spacy/lang/ml/__init__.py
+++ b/spacy/lang/ml/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class MalayalamDefaults(Language.Defaults):
+class MalayalamDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/mr/__init__.py b/spacy/lang/mr/__init__.py
index af0c49878..3e172fa60 100644
--- a/spacy/lang/mr/__init__.py
+++ b/spacy/lang/mr/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class MarathiDefaults(Language.Defaults):
+class MarathiDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/nb/__init__.py b/spacy/lang/nb/__init__.py
index d08f8f768..e079236fd 100644
--- a/spacy/lang/nb/__init__.py
+++ b/spacy/lang/nb/__init__.py
@@ -5,11 +5,11 @@ from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_INFIXES
from .punctuation import TOKENIZER_SUFFIXES
from .stop_words import STOP_WORDS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Lemmatizer
-class NorwegianDefaults(Language.Defaults):
+class NorwegianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
infixes = TOKENIZER_INFIXES
diff --git a/spacy/lang/nb/syntax_iterators.py b/spacy/lang/nb/syntax_iterators.py
index 68117a54d..d86662693 100644
--- a/spacy/lang/nb/syntax_iterators.py
+++ b/spacy/lang/nb/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# fmt: off
labels = ["nsubj", "nsubj:pass", "obj", "iobj", "ROOT", "appos", "nmod", "nmod:poss"]
diff --git a/spacy/lang/ne/__init__.py b/spacy/lang/ne/__init__.py
index 68632e9ad..0028d1b0b 100644
--- a/spacy/lang/ne/__init__.py
+++ b/spacy/lang/ne/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class NepaliDefaults(Language.Defaults):
+class NepaliDefaults(BaseDefaults):
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/nl/__init__.py b/spacy/lang/nl/__init__.py
index 0a6480a1d..ad2205a0b 100644
--- a/spacy/lang/nl/__init__.py
+++ b/spacy/lang/nl/__init__.py
@@ -9,10 +9,10 @@ from .punctuation import TOKENIZER_SUFFIXES
from .stop_words import STOP_WORDS
from .syntax_iterators import SYNTAX_ITERATORS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class DutchDefaults(Language.Defaults):
+class DutchDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
infixes = TOKENIZER_INFIXES
diff --git a/spacy/lang/nl/syntax_iterators.py b/spacy/lang/nl/syntax_iterators.py
index 1959ba6e2..1ab5e7cff 100644
--- a/spacy/lang/nl/syntax_iterators.py
+++ b/spacy/lang/nl/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on Doc and Span.
The definition is inspired by https://www.nltk.org/book/ch07.html
diff --git a/spacy/lang/pl/__init__.py b/spacy/lang/pl/__init__.py
index 1d71244a2..02c96799b 100644
--- a/spacy/lang/pl/__init__.py
+++ b/spacy/lang/pl/__init__.py
@@ -8,7 +8,7 @@ from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .lemmatizer import PolishLemmatizer
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
TOKENIZER_EXCEPTIONS = {
@@ -16,7 +16,7 @@ TOKENIZER_EXCEPTIONS = {
}
-class PolishDefaults(Language.Defaults):
+class PolishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
infixes = TOKENIZER_INFIXES
diff --git a/spacy/lang/pt/__init__.py b/spacy/lang/pt/__init__.py
index 0447099f0..9ae6501fb 100644
--- a/spacy/lang/pt/__init__.py
+++ b/spacy/lang/pt/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES, TOKENIZER_PREFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class PortugueseDefaults(Language.Defaults):
+class PortugueseDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
prefixes = TOKENIZER_PREFIXES
diff --git a/spacy/lang/ro/__init__.py b/spacy/lang/ro/__init__.py
index f0d8d8d31..50027ffd2 100644
--- a/spacy/lang/ro/__init__.py
+++ b/spacy/lang/ro/__init__.py
@@ -3,14 +3,14 @@ from .stop_words import STOP_WORDS
from .punctuation import TOKENIZER_PREFIXES, TOKENIZER_INFIXES
from .punctuation import TOKENIZER_SUFFIXES
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
# Lemma data note:
# Original pairs downloaded from http://www.lexiconista.com/datasets/lemmatization/
# Replaced characters using cedillas with the correct ones (ș and ț)
-class RomanianDefaults(Language.Defaults):
+class RomanianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
prefixes = TOKENIZER_PREFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/ru/__init__.py b/spacy/lang/ru/__init__.py
index 0f645ddb1..5d31d8ea2 100644
--- a/spacy/lang/ru/__init__.py
+++ b/spacy/lang/ru/__init__.py
@@ -5,10 +5,10 @@ from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
from .lemmatizer import RussianLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class RussianDefaults(Language.Defaults):
+class RussianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/ru/lemmatizer.py b/spacy/lang/ru/lemmatizer.py
index a56938641..2fc3a471b 100644
--- a/spacy/lang/ru/lemmatizer.py
+++ b/spacy/lang/ru/lemmatizer.py
@@ -58,7 +58,9 @@ class RussianLemmatizer(Lemmatizer):
if not len(filtered_analyses):
return [string.lower()]
if morphology is None or (len(morphology) == 1 and POS in morphology):
- return list(dict.fromkeys([analysis.normal_form for analysis in filtered_analyses]))
+ return list(
+ dict.fromkeys([analysis.normal_form for analysis in filtered_analyses])
+ )
if univ_pos in ("ADJ", "DET", "NOUN", "PROPN"):
features_to_compare = ["Case", "Number", "Gender"]
elif univ_pos == "NUM":
@@ -89,7 +91,9 @@ class RussianLemmatizer(Lemmatizer):
filtered_analyses.append(analysis)
if not len(filtered_analyses):
return [string.lower()]
- return list(dict.fromkeys([analysis.normal_form for analysis in filtered_analyses]))
+ return list(
+ dict.fromkeys([analysis.normal_form for analysis in filtered_analyses])
+ )
def pymorphy2_lookup_lemmatize(self, token: Token) -> List[str]:
string = token.text
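
The Russian lemmatizer hunks only re-wrap a long expression for line length; the underlying idiom is dict.fromkeys, which deduplicates the analyses while preserving their order:

    forms = ["стали", "сталь", "стали"]
    unique = list(dict.fromkeys(forms))     # ['стали', 'сталь'], first occurrences kept in order
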
diff --git a/spacy/lang/sa/__init__.py b/spacy/lang/sa/__init__.py
index 345137817..61398af6c 100644
--- a/spacy/lang/sa/__init__.py
+++ b/spacy/lang/sa/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SanskritDefaults(Language.Defaults):
+class SanskritDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/si/__init__.py b/spacy/lang/si/__init__.py
index d77e3bb8b..971cee3c6 100644
--- a/spacy/lang/si/__init__.py
+++ b/spacy/lang/si/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SinhalaDefaults(Language.Defaults):
+class SinhalaDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/sk/__init__.py b/spacy/lang/sk/__init__.py
index 4003c7340..da6e3048e 100644
--- a/spacy/lang/sk/__init__.py
+++ b/spacy/lang/sk/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SlovakDefaults(Language.Defaults):
+class SlovakDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/sl/__init__.py b/spacy/lang/sl/__init__.py
index 0330cc4d0..9ddd676bf 100644
--- a/spacy/lang/sl/__init__.py
+++ b/spacy/lang/sl/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SlovenianDefaults(Language.Defaults):
+class SlovenianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/sq/__init__.py b/spacy/lang/sq/__init__.py
index a4bacfa49..5e32a0cbe 100644
--- a/spacy/lang/sq/__init__.py
+++ b/spacy/lang/sq/__init__.py
@@ -1,8 +1,8 @@
from .stop_words import STOP_WORDS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class AlbanianDefaults(Language.Defaults):
+class AlbanianDefaults(BaseDefaults):
stop_words = STOP_WORDS
diff --git a/spacy/lang/sr/__init__.py b/spacy/lang/sr/__init__.py
index 165e54975..fd0c8c832 100644
--- a/spacy/lang/sr/__init__.py
+++ b/spacy/lang/sr/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SerbianDefaults(Language.Defaults):
+class SerbianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/sv/__init__.py b/spacy/lang/sv/__init__.py
index aa8d3f110..6963e8b79 100644
--- a/spacy/lang/sv/__init__.py
+++ b/spacy/lang/sv/__init__.py
@@ -4,7 +4,7 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .syntax_iterators import SYNTAX_ITERATORS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...pipeline import Lemmatizer
@@ -12,7 +12,7 @@ from ...pipeline import Lemmatizer
from ..da.punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
-class SwedishDefaults(Language.Defaults):
+class SwedishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
suffixes = TOKENIZER_SUFFIXES
diff --git a/spacy/lang/sv/syntax_iterators.py b/spacy/lang/sv/syntax_iterators.py
index d5ae47853..06ad016ac 100644
--- a/spacy/lang/sv/syntax_iterators.py
+++ b/spacy/lang/sv/syntax_iterators.py
@@ -1,11 +1,11 @@
-from typing import Union, Iterator
+from typing import Union, Iterator, Tuple
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
from ...tokens import Doc, Span
-def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Span]:
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""Detect base noun phrases from a dependency parse. Works on Doc and Span."""
# fmt: off
labels = ["nsubj", "nsubj:pass", "dobj", "obj", "iobj", "ROOT", "appos", "nmod", "nmod:poss"]
diff --git a/spacy/lang/ta/__init__.py b/spacy/lang/ta/__init__.py
index ac5fc7124..4929a4b97 100644
--- a/spacy/lang/ta/__init__.py
+++ b/spacy/lang/ta/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TamilDefaults(Language.Defaults):
+class TamilDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/te/__init__.py b/spacy/lang/te/__init__.py
index e6dc80e28..77cc2fe9b 100644
--- a/spacy/lang/te/__init__.py
+++ b/spacy/lang/te/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TeluguDefaults(Language.Defaults):
+class TeluguDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/th/__init__.py b/spacy/lang/th/__init__.py
index a89d4dc77..12b1527e0 100644
--- a/spacy/lang/th/__init__.py
+++ b/spacy/lang/th/__init__.py
@@ -1,6 +1,6 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...tokens import Doc
from ...util import DummyTokenizer, registry, load_config_from_str
from ...vocab import Vocab
@@ -40,7 +40,7 @@ class ThaiTokenizer(DummyTokenizer):
return Doc(self.vocab, words=words, spaces=spaces)
-class ThaiDefaults(Language.Defaults):
+class ThaiDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/ti/__init__.py b/spacy/lang/ti/__init__.py
index 709fb21cb..c74c081b5 100644
--- a/spacy/lang/ti/__init__.py
+++ b/spacy/lang/ti/__init__.py
@@ -4,12 +4,12 @@ from .punctuation import TOKENIZER_SUFFIXES
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...attrs import LANG
from ...util import update_exc
-class TigrinyaDefaults(Language.Defaults):
+class TigrinyaDefaults(BaseDefaults):
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
lex_attr_getters.update(LEX_ATTRS)
lex_attr_getters[LANG] = lambda text: "ti"
diff --git a/spacy/lang/tl/__init__.py b/spacy/lang/tl/__init__.py
index 61530dc30..30838890a 100644
--- a/spacy/lang/tl/__init__.py
+++ b/spacy/lang/tl/__init__.py
@@ -1,10 +1,10 @@
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TagalogDefaults(Language.Defaults):
+class TagalogDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/tn/__init__.py b/spacy/lang/tn/__init__.py
index 99907c28a..28e887eea 100644
--- a/spacy/lang/tn/__init__.py
+++ b/spacy/lang/tn/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class SetswanaDefaults(Language.Defaults):
+class SetswanaDefaults(BaseDefaults):
infixes = TOKENIZER_INFIXES
stop_words = STOP_WORDS
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/tr/__init__.py b/spacy/lang/tr/__init__.py
index 679411acf..02b5c7bf4 100644
--- a/spacy/lang/tr/__init__.py
+++ b/spacy/lang/tr/__init__.py
@@ -2,10 +2,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS, TOKEN_MATCH
from .stop_words import STOP_WORDS
from .syntax_iterators import SYNTAX_ITERATORS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TurkishDefaults(Language.Defaults):
+class TurkishDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/tr/syntax_iterators.py b/spacy/lang/tr/syntax_iterators.py
index 3fd726fb5..769af1223 100644
--- a/spacy/lang/tr/syntax_iterators.py
+++ b/spacy/lang/tr/syntax_iterators.py
@@ -1,8 +1,10 @@
+from typing import Union, Iterator, Tuple
+from ...tokens import Doc, Span
from ...symbols import NOUN, PROPN, PRON
from ...errors import Errors
-def noun_chunks(doclike):
+def noun_chunks(doclike: Union[Doc, Span]) -> Iterator[Tuple[int, int, int]]:
"""
Detect base noun phrases from a dependency parse. Works on both Doc and Span.
"""
diff --git a/spacy/lang/tt/__init__.py b/spacy/lang/tt/__init__.py
index c8e293f29..d5e1e87ef 100644
--- a/spacy/lang/tt/__init__.py
+++ b/spacy/lang/tt/__init__.py
@@ -2,10 +2,10 @@ from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_INFIXES
from .stop_words import STOP_WORDS
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class TatarDefaults(Language.Defaults):
+class TatarDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
infixes = TOKENIZER_INFIXES
lex_attr_getters = LEX_ATTRS
diff --git a/spacy/lang/uk/__init__.py b/spacy/lang/uk/__init__.py
index 2eef110b2..21f9649f2 100644
--- a/spacy/lang/uk/__init__.py
+++ b/spacy/lang/uk/__init__.py
@@ -6,10 +6,10 @@ from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .lemmatizer import UkrainianLemmatizer
-from ...language import Language
+from ...language import Language, BaseDefaults
-class UkrainianDefaults(Language.Defaults):
+class UkrainianDefaults(BaseDefaults):
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/ur/__init__.py b/spacy/lang/ur/__init__.py
index e3dee5805..266c5a73d 100644
--- a/spacy/lang/ur/__init__.py
+++ b/spacy/lang/ur/__init__.py
@@ -1,10 +1,10 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .punctuation import TOKENIZER_SUFFIXES
-from ...language import Language
+from ...language import Language, BaseDefaults
-class UrduDefaults(Language.Defaults):
+class UrduDefaults(BaseDefaults):
suffixes = TOKENIZER_SUFFIXES
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/vi/__init__.py b/spacy/lang/vi/__init__.py
index afc715ff3..822dc348c 100644
--- a/spacy/lang/vi/__init__.py
+++ b/spacy/lang/vi/__init__.py
@@ -6,7 +6,7 @@ import string
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...tokens import Doc
from ...util import DummyTokenizer, registry, load_config_from_str
from ...vocab import Vocab
@@ -145,7 +145,7 @@ class VietnameseTokenizer(DummyTokenizer):
def to_disk(self, path: Union[str, Path], **kwargs) -> None:
path = util.ensure_path(path)
serializers = {"cfg": lambda p: srsly.write_json(p, self._get_config())}
- return util.to_disk(path, serializers, [])
+ util.to_disk(path, serializers, [])
def from_disk(self, path: Union[str, Path], **kwargs) -> "VietnameseTokenizer":
path = util.ensure_path(path)
@@ -154,7 +154,7 @@ class VietnameseTokenizer(DummyTokenizer):
return self
-class VietnameseDefaults(Language.Defaults):
+class VietnameseDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/yo/__init__.py b/spacy/lang/yo/__init__.py
index df6bb7d4a..6c38ec8af 100644
--- a/spacy/lang/yo/__init__.py
+++ b/spacy/lang/yo/__init__.py
@@ -1,9 +1,9 @@
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
-from ...language import Language
+from ...language import Language, BaseDefaults
-class YorubaDefaults(Language.Defaults):
+class YorubaDefaults(BaseDefaults):
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
diff --git a/spacy/lang/zh/__init__.py b/spacy/lang/zh/__init__.py
index c6dd7bb85..fdf6776e2 100644
--- a/spacy/lang/zh/__init__.py
+++ b/spacy/lang/zh/__init__.py
@@ -6,7 +6,7 @@ import warnings
from pathlib import Path
from ...errors import Warnings, Errors
-from ...language import Language
+from ...language import Language, BaseDefaults
from ...scorer import Scorer
from ...tokens import Doc
from ...training import validate_examples, Example
@@ -57,21 +57,21 @@ def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char):
class ChineseTokenizer(DummyTokenizer):
def __init__(self, vocab: Vocab, segmenter: Segmenter = Segmenter.char):
self.vocab = vocab
- if isinstance(segmenter, Segmenter):
- segmenter = segmenter.value
- self.segmenter = segmenter
+ self.segmenter = (
+ segmenter.value if isinstance(segmenter, Segmenter) else segmenter
+ )
self.pkuseg_seg = None
self.jieba_seg = None
- if segmenter not in Segmenter.values():
+ if self.segmenter not in Segmenter.values():
warn_msg = Warnings.W103.format(
lang="Chinese",
- segmenter=segmenter,
+ segmenter=self.segmenter,
supported=", ".join(Segmenter.values()),
default="'char' (character segmentation)",
)
warnings.warn(warn_msg)
self.segmenter = Segmenter.char
- if segmenter == Segmenter.jieba:
+ if self.segmenter == Segmenter.jieba:
self.jieba_seg = try_jieba_import()
def initialize(
@@ -91,7 +91,7 @@ class ChineseTokenizer(DummyTokenizer):
def __call__(self, text: str) -> Doc:
if self.segmenter == Segmenter.jieba:
- words = list([x for x in self.jieba_seg.cut(text, cut_all=False) if x])
+ words = list([x for x in self.jieba_seg.cut(text, cut_all=False) if x]) # type: ignore[union-attr]
(words, spaces) = util.get_words_and_spaces(words, text)
return Doc(self.vocab, words=words, spaces=spaces)
elif self.segmenter == Segmenter.pkuseg:
@@ -122,7 +122,7 @@ class ChineseTokenizer(DummyTokenizer):
try:
import spacy_pkuseg
- self.pkuseg_seg.preprocesser = spacy_pkuseg.Preprocesser(None)
+ self.pkuseg_seg.preprocesser = spacy_pkuseg.Preprocesser(None) # type: ignore[attr-defined]
except ImportError:
msg = (
"spacy_pkuseg not installed: unable to reset pkuseg "
@@ -130,7 +130,7 @@ class ChineseTokenizer(DummyTokenizer):
)
raise ImportError(msg) from None
for word in words:
- self.pkuseg_seg.preprocesser.insert(word.strip(), "")
+ self.pkuseg_seg.preprocesser.insert(word.strip(), "") # type: ignore[attr-defined]
else:
warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
warnings.warn(warn_msg)
@@ -283,7 +283,7 @@ class ChineseTokenizer(DummyTokenizer):
util.from_disk(path, serializers, [])
-class ChineseDefaults(Language.Defaults):
+class ChineseDefaults(BaseDefaults):
config = load_config_from_str(DEFAULT_CONFIG)
lex_attr_getters = LEX_ATTRS
stop_words = STOP_WORDS
@@ -295,7 +295,7 @@ class Chinese(Language):
Defaults = ChineseDefaults
-def try_jieba_import() -> None:
+def try_jieba_import():
try:
import jieba
@@ -311,7 +311,7 @@ def try_jieba_import() -> None:
raise ImportError(msg) from None
-def try_pkuseg_import(pkuseg_model: str, pkuseg_user_dict: str) -> None:
+def try_pkuseg_import(pkuseg_model: Optional[str], pkuseg_user_dict: Optional[str]):
try:
import spacy_pkuseg
@@ -319,9 +319,9 @@ def try_pkuseg_import(pkuseg_model: str, pkuseg_user_dict: str) -> None:
msg = "spacy-pkuseg not installed. To use pkuseg, " + _PKUSEG_INSTALL_MSG
raise ImportError(msg) from None
try:
- return spacy_pkuseg.pkuseg(pkuseg_model, pkuseg_user_dict)
+ return spacy_pkuseg.pkuseg(pkuseg_model, user_dict=pkuseg_user_dict)
except FileNotFoundError:
- msg = "Unable to load pkuseg model from: " + pkuseg_model
+ msg = "Unable to load pkuseg model from: " + str(pkuseg_model or "")
raise FileNotFoundError(msg) from None
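
The Chinese tokenizer changes normalise the segmenter argument to its string value once in __init__ and then compare self.segmenter everywhere, so the attribute has one consistent type; the import helpers also drop their incorrect -> None annotations, since they actually return a module or a pkuseg instance. A sketch of the normalisation, assuming a str-valued Segmenter enum like the one in spacy.lang.zh:

    from enum import Enum
    from typing import Union

    class Segmenter(str, Enum):
        char = "char"
        jieba = "jieba"

    class ToyTokenizer:
        def __init__(self, segmenter: Union[str, Segmenter] = "char") -> None:
            # normalise once; all later checks go through the attribute
            self.segmenter = segmenter.value if isinstance(segmenter, Segmenter) else segmenter
            self.use_jieba = self.segmenter == Segmenter.jieba  # works because Segmenter is a str Enum
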
diff --git a/spacy/language.py b/spacy/language.py
index fd3773f82..83de83702 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -1,6 +1,7 @@
-from typing import Iterator, Optional, Any, Dict, Callable, Iterable, TypeVar
-from typing import Union, List, Pattern, overload
-from typing import Tuple
+from typing import Iterator, Optional, Any, Dict, Callable, Iterable
+from typing import Union, Tuple, List, Set, Pattern, Sequence
+from typing import NoReturn, TYPE_CHECKING, TypeVar, cast, overload
+
from dataclasses import dataclass
import random
import itertools
@@ -9,13 +10,14 @@ from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
import warnings
-from thinc.api import get_current_ops, Config, Optimizer
+from thinc.api import get_current_ops, Config, CupyOps, Optimizer
import srsly
import multiprocessing as mp
from itertools import chain, cycle
from timeit import default_timer as timer
import traceback
+from . import ty
from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
@@ -37,6 +39,11 @@ from .git_info import GIT_VERSION
from . import util
from . import about
from .lookups import load_lookups
+from .compat import Literal
+
+
+if TYPE_CHECKING:
+ from .pipeline import Pipe # noqa: F401
# This is the base config with all settings (training etc.)
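
The Pipe import above is guarded by TYPE_CHECKING because spacy.pipeline itself imports from spacy.language; the guard gives mypy the real type while avoiding the circular import at runtime, and later annotations in the file refer to it as the string "Pipe". The shape of the pattern, with a hypothetical package name:

    from typing import TYPE_CHECKING, List, Tuple

    if TYPE_CHECKING:
        from mypkg.pipeline import Pipe    # hypothetical module, only imported by the type checker

    class Registry:
        def __init__(self) -> None:
            # "Pipe" in quotes is a forward reference, resolved only during type checking
            self._components: List[Tuple[str, "Pipe"]] = []
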
@@ -46,6 +53,9 @@ DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
# in the main config and only added via the 'init fill-config' command
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
+# Type variable for contexts piped with documents
+_AnyContext = TypeVar("_AnyContext")
+
class BaseDefaults:
"""Language data defaults, available via Language.Defaults. Can be
@@ -55,14 +65,14 @@ class BaseDefaults:
config: Config = Config(section_order=CONFIG_SECTION_ORDER)
tokenizer_exceptions: Dict[str, List[dict]] = BASE_EXCEPTIONS
- prefixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_PREFIXES
- suffixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_SUFFIXES
- infixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_INFIXES
- token_match: Optional[Pattern] = None
- url_match: Optional[Pattern] = URL_MATCH
+ prefixes: Optional[Sequence[Union[str, Pattern]]] = TOKENIZER_PREFIXES
+ suffixes: Optional[Sequence[Union[str, Pattern]]] = TOKENIZER_SUFFIXES
+ infixes: Optional[Sequence[Union[str, Pattern]]] = TOKENIZER_INFIXES
+ token_match: Optional[Callable] = None
+ url_match: Optional[Callable] = URL_MATCH
syntax_iterators: Dict[str, Callable] = {}
lex_attr_getters: Dict[int, Callable[[str], Any]] = {}
- stop_words = set()
+ stop_words: Set[str] = set()
writing_system = {"direction": "ltr", "has_case": True, "has_letters": True}
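
The BaseDefaults annotations are adjusted to match what languages actually assign: Sequence accepts the tuples and lists different languages use for their affix rules, token_match and url_match typically hold a compiled pattern's bound .match method (a callable, not a Pattern object), and stop_words gets an explicit Set[str]. Illustrative values only:

    import re
    from typing import Callable, Optional, Pattern, Sequence, Set, Union

    class ToyDefaults:
        prefixes: Optional[Sequence[Union[str, Pattern]]] = ("§", re.compile(r"\d+"))  # tuple is fine
        url_match: Optional[Callable] = re.compile(r"https?://\S+").match              # bound method
        stop_words: Set[str] = set()
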
@@ -111,7 +121,7 @@ class Language:
"""
Defaults = BaseDefaults
- lang: str = None
+ lang: Optional[str] = None
default_config = DEFAULT_CONFIG
factories = SimpleFrozenDict(error=Errors.E957)
@@ -154,7 +164,7 @@ class Language:
self._config = DEFAULT_CONFIG.merge(self.default_config)
self._meta = dict(meta)
self._path = None
- self._optimizer = None
+ self._optimizer: Optional[Optimizer] = None
# Component meta and configs are only needed on the instance
self._pipe_meta: Dict[str, "FactoryMeta"] = {} # meta by component
self._pipe_configs: Dict[str, Config] = {} # config by component
@@ -170,8 +180,8 @@ class Language:
self.vocab: Vocab = vocab
if self.lang is None:
self.lang = self.vocab.lang
- self._components = []
- self._disabled = set()
+ self._components: List[Tuple[str, "Pipe"]] = []
+ self._disabled: Set[str] = set()
self.max_length = max_length
# Create the default tokenizer from the default config
if not create_tokenizer:
@@ -291,7 +301,7 @@ class Language:
return SimpleFrozenList(names)
@property
- def components(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
+ def components(self) -> List[Tuple[str, "Pipe"]]:
"""Get all (name, component) tuples in the pipeline, including the
currently disabled components.
"""
@@ -310,12 +320,12 @@ class Language:
return SimpleFrozenList(names, error=Errors.E926.format(attr="component_names"))
@property
- def pipeline(self) -> List[Tuple[str, Callable[[Doc], Doc]]]:
+ def pipeline(self) -> List[Tuple[str, "Pipe"]]:
"""The processing pipeline consisting of (name, component) tuples. The
components are called on the Doc in order as it passes through the
pipeline.
- RETURNS (List[Tuple[str, Callable[[Doc], Doc]]]): The pipeline.
+ RETURNS (List[Tuple[str, Pipe]]): The pipeline.
"""
pipes = [(n, p) for n, p in self._components if n not in self._disabled]
return SimpleFrozenList(pipes, error=Errors.E926.format(attr="pipeline"))
@@ -423,7 +433,7 @@ class Language:
assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False,
- default_score_weights: Dict[str, float] = SimpleFrozenDict(),
+ default_score_weights: Dict[str, Optional[float]] = SimpleFrozenDict(),
func: Optional[Callable] = None,
) -> Callable:
"""Register a new pipeline component factory. Can be used as a decorator
@@ -440,7 +450,7 @@ class Language:
e.g. "token.ent_id". Used for pipeline analysis.
retokenizes (bool): Whether the component changes the tokenization.
Used for pipeline analysis.
- default_score_weights (Dict[str, float]): The scores to report during
+ default_score_weights (Dict[str, Optional[float]]): The scores to report during
training, and their default weight towards the final score used to
select the best model. Weights should sum to 1.0 per component and
will be combined and normalized for the whole pipeline. If None,
@@ -505,12 +515,12 @@ class Language:
@classmethod
def component(
cls,
- name: Optional[str] = None,
+ name: str,
*,
assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False,
- func: Optional[Callable[[Doc], Doc]] = None,
+ func: Optional["Pipe"] = None,
) -> Callable:
"""Register a new pipeline component. Can be used for stateless function
components that don't require a separate factory. Can be used as a
@@ -533,11 +543,11 @@ class Language:
raise ValueError(Errors.E963.format(decorator="component"))
component_name = name if name is not None else util.get_object_name(func)
- def add_component(component_func: Callable[[Doc], Doc]) -> Callable:
+ def add_component(component_func: "Pipe") -> Callable:
if isinstance(func, type): # function is a class
raise ValueError(Errors.E965.format(name=component_name))
- def factory_func(nlp: cls, name: str) -> Callable[[Doc], Doc]:
+ def factory_func(nlp, name: str) -> "Pipe":
return component_func
internal_name = cls.get_factory_name(name)
@@ -587,7 +597,7 @@ class Language:
print_pipe_analysis(analysis, keys=keys)
return analysis
- def get_pipe(self, name: str) -> Callable[[Doc], Doc]:
+ def get_pipe(self, name: str) -> "Pipe":
"""Get a pipeline component for a given component name.
name (str): Name of pipeline component to get.
@@ -608,7 +618,7 @@ class Language:
config: Dict[str, Any] = SimpleFrozenDict(),
raw_config: Optional[Config] = None,
validate: bool = True,
- ) -> Callable[[Doc], Doc]:
+ ) -> "Pipe":
"""Create a pipeline component. Mostly used internally. To create and
add a component to the pipeline, you can use nlp.add_pipe.
@@ -620,7 +630,7 @@ class Language:
raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
- RETURNS (Callable[[Doc], Doc]): The pipeline component.
+ RETURNS (Pipe): The pipeline component.
DOCS: https://spacy.io/api/language#create_pipe
"""
@@ -675,7 +685,7 @@ class Language:
def create_pipe_from_source(
self, source_name: str, source: "Language", *, name: str
- ) -> Tuple[Callable[[Doc], Doc], str]:
+ ) -> Tuple["Pipe", str]:
"""Create a pipeline component by copying it from an existing model.
source_name (str): Name of the component in the source pipeline.
@@ -725,7 +735,7 @@ class Language:
config: Dict[str, Any] = SimpleFrozenDict(),
raw_config: Optional[Config] = None,
validate: bool = True,
- ) -> Callable[[Doc], Doc]:
+ ) -> "Pipe":
"""Add a component to the processing pipeline. Valid components are
callables that take a `Doc` object, modify it and return it. Only one
of before/after/first/last can be set. Default behaviour is "last".
@@ -748,7 +758,7 @@ class Language:
raw_config (Optional[Config]): Internals: the non-interpolated config.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
- RETURNS (Callable[[Doc], Doc]): The pipeline component.
+ RETURNS (Pipe): The pipeline component.
DOCS: https://spacy.io/api/language#add_pipe
"""
@@ -859,7 +869,7 @@ class Language:
*,
config: Dict[str, Any] = SimpleFrozenDict(),
validate: bool = True,
- ) -> Callable[[Doc], Doc]:
+ ) -> "Pipe":
"""Replace a component in the pipeline.
name (str): Name of the component to replace.
@@ -868,7 +878,7 @@ class Language:
component. Will be merged with default config, if available.
validate (bool): Whether to validate the component config against the
arguments and types expected by the factory.
- RETURNS (Callable[[Doc], Doc]): The new pipeline component.
+ RETURNS (Pipe): The new pipeline component.
DOCS: https://spacy.io/api/language#replace_pipe
"""
@@ -920,7 +930,7 @@ class Language:
init_cfg = self._config["initialize"]["components"].pop(old_name)
self._config["initialize"]["components"][new_name] = init_cfg
- def remove_pipe(self, name: str) -> Tuple[str, Callable[[Doc], Doc]]:
+ def remove_pipe(self, name: str) -> Tuple[str, "Pipe"]:
"""Remove a component from the pipeline.
name (str): Name of the component to remove.
@@ -980,7 +990,7 @@ class Language:
text (Union[str, Doc]): If `str`, the text to be processed. If `Doc`,
the doc will be passed directly to the pipeline, skipping
`Language.make_doc`.
- disable (list): Names of the pipeline components to disable.
+ disable (List[str]): Names of the pipeline components to disable.
component_cfg (Dict[str, dict]): An optional dictionary with extra
keyword arguments for specific components.
RETURNS (Doc): A container for accessing the annotations.
@@ -999,7 +1009,7 @@ class Language:
if hasattr(proc, "get_error_handler"):
error_handler = proc.get_error_handler()
try:
- doc = proc(doc, **component_cfg.get(name, {}))
+ doc = proc(doc, **component_cfg.get(name, {})) # type: ignore[call-arg]
except KeyError as e:
# This typically happens if a component is not initialized
raise ValueError(Errors.E109.format(name=name)) from e
@@ -1019,7 +1029,7 @@ class Language:
"""
warnings.warn(Warnings.W096, DeprecationWarning)
if len(names) == 1 and isinstance(names[0], (list, tuple)):
- names = names[0] # support list of names instead of spread
+ names = names[0] # type: ignore[assignment] # support list of names instead of spread
return self.select_pipes(disable=names)
def select_pipes(
@@ -1054,6 +1064,7 @@ class Language:
)
)
disable = to_disable
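+ # The assert below only narrows the Optional type for mypy; disable is always set at this point.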
+ assert disable is not None
# DisabledPipes will restore the pipes in 'disable' when it's done, so we need to exclude
# those pipes that were already disabled.
disable = [d for d in disable if d not in self._disabled]
@@ -1112,7 +1123,7 @@ class Language:
raise ValueError(Errors.E989)
if losses is None:
losses = {}
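+ # Only short-circuit on an explicitly empty list; other iterables (e.g. generators) can't be len()-checked without consuming them.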
- if len(examples) == 0:
+ if isinstance(examples, list) and len(examples) == 0:
return losses
validate_examples(examples, "Language.update")
examples = _copy_examples(examples)
@@ -1129,12 +1140,13 @@ class Language:
component_cfg[name].setdefault("drop", drop)
pipe_kwargs[name].setdefault("batch_size", self.batch_size)
for name, proc in self.pipeline:
+ # "type: ignore" is used below because mypy can't narrow types based on hasattr checks
if name not in exclude and hasattr(proc, "update"):
- proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
+ proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) # type: ignore
if sgd not in (None, False):
if (
name not in exclude
- and hasattr(proc, "is_trainable")
+ and isinstance(proc, ty.TrainableComponent)
and proc.is_trainable
and proc.model not in (True, False, None)
):
@@ -1184,8 +1196,10 @@ class Language:
DOCS: https://spacy.io/api/language#rehearse
"""
- if len(examples) == 0:
- return
+ if losses is None:
+ losses = {}
+ if isinstance(examples, list) and len(examples) == 0:
+ return losses
validate_examples(examples, "Language.rehearse")
if sgd is None:
if self._optimizer is None:
@@ -1200,18 +1214,18 @@ class Language:
def get_grads(W, dW, key=None):
grads[key] = (W, dW)
- get_grads.learn_rate = sgd.learn_rate
- get_grads.b1 = sgd.b1
- get_grads.b2 = sgd.b2
+ get_grads.learn_rate = sgd.learn_rate # type: ignore[attr-defined, union-attr]
+ get_grads.b1 = sgd.b1 # type: ignore[attr-defined, union-attr]
+ get_grads.b2 = sgd.b2 # type: ignore[attr-defined, union-attr]
for name, proc in pipes:
if name in exclude or not hasattr(proc, "rehearse"):
continue
grads = {}
- proc.rehearse(
+ proc.rehearse( # type: ignore[attr-defined]
examples, sgd=get_grads, losses=losses, **component_cfg.get(name, {})
)
for key, (W, dW) in grads.items():
- sgd(W, dW, key=key)
+ sgd(W, dW, key=key) # type: ignore[call-arg, misc]
return losses
def begin_training(
@@ -1268,14 +1282,14 @@ class Language:
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
if hasattr(self.tokenizer, "initialize"):
tok_settings = validate_init_settings(
- self.tokenizer.initialize,
+ self.tokenizer.initialize, # type: ignore[union-attr]
I["tokenizer"],
section="tokenizer",
name="tokenizer",
)
- self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)
+ self.tokenizer.initialize(get_examples, nlp=self, **tok_settings) # type: ignore[union-attr]
for name, proc in self.pipeline:
- if hasattr(proc, "initialize"):
+ if isinstance(proc, ty.InitializableComponent):
p_settings = I["components"].get(name, {})
p_settings = validate_init_settings(
proc.initialize, p_settings, section="components", name=name
@@ -1314,7 +1328,7 @@ class Language:
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
for name, proc in self.pipeline:
if hasattr(proc, "_rehearsal_model"):
- proc._rehearsal_model = deepcopy(proc.model)
+ proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined]
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:
@@ -1323,14 +1337,12 @@ class Language:
def set_error_handler(
self,
- error_handler: Callable[
- [str, Callable[[Doc], Doc], List[Doc], Exception], None
- ],
+ error_handler: Callable[[str, "Pipe", List[Doc], Exception], NoReturn],
):
"""Set an error handler object for all the components in the pipeline that implement
a set_error_handler function.
- error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]):
+ error_handler (Callable[[str, Pipe, List[Doc], Exception], NoReturn]):
Function that deals with a failing batch of documents. This callable function should take in
the component's name, the component itself, the offending batch of documents, and the exception
that was thrown.
@@ -1349,7 +1361,7 @@ class Language:
scorer: Optional[Scorer] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
scorer_cfg: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Union[float, dict]]:
+ ) -> Dict[str, Any]:
"""Evaluate a model's pipeline components.
examples (Iterable[Example]): `Example` objects.
@@ -1427,7 +1439,7 @@ class Language:
yield
else:
contexts = [
- pipe.use_params(params)
+ pipe.use_params(params) # type: ignore[attr-defined]
for name, pipe in self.pipeline
if hasattr(pipe, "use_params") and hasattr(pipe, "model")
]
@@ -1445,14 +1457,25 @@ class Language:
except StopIteration:
pass
- _AnyContext = TypeVar("_AnyContext")
-
@overload
def pipe(
self,
- texts: Iterable[Tuple[Union[str, Doc], _AnyContext]],
+ texts: Iterable[Union[str, Doc]],
*,
- as_tuples: bool = ...,
+ as_tuples: Literal[False] = ...,
+ batch_size: Optional[int] = ...,
+ disable: Iterable[str] = ...,
+ component_cfg: Optional[Dict[str, Dict[str, Any]]] = ...,
+ n_process: int = ...,
+ ) -> Iterator[Doc]:
+ ...
+
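+ # The Literal-typed `as_tuples` argument lets mypy pick the matching overload,
+ # so callers see either Iterator[Doc] or Iterator[Tuple[Doc, _AnyContext]].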
+ @overload
+ def pipe( # noqa: F811
+ self,
+ texts: Iterable[Tuple[str, _AnyContext]],
+ *,
+ as_tuples: Literal[True] = ...,
batch_size: Optional[int] = ...,
disable: Iterable[str] = ...,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = ...,
@@ -1462,14 +1485,14 @@ class Language:
def pipe( # noqa: F811
self,
- texts: Iterable[Union[str, Doc]],
+ texts: Union[Iterable[Union[str, Doc]], Iterable[Tuple[str, _AnyContext]]],
*,
as_tuples: bool = False,
batch_size: Optional[int] = None,
disable: Iterable[str] = SimpleFrozenList(),
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
n_process: int = 1,
- ) -> Iterator[Doc]:
+ ) -> Union[Iterator[Doc], Iterator[Tuple[Doc, _AnyContext]]]:
"""Process texts as a stream, and yield `Doc` objects in order.
texts (Iterable[Union[str, Doc]]): A sequence of texts or docs to
@@ -1486,9 +1509,9 @@ class Language:
DOCS: https://spacy.io/api/language#pipe
"""
- if n_process == -1:
- n_process = mp.cpu_count()
+ # Handle texts with context as tuples
if as_tuples:
+ texts = cast(Iterable[Tuple[str, _AnyContext]], texts)
text_context1, text_context2 = itertools.tee(texts)
texts = (tc[0] for tc in text_context1)
contexts = (tc[1] for tc in text_context2)
@@ -1502,6 +1525,13 @@ class Language:
for doc, context in zip(docs, contexts):
yield (doc, context)
return
+
+ # At this point, we know that we're dealing with an iterable of plain texts
+ texts = cast(Iterable[str], texts)
+
+ # Set argument defaults
+ if n_process == -1:
+ n_process = mp.cpu_count()
if component_cfg is None:
component_cfg = {}
if batch_size is None:
@@ -1526,6 +1556,9 @@ class Language:
pipes.append(f)
if n_process != 1:
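+ # Warn if multiprocessing is requested while some components run their models on the GPU.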
+ if self._has_gpu_model(disable):
+ warnings.warn(Warnings.W114)
+
docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size)
else:
# if n_process == 1, no processes are forked.
@@ -1535,17 +1568,28 @@ class Language:
for doc in docs:
yield doc
+ def _has_gpu_model(self, disable: Iterable[str]):
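+ # A pipeline counts as GPU-backed if any enabled, trainable component's model runs on CupyOps.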
+ for name, proc in self.pipeline:
+ is_trainable = hasattr(proc, "is_trainable") and proc.is_trainable # type: ignore
+ if name in disable or not is_trainable:
+ continue
+
+ if hasattr(proc, "model") and hasattr(proc.model, "ops") and isinstance(proc.model.ops, CupyOps): # type: ignore
+ return True
+
+ return False
+
def _multiprocessing_pipe(
self,
texts: Iterable[str],
- pipes: Iterable[Callable[[Doc], Doc]],
+ pipes: Iterable[Callable[..., Iterator[Doc]]],
n_process: int,
batch_size: int,
- ) -> None:
+ ) -> Iterator[Doc]:
# raw_texts is used later to stop iteration.
texts, raw_texts = itertools.tee(texts)
# for sending texts to worker
- texts_q = [mp.Queue() for _ in range(n_process)]
+ texts_q: List[mp.Queue] = [mp.Queue() for _ in range(n_process)]
# for receiving byte-encoded docs from worker
bytedocs_recv_ch, bytedocs_send_ch = zip(
*[mp.Pipe(False) for _ in range(n_process)]
@@ -1604,7 +1648,7 @@ class Language:
# components don't receive the pipeline then. So this does have to be
# here :(
for i, (name1, proc1) in enumerate(self.pipeline):
- if hasattr(proc1, "find_listeners"):
+ if isinstance(proc1, ty.ListenedToComponent):
for name2, proc2 in self.pipeline[i + 1 :]:
proc1.find_listeners(proc2)
@@ -1792,25 +1836,25 @@ class Language:
)
# Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline:
- # Remove listeners not in the pipeline
- listener_names = getattr(proc, "listening_components", [])
- unused_listener_names = [
- ll for ll in listener_names if ll not in nlp.pipe_names
- ]
- for listener_name in unused_listener_names:
- for listener in proc.listener_map.get(listener_name, []):
- proc.remove_listener(listener, listener_name)
+ if isinstance(proc, ty.ListenedToComponent):
+ # Remove listeners not in the pipeline
+ listener_names = proc.listening_components
+ unused_listener_names = [
+ ll for ll in listener_names if ll not in nlp.pipe_names
+ ]
+ for listener_name in unused_listener_names:
+ for listener in proc.listener_map.get(listener_name, []):
+ proc.remove_listener(listener, listener_name)
- for listener in getattr(
- proc, "listening_components", []
- ): # e.g. tok2vec/transformer
- # If it's a component sourced from another pipeline, we check if
- # the tok2vec listeners should be replaced with standalone tok2vec
- # models (e.g. so component can be frozen without its performance
- # degrading when other components/tok2vec are updated)
- paths = sourced.get(listener, {}).get("replace_listeners", [])
- if paths:
- nlp.replace_listeners(name, listener, paths)
+ for listener_name in proc.listening_components:
+ # e.g. tok2vec/transformer
+ # If it's a component sourced from another pipeline, we check if
+ # the tok2vec listeners should be replaced with standalone tok2vec
+ # models (e.g. so component can be frozen without its performance
+ # degrading when other components/tok2vec are updated)
+ paths = sourced.get(listener_name, {}).get("replace_listeners", [])
+ if paths:
+ nlp.replace_listeners(name, listener_name, paths)
return nlp
def replace_listeners(
@@ -1860,20 +1904,15 @@ class Language:
raise ValueError(err)
tok2vec = self.get_pipe(tok2vec_name)
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
- tok2vec_model = tok2vec.model
- if (
- not hasattr(tok2vec, "model")
- or not hasattr(tok2vec, "listener_map")
- or not hasattr(tok2vec, "remove_listener")
- or "model" not in tok2vec_cfg
- ):
+ if not isinstance(tok2vec, ty.ListenedToComponent):
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
+ tok2vec_model = tok2vec.model
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
pipe = self.get_pipe(pipe_name)
pipe_cfg = self._pipe_configs[pipe_name]
if listeners:
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
- if len(listeners) != len(pipe_listeners):
+ if len(list(listeners)) != len(pipe_listeners):
# The number of listeners defined in the component model doesn't
# match the listeners to replace, so we won't be able to update
# the nodes and generate a matching config
@@ -1907,7 +1946,7 @@ class Language:
new_model = tok2vec_model.copy()
if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model)
- util.replace_model_node(pipe.model, listener, new_model)
+ util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name)
def to_disk(
@@ -1918,13 +1957,13 @@ class Language:
path (str / Path): Path to a directory, which will be created if
it doesn't exist.
- exclude (list): Names of components or serialization fields to exclude.
+ exclude (Iterable[str]): Names of components or serialization fields to exclude.
DOCS: https://spacy.io/api/language#to_disk
"""
path = util.ensure_path(path)
serializers = {}
- serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(
+ serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr]
p, exclude=["vocab"]
)
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
@@ -1934,7 +1973,7 @@ class Language:
continue
if not hasattr(proc, "to_disk"):
continue
- serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"])
+ serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"]) # type: ignore[misc]
serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
util.to_disk(path, serializers, exclude)
@@ -1950,7 +1989,7 @@ class Language:
model will be loaded.
path (str / Path): A path to a directory.
- exclude (list): Names of components or serialization fields to exclude.
+ exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (Language): The modified `Language` object.
DOCS: https://spacy.io/api/language#from_disk
@@ -1970,13 +2009,13 @@ class Language:
path = util.ensure_path(path)
deserializers = {}
- if Path(path / "config.cfg").exists():
+ if Path(path / "config.cfg").exists(): # type: ignore[operator]
deserializers["config.cfg"] = lambda p: self.config.from_disk(
p, interpolate=False, overrides=overrides
)
- deserializers["meta.json"] = deserialize_meta
- deserializers["vocab"] = deserialize_vocab
- deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
+ deserializers["meta.json"] = deserialize_meta # type: ignore[assignment]
+ deserializers["vocab"] = deserialize_vocab # type: ignore[assignment]
+ deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk( # type: ignore[union-attr]
p, exclude=["vocab"]
)
for name, proc in self._components:
@@ -1984,28 +2023,28 @@ class Language:
continue
if not hasattr(proc, "from_disk"):
continue
- deserializers[name] = lambda p, proc=proc: proc.from_disk(
+ deserializers[name] = lambda p, proc=proc: proc.from_disk( # type: ignore[misc]
p, exclude=["vocab"]
)
- if not (path / "vocab").exists() and "vocab" not in exclude:
+ if not (path / "vocab").exists() and "vocab" not in exclude: # type: ignore[operator]
# Convert to list here in case exclude is (default) tuple
exclude = list(exclude) + ["vocab"]
- util.from_disk(path, deserializers, exclude)
- self._path = path
+ util.from_disk(path, deserializers, exclude) # type: ignore[arg-type]
+ self._path = path # type: ignore[assignment]
self._link_components()
return self
def to_bytes(self, *, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the current state to a binary string.
- exclude (list): Names of components or serialization fields to exclude.
+ exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (bytes): The serialized form of the `Language` object.
DOCS: https://spacy.io/api/language#to_bytes
"""
- serializers = {}
+ serializers: Dict[str, Callable[[], bytes]] = {}
serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
- serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
+ serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr]
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self._components:
@@ -2013,7 +2052,7 @@ class Language:
continue
if not hasattr(proc, "to_bytes"):
continue
- serializers[name] = lambda proc=proc: proc.to_bytes(exclude=["vocab"])
+ serializers[name] = lambda proc=proc: proc.to_bytes(exclude=["vocab"]) # type: ignore[misc]
return util.to_bytes(serializers, exclude)
def from_bytes(
@@ -2022,7 +2061,7 @@ class Language:
"""Load state from a binary string.
bytes_data (bytes): The data to load from.
- exclude (list): Names of components or serialization fields to exclude.
+ exclude (Iterable[str]): Names of components or serialization fields to exclude.
RETURNS (Language): The `Language` object.
DOCS: https://spacy.io/api/language#from_bytes
@@ -2035,13 +2074,13 @@ class Language:
# from self.vocab.vectors, so set the name directly
self.vocab.vectors.name = data.get("vectors", {}).get("name")
- deserializers = {}
+ deserializers: Dict[str, Callable[[bytes], Any]] = {}
deserializers["config.cfg"] = lambda b: self.config.from_bytes(
b, interpolate=False
)
deserializers["meta.json"] = deserialize_meta
deserializers["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
- deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
+ deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes( # type: ignore[union-attr]
b, exclude=["vocab"]
)
for name, proc in self._components:
@@ -2049,7 +2088,7 @@ class Language:
continue
if not hasattr(proc, "from_bytes"):
continue
- deserializers[name] = lambda b, proc=proc: proc.from_bytes(
+ deserializers[name] = lambda b, proc=proc: proc.from_bytes( # type: ignore[misc]
b, exclude=["vocab"]
)
util.from_bytes(bytes_data, deserializers, exclude)
@@ -2071,7 +2110,7 @@ class FactoryMeta:
requires: Iterable[str] = tuple()
retokenizes: bool = False
scores: Iterable[str] = tuple()
- default_score_weights: Optional[Dict[str, float]] = None # noqa: E704
+ default_score_weights: Optional[Dict[str, Optional[float]]] = None # noqa: E704
class DisabledPipes(list):
@@ -2111,7 +2150,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:
def _apply_pipes(
ensure_doc: Callable[[Union[str, Doc]], Doc],
- pipes: Iterable[Callable[[Doc], Doc]],
+ pipes: Iterable[Callable[..., Iterator[Doc]]],
receiver,
sender,
underscore_state: Tuple[dict, dict, dict],
@@ -2120,7 +2159,7 @@ def _apply_pipes(
ensure_doc (Callable[[Union[str, Doc]], Doc]): Function to create Doc from text
or raise an error if the input is neither a Doc nor a string.
- pipes (Iterable[Callable[[Doc], Doc]]): The components to apply.
+ pipes (Iterable[Pipe]): The components to apply.
receiver (multiprocessing.Connection): Pipe to receive text. Usually
created by `multiprocessing.Pipe()`
sender (multiprocessing.Connection): Pipe to send doc. Usually created by
@@ -2134,11 +2173,11 @@ def _apply_pipes(
texts = receiver.get()
docs = (ensure_doc(text) for text in texts)
for pipe in pipes:
- docs = pipe(docs)
+ docs = pipe(docs) # type: ignore[arg-type, assignment]
# Connection does not accept unpicklable objects, so send a list.
byte_docs = [(doc.to_bytes(), None) for doc in docs]
padding = [(None, None)] * (len(texts) - len(byte_docs))
- sender.send(byte_docs + padding)
+ sender.send(byte_docs + padding) # type: ignore[operator]
except Exception:
error_msg = [(None, srsly.msgpack_dumps(traceback.format_exc()))]
padding = [(None, None)] * (len(texts) - 1)
diff --git a/spacy/lookups.py b/spacy/lookups.py
index 025afa04b..b2f3dc15e 100644
--- a/spacy/lookups.py
+++ b/spacy/lookups.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Union, Optional
+from typing import Any, List, Union, Optional, Dict
from pathlib import Path
import srsly
from preshed.bloom import BloomFilter
@@ -34,9 +34,9 @@ def load_lookups(lang: str, tables: List[str], strict: bool = True) -> "Lookups"
if table not in data:
if strict:
raise ValueError(Errors.E955.format(table=table, lang=lang))
- language_data = {}
+ language_data = {} # type: ignore[var-annotated]
else:
- language_data = load_language_data(data[table])
+ language_data = load_language_data(data[table]) # type: ignore[assignment]
lookups.add_table(table, language_data)
return lookups
@@ -116,7 +116,7 @@ class Table(OrderedDict):
key = get_string_id(key)
return OrderedDict.get(self, key, default)
- def __contains__(self, key: Union[str, int]) -> bool:
+ def __contains__(self, key: Union[str, int]) -> bool: # type: ignore[override]
"""Check whether a key is in the table. String keys will be hashed.
key (str / int): The key to check.
@@ -172,7 +172,7 @@ class Lookups:
DOCS: https://spacy.io/api/lookups#init
"""
- self._tables = {}
+ self._tables: Dict[str, Table] = {}
def __contains__(self, name: str) -> bool:
"""Check if the lookups contain a table of a given name. Delegates to
diff --git a/spacy/matcher/matcher.pyi b/spacy/matcher/matcher.pyi
index 3be065bcd..ec4a88eaf 100644
--- a/spacy/matcher/matcher.pyi
+++ b/spacy/matcher/matcher.pyi
@@ -9,7 +9,7 @@ class Matcher:
def __contains__(self, key: str) -> bool: ...
def add(
self,
- key: str,
+ key: Union[str, int],
patterns: List[List[Dict[str, Any]]],
*,
on_match: Optional[
@@ -39,3 +39,4 @@ class Matcher:
allow_missing: bool = ...,
with_alignments: bool = ...
) -> Union[List[Tuple[int, int, int]], List[Span]]: ...
+ def _normalize_key(self, key: Any) -> Any: ...
diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx
index e080ce7fd..745d7cf43 100644
--- a/spacy/matcher/matcher.pyx
+++ b/spacy/matcher/matcher.pyx
@@ -96,12 +96,10 @@ cdef class Matcher:
by returning a non-overlapping set per key, either taking preference to
the first greedy match ("FIRST"), or the longest ("LONGEST").
- As of spaCy v2.2.2, Matcher.add supports the future API, which makes
- the patterns the second argument and a list (instead of a variable
- number of arguments). The on_match callback becomes an optional keyword
- argument.
+ Since spaCy v2.2.2, Matcher.add takes a list of patterns as the second
+ argument, and the on_match callback is an optional keyword argument.
- key (str): The match ID.
+ key (Union[str, int]): The match ID.
patterns (list): The patterns to add for the given key.
on_match (callable): Optional callback executed on match.
greedy (str): Optional filter: "FIRST" or "LONGEST".
diff --git a/spacy/matcher/phrasematcher.pyi b/spacy/matcher/phrasematcher.pyi
new file mode 100644
index 000000000..d73633ec0
--- /dev/null
+++ b/spacy/matcher/phrasematcher.pyi
@@ -0,0 +1,25 @@
+from typing import List, Tuple, Union, Optional, Callable, Any, Dict
+
+from . import Matcher
+from ..vocab import Vocab
+from ..tokens import Doc, Span
+
+class PhraseMatcher:
+ def __init__(
+ self, vocab: Vocab, attr: Optional[Union[int, str]], validate: bool = ...
+ ) -> None: ...
+ def __call__(
+ self,
+ doclike: Union[Doc, Span],
+ *,
+ as_spans: bool = ...,
+ ) -> Union[List[Tuple[int, int, int]], List[Span]]: ...
+ def add(
+ self,
+ key: str,
+ docs: List[List[Dict[str, Any]]],
+ *,
+ on_match: Optional[
+ Callable[[Matcher, Doc, int, List[Tuple[Any, ...]]], Any]
+ ] = ...,
+ ) -> None: ...
diff --git a/spacy/matcher/phrasematcher.pyx b/spacy/matcher/phrasematcher.pyx
index d8486b84b..2ff5105ad 100644
--- a/spacy/matcher/phrasematcher.pyx
+++ b/spacy/matcher/phrasematcher.pyx
@@ -157,9 +157,8 @@ cdef class PhraseMatcher:
"""Add a match-rule to the phrase-matcher. A match-rule consists of: an ID
key, an on_match callback, and one or more patterns.
- As of spaCy v2.2.2, PhraseMatcher.add supports the future API, which
- makes the patterns the second argument and a list (instead of a variable
- number of arguments). The on_match callback becomes an optional keyword
+ Since spaCy v2.2.2, PhraseMatcher.add takes a list of patterns as the
+ second argument, with the on_match callback as an optional keyword
argument.
key (str): The match ID.
diff --git a/spacy/ml/__init__.py b/spacy/ml/__init__.py
index c382d915b..fce8ae5af 100644
--- a/spacy/ml/__init__.py
+++ b/spacy/ml/__init__.py
@@ -1 +1,2 @@
+from .callbacks import create_models_with_nvtx_range # noqa: F401
from .models import * # noqa: F401, F403
diff --git a/spacy/ml/_character_embed.py b/spacy/ml/_character_embed.py
index 0ed28b859..e46735102 100644
--- a/spacy/ml/_character_embed.py
+++ b/spacy/ml/_character_embed.py
@@ -44,7 +44,7 @@ def forward(model: Model, docs: List[Doc], is_train: bool):
# Let's say I have a 2d array of indices, and a 3d table of data. What numpy
# incantation do I chant to get
# output[i, j, k] == data[j, ids[i, j], k]?
- doc_vectors[:, nCv] = E[nCv, doc_ids[:, nCv]]
+ doc_vectors[:, nCv] = E[nCv, doc_ids[:, nCv]] # type: ignore[call-overload, index]
output.append(doc_vectors.reshape((len(doc), nO)))
ids.append(doc_ids)
diff --git a/spacy/ml/callbacks.py b/spacy/ml/callbacks.py
new file mode 100644
index 000000000..b0d088182
--- /dev/null
+++ b/spacy/ml/callbacks.py
@@ -0,0 +1,39 @@
+from functools import partial
+from typing import Type, Callable, TYPE_CHECKING
+
+from thinc.layers import with_nvtx_range
+from thinc.model import Model, wrap_model_recursive
+
+from ..util import registry
+
+if TYPE_CHECKING:
+ # This lets us add type hints for mypy etc. without causing circular imports
+ from ..language import Language # noqa: F401
+
+
+@registry.callbacks("spacy.models_with_nvtx_range.v1")
+def create_models_with_nvtx_range(
+ forward_color: int = -1, backprop_color: int = -1
+) -> Callable[["Language"], "Language"]:
+ def models_with_nvtx_range(nlp):
+ pipes = [
+ pipe
+ for _, pipe in nlp.components
+ if hasattr(pipe, "is_trainable") and pipe.is_trainable
+ ]
+
+ # We need to process all models jointly to avoid wrapping callbacks twice.
+ models = Model(
+ "wrap_with_nvtx_range",
+ forward=lambda model, X, is_train: ...,
+ layers=[pipe.model for pipe in pipes],
+ )
+
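+ # Walk the combined graph and wrap each node with an NVTX range for profiling.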
+ for node in models.walk():
+ with_nvtx_range(
+ node, forward_color=forward_color, backprop_color=backprop_color
+ )
+
+ return nlp
+
+ return models_with_nvtx_range
diff --git a/spacy/ml/extract_ngrams.py b/spacy/ml/extract_ngrams.py
index c1c2929fd..c9c82f369 100644
--- a/spacy/ml/extract_ngrams.py
+++ b/spacy/ml/extract_ngrams.py
@@ -6,7 +6,7 @@ from ..attrs import LOWER
@registry.layers("spacy.extract_ngrams.v1")
def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model:
- model = Model("extract_ngrams", forward)
+ model: Model = Model("extract_ngrams", forward)
model.attrs["ngram_size"] = ngram_size
model.attrs["attr"] = attr
return model
@@ -19,7 +19,7 @@ def forward(model: Model, docs, is_train: bool):
unigrams = model.ops.asarray(doc.to_array([model.attrs["attr"]]))
ngrams = [unigrams]
for n in range(2, model.attrs["ngram_size"] + 1):
- ngrams.append(model.ops.ngrams(n, unigrams))
+ ngrams.append(model.ops.ngrams(n, unigrams)) # type: ignore[arg-type]
keys = model.ops.xp.concatenate(ngrams)
keys, vals = model.ops.xp.unique(keys, return_counts=True)
batch_keys.append(keys)
diff --git a/spacy/ml/extract_spans.py b/spacy/ml/extract_spans.py
index 8afd1a3cc..9bc972032 100644
--- a/spacy/ml/extract_spans.py
+++ b/spacy/ml/extract_spans.py
@@ -28,13 +28,13 @@ def forward(
X, spans = source_spans
assert spans.dataXd.ndim == 2
indices = _get_span_indices(ops, spans, X.lengths)
- Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0])
+ Y = Ragged(X.dataXd[indices], spans.dataXd[:, 1] - spans.dataXd[:, 0]) # type: ignore[arg-type, index]
x_shape = X.dataXd.shape
x_lengths = X.lengths
def backprop_windows(dY: Ragged) -> Tuple[Ragged, Ragged]:
dX = Ragged(ops.alloc2f(*x_shape), x_lengths)
- ops.scatter_add(dX.dataXd, indices, dY.dataXd)
+ ops.scatter_add(dX.dataXd, indices, dY.dataXd) # type: ignore[arg-type]
return (dX, spans)
return Y, backprop_windows
@@ -51,7 +51,7 @@ def _get_span_indices(ops, spans: Ragged, lengths: Ints1d) -> Ints1d:
for i, length in enumerate(lengths):
spans_i = spans[i].dataXd + offset
for j in range(spans_i.shape[0]):
- indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1]))
+ indices.append(ops.xp.arange(spans_i[j, 0], spans_i[j, 1])) # type: ignore[call-overload, index]
offset += length
return ops.flatten(indices)
diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py
index 645b67c62..831fee90f 100644
--- a/spacy/ml/models/entity_linker.py
+++ b/spacy/ml/models/entity_linker.py
@@ -1,16 +1,19 @@
from pathlib import Path
-from typing import Optional, Callable, Iterable
+from typing import Optional, Callable, Iterable, List
+from thinc.types import Floats2d
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear
from ...util import registry
from ...kb import KnowledgeBase, Candidate, get_candidates
from ...vocab import Vocab
-from ...tokens import Span
+from ...tokens import Span, Doc
@registry.architectures("spacy.EntityLinker.v1")
-def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
+def build_nel_encoder(
+ tok2vec: Model, nO: Optional[int] = None
+) -> Model[List[Doc], Floats2d]:
with Model.define_operators({">>": chain, "**": clone}):
token_width = tok2vec.maybe_get_dim("nO")
output_layer = Linear(nO=nO, nI=token_width)
@@ -18,7 +21,7 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
tok2vec
>> list2ragged()
>> reduce_mean()
- >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))
+ >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) # type: ignore[arg-type]
>> output_layer
)
model.set_ref("output_layer", output_layer)
diff --git a/spacy/ml/models/multi_task.py b/spacy/ml/models/multi_task.py
index 97bef2d0e..37473b7f4 100644
--- a/spacy/ml/models/multi_task.py
+++ b/spacy/ml/models/multi_task.py
@@ -1,7 +1,9 @@
-from typing import Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING
+from typing import Any, Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING, cast
+from thinc.types import Floats2d
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
from thinc.api import MultiSoftmax, list2array
from thinc.api import to_categorical, CosineDistance, L2Distance
+from thinc.loss import Loss
from ...util import registry, OOV_RANK
from ...errors import Errors
@@ -30,6 +32,7 @@ def create_pretrain_vectors(
return model
def create_vectors_loss() -> Callable:
+ distance: Loss
if loss == "cosine":
distance = CosineDistance(normalize=True, ignore_zeros=True)
return partial(get_vectors_loss, distance=distance)
@@ -115,7 +118,7 @@ def build_cloze_multi_task_model(
) -> Model:
nO = vocab.vectors.data.shape[1]
output_layer = chain(
- list2array(),
+ cast(Model[List["Floats2d"], Floats2d], list2array()),
Maxout(
nO=hidden_size,
nI=tok2vec.get_dim("nO"),
@@ -136,10 +139,10 @@ def build_cloze_characters_multi_task_model(
vocab: "Vocab", tok2vec: Model, maxout_pieces: int, hidden_size: int, nr_char: int
) -> Model:
output_layer = chain(
- list2array(),
+ cast(Model[List["Floats2d"], Floats2d], list2array()),
Maxout(nO=hidden_size, nP=maxout_pieces),
LayerNorm(nI=hidden_size),
- MultiSoftmax([256] * nr_char, nI=hidden_size),
+ MultiSoftmax([256] * nr_char, nI=hidden_size), # type: ignore[arg-type]
)
model = build_masked_language_model(vocab, chain(tok2vec, output_layer))
model.set_ref("tok2vec", tok2vec)
@@ -171,7 +174,7 @@ def build_masked_language_model(
if wrapped.has_dim(dim):
model.set_dim(dim, wrapped.get_dim(dim))
- mlm_model = Model(
+ mlm_model: Model = Model(
"masked-language-model",
mlm_forward,
layers=[wrapped_model],
@@ -185,13 +188,19 @@ def build_masked_language_model(
class _RandomWords:
def __init__(self, vocab: "Vocab") -> None:
+ # Extract lexeme representations
self.words = [lex.text for lex in vocab if lex.prob != 0.0]
- self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
self.words = self.words[:10000]
- self.probs = self.probs[:10000]
- self.probs = numpy.exp(numpy.array(self.probs, dtype="f"))
- self.probs /= self.probs.sum()
- self._cache = []
+
+ # Compute normalized lexeme probabilities
+ probs = [lex.prob for lex in vocab if lex.prob != 0.0]
+ probs = probs[:10000]
+ probs: numpy.ndarray = numpy.exp(numpy.array(probs, dtype="f"))
+ probs /= probs.sum()
+ self.probs = probs
+
+ # Initialize cache
+ self._cache: List[int] = []
def next(self) -> str:
if not self._cache:
diff --git a/spacy/ml/models/parser.py b/spacy/ml/models/parser.py
index 97137313d..63284e766 100644
--- a/spacy/ml/models/parser.py
+++ b/spacy/ml/models/parser.py
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional, List, cast
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.types import Floats2d
@@ -70,7 +70,11 @@ def build_tb_parser_model(
else:
raise ValueError(Errors.E917.format(value=state_type))
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
- tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))
+ tok2vec = chain(
+ tok2vec,
+ cast(Model[List["Floats2d"], Floats2d], list2array()),
+ Linear(hidden_width, t2v_width),
+ )
tok2vec.set_dim("nO", hidden_width)
lower = _define_lower(
nO=hidden_width if use_upper else nO,
@@ -80,7 +84,7 @@ def build_tb_parser_model(
)
upper = None
if use_upper:
- with use_ops("numpy"):
+ with use_ops("cpu"):
# Initialize weights at zero, as it's a classification layer.
upper = _define_upper(nO=nO, nI=None)
return TransitionModel(tok2vec, lower, upper, resize_output)
@@ -110,7 +114,7 @@ def _resize_upper(model, new_nO):
smaller = upper
nI = smaller.maybe_get_dim("nI")
- with use_ops("numpy"):
+ with use_ops("cpu"):
larger = _define_upper(nO=new_nO, nI=nI)
# it could be that the model is not initialized yet, then skip this bit
if smaller.has_param("W"):
diff --git a/spacy/ml/models/spancat.py b/spacy/ml/models/spancat.py
index 5c49fef40..893db2e6d 100644
--- a/spacy/ml/models/spancat.py
+++ b/spacy/ml/models/spancat.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, cast
from thinc.api import Model, with_getitem, chain, list2ragged, Logistic
from thinc.api import Maxout, Linear, concatenate, glorot_uniform_init
from thinc.api import reduce_mean, reduce_max, reduce_first, reduce_last
@@ -9,7 +9,7 @@ from ...tokens import Doc
from ..extract_spans import extract_spans
-@registry.layers.register("spacy.LinearLogistic.v1")
+@registry.layers("spacy.LinearLogistic.v1")
def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
"""An output layer for multi-label classification. It uses a linear layer
followed by a logistic activation.
@@ -17,18 +17,23 @@ def build_linear_logistic(nO=None, nI=None) -> Model[Floats2d, Floats2d]:
return chain(Linear(nO=nO, nI=nI, init_W=glorot_uniform_init), Logistic())
-@registry.layers.register("spacy.mean_max_reducer.v1")
+@registry.layers("spacy.mean_max_reducer.v1")
def build_mean_max_reducer(hidden_size: int) -> Model[Ragged, Floats2d]:
"""Reduce sequences by concatenating their mean and max pooled vectors,
and then combine the concatenated vectors with a hidden layer.
"""
return chain(
- concatenate(reduce_last(), reduce_first(), reduce_mean(), reduce_max()),
+ concatenate(
+ cast(Model[Ragged, Floats2d], reduce_last()),
+ cast(Model[Ragged, Floats2d], reduce_first()),
+ reduce_mean(),
+ reduce_max(),
+ ),
Maxout(nO=hidden_size, normalize=True, dropout=0.0),
)
-@registry.architectures.register("spacy.SpanCategorizer.v1")
+@registry.architectures("spacy.SpanCategorizer.v1")
def build_spancat_model(
tok2vec: Model[List[Doc], List[Floats2d]],
reducer: Model[Ragged, Floats2d],
@@ -43,7 +48,12 @@ def build_spancat_model(
scorer (Model[Floats2d, Floats2d]): The scorer model.
"""
model = chain(
- with_getitem(0, chain(tok2vec, list2ragged())),
+ cast(
+ Model[Tuple[List[Doc], Ragged], Tuple[Ragged, Ragged]],
+ with_getitem(
+ 0, chain(tok2vec, cast(Model[List[Floats2d], Ragged], list2ragged()))
+ ),
+ ),
extract_spans(),
reducer,
scorer,
diff --git a/spacy/ml/models/tagger.py b/spacy/ml/models/tagger.py
index 87944e305..9c7fe042d 100644
--- a/spacy/ml/models/tagger.py
+++ b/spacy/ml/models/tagger.py
@@ -20,7 +20,7 @@ def build_tagger_model(
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
- softmax = with_array(output_layer)
+ softmax = with_array(output_layer) # type: ignore
model = chain(tok2vec, softmax)
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", output_layer)
diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py
index e3f6e944a..c8c146f02 100644
--- a/spacy/ml/models/textcat.py
+++ b/spacy/ml/models/textcat.py
@@ -37,7 +37,7 @@ def build_simple_cnn_text_classifier(
if exclusive_classes:
output_layer = Softmax(nO=nO, nI=nI)
fill_defaults["b"] = NEG_VALUE
- resizable_layer = resizable(
+ resizable_layer: Model = resizable(
output_layer,
resize_layer=partial(
resize_linear_weighted, fill_defaults=fill_defaults
@@ -59,7 +59,7 @@ def build_simple_cnn_text_classifier(
resizable_layer=resizable_layer,
)
model.set_ref("tok2vec", tok2vec)
- model.set_dim("nO", nO)
+ model.set_dim("nO", nO) # type: ignore # TODO: remove type ignore once Thinc has been updated
model.attrs["multi_label"] = not exclusive_classes
return model
@@ -85,7 +85,7 @@ def build_bow_text_classifier(
if not no_output_layer:
fill_defaults["b"] = NEG_VALUE
output_layer = softmax_activation() if exclusive_classes else Logistic()
- resizable_layer = resizable(
+ resizable_layer = resizable( # type: ignore[var-annotated]
sparse_linear,
resize_layer=partial(resize_linear_weighted, fill_defaults=fill_defaults),
)
@@ -93,7 +93,7 @@ def build_bow_text_classifier(
model = with_cpu(model, model.ops)
if output_layer:
model = model >> with_cpu(output_layer, output_layer.ops)
- model.set_dim("nO", nO)
+ model.set_dim("nO", nO) # type: ignore[arg-type]
model.set_ref("output_layer", sparse_linear)
model.attrs["multi_label"] = not exclusive_classes
model.attrs["resize_output"] = partial(
@@ -130,14 +130,14 @@ def build_text_classifier_v2(
model = (linear_model | cnn_model) >> output_layer
model.set_ref("tok2vec", tok2vec)
if model.has_dim("nO") is not False:
- model.set_dim("nO", nO)
+ model.set_dim("nO", nO) # type: ignore[arg-type]
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.set_ref("attention_layer", attention_layer)
model.set_ref("maxout_layer", maxout_layer)
model.set_ref("norm_layer", norm_layer)
model.attrs["multi_label"] = not exclusive_classes
- model.init = init_ensemble_textcat
+ model.init = init_ensemble_textcat # type: ignore[assignment]
return model
@@ -164,7 +164,7 @@ def build_text_classifier_lowdata(
>> list2ragged()
>> ParametricAttention(width)
>> reduce_sum()
- >> residual(Relu(width, width)) ** 2
+ >> residual(Relu(width, width)) ** 2 # type: ignore[arg-type]
>> Linear(nO, width)
)
if dropout:
diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py
index 76ec87054..8d78e418f 100644
--- a/spacy/ml/models/tok2vec.py
+++ b/spacy/ml/models/tok2vec.py
@@ -1,5 +1,5 @@
-from typing import Optional, List, Union
-from thinc.types import Floats2d
+from typing import Optional, List, Union, cast
+from thinc.types import Floats2d, Ints2d, Ragged
from thinc.api import chain, clone, concatenate, with_array, with_padded
from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
@@ -158,26 +158,30 @@ def MultiHashEmbed(
embeddings = [make_hash_embed(i) for i in range(len(attrs))]
concat_size = width * (len(embeddings) + include_static_vectors)
+ max_out: Model[Ragged, Ragged] = with_array(
+ Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True) # type: ignore
+ )
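+ # The cast(...) wrappers below only refine static types for mypy; cast is a no-op at runtime.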
if include_static_vectors:
+ feature_extractor: Model[List[Doc], Ragged] = chain(
+ FeatureExtractor(attrs),
+ cast(Model[List[Ints2d], Ragged], list2ragged()),
+ with_array(concatenate(*embeddings)),
+ )
model = chain(
concatenate(
- chain(
- FeatureExtractor(attrs),
- list2ragged(),
- with_array(concatenate(*embeddings)),
- ),
+ feature_extractor,
StaticVectors(width, dropout=0.0),
),
- with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
- ragged2list(),
+ max_out,
+ cast(Model[Ragged, List[Floats2d]], ragged2list()),
)
else:
model = chain(
FeatureExtractor(list(attrs)),
- list2ragged(),
+ cast(Model[List[Ints2d], Ragged], list2ragged()),
with_array(concatenate(*embeddings)),
- with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
- ragged2list(),
+ max_out,
+ cast(Model[Ragged, List[Floats2d]], ragged2list()),
)
return model
@@ -220,37 +224,41 @@ def CharacterEmbed(
"""
feature = intify_attr(feature)
if feature is None:
- raise ValueError(Errors.E911(feat=feature))
+ raise ValueError(Errors.E911.format(feat=feature))
+ char_embed = chain(
+ _character_embed.CharacterEmbed(nM=nM, nC=nC),
+ cast(Model[List[Floats2d], Ragged], list2ragged()),
+ )
+ feature_extractor: Model[List[Doc], Ragged] = chain(
+ FeatureExtractor([feature]),
+ cast(Model[List[Ints2d], Ragged], list2ragged()),
+ with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)), # type: ignore
+ )
+ max_out: Model[Ragged, Ragged]
if include_static_vectors:
+ max_out = with_array(
+ Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0) # type: ignore
+ )
model = chain(
concatenate(
- chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
- chain(
- FeatureExtractor([feature]),
- list2ragged(),
- with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
- ),
+ char_embed,
+ feature_extractor,
StaticVectors(width, dropout=0.0),
),
- with_array(
- Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)
- ),
- ragged2list(),
+ max_out,
+ cast(Model[Ragged, List[Floats2d]], ragged2list()),
)
else:
+ max_out = with_array(
+ Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0) # type: ignore
+ )
model = chain(
concatenate(
- chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
- chain(
- FeatureExtractor([feature]),
- list2ragged(),
- with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
- ),
+ char_embed,
+ feature_extractor,
),
- with_array(
- Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)
- ),
- ragged2list(),
+ max_out,
+ cast(Model[Ragged, List[Floats2d]], ragged2list()),
)
return model
@@ -281,10 +289,10 @@ def MaxoutWindowEncoder(
normalize=True,
),
)
- model = clone(residual(cnn), depth)
+ model = clone(residual(cnn), depth) # type: ignore[arg-type]
model.set_dim("nO", width)
receptive_field = window_size * depth
- return with_array(model, pad=receptive_field)
+ return with_array(model, pad=receptive_field) # type: ignore[arg-type]
@registry.architectures("spacy.MishWindowEncoder.v2")
@@ -305,9 +313,9 @@ def MishWindowEncoder(
expand_window(window_size=window_size),
Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
)
- model = clone(residual(cnn), depth)
+ model = clone(residual(cnn), depth) # type: ignore[arg-type]
model.set_dim("nO", width)
- return with_array(model)
+ return with_array(model) # type: ignore[arg-type]
@registry.architectures("spacy.TorchBiLSTMEncoder.v1")
diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py
index 4e7262e7d..53ef01906 100644
--- a/spacy/ml/staticvectors.py
+++ b/spacy/ml/staticvectors.py
@@ -49,7 +49,7 @@ def forward(
# Convert negative indices to 0-vectors (TODO: more options for UNK tokens)
vectors_data[rows < 0] = 0
output = Ragged(
- vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")
+ vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i") # type: ignore
)
mask = None
if is_train:
@@ -62,7 +62,9 @@ def forward(
d_output.data *= mask
model.inc_grad(
"W",
- model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
+ model.ops.gemm(
+ cast(Floats2d, d_output.data), model.ops.as_contig(V[rows]), trans1=True
+ ),
)
return []
@@ -97,4 +99,7 @@ def _handle_empty(ops: Ops, nO: int):
def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
- return ops.get_dropout_mask((nO,), rate) if rate is not None else None
+ if rate is not None:
+ mask = ops.get_dropout_mask((nO,), rate)
+ return mask # type: ignore
+ return None
diff --git a/spacy/pipe_analysis.py b/spacy/pipe_analysis.py
index d0362e7e1..245747061 100644
--- a/spacy/pipe_analysis.py
+++ b/spacy/pipe_analysis.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Iterable, Optional, Union, TYPE_CHECKING
+from typing import List, Set, Dict, Iterable, ItemsView, Union, TYPE_CHECKING
from wasabi import msg
from .tokens import Doc, Token, Span
@@ -67,7 +67,7 @@ def get_attr_info(nlp: "Language", attr: str) -> Dict[str, List[str]]:
RETURNS (Dict[str, List[str]]): A dict keyed by "assigns" and "requires",
mapped to a list of component names.
"""
- result = {"assigns": [], "requires": []}
+ result: Dict[str, List[str]] = {"assigns": [], "requires": []}
for pipe_name in nlp.pipe_names:
meta = nlp.get_pipe_meta(pipe_name)
if attr in meta.assigns:
@@ -79,7 +79,7 @@ def get_attr_info(nlp: "Language", attr: str) -> Dict[str, List[str]]:
def analyze_pipes(
nlp: "Language", *, keys: List[str] = DEFAULT_KEYS
-) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
+) -> Dict[str, Dict[str, Union[List[str], Dict]]]:
"""Print a formatted summary for the current nlp object's pipeline. Shows
a table with the pipeline components and why they assign and require, as
well as any problems if available.
@@ -88,8 +88,11 @@ def analyze_pipes(
keys (List[str]): The meta keys to show in the table.
RETURNS (dict): A dict with "summary" and "problems".
"""
- result = {"summary": {}, "problems": {}}
- all_attrs = set()
+ result: Dict[str, Dict[str, Union[List[str], Dict]]] = {
+ "summary": {},
+ "problems": {},
+ }
+ all_attrs: Set[str] = set()
for i, name in enumerate(nlp.pipe_names):
meta = nlp.get_pipe_meta(name)
all_attrs.update(meta.assigns)
@@ -102,19 +105,18 @@ def analyze_pipes(
prev_meta = nlp.get_pipe_meta(prev_name)
for annot in prev_meta.assigns:
requires[annot] = True
- result["problems"][name] = []
- for annot, fulfilled in requires.items():
- if not fulfilled:
- result["problems"][name].append(annot)
+ result["problems"][name] = [
+ annot for annot, fulfilled in requires.items() if not fulfilled
+ ]
result["attrs"] = {attr: get_attr_info(nlp, attr) for attr in all_attrs}
return result
def print_pipe_analysis(
- analysis: Dict[str, Union[List[str], Dict[str, List[str]]]],
+ analysis: Dict[str, Dict[str, Union[List[str], Dict]]],
*,
keys: List[str] = DEFAULT_KEYS,
-) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
+) -> None:
"""Print a formatted version of the pipe analysis produced by analyze_pipes.
analysis (Dict[str, Union[List[str], Dict[str, List[str]]]]): The analysis.
@@ -122,7 +124,7 @@ def print_pipe_analysis(
"""
msg.divider("Pipeline Overview")
header = ["#", "Component", *[key.capitalize() for key in keys]]
- summary = analysis["summary"].items()
+ summary: ItemsView = analysis["summary"].items()
body = [[i, n, *[v for v in m.values()]] for i, (n, m) in enumerate(summary)]
msg.table(body, header=header, divider=True, multiline=True)
n_problems = sum(len(p) for p in analysis["problems"].values())
diff --git a/spacy/pipeline/attributeruler.py b/spacy/pipeline/attributeruler.py
index b1a2f3e9c..0d9494865 100644
--- a/spacy/pipeline/attributeruler.py
+++ b/spacy/pipeline/attributeruler.py
@@ -95,9 +95,9 @@ class AttributeRuler(Pipe):
self.vocab = vocab
self.matcher = Matcher(self.vocab, validate=validate)
self.validate = validate
- self.attrs = []
- self._attrs_unnormed = [] # store for reference
- self.indices = []
+ self.attrs: List[Dict] = []
+ self._attrs_unnormed: List[Dict] = [] # store for reference
+ self.indices: List[int] = []
self.scorer = scorer
def clear(self) -> None:
@@ -144,13 +144,13 @@ class AttributeRuler(Pipe):
self.set_annotations(doc, matches)
return doc
except Exception as e:
- error_handler(self.name, self, [doc], e)
+ return error_handler(self.name, self, [doc], e)
def match(self, doc: Doc):
- matches = self.matcher(doc, allow_missing=True)
+ matches = self.matcher(doc, allow_missing=True, as_spans=False)
# Sort by the attribute ID, so that later rules have precedence
matches = [
- (int(self.vocab.strings[m_id]), m_id, s, e) for m_id, s, e in matches
+ (int(self.vocab.strings[m_id]), m_id, s, e) for m_id, s, e in matches # type: ignore
]
matches.sort()
return matches
@@ -196,7 +196,7 @@ class AttributeRuler(Pipe):
else:
morph = self.vocab.morphology.add(attrs["MORPH"])
attrs["MORPH"] = self.vocab.strings[morph]
- self.add([pattern], attrs)
+ self.add([pattern], attrs) # type: ignore[list-item]
def load_from_morph_rules(
self, morph_rules: Dict[str, Dict[str, Dict[Union[int, str], Union[int, str]]]]
@@ -220,7 +220,7 @@ class AttributeRuler(Pipe):
elif morph_attrs:
morph = self.vocab.morphology.add(morph_attrs)
attrs["MORPH"] = self.vocab.strings[morph]
- self.add([pattern], attrs)
+ self.add([pattern], attrs) # type: ignore[list-item]
def add(
self, patterns: Iterable[MatcherPatternType], attrs: Dict, index: int = 0
@@ -240,7 +240,7 @@ class AttributeRuler(Pipe):
# We need to make a string here, because otherwise the ID we pass back
# will be interpreted as the hash of a string, rather than an ordinal.
key = str(len(self.attrs))
- self.matcher.add(self.vocab.strings.add(key), patterns)
+ self.matcher.add(self.vocab.strings.add(key), patterns) # type: ignore[arg-type]
self._attrs_unnormed.append(attrs)
attrs = normalize_token_attrs(self.vocab, attrs)
self.attrs.append(attrs)
@@ -256,7 +256,7 @@ class AttributeRuler(Pipe):
DOCS: https://spacy.io/api/attributeruler#add_patterns
"""
for p in patterns:
- self.add(**p)
+ self.add(**p) # type: ignore[arg-type]
@property
def patterns(self) -> List[AttributeRulerPatternType]:
@@ -265,10 +265,10 @@ class AttributeRuler(Pipe):
for i in range(len(self.attrs)):
p = {}
p["patterns"] = self.matcher.get(str(i))[1]
- p["attrs"] = self._attrs_unnormed[i]
- p["index"] = self.indices[i]
+ p["attrs"] = self._attrs_unnormed[i] # type: ignore
+ p["index"] = self.indices[i] # type: ignore
all_patterns.append(p)
- return all_patterns
+ return all_patterns # type: ignore[return-value]
def to_bytes(self, exclude: Iterable[str] = SimpleFrozenList()) -> bytes:
"""Serialize the AttributeRuler to a bytestring.
diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py
index 80e135a30..1169e898d 100644
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -1,4 +1,5 @@
-from typing import Optional, Iterable, Callable, Dict, Union, List
+from typing import Optional, Iterable, Callable, Dict, Union, List, Any
+from thinc.types import Floats2d
from pathlib import Path
from itertools import islice
import srsly
@@ -162,7 +163,7 @@ class EntityLinker(TrainablePipe):
self.incl_prior = incl_prior
self.incl_context = incl_context
self.get_candidates = get_candidates
- self.cfg = {"overwrite": overwrite}
+ self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
@@ -189,7 +190,7 @@ class EntityLinker(TrainablePipe):
get_examples: Callable[[], Iterable[Example]],
*,
nlp: Optional[Language] = None,
- kb_loader: Callable[[Vocab], KnowledgeBase] = None,
+ kb_loader: Optional[Callable[[Vocab], KnowledgeBase]] = None,
):
"""Initialize the pipe for training, using a representative set
of data examples.
@@ -284,7 +285,7 @@ class EntityLinker(TrainablePipe):
losses[self.name] += loss
return losses
- def get_loss(self, examples: Iterable[Example], sentence_encodings):
+ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
validate_examples(examples, "EntityLinker.get_loss")
entity_encodings = []
for eg in examples:
@@ -300,8 +301,9 @@ class EntityLinker(TrainablePipe):
method="get_loss", msg="gold entities do not match up"
)
raise RuntimeError(err)
- gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
- loss = self.distance.get_loss(sentence_encodings, entity_encodings)
+ # TODO: fix typing issue here
+ gradients = self.distance.get_grad(sentence_encodings, entity_encodings) # type: ignore
+ loss = self.distance.get_loss(sentence_encodings, entity_encodings) # type: ignore
loss = loss / len(entity_encodings)
return float(loss), gradients
@@ -311,13 +313,13 @@ class EntityLinker(TrainablePipe):
no prediction.
docs (Iterable[Doc]): The documents to predict.
- RETURNS (List[int]): The models prediction for each document.
+ RETURNS (List[str]): The model's prediction for each document.
DOCS: https://spacy.io/api/entitylinker#predict
"""
self.validate_kb()
entity_count = 0
- final_kb_ids = []
+ final_kb_ids: List[str] = []
if not docs:
return final_kb_ids
if isinstance(docs, Doc):
@@ -347,7 +349,7 @@ class EntityLinker(TrainablePipe):
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
else:
- candidates = self.get_candidates(self.kb, ent)
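+ # Materialize the candidate iterable so it can be checked for emptiness and reused.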
+ candidates = list(self.get_candidates(self.kb, ent))
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
@@ -492,7 +494,7 @@ class EntityLinker(TrainablePipe):
except AttributeError:
raise ValueError(Errors.E149) from None
- deserialize = {}
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
deserialize["kb"] = lambda p: self.kb.from_disk(p)
diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py
index ad67a7a1f..2c3db2575 100644
--- a/spacy/pipeline/entityruler.py
+++ b/spacy/pipeline/entityruler.py
@@ -1,5 +1,6 @@
import warnings
from typing import Optional, Union, List, Dict, Tuple, Iterable, Any, Callable, Sequence
+from typing import cast
from collections import defaultdict
from pathlib import Path
import srsly
@@ -114,8 +115,8 @@ class EntityRuler(Pipe):
self.nlp = nlp
self.name = name
self.overwrite = overwrite_ents
- self.token_patterns = defaultdict(list)
- self.phrase_patterns = defaultdict(list)
+ self.token_patterns = defaultdict(list) # type: ignore
+ self.phrase_patterns = defaultdict(list) # type: ignore
self._validate = validate
self.matcher = Matcher(nlp.vocab, validate=validate)
self.phrase_matcher_attr = phrase_matcher_attr
@@ -123,7 +124,7 @@ class EntityRuler(Pipe):
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
)
self.ent_id_sep = ent_id_sep
- self._ent_ids = defaultdict(dict)
+ self._ent_ids = defaultdict(tuple) # type: ignore
if patterns is not None:
self.add_patterns(patterns)
self.scorer = scorer
@@ -152,19 +153,22 @@ class EntityRuler(Pipe):
self.set_annotations(doc, matches)
return doc
except Exception as e:
- error_handler(self.name, self, [doc], e)
+ return error_handler(self.name, self, [doc], e)
def match(self, doc: Doc):
self._require_patterns()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="\\[W036")
- matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
- matches = set(
+ matches = cast(
+ List[Tuple[int, int, int]],
+ list(self.matcher(doc)) + list(self.phrase_matcher(doc)),
+ )
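+ # Deduplicate matches and drop zero-length ones before sorting.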
+ final_matches = set(
[(m_id, start, end) for m_id, start, end in matches if start != end]
)
get_sort_key = lambda m: (m[2] - m[1], -m[1])
- matches = sorted(matches, key=get_sort_key, reverse=True)
- return matches
+ final_matches = sorted(final_matches, key=get_sort_key, reverse=True)
+ return final_matches
def set_annotations(self, doc, matches):
"""Modify the document in place"""
@@ -229,10 +233,10 @@ class EntityRuler(Pipe):
"""
self.clear()
if patterns:
- self.add_patterns(patterns)
+ self.add_patterns(patterns) # type: ignore[arg-type]
@property
- def ent_ids(self) -> Tuple[str, ...]:
+ def ent_ids(self) -> Tuple[Optional[str], ...]:
"""All entity ids present in the match patterns `id` properties
RETURNS (set): The string entity ids.
@@ -317,17 +321,17 @@ class EntityRuler(Pipe):
if ent_id:
phrase_pattern["id"] = ent_id
phrase_patterns.append(phrase_pattern)
- for entry in token_patterns + phrase_patterns:
+ for entry in token_patterns + phrase_patterns: # type: ignore[operator]
label = entry["label"]
if "id" in entry:
ent_label = label
label = self._create_label(label, entry["id"])
key = self.matcher._normalize_key(label)
self._ent_ids[key] = (ent_label, entry["id"])
- pattern = entry["pattern"]
+ pattern = entry["pattern"] # type: ignore
if isinstance(pattern, Doc):
self.phrase_patterns[label].append(pattern)
- self.phrase_matcher.add(label, [pattern])
+ self.phrase_matcher.add(label, [pattern]) # type: ignore
elif isinstance(pattern, list):
self.token_patterns[label].append(pattern)
self.matcher.add(label, [pattern])
@@ -338,7 +342,7 @@ class EntityRuler(Pipe):
"""Reset all patterns."""
self.token_patterns = defaultdict(list)
self.phrase_patterns = defaultdict(list)
- self._ent_ids = defaultdict(dict)
+ self._ent_ids = defaultdict(tuple)
self.matcher = Matcher(self.nlp.vocab, validate=self._validate)
self.phrase_matcher = PhraseMatcher(
self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
@@ -349,7 +353,7 @@ class EntityRuler(Pipe):
if len(self) == 0:
warnings.warn(Warnings.W036.format(name=self.name))
- def _split_label(self, label: str) -> Tuple[str, str]:
+ def _split_label(self, label: str) -> Tuple[str, Optional[str]]:
"""Split Entity label into ent_label and ent_id if it contains self.ent_id_sep
label (str): The value of label in a pattern entry
@@ -359,11 +363,12 @@ class EntityRuler(Pipe):
ent_label, ent_id = label.rsplit(self.ent_id_sep, 1)
else:
ent_label = label
- ent_id = None
+ ent_id = None # type: ignore
return ent_label, ent_id
- def _create_label(self, label: str, ent_id: str) -> str:
+ def _create_label(self, label: Any, ent_id: Any) -> str:
"""Join Entity label with ent_id if the pattern has an `id` attribute
+ If ent_id is not a string, the label is returned as is.
label (str): The label to set for ent.label_
ent_id (str): The label
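In the match() hunk above, the combined matcher output is wrapped in typing.cast and the filtered result is bound to a new name (final_matches) rather than reassigned to matches, so each variable keeps a single type that mypy can track. A rough sketch of the same pattern on plain tuples; the data and function name are illustrative only:

from typing import List, Set, Tuple, cast

def dedupe_matches(raw: object) -> List[Tuple[int, int, int]]:
    # cast() only informs the type checker; it performs no runtime conversion.
    matches = cast(List[Tuple[int, int, int]], raw)
    # Binding the filtered result to a new name keeps each variable at one consistent type.
    final_matches: Set[Tuple[int, int, int]] = {
        (m_id, start, end) for m_id, start, end in matches if start != end
    }
    get_sort_key = lambda m: (m[2] - m[1], -m[1])
    return sorted(final_matches, key=get_sort_key, reverse=True)

print(dedupe_matches([(1, 0, 2), (1, 0, 2), (2, 3, 3)]))  # [(1, 0, 2)]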
diff --git a/spacy/pipeline/functions.py b/spacy/pipeline/functions.py
index 03c7db422..f0a75dc2c 100644
--- a/spacy/pipeline/functions.py
+++ b/spacy/pipeline/functions.py
@@ -25,7 +25,7 @@ def merge_noun_chunks(doc: Doc) -> Doc:
with doc.retokenize() as retokenizer:
for np in doc.noun_chunks:
attrs = {"tag": np.root.tag, "dep": np.root.dep}
- retokenizer.merge(np, attrs=attrs)
+ retokenizer.merge(np, attrs=attrs) # type: ignore[arg-type]
return doc
@@ -45,7 +45,7 @@ def merge_entities(doc: Doc):
with doc.retokenize() as retokenizer:
for ent in doc.ents:
attrs = {"tag": ent.root.tag, "dep": ent.root.dep, "ent_type": ent.label}
- retokenizer.merge(ent, attrs=attrs)
+ retokenizer.merge(ent, attrs=attrs) # type: ignore[arg-type]
return doc
@@ -63,7 +63,7 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
merger = Matcher(doc.vocab)
merger.add("SUBTOK", [[{"DEP": label, "op": "+"}]])
matches = merger(doc)
- spans = util.filter_spans([doc[start : end + 1] for _, start, end in matches])
+ spans = util.filter_spans([doc[start : end + 1] for _, start, end in matches]) # type: ignore[misc, operator]
with doc.retokenize() as retokenizer:
for span in spans:
retokenizer.merge(span)
@@ -93,11 +93,11 @@ class TokenSplitter:
if len(t.text) >= self.min_length:
orths = []
heads = []
- attrs = {}
+ attrs = {} # type: ignore[var-annotated]
for i in range(0, len(t.text), self.split_length):
orths.append(t.text[i : i + self.split_length])
heads.append((t, i / self.split_length))
- retokenizer.split(t, orths, heads, attrs)
+ retokenizer.split(t, orths, heads, attrs) # type: ignore[arg-type]
return doc
def _get_config(self) -> Dict[str, Any]:
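The ignores added above use error-code-specific forms such as # type: ignore[arg-type], which silence only the named mypy error and leave any other problem on the same line visible. A generic illustration, unrelated to spaCy's code:

from typing import Dict, List

def word_lengths(words: List[str]) -> Dict[str, int]:
    counts = {}  # type: ignore[var-annotated]  # silences only "need type annotation for counts"
    for word in words:
        counts[word] = len(word)
    return counts

ids: List[int] = []
# At runtime a list accepts any object, but mypy flags the str argument.
# The coded ignore suppresses just that arg-type error, nothing else on the line.
ids.append("not-an-int")  # type: ignore[arg-type]

print(word_lengths(["spacy", "mypy"]), ids)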
diff --git a/spacy/pipeline/lemmatizer.py b/spacy/pipeline/lemmatizer.py
index 5adae10d2..9c2fc2f09 100644
--- a/spacy/pipeline/lemmatizer.py
+++ b/spacy/pipeline/lemmatizer.py
@@ -111,7 +111,7 @@ class Lemmatizer(Pipe):
if not hasattr(self, mode_attr):
raise ValueError(Errors.E1003.format(mode=mode))
self.lemmatize = getattr(self, mode_attr)
- self.cache = {}
+ self.cache = {} # type: ignore[var-annotated]
self.scorer = scorer
@property
@@ -201,7 +201,7 @@ class Lemmatizer(Pipe):
DOCS: https://spacy.io/api/lemmatizer#rule_lemmatize
"""
- cache_key = (token.orth, token.pos, token.morph.key)
+ cache_key = (token.orth, token.pos, token.morph.key) # type: ignore[attr-defined]
if cache_key in self.cache:
return self.cache[cache_key]
string = token.text
@@ -297,7 +297,7 @@ class Lemmatizer(Pipe):
DOCS: https://spacy.io/api/lemmatizer#from_disk
"""
- deserialize = {}
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
util.from_disk(path, deserialize, exclude)
@@ -328,7 +328,7 @@ class Lemmatizer(Pipe):
DOCS: https://spacy.io/api/lemmatizer#from_bytes
"""
- deserialize = {}
+ deserialize: Dict[str, Callable[[Any], Any]] = {}
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
util.from_bytes(bytes_data, deserialize, exclude)
diff --git a/spacy/pipeline/pipe.pyi b/spacy/pipeline/pipe.pyi
new file mode 100644
index 000000000..c7c0568f9
--- /dev/null
+++ b/spacy/pipeline/pipe.pyi
@@ -0,0 +1,38 @@
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, Iterator, List
+from typing import NoReturn, Optional, Tuple, Union
+
+from ..tokens.doc import Doc
+
+from ..training import Example
+from ..language import Language
+
+class Pipe:
+ def __call__(self, doc: Doc) -> Doc: ...
+ def pipe(
+ self, stream: Iterable[Doc], *, batch_size: int = ...
+ ) -> Iterator[Doc]: ...
+ def initialize(
+ self,
+ get_examples: Callable[[], Iterable[Example]],
+ *,
+ nlp: Language = ...,
+ ) -> None: ...
+ def score(
+ self, examples: Iterable[Example], **kwargs: Any
+ ) -> Dict[str, Union[float, Dict[str, float]]]: ...
+ @property
+ def is_trainable(self) -> bool: ...
+ @property
+ def labels(self) -> Tuple[str, ...]: ...
+ @property
+ def label_data(self) -> Any: ...
+ def _require_labels(self) -> None: ...
+ def set_error_handler(
+ self, error_handler: Callable[[str, "Pipe", List[Doc], Exception], NoReturn]
+ ) -> None: ...
+ def get_error_handler(
+ self,
+ ) -> Callable[[str, "Pipe", List[Doc], Exception], NoReturn]: ...
+
+def deserialize_config(path: Path) -> Any: ...
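pipe.pyi is a new stub file: pipe.pyx is compiled Cython that mypy cannot analyse, so the stub declares the public signatures the checker should assume. A sketch of the idea for a hypothetical compiled module called fastmath:

# Contents of a hypothetical stub "fastmath.pyi" for a compiled extension module.
# Stubs are never executed: mypy reads the declared signatures here instead of
# the .pyx/.so implementation, which it cannot analyse.
from typing import Iterable

def mean(values: Iterable[float]) -> float: ...
def clip(value: float, low: float = ..., high: float = ...) -> float: ...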
diff --git a/spacy/pipeline/pipe.pyx b/spacy/pipeline/pipe.pyx
index 14f9f08f8..9eddc1e3f 100644
--- a/spacy/pipeline/pipe.pyx
+++ b/spacy/pipeline/pipe.pyx
@@ -99,7 +99,7 @@ cdef class Pipe:
return False
@property
- def labels(self) -> Optional[Tuple[str]]:
+ def labels(self) -> Tuple[str, ...]:
return tuple()
@property
@@ -126,7 +126,7 @@ cdef class Pipe:
"""
self.error_handler = error_handler
- def get_error_handler(self) -> Optional[Callable]:
+ def get_error_handler(self) -> Callable:
"""Retrieve the error handler function.
RETURNS (Callable): The error handler, or if it's not set a default function that just reraises.
diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py
index ef1880372..5b84ce8fb 100644
--- a/spacy/pipeline/spancat.py
+++ b/spacy/pipeline/spancat.py
@@ -1,9 +1,10 @@
import numpy
-from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any
+from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer
-from thinc.types import Ragged, Ints2d, Floats2d
+from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
+from ..compat import Protocol, runtime_checkable
from ..scorer import Scorer
from ..language import Language
from .trainable_pipe import TrainablePipe
@@ -44,13 +45,19 @@ depth = 4
DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
+@runtime_checkable
+class Suggester(Protocol):
+ def __call__(self, docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
+ ...
+
+
@registry.misc("spacy.ngram_suggester.v1")
-def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
+def build_ngram_suggester(sizes: List[int]) -> Suggester:
"""Suggest all spans of the given lengths. Spans are returned as a ragged
array of integers. The array has two columns, indicating the start and end
position."""
- def ngram_suggester(docs: List[Doc], *, ops: Optional[Ops] = None) -> Ragged:
+ def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
if ops is None:
ops = get_current_ops()
spans = []
@@ -67,10 +74,11 @@ def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
if spans:
assert spans[-1].ndim == 2, spans[-1].shape
lengths.append(length)
+ lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
if len(spans) > 0:
- output = Ragged(ops.xp.vstack(spans), ops.asarray(lengths, dtype="i"))
+ output = Ragged(ops.xp.vstack(spans), lengths_array)
else:
- output = Ragged(ops.xp.zeros((0, 0)), ops.asarray(lengths, dtype="i"))
+ output = Ragged(ops.xp.zeros((0, 0)), lengths_array)
assert output.dataXd.ndim == 2
return output
@@ -79,13 +87,11 @@ def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
@registry.misc("spacy.ngram_range_suggester.v1")
-def build_ngram_range_suggester(
- min_size: int, max_size: int
-) -> Callable[[List[Doc]], Ragged]:
+def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
"""Suggest all spans of the given lengths between a given min and max value - both inclusive.
Spans are returned as a ragged array of integers. The array has two columns,
indicating the start and end position."""
- sizes = range(min_size, max_size + 1)
+ sizes = list(range(min_size, max_size + 1))
return build_ngram_suggester(sizes)
@@ -105,7 +111,7 @@ def build_ngram_range_suggester(
def make_spancat(
nlp: Language,
name: str,
- suggester: Callable[[List[Doc]], Ragged],
+ suggester: Suggester,
model: Model[Tuple[List[Doc], Ragged], Floats2d],
spans_key: str,
scorer: Optional[Callable],
@@ -116,7 +122,7 @@ def make_spancat(
parts: a suggester function that proposes candidate spans, and a labeller
model that predicts one or more labels for each span.
- suggester (Callable[List[Doc], Ragged]): A function that suggests spans.
+ suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
Spans are returned as a ragged array with two integer columns, for the
start and end positions.
model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
@@ -172,7 +178,7 @@ class SpanCategorizer(TrainablePipe):
self,
vocab: Vocab,
model: Model[Tuple[List[Doc], Ragged], Floats2d],
- suggester: Callable[[List[Doc]], Ragged],
+ suggester: Suggester,
name: str = "spancat",
*,
spans_key: str = "spans",
@@ -218,7 +224,7 @@ class SpanCategorizer(TrainablePipe):
initialization and training, the component will look for spans on the
reference document under the same key.
"""
- return self.cfg["spans_key"]
+ return str(self.cfg["spans_key"])
def add_label(self, label: str) -> int:
"""Add a new label to the pipe.
@@ -233,7 +239,7 @@ class SpanCategorizer(TrainablePipe):
if label in self.labels:
return 0
self._allow_extra_label()
- self.cfg["labels"].append(label)
+ self.cfg["labels"].append(label) # type: ignore
self.vocab.strings.add(label)
return 1
@@ -243,7 +249,7 @@ class SpanCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/spancategorizer#labels
"""
- return tuple(self.cfg["labels"])
+ return tuple(self.cfg["labels"]) # type: ignore
@property
def label_data(self) -> List[str]:
@@ -262,8 +268,8 @@ class SpanCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/spancategorizer#predict
"""
indices = self.suggester(docs, ops=self.model.ops)
- scores = self.model.predict((docs, indices))
- return (indices, scores)
+ scores = self.model.predict((docs, indices)) # type: ignore
+ return indices, scores
def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None:
"""Modify a batch of Doc objects, using pre-computed scores.
@@ -279,7 +285,7 @@ class SpanCategorizer(TrainablePipe):
for i, doc in enumerate(docs):
indices_i = indices[i].dataXd
doc.spans[self.key] = self._make_span_group(
- doc, indices_i, scores[offset : offset + indices.lengths[i]], labels
+ doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type]
)
offset += indices.lengths[i]
@@ -318,14 +324,14 @@ class SpanCategorizer(TrainablePipe):
set_dropout_rate(self.model, drop)
scores, backprop_scores = self.model.begin_update((docs, spans))
loss, d_scores = self.get_loss(examples, (spans, scores))
- backprop_scores(d_scores)
+ backprop_scores(d_scores) # type: ignore
if sgd is not None:
self.finish_update(sgd)
losses[self.name] += loss
return losses
def get_loss(
- self, examples: Iterable[Example], spans_scores: Tuple[Ragged, Ragged]
+ self, examples: Iterable[Example], spans_scores: Tuple[Ragged, Floats2d]
) -> Tuple[float, float]:
"""Find the loss and gradient of loss for the batch of documents and
their predicted scores.
@@ -350,8 +356,8 @@ class SpanCategorizer(TrainablePipe):
spans_index = {}
spans_i = spans[i].dataXd
for j in range(spans.lengths[i]):
- start = int(spans_i[j, 0])
- end = int(spans_i[j, 1])
+ start = int(spans_i[j, 0]) # type: ignore
+ end = int(spans_i[j, 1]) # type: ignore
spans_index[(start, end)] = offset + j
for gold_span in self._get_aligned_spans(eg):
key = (gold_span.start, gold_span.end)
@@ -362,7 +368,7 @@ class SpanCategorizer(TrainablePipe):
# The target is a flat array for all docs. Track the position
# we're at within the flat array.
offset += spans.lengths[i]
- target = self.model.ops.asarray(target, dtype="f")
+ target = self.model.ops.asarray(target, dtype="f") # type: ignore
# The target will have the values 0 (for untrue predictions) or 1
# (for true predictions).
# The scores should be in the range [0, 1].
@@ -378,7 +384,7 @@ class SpanCategorizer(TrainablePipe):
self,
get_examples: Callable[[], Iterable[Example]],
*,
- nlp: Language = None,
+ nlp: Optional[Language] = None,
labels: Optional[List[str]] = None,
) -> None:
"""Initialize the pipe for training, using a representative set
@@ -386,14 +392,14 @@ class SpanCategorizer(TrainablePipe):
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
- nlp (Language): The current nlp object the component is part of.
- labels: The labels to add to the component, typically generated by the
+ nlp (Optional[Language]): The current nlp object the component is part of.
+ labels (Optional[List[str]]): The labels to add to the component, typically generated by the
`init labels` command. If no labels are provided, the get_examples
callback is used to extract the labels from the data.
DOCS: https://spacy.io/api/spancategorizer#initialize
"""
- subbatch = []
+ subbatch: List[Example] = []
if labels is not None:
for label in labels:
self.add_label(label)
@@ -412,7 +418,7 @@ class SpanCategorizer(TrainablePipe):
else:
self.model.initialize()
- def _validate_categories(self, examples):
+ def _validate_categories(self, examples: Iterable[Example]):
# TODO
pass
@@ -429,10 +435,11 @@ class SpanCategorizer(TrainablePipe):
threshold = self.cfg["threshold"]
keeps = scores >= threshold
- ranked = (scores * -1).argsort()
+ ranked = (scores * -1).argsort() # type: ignore
if max_positive is not None:
- filter = ranked[:, max_positive:]
- for i, row in enumerate(filter):
+ assert isinstance(max_positive, int)
+ span_filter = ranked[:, max_positive:]
+ for i, row in enumerate(span_filter):
keeps[i, row] = False
spans.attrs["scores"] = scores[keeps].flatten()
diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index 6956a919d..30a65ec52 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -164,7 +164,7 @@ class TextCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/textcategorizer#labels
"""
- return tuple(self.cfg["labels"])
+ return tuple(self.cfg["labels"]) # type: ignore[arg-type, return-value]
@property
def label_data(self) -> List[str]:
@@ -172,7 +172,7 @@ class TextCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/textcategorizer#label_data
"""
- return self.labels
+ return self.labels # type: ignore[return-value]
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
@@ -186,7 +186,7 @@ class TextCategorizer(TrainablePipe):
# Handle cases where there are no tokens in any docs.
tensors = [doc.tensor for doc in docs]
xp = get_array_module(tensors)
- scores = xp.zeros((len(docs), len(self.labels)))
+ scores = xp.zeros((len(list(docs)), len(self.labels)))
return scores
scores = self.model.predict(docs)
scores = self.model.ops.asarray(scores)
@@ -263,8 +263,9 @@ class TextCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/textcategorizer#rehearse
"""
- if losses is not None:
- losses.setdefault(self.name, 0.0)
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
if self._rehearsal_model is None:
return losses
validate_examples(examples, "TextCategorizer.rehearse")
@@ -280,23 +281,23 @@ class TextCategorizer(TrainablePipe):
bp_scores(gradient)
if sgd is not None:
self.finish_update(sgd)
- if losses is not None:
- losses[self.name] += (gradient ** 2).sum()
+ losses[self.name] += (gradient ** 2).sum()
return losses
def _examples_to_truth(
- self, examples: List[Example]
+ self, examples: Iterable[Example]
) -> Tuple[numpy.ndarray, numpy.ndarray]:
- truths = numpy.zeros((len(examples), len(self.labels)), dtype="f")
- not_missing = numpy.ones((len(examples), len(self.labels)), dtype="f")
+ nr_examples = len(list(examples))
+ truths = numpy.zeros((nr_examples, len(self.labels)), dtype="f")
+ not_missing = numpy.ones((nr_examples, len(self.labels)), dtype="f")
for i, eg in enumerate(examples):
for j, label in enumerate(self.labels):
if label in eg.reference.cats:
truths[i, j] = eg.reference.cats[label]
else:
not_missing[i, j] = 0.0
- truths = self.model.ops.asarray(truths)
- return truths, not_missing
+ truths = self.model.ops.asarray(truths) # type: ignore
+ return truths, not_missing # type: ignore
def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
"""Find the loss and gradient of loss for the batch of documents and
@@ -311,7 +312,7 @@ class TextCategorizer(TrainablePipe):
validate_examples(examples, "TextCategorizer.get_loss")
self._validate_categories(examples)
truths, not_missing = self._examples_to_truth(examples)
- not_missing = self.model.ops.asarray(not_missing)
+ not_missing = self.model.ops.asarray(not_missing) # type: ignore
d_scores = (scores - truths) / scores.shape[0]
d_scores *= not_missing
mean_square_error = (d_scores ** 2).sum(axis=1).mean()
@@ -330,11 +331,9 @@ class TextCategorizer(TrainablePipe):
if label in self.labels:
return 0
self._allow_extra_label()
- self.cfg["labels"].append(label)
+ self.cfg["labels"].append(label) # type: ignore[attr-defined]
if self.model and "resize_output" in self.model.attrs:
- self.model = self.model.attrs["resize_output"](
- self.model, len(self.cfg["labels"])
- )
+ self.model = self.model.attrs["resize_output"](self.model, len(self.labels))
self.vocab.strings.add(label)
return 1
@@ -387,7 +386,7 @@ class TextCategorizer(TrainablePipe):
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample)
- def _validate_categories(self, examples: List[Example]):
+ def _validate_categories(self, examples: Iterable[Example]):
"""Check whether the provided examples all have single-label cats annotations."""
for ex in examples:
if list(ex.reference.cats.values()).count(1.0) > 1:
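The rehearse() change above normalises losses to a dict once at the top instead of guarding every update with "if losses is not None", which also lets the method always return a dict. A generic sketch of that pattern; the metric name is made up:

from typing import Dict, Optional

def accumulate(step_loss: float, losses: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    # Normalising once means every later access can assume a real dict,
    # and the function can always return Dict[str, float] rather than an Optional.
    if losses is None:
        losses = {}
    losses.setdefault("textcat", 0.0)
    losses["textcat"] += step_loss
    return losses

running = accumulate(0.25)
running = accumulate(0.5, running)
print(running)  # {'textcat': 0.75}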
diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py
index efa7d28b5..a7bfacca7 100644
--- a/spacy/pipeline/textcat_multilabel.py
+++ b/spacy/pipeline/textcat_multilabel.py
@@ -158,7 +158,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
self.cfg = dict(cfg)
self.scorer = scorer
- def initialize(
+ def initialize( # type: ignore[override]
self,
get_examples: Callable[[], Iterable[Example]],
*,
@@ -193,7 +193,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample)
- def _validate_categories(self, examples: List[Example]):
+ def _validate_categories(self, examples: Iterable[Example]):
"""This component allows any type of single- or multi-label annotations.
This method overwrites the more strict one from 'textcat'."""
pass
diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index 00d9548a4..cb601e5dc 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Iterable, Optional, Dict, Callable, List
+from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
from thinc.api import Model, set_dropout_rate, Optimizer, Config
from itertools import islice
@@ -60,8 +60,8 @@ class Tok2Vec(TrainablePipe):
self.vocab = vocab
self.model = model
self.name = name
- self.listener_map = {}
- self.cfg = {}
+ self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
+ self.cfg: Dict[str, Any] = {}
@property
def listeners(self) -> List["Tok2VecListener"]:
@@ -245,12 +245,12 @@ class Tok2VecListener(Model):
"""
Model.__init__(self, name=self.name, forward=forward, dims={"nO": width})
self.upstream_name = upstream_name
- self._batch_id = None
+ self._batch_id: Optional[int] = None
self._outputs = None
self._backprop = None
@classmethod
- def get_batch_id(cls, inputs: List[Doc]) -> int:
+ def get_batch_id(cls, inputs: Iterable[Doc]) -> int:
"""Calculate a content-sensitive hash of the batch of documents, to check
whether the next batch of documents is unexpected.
"""
diff --git a/spacy/schemas.py b/spacy/schemas.py
index bd3f0ecf0..b3ea11d8b 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -44,7 +44,7 @@ def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
for error in errors:
err_loc = " -> ".join([str(p) for p in error.get("loc", [])])
data[err_loc].append(error.get("msg"))
- return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()]
+ return [f"[{loc}] {', '.join(msg)}" for loc, msg in data.items()] # type: ignore[arg-type]
# Initialization
@@ -82,7 +82,7 @@ def get_arg_model(
except ValueError:
# Typically happens if the method is part of a Cython module without
# binding=True. Here we just use an empty model that allows everything.
- return create_model(name, __config__=ArgSchemaConfigExtra)
+ return create_model(name, __config__=ArgSchemaConfigExtra) # type: ignore[arg-type, return-value]
has_variable = False
for param in sig.parameters.values():
if param.name in exclude:
@@ -102,8 +102,8 @@ def get_arg_model(
default = param.default if param.default != param.empty else default_empty
sig_args[param.name] = (annotation, default)
is_strict = strict and not has_variable
- sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra
- return create_model(name, **sig_args)
+ sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra # type: ignore[assignment]
+ return create_model(name, **sig_args) # type: ignore[arg-type, return-value]
def validate_init_settings(
@@ -198,10 +198,10 @@ class TokenPatternNumber(BaseModel):
class TokenPatternOperator(str, Enum):
- plus: StrictStr = "+"
- start: StrictStr = "*"
- question: StrictStr = "?"
- exclamation: StrictStr = "!"
+ plus: StrictStr = StrictStr("+")
+ start: StrictStr = StrictStr("*")
+ question: StrictStr = StrictStr("?")
+ exclamation: StrictStr = StrictStr("!")
StringValue = Union[TokenPatternString, StrictStr]
@@ -386,7 +386,7 @@ class ConfigSchemaInit(BaseModel):
class ConfigSchema(BaseModel):
training: ConfigSchemaTraining
nlp: ConfigSchemaNlp
- pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
+ pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {} # type: ignore[assignment]
components: Dict[str, Dict[str, Any]]
corpora: Dict[str, Reader]
initialize: ConfigSchemaInit
diff --git a/spacy/scorer.py b/spacy/scorer.py
index bd305c123..49d51a4b3 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -1,4 +1,5 @@
-from typing import Optional, Iterable, Dict, Set, Any, Callable, TYPE_CHECKING
+from typing import Optional, Iterable, Dict, Set, List, Any, Callable, Tuple
+from typing import TYPE_CHECKING
import numpy as np
from collections import defaultdict
@@ -74,8 +75,8 @@ class ROCAUCScore:
may throw an error."""
def __init__(self) -> None:
- self.golds = []
- self.cands = []
+ self.golds: List[Any] = []
+ self.cands: List[Any] = []
self.saved_score = 0.0
self.saved_score_at_len = 0
@@ -111,9 +112,10 @@ class Scorer:
DOCS: https://spacy.io/api/scorer#init
"""
- self.nlp = nlp
self.cfg = cfg
- if not nlp:
+ if nlp:
+ self.nlp = nlp
+ else:
nlp = get_lang_class(default_lang)()
for pipe in default_pipeline:
nlp.add_pipe(pipe)
@@ -129,7 +131,7 @@ class Scorer:
"""
scores = {}
if hasattr(self.nlp.tokenizer, "score"):
- scores.update(self.nlp.tokenizer.score(examples, **self.cfg))
+ scores.update(self.nlp.tokenizer.score(examples, **self.cfg)) # type: ignore
for name, component in self.nlp.pipeline:
if hasattr(component, "score"):
scores.update(component.score(examples, **self.cfg))
@@ -191,7 +193,7 @@ class Scorer:
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
- missing_values: Set[Any] = MISSING_VALUES,
+ missing_values: Set[Any] = MISSING_VALUES, # type: ignore[assignment]
**cfg,
) -> Dict[str, Any]:
"""Returns an accuracy score for a token-level attribute.
@@ -201,6 +203,8 @@ class Scorer:
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an
individual token.
+ missing_values (Set[Any]): Attribute values to treat as missing annotation
+ in the reference annotation.
RETURNS (Dict[str, Any]): A dictionary containing the accuracy score
under the key attr_acc.
@@ -240,7 +244,7 @@ class Scorer:
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
- missing_values: Set[Any] = MISSING_VALUES,
+ missing_values: Set[Any] = MISSING_VALUES, # type: ignore[assignment]
**cfg,
) -> Dict[str, Any]:
"""Return PRF scores per feat for a token attribute in UFEATS format.
@@ -250,6 +254,8 @@ class Scorer:
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an
individual token.
+ missing_values (Set[Any]): Attribute values to treat as missing annotation
+ in the reference annotation.
RETURNS (dict): A dictionary containing the per-feat PRF scores under
the key attr_per_feat.
"""
@@ -258,7 +264,7 @@ class Scorer:
pred_doc = example.predicted
gold_doc = example.reference
align = example.alignment
- gold_per_feat = {}
+ gold_per_feat: Dict[str, Set] = {}
missing_indices = set()
for gold_i, token in enumerate(gold_doc):
value = getter(token, attr)
@@ -273,7 +279,7 @@ class Scorer:
gold_per_feat[field].add((gold_i, feat))
else:
missing_indices.add(gold_i)
- pred_per_feat = {}
+ pred_per_feat: Dict[str, Set] = {}
for token in pred_doc:
if token.orth_.isspace():
continue
@@ -350,7 +356,7 @@ class Scorer:
+ [k.label_ for k in getter(pred_doc, attr)]
)
# Set up all labels for per type scoring and prepare gold per type
- gold_per_type = {label: set() for label in labels}
+ gold_per_type: Dict[str, Set] = {label: set() for label in labels}
for label in labels:
if label not in score_per_type:
score_per_type[label] = PRFScore()
@@ -358,16 +364,18 @@ class Scorer:
gold_spans = set()
pred_spans = set()
for span in getter(gold_doc, attr):
+ gold_span: Tuple
if labeled:
gold_span = (span.label_, span.start, span.end - 1)
else:
gold_span = (span.start, span.end - 1)
gold_spans.add(gold_span)
gold_per_type[span.label_].add(gold_span)
- pred_per_type = {label: set() for label in labels}
+ pred_per_type: Dict[str, Set] = {label: set() for label in labels}
for span in example.get_aligned_spans_x2y(
getter(pred_doc, attr), allow_overlap
):
+ pred_span: Tuple
if labeled:
pred_span = (span.label_, span.start, span.end - 1)
else:
@@ -382,7 +390,7 @@ class Scorer:
# Score for all labels
score.score_set(pred_spans, gold_spans)
# Assemble final result
- final_scores = {
+ final_scores: Dict[str, Any] = {
f"{attr}_p": None,
f"{attr}_r": None,
f"{attr}_f": None,
@@ -508,7 +516,7 @@ class Scorer:
sum(auc.score if auc.is_binary() else 0.0 for auc in auc_per_type.values())
/ n_cats
)
- results = {
+ results: Dict[str, Any] = {
f"{attr}_score": None,
f"{attr}_score_desc": None,
f"{attr}_micro_p": micro_prf.precision,
@@ -613,7 +621,7 @@ class Scorer:
head_attr: str = "head",
head_getter: Callable[[Token, str], Token] = getattr,
ignore_labels: Iterable[str] = SimpleFrozenList(),
- missing_values: Set[Any] = MISSING_VALUES,
+ missing_values: Set[Any] = MISSING_VALUES, # type: ignore[assignment]
**cfg,
) -> Dict[str, Any]:
"""Returns the UAS, LAS, and LAS per type scores for dependency
@@ -630,6 +638,8 @@ class Scorer:
head_getter(token, attr) should return the value of the head for an
individual token.
ignore_labels (Tuple): Labels to ignore while scoring (e.g., punct).
+ missing_values (Set[Any]): Attribute values to treat as missing annotation
+ in the reference annotation.
RETURNS (Dict[str, Any]): A dictionary containing the scores:
attr_uas, attr_las, and attr_las_per_type.
@@ -644,7 +654,7 @@ class Scorer:
pred_doc = example.predicted
align = example.alignment
gold_deps = set()
- gold_deps_per_dep = {}
+ gold_deps_per_dep: Dict[str, Set] = {}
for gold_i, token in enumerate(gold_doc):
dep = getter(token, attr)
head = head_getter(token, head_attr)
@@ -659,12 +669,12 @@ class Scorer:
else:
missing_indices.add(gold_i)
pred_deps = set()
- pred_deps_per_dep = {}
+ pred_deps_per_dep: Dict[str, Set] = {}
for token in pred_doc:
if token.orth_.isspace():
continue
if align.x2y.lengths[token.i] != 1:
- gold_i = None
+ gold_i = None # type: ignore
else:
gold_i = align.x2y[token.i].dataXd[0, 0]
if gold_i not in missing_indices:
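Several scorer hunks annotate per-label accumulators such as gold_per_feat and pred_per_type as Dict[str, Set], since they start out empty and mypy has nothing to infer an element type from. A small, self-contained illustration of the accumulation pattern, detached from spaCy's Example and Token objects:

from typing import Dict, Iterable, Set, Tuple

def group_spans(spans: Iterable[Tuple[str, int, int]]) -> Dict[str, Set[Tuple[int, int]]]:
    # The explicit annotation is required: an empty dict literal cannot be inferred.
    per_label: Dict[str, Set[Tuple[int, int]]] = {}
    for label, start, end in spans:
        per_label.setdefault(label, set()).add((start, end))
    return per_label

print(group_spans([("PERSON", 0, 2), ("ORG", 3, 5), ("PERSON", 7, 8)]))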
diff --git a/spacy/strings.pyi b/spacy/strings.pyi
index 57bf71b93..5b4147e12 100644
--- a/spacy/strings.pyi
+++ b/spacy/strings.pyi
@@ -1,7 +1,7 @@
from typing import Optional, Iterable, Iterator, Union, Any
from pathlib import Path
-def get_string_id(key: str) -> int: ...
+def get_string_id(key: Union[str, int]) -> int: ...
class StringStore:
def __init__(
diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py
index 10982bac1..b88d11f0e 100644
--- a/spacy/tests/conftest.py
+++ b/spacy/tests/conftest.py
@@ -3,8 +3,13 @@ from spacy.util import get_lang_class
def pytest_addoption(parser):
- parser.addoption("--slow", action="store_true", help="include slow tests")
- parser.addoption("--issue", action="store", help="test specific issues")
+ try:
+ parser.addoption("--slow", action="store_true", help="include slow tests")
+ parser.addoption("--issue", action="store", help="test specific issues")
+ # Options are already added, e.g. if conftest is copied in a build pipeline
+ # and runs twice
+ except ValueError:
+ pass
def pytest_runtest_setup(item):
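The conftest change wraps parser.addoption in try/except because pytest raises ValueError when the same option is registered twice, e.g. if the conftest is copied into a build pipeline and collected again. A sketch of the guard together with the way such a flag is typically consumed; the marker handling below is a generic pytest pattern, not spaCy's exact conftest:

import pytest

def pytest_addoption(parser):
    try:
        parser.addoption("--slow", action="store_true", help="include slow tests")
    except ValueError:
        # Already registered, e.g. the conftest was collected twice.
        pass

def pytest_runtest_setup(item):
    # Skip tests marked @pytest.mark.slow unless --slow was passed.
    if "slow" in item.keywords and not item.config.getoption("--slow"):
        pytest.skip("need --slow option to run")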
diff --git a/spacy/tests/lang/fr/test_prefix_suffix_infix.py b/spacy/tests/lang/fr/test_prefix_suffix_infix.py
index 2ead34069..7770f807b 100644
--- a/spacy/tests/lang/fr/test_prefix_suffix_infix.py
+++ b/spacy/tests/lang/fr/test_prefix_suffix_infix.py
@@ -1,5 +1,5 @@
import pytest
-from spacy.language import Language
+from spacy.language import Language, BaseDefaults
from spacy.lang.punctuation import TOKENIZER_INFIXES
from spacy.lang.char_classes import ALPHA
@@ -12,7 +12,7 @@ def test_issue768(text, expected_tokens):
SPLIT_INFIX = r"(?<=[{a}]\')(?=[{a}])".format(a=ALPHA)
class FrenchTest(Language):
- class Defaults(Language.Defaults):
+ class Defaults(BaseDefaults):
infixes = TOKENIZER_INFIXES + [SPLIT_INFIX]
fr_tokenizer_w_infix = FrenchTest().tokenizer
diff --git a/spacy/tests/lang/hu/test_tokenizer.py b/spacy/tests/lang/hu/test_tokenizer.py
index fd3acd0a0..0488474ae 100644
--- a/spacy/tests/lang/hu/test_tokenizer.py
+++ b/spacy/tests/lang/hu/test_tokenizer.py
@@ -294,7 +294,7 @@ WIKI_TESTS = [
]
EXTRA_TESTS = (
- DOT_TESTS + QUOTE_TESTS + NUMBER_TESTS + HYPHEN_TESTS + WIKI_TESTS + TYPO_TESTS
+ DOT_TESTS + QUOTE_TESTS + NUMBER_TESTS + HYPHEN_TESTS + WIKI_TESTS + TYPO_TESTS # type: ignore[operator]
)
# normal: default tests + 10% of extra tests
diff --git a/spacy/tests/package/test_requirements.py b/spacy/tests/package/test_requirements.py
index 8e042c9cf..1d51bd609 100644
--- a/spacy/tests/package/test_requirements.py
+++ b/spacy/tests/package/test_requirements.py
@@ -12,6 +12,10 @@ def test_build_dependencies():
"flake8",
"hypothesis",
"pre-commit",
+ "mypy",
+ "types-dataclasses",
+ "types-mock",
+ "types-requests",
]
# ignore language-specific packages that shouldn't be installed by all
libs_ignore_setup = [
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index a30001b27..21094bcb1 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -2,14 +2,14 @@ import pytest
from numpy.testing import assert_equal
from spacy.attrs import ENT_IOB
-from spacy import util
+from spacy import util, registry
from spacy.lang.en import English
from spacy.language import Language
from spacy.lookups import Lookups
from spacy.pipeline._parser_internals.ner import BiluoPushDown
from spacy.training import Example
from spacy.tokens import Doc, Span
-from spacy.vocab import Vocab, registry
+from spacy.vocab import Vocab
import logging
from ..util import make_tempdir
diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py
index b97795344..a98d01964 100644
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -154,6 +154,40 @@ def test_kb_serialize(nlp):
mykb.from_disk(d / "unknown" / "kb")
+@pytest.mark.issue(9137)
+def test_kb_serialize_2(nlp):
+ v = [5, 6, 7, 8]
+ kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+ kb1.set_entities(["E1"], [1], [v])
+ assert kb1.get_vector("E1") == v
+ with make_tempdir() as d:
+ kb1.to_disk(d / "kb")
+ kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+ kb2.from_disk(d / "kb")
+ assert kb2.get_vector("E1") == v
+
+
+def test_kb_set_entities(nlp):
+ """Test that set_entities entirely overwrites the previous set of entities"""
+ v = [5, 6, 7, 8]
+ v1 = [1, 1, 1, 0]
+ v2 = [2, 2, 2, 3]
+ kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+ kb1.set_entities(["E0"], [1], [v])
+ assert kb1.get_entity_strings() == ["E0"]
+ kb1.set_entities(["E1", "E2"], [1, 9], [v1, v2])
+ assert set(kb1.get_entity_strings()) == {"E1", "E2"}
+ assert kb1.get_vector("E1") == v1
+ assert kb1.get_vector("E2") == v2
+ with make_tempdir() as d:
+ kb1.to_disk(d / "kb")
+ kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+ kb2.from_disk(d / "kb")
+ assert set(kb2.get_entity_strings()) == {"E1", "E2"}
+ assert kb2.get_vector("E1") == v1
+ assert kb2.get_vector("E2") == v2
+
+
def test_kb_serialize_vocab(nlp):
"""Test serialization of the KB and custom strings"""
entity = "MyFunnyID"
diff --git a/spacy/tests/pipeline/test_pipe_factories.py b/spacy/tests/pipeline/test_pipe_factories.py
index f1f0c8a6e..0c2554727 100644
--- a/spacy/tests/pipeline/test_pipe_factories.py
+++ b/spacy/tests/pipeline/test_pipe_factories.py
@@ -135,8 +135,8 @@ def test_pipe_class_component_defaults():
self,
nlp: Language,
name: str,
- value1: StrictInt = 10,
- value2: StrictStr = "hello",
+ value1: StrictInt = StrictInt(10),
+ value2: StrictStr = StrictStr("hello"),
):
self.nlp = nlp
self.value1 = value1
@@ -196,7 +196,11 @@ def test_pipe_class_component_model_custom():
@Language.factory(name, default_config=default_config)
class Component:
def __init__(
- self, nlp: Language, model: Model, name: str, value1: StrictInt = 10
+ self,
+ nlp: Language,
+ model: Model,
+ name: str,
+ value1: StrictInt = StrictInt(10),
):
self.nlp = nlp
self.model = model
diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py
index d4d0617d7..5c3a9d27d 100644
--- a/spacy/tests/pipeline/test_spancat.py
+++ b/spacy/tests/pipeline/test_spancat.py
@@ -6,8 +6,8 @@ from thinc.api import get_current_ops
from spacy import util
from spacy.lang.en import English
from spacy.language import Language
-from spacy.tokens.doc import SpanGroups
from spacy.tokens import SpanGroup
+from spacy.tokens._dict_proxies import SpanGroups
from spacy.training import Example
from spacy.util import fix_random_seed, registry, make_tempdir
diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py
index e51c7f45b..23afaf26c 100644
--- a/spacy/tests/serialize/test_serialize_doc.py
+++ b/spacy/tests/serialize/test_serialize_doc.py
@@ -1,5 +1,5 @@
import pytest
-from spacy.tokens.doc import Underscore
+from spacy.tokens.underscore import Underscore
import spacy
from spacy.lang.en import English
diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py
index 7a9021af0..444b1c83e 100644
--- a/spacy/tests/test_language.py
+++ b/spacy/tests/test_language.py
@@ -10,11 +10,21 @@ from spacy.lang.en import English
from spacy.lang.de import German
from spacy.util import registry, ignore_error, raise_error, find_matching_language
import spacy
-from thinc.api import NumpyOps, get_current_ops
+from thinc.api import CupyOps, NumpyOps, get_current_ops
from .util import add_vecs_to_vocab, assert_docs_equal
+try:
+ import torch
+
+ # Ensure that we don't deadlock in multiprocessing tests.
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+except ImportError:
+ pass
+
+
def evil_component(doc):
if "2" in doc.text:
raise ValueError("no dice")
@@ -603,3 +613,17 @@ def test_invalid_arg_to_pipeline(nlp):
list(nlp.pipe(int_list)) # type: ignore
with pytest.raises(ValueError):
nlp(int_list) # type: ignore
+
+
+@pytest.mark.skipif(
+ not isinstance(get_current_ops(), CupyOps), reason="test requires GPU"
+)
+def test_multiprocessing_gpu_warning(nlp2, texts):
+ texts = texts * 10
+ docs = nlp2.pipe(texts, n_process=2, batch_size=2)
+
+ with pytest.warns(UserWarning, match="multiprocessing with GPU models"):
+ with pytest.raises(ValueError):
+ # Trigger multi-processing.
+ for _ in docs:
+ pass
diff --git a/spacy/tests/test_ty.py b/spacy/tests/test_ty.py
new file mode 100644
index 000000000..2037520df
--- /dev/null
+++ b/spacy/tests/test_ty.py
@@ -0,0 +1,18 @@
+import spacy
+from spacy import ty
+
+
+def test_component_types():
+ nlp = spacy.blank("en")
+ tok2vec = nlp.create_pipe("tok2vec")
+ tagger = nlp.create_pipe("tagger")
+ entity_ruler = nlp.create_pipe("entity_ruler")
+ assert isinstance(tok2vec, ty.TrainableComponent)
+ assert isinstance(tagger, ty.TrainableComponent)
+ assert not isinstance(entity_ruler, ty.TrainableComponent)
+ assert isinstance(tok2vec, ty.InitializableComponent)
+ assert isinstance(tagger, ty.InitializableComponent)
+ assert isinstance(entity_ruler, ty.InitializableComponent)
+ assert isinstance(tok2vec, ty.ListenedToComponent)
+ assert not isinstance(tagger, ty.ListenedToComponent)
+ assert not isinstance(entity_ruler, ty.ListenedToComponent)
diff --git a/spacy/tests/training/test_readers.py b/spacy/tests/training/test_readers.py
index 1f262c011..c0c51b287 100644
--- a/spacy/tests/training/test_readers.py
+++ b/spacy/tests/training/test_readers.py
@@ -28,7 +28,7 @@ def test_readers():
"""
@registry.readers("myreader.v1")
- def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
+ def myreader() -> Dict[str, Callable[[Language], Iterable[Example]]]:
annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
def reader(nlp: Language):
diff --git a/spacy/tests/vocab_vectors/test_vectors.py b/spacy/tests/vocab_vectors/test_vectors.py
index 8a7dd22c3..23597455f 100644
--- a/spacy/tests/vocab_vectors/test_vectors.py
+++ b/spacy/tests/vocab_vectors/test_vectors.py
@@ -5,7 +5,7 @@ from thinc.api import get_current_ops
from spacy.vocab import Vocab
from spacy.vectors import Vectors
from spacy.tokenizer import Tokenizer
-from spacy.strings import hash_string
+from spacy.strings import hash_string # type: ignore
from spacy.tokens import Doc
from ..util import add_vecs_to_vocab, get_cosine, make_tempdir
diff --git a/spacy/tokens/_dict_proxies.py b/spacy/tokens/_dict_proxies.py
index 9ee1ad02f..470d3430f 100644
--- a/spacy/tokens/_dict_proxies.py
+++ b/spacy/tokens/_dict_proxies.py
@@ -1,9 +1,10 @@
-from typing import Iterable, Tuple, Union, TYPE_CHECKING
+from typing import Iterable, Tuple, Union, Optional, TYPE_CHECKING
import weakref
from collections import UserDict
import srsly
from .span_group import SpanGroup
+from ..errors import Errors
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
@@ -13,7 +14,7 @@ if TYPE_CHECKING:
# Why inherit from UserDict instead of dict here?
# Well, the 'dict' class doesn't necessarily delegate everything nicely,
-# for performance reasons. The UserDict is slower by better behaved.
+# for performance reasons. The UserDict is slower but better behaved.
# See https://treyhunner.com/2019/04/why-you-shouldnt-inherit-from-list-and-dict-in-python/0ww
class SpanGroups(UserDict):
"""A dict-like proxy held by the Doc, to control access to span groups."""
@@ -22,7 +23,7 @@ class SpanGroups(UserDict):
self, doc: "Doc", items: Iterable[Tuple[str, SpanGroup]] = tuple()
) -> None:
self.doc_ref = weakref.ref(doc)
- UserDict.__init__(self, items)
+ UserDict.__init__(self, items) # type: ignore[arg-type]
def __setitem__(self, key: str, value: Union[SpanGroup, Iterable["Span"]]) -> None:
if not isinstance(value, SpanGroup):
@@ -31,11 +32,12 @@ class SpanGroups(UserDict):
UserDict.__setitem__(self, key, value)
def _make_span_group(self, name: str, spans: Iterable["Span"]) -> SpanGroup:
- return SpanGroup(self.doc_ref(), name=name, spans=spans)
+ doc = self._ensure_doc()
+ return SpanGroup(doc, name=name, spans=spans)
- def copy(self, doc: "Doc" = None) -> "SpanGroups":
+ def copy(self, doc: Optional["Doc"] = None) -> "SpanGroups":
if doc is None:
- doc = self.doc_ref()
+ doc = self._ensure_doc()
return SpanGroups(doc).from_bytes(self.to_bytes())
def to_bytes(self) -> bytes:
@@ -47,8 +49,14 @@ class SpanGroups(UserDict):
def from_bytes(self, bytes_data: bytes) -> "SpanGroups":
msg = srsly.msgpack_loads(bytes_data)
self.clear()
- doc = self.doc_ref()
+ doc = self._ensure_doc()
for value_bytes in msg:
group = SpanGroup(doc).from_bytes(value_bytes)
self[group.name] = group
return self
+
+ def _ensure_doc(self) -> "Doc":
+ doc = self.doc_ref()
+ if doc is None:
+ raise ValueError(Errors.E866)
+ return doc
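_ensure_doc above centralises the weakref dereference so callers receive a real Doc or a clear error instead of a possibly-None value, which is also what satisfies mypy's Optional checks. The generic shape of that pattern, using placeholder class names:

import weakref

class Owner:
    pass

class View:
    def __init__(self, owner: Owner) -> None:
        # Hold a weak reference so the view does not keep its owner alive.
        self._owner_ref = weakref.ref(owner)

    def _ensure_owner(self) -> Owner:
        owner = self._owner_ref()  # returns None once the owner is garbage-collected
        if owner is None:
            raise ValueError("the owning object no longer exists")
        return owner

owner = Owner()
view = View(owner)
print(view._ensure_owner() is owner)  # True while the owner is alive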
diff --git a/spacy/tokens/_retokenize.pyi b/spacy/tokens/_retokenize.pyi
index b829b71a3..8834d38c0 100644
--- a/spacy/tokens/_retokenize.pyi
+++ b/spacy/tokens/_retokenize.pyi
@@ -2,6 +2,7 @@ from typing import Dict, Any, Union, List, Tuple
from .doc import Doc
from .span import Span
from .token import Token
+from .. import Vocab
class Retokenizer:
def __init__(self, doc: Doc) -> None: ...
@@ -15,3 +16,6 @@ class Retokenizer:
) -> None: ...
def __enter__(self) -> Retokenizer: ...
def __exit__(self, *args: Any) -> None: ...
+
+def normalize_token_attrs(vocab: Vocab, attrs: Dict): ...
+def set_token_attrs(py_token: Token, attrs: Dict): ...
diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py
index 1d26b968c..bd2bdb811 100644
--- a/spacy/tokens/_serialize.py
+++ b/spacy/tokens/_serialize.py
@@ -1,6 +1,7 @@
-from typing import Iterable, Iterator, Union
+from typing import List, Dict, Set, Iterable, Iterator, Union, Optional
from pathlib import Path
import numpy
+from numpy import ndarray
import zlib
import srsly
from thinc.api import NumpyOps
@@ -74,13 +75,13 @@ class DocBin:
self.version = "0.1"
self.attrs = [attr for attr in attrs if attr != ORTH and attr != SPACY]
self.attrs.insert(0, ORTH) # Ensure ORTH is always attrs[0]
- self.tokens = []
- self.spaces = []
- self.cats = []
- self.span_groups = []
- self.user_data = []
- self.flags = []
- self.strings = set()
+ self.tokens: List[ndarray] = []
+ self.spaces: List[ndarray] = []
+ self.cats: List[Dict] = []
+ self.span_groups: List[bytes] = []
+ self.user_data: List[Optional[bytes]] = []
+ self.flags: List[Dict] = []
+ self.strings: Set[str] = set()
self.store_user_data = store_user_data
for doc in docs:
self.add(doc)
@@ -139,11 +140,11 @@ class DocBin:
for i in range(len(self.tokens)):
flags = self.flags[i]
tokens = self.tokens[i]
- spaces = self.spaces[i]
+ spaces: Optional[ndarray] = self.spaces[i]
if flags.get("has_unknown_spaces"):
spaces = None
- doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
- doc = doc.from_array(self.attrs, tokens)
+ doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces) # type: ignore
+ doc = doc.from_array(self.attrs, tokens) # type: ignore
doc.cats = self.cats[i]
if self.span_groups[i]:
doc.spans.from_bytes(self.span_groups[i])
diff --git a/spacy/tokens/doc.pyi b/spacy/tokens/doc.pyi
index 8688fb91f..2b18cee7a 100644
--- a/spacy/tokens/doc.pyi
+++ b/spacy/tokens/doc.pyi
@@ -1,16 +1,5 @@
-from typing import (
- Callable,
- Protocol,
- Iterable,
- Iterator,
- Optional,
- Union,
- Tuple,
- List,
- Dict,
- Any,
- overload,
-)
+from typing import Callable, Protocol, Iterable, Iterator, Optional
+from typing import Union, Tuple, List, Dict, Any, overload
from cymem.cymem import Pool
from thinc.types import Floats1d, Floats2d, Ints2d
from .span import Span
@@ -24,7 +13,7 @@ from pathlib import Path
import numpy
class DocMethod(Protocol):
- def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ...
+ def __call__(self: Doc, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
class Doc:
vocab: Vocab
@@ -150,6 +139,7 @@ class Doc:
self, attr_id: int, exclude: Optional[Any] = ..., counts: Optional[Any] = ...
) -> Dict[Any, int]: ...
def from_array(self, attrs: List[int], array: Ints2d) -> Doc: ...
+ def to_array(self, py_attr_ids: List[int]) -> numpy.ndarray: ...
@staticmethod
def from_docs(
docs: List[Doc],
diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx
index 8ea94558f..d65d18f48 100644
--- a/spacy/tokens/doc.pyx
+++ b/spacy/tokens/doc.pyx
@@ -914,7 +914,7 @@ cdef class Doc:
can specify attributes by integer ID (e.g. spacy.attrs.LEMMA) or
string name (e.g. 'LEMMA' or 'lemma').
- attr_ids (list[]): A list of attributes (int IDs or string names).
+ py_attr_ids (list[]): A list of attributes (int IDs or string names).
RETURNS (numpy.ndarray[long, ndim=2]): A feature matrix, with one row
per word, and one column per attribute indicated in the input
`attr_ids`.
diff --git a/spacy/tokens/morphanalysis.pyi b/spacy/tokens/morphanalysis.pyi
index c7e05e58f..b86203cc4 100644
--- a/spacy/tokens/morphanalysis.pyi
+++ b/spacy/tokens/morphanalysis.pyi
@@ -11,8 +11,8 @@ class MorphAnalysis:
def __iter__(self) -> Iterator[str]: ...
def __len__(self) -> int: ...
def __hash__(self) -> int: ...
- def __eq__(self, other: MorphAnalysis) -> bool: ...
- def __ne__(self, other: MorphAnalysis) -> bool: ...
+ def __eq__(self, other: MorphAnalysis) -> bool: ... # type: ignore[override]
+ def __ne__(self, other: MorphAnalysis) -> bool: ... # type: ignore[override]
def get(self, field: Any) -> List[str]: ...
def to_json(self) -> str: ...
def to_dict(self) -> Dict[str, str]: ...
diff --git a/spacy/tokens/span.pyi b/spacy/tokens/span.pyi
index 4f65abace..697051e81 100644
--- a/spacy/tokens/span.pyi
+++ b/spacy/tokens/span.pyi
@@ -7,7 +7,7 @@ from ..lexeme import Lexeme
from ..vocab import Vocab
class SpanMethod(Protocol):
- def __call__(self: Span, *args: Any, **kwargs: Any) -> Any: ...
+ def __call__(self: Span, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
class Span:
@classmethod
@@ -45,7 +45,7 @@ class Span:
doc: Doc,
start: int,
end: int,
- label: int = ...,
+ label: Union[str, int] = ...,
vector: Optional[Floats1d] = ...,
vector_norm: Optional[float] = ...,
kb_id: Optional[int] = ...,
@@ -65,6 +65,8 @@ class Span:
def get_lca_matrix(self) -> Ints2d: ...
def similarity(self, other: Union[Doc, Span, Token, Lexeme]) -> float: ...
@property
+ def doc(self) -> Doc: ...
+ @property
def vocab(self) -> Vocab: ...
@property
def sent(self) -> Span: ...
diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx
index 342b9ffab..96f843a33 100644
--- a/spacy/tokens/span.pyx
+++ b/spacy/tokens/span.pyx
@@ -86,7 +86,7 @@ cdef class Span:
doc (Doc): The parent document.
start (int): The index of the first token of the span.
end (int): The index of the first token after the span.
- label (uint64): A label to attach to the Span, e.g. for named entities.
+ label (int or str): A label to attach to the Span, e.g. for named entities.
vector (ndarray[ndim=1, dtype='float32']): A meaning representation
of the span.
vector_norm (float): The L2 norm of the span's vector representation.
diff --git a/spacy/tokens/span_group.pyi b/spacy/tokens/span_group.pyi
index 4bd6bec27..26efc3ba0 100644
--- a/spacy/tokens/span_group.pyi
+++ b/spacy/tokens/span_group.pyi
@@ -3,6 +3,8 @@ from .doc import Doc
from .span import Span
class SpanGroup:
+ name: str
+ attrs: Dict[str, Any]
def __init__(
self,
doc: Doc,
diff --git a/spacy/tokens/token.pyi b/spacy/tokens/token.pyi
index 23d028ffd..bd585d034 100644
--- a/spacy/tokens/token.pyi
+++ b/spacy/tokens/token.pyi
@@ -16,7 +16,7 @@ from ..vocab import Vocab
from .underscore import Underscore
class TokenMethod(Protocol):
- def __call__(self: Token, *args: Any, **kwargs: Any) -> Any: ...
+ def __call__(self: Token, *args: Any, **kwargs: Any) -> Any: ... # type: ignore[misc]
class Token:
i: int
diff --git a/spacy/tokens/token.pyx b/spacy/tokens/token.pyx
index 75d908601..aa97e2b07 100644
--- a/spacy/tokens/token.pyx
+++ b/spacy/tokens/token.pyx
@@ -600,7 +600,7 @@ cdef class Token:
yield from word.subtree
@property
- def left_edge(self):
+ def left_edge(self) -> "Token":
"""The leftmost token of this token's syntactic descendents.
RETURNS (Token): The first token such that `self.is_ancestor(token)`.
@@ -608,7 +608,7 @@ cdef class Token:
return self.doc[self.c.l_edge]
@property
- def right_edge(self):
+ def right_edge(self) -> "Token":
"""The rightmost token of this token's syntactic descendents.
RETURNS (Token): The last token such that `self.is_ancestor(token)`.
diff --git a/spacy/tokens/underscore.py b/spacy/tokens/underscore.py
index b7966fd6e..7fa7bf095 100644
--- a/spacy/tokens/underscore.py
+++ b/spacy/tokens/underscore.py
@@ -1,3 +1,4 @@
+from typing import Dict, Any
import functools
import copy
@@ -6,9 +7,9 @@ from ..errors import Errors
class Underscore:
mutable_types = (dict, list, set)
- doc_extensions = {}
- span_extensions = {}
- token_extensions = {}
+ doc_extensions: Dict[Any, Any] = {}
+ span_extensions: Dict[Any, Any] = {}
+ token_extensions: Dict[Any, Any] = {}
def __init__(self, extensions, obj, start=None, end=None):
object.__setattr__(self, "_extensions", extensions)
diff --git a/spacy/training/augment.py b/spacy/training/augment.py
index 0dae92143..63b54034c 100644
--- a/spacy/training/augment.py
+++ b/spacy/training/augment.py
@@ -22,8 +22,8 @@ class OrthVariantsPaired(BaseModel):
class OrthVariants(BaseModel):
- paired: List[OrthVariantsPaired] = {}
- single: List[OrthVariantsSingle] = {}
+ paired: List[OrthVariantsPaired] = []
+ single: List[OrthVariantsSingle] = []
@registry.augmenters("spacy.orth_variants.v1")
@@ -76,7 +76,7 @@ def lower_casing_augmenter(
def orth_variants_augmenter(
nlp: "Language",
example: Example,
- orth_variants: dict,
+ orth_variants: Dict,
*,
level: float = 0.0,
lower: float = 0.0,
diff --git a/spacy/training/batchers.py b/spacy/training/batchers.py
index e79ba79b0..f0b6c3123 100644
--- a/spacy/training/batchers.py
+++ b/spacy/training/batchers.py
@@ -1,4 +1,4 @@
-from typing import Union, Iterable, Sequence, TypeVar, List, Callable
+from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
from typing import Optional, Any
from functools import partial
import itertools
@@ -6,7 +6,7 @@ import itertools
from ..util import registry, minibatch
-Sizing = Union[Iterable[int], int]
+Sizing = Union[Sequence[int], int]
ItemT = TypeVar("ItemT")
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
@@ -24,7 +24,7 @@ def configure_minibatch_by_padded_size(
The padded size is defined as the maximum length of sequences within the
batch multiplied by the number of sequences in the batch.
- size (int or Iterable[int]): The largest padded size to batch sequences into.
+ size (int or Sequence[int]): The largest padded size to batch sequences into.
Can be a single integer, or a sequence, allowing for variable batch sizes.
buffer (int): The number of sequences to accumulate before sorting by length.
A larger buffer will result in more even sizing, but if the buffer is
@@ -56,7 +56,7 @@ def configure_minibatch_by_words(
) -> BatcherT:
"""Create a batcher that uses the "minibatch by words" strategy.
- size (int or Iterable[int]): The target number of words per batch.
+ size (int or Sequence[int]): The target number of words per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
tolerance (float): What percentage of the size to allow batches to exceed.
discard_oversize (bool): Whether to discard sequences that by themselves
@@ -80,7 +80,7 @@ def configure_minibatch(
) -> BatcherT:
"""Create a batcher that creates batches of the specified size.
- size (int or Iterable[int]): The target number of items per batch.
+ size (int or Sequence[int]): The target number of items per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
"""
optionals = {"get_length": get_length} if get_length is not None else {}
@@ -100,7 +100,7 @@ def minibatch_by_padded_size(
The padded size is defined as the maximum length of sequences within the
batch multiplied by the number of sequences in the batch.
- size (int): The largest padded size to batch sequences into.
+ size (int or Sequence[int]): The largest padded size to batch sequences into.
buffer (int): The number of sequences to accumulate before sorting by length.
A larger buffer will result in more even sizing, but if the buffer is
very large, the iteration order will be less random, which can result
@@ -111,9 +111,9 @@ def minibatch_by_padded_size(
The `len` function is used by default.
"""
if isinstance(size, int):
- size_ = itertools.repeat(size)
+ size_ = itertools.repeat(size) # type: Iterator[int]
else:
- size_ = size
+ size_ = iter(size)
for outer_batch in minibatch(seqs, size=buffer):
outer_batch = list(outer_batch)
target_size = next(size_)
@@ -138,7 +138,7 @@ def minibatch_by_words(
themselves, or be discarded if discard_oversize=True.
seqs (Iterable[Sequence]): The sequences to minibatch.
- size (int or Iterable[int]): The target number of words per batch.
+ size (int or Sequence[int]): The target number of words per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
tolerance (float): What percentage of the size to allow batches to exceed.
discard_oversize (bool): Whether to discard sequences that by themselves
@@ -147,11 +147,9 @@ def minibatch_by_words(
item. The `len` function is used by default.
"""
if isinstance(size, int):
- size_ = itertools.repeat(size)
- elif isinstance(size, List):
- size_ = iter(size)
+ size_ = itertools.repeat(size) # type: Iterator[int]
else:
- size_ = size
+ size_ = iter(size)
target_size = next(size_)
tol_size = target_size * tolerance
batch = []
@@ -216,7 +214,7 @@ def _batch_by_length(
lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
lengths_indices.sort()
batches = []
- batch = []
+ batch: List[int] = []
for length, i in lengths_indices:
if not batch:
batch.append(i)
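The batcher hunks above change Sizing to Union[Sequence[int], int] and normalise both cases to a single iterator of target sizes, dropping the separate isinstance(size, List) branch. A compact sketch of that normalisation outside spaCy:

import itertools
from typing import Iterator, Sequence, Union

Sizing = Union[Sequence[int], int]

def size_iterator(size: Sizing) -> Iterator[int]:
    if isinstance(size, int):
        # A single int repeats forever; a sequence yields a schedule of sizes.
        return itertools.repeat(size)
    return iter(size)

sizes = size_iterator([2, 4, 8])
print(next(sizes), next(sizes), next(sizes))  # 2 4 8
print(next(size_iterator(32)))                # 32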
diff --git a/spacy/training/callbacks.py b/spacy/training/callbacks.py
index 2a21be98c..426fddf90 100644
--- a/spacy/training/callbacks.py
+++ b/spacy/training/callbacks.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Callable, Optional
from ..errors import Errors
from ..language import Language
from ..util import load_model, registry, logger
@@ -8,7 +8,7 @@ from ..util import load_model, registry, logger
def create_copy_from_base_model(
tokenizer: Optional[str] = None,
vocab: Optional[str] = None,
-) -> Language:
+) -> Callable[[Language], Language]:
def copy_from_base_model(nlp):
if tokenizer:
logger.info(f"Copying tokenizer from: {tokenizer}")
diff --git a/spacy/training/corpus.py b/spacy/training/corpus.py
index 606dbfb4a..b30d918fd 100644
--- a/spacy/training/corpus.py
+++ b/spacy/training/corpus.py
@@ -41,8 +41,8 @@ def create_docbin_reader(
@util.registry.readers("spacy.JsonlCorpus.v1")
def create_jsonl_reader(
- path: Optional[Path], min_length: int = 0, max_length: int = 0, limit: int = 0
-) -> Callable[["Language"], Iterable[Doc]]:
+ path: Union[str, Path], min_length: int = 0, max_length: int = 0, limit: int = 0
+) -> Callable[["Language"], Iterable[Example]]:
return JsonlCorpus(path, min_length=min_length, max_length=max_length, limit=limit)
@@ -129,15 +129,15 @@ class Corpus:
"""
ref_docs = self.read_docbin(nlp.vocab, walk_corpus(self.path, FILE_TYPE))
if self.shuffle:
- ref_docs = list(ref_docs)
- random.shuffle(ref_docs)
+ ref_docs = list(ref_docs) # type: ignore
+ random.shuffle(ref_docs) # type: ignore
if self.gold_preproc:
examples = self.make_examples_gold_preproc(nlp, ref_docs)
else:
examples = self.make_examples(nlp, ref_docs)
for real_eg in examples:
- for augmented_eg in self.augmenter(nlp, real_eg):
+ for augmented_eg in self.augmenter(nlp, real_eg): # type: ignore[operator]
yield augmented_eg
def _make_example(
@@ -190,7 +190,7 @@ class Corpus:
i = 0
for loc in locs:
loc = util.ensure_path(loc)
- if loc.parts[-1].endswith(FILE_TYPE):
+ if loc.parts[-1].endswith(FILE_TYPE): # type: ignore[union-attr]
doc_bin = DocBin().from_disk(loc)
docs = doc_bin.get_docs(vocab)
for doc in docs:
@@ -202,7 +202,7 @@ class Corpus:
class JsonlCorpus:
- """Iterate Doc objects from a file or directory of jsonl
+ """Iterate Example objects from a file or directory of jsonl
formatted raw text files.
path (Path): The directory or filename to read from.
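
The docstring fix above matters for callers: `JsonlCorpus` yields `Example` objects wrapping the raw texts, not plain `Doc`s. A minimal sketch, assuming a local `texts.jsonl` file with one `{"text": ...}` object per line:

```python
import spacy
from spacy.training.corpus import JsonlCorpus

nlp = spacy.blank("en")
corpus = JsonlCorpus("texts.jsonl", min_length=1)
for eg in corpus(nlp):
    # eg is an Example, per the updated return annotation
    print(type(eg).__name__, eg.text[:40])
```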
diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py
index 4eb8ea276..96abcc7cd 100644
--- a/spacy/training/initialize.py
+++ b/spacy/training/initialize.py
@@ -106,7 +106,7 @@ def init_vocab(
data: Optional[Path] = None,
lookups: Optional[Lookups] = None,
vectors: Optional[str] = None,
-) -> "Language":
+) -> None:
if lookups:
nlp.vocab.lookups = lookups
logger.info(f"Added vocab lookups: {', '.join(lookups.tables)}")
@@ -164,7 +164,7 @@ def load_vectors_into_model(
logger.warning(Warnings.W112.format(name=name))
for lex in nlp.vocab:
- lex.rank = nlp.vocab.vectors.key2row.get(lex.orth, OOV_RANK)
+ lex.rank = nlp.vocab.vectors.key2row.get(lex.orth, OOV_RANK) # type: ignore[attr-defined]
def init_tok2vec(
@@ -203,7 +203,7 @@ def convert_vectors(
nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
for lex in nlp.vocab:
if lex.rank and lex.rank != OOV_RANK:
- nlp.vocab.vectors.add(lex.orth, row=lex.rank)
+ nlp.vocab.vectors.add(lex.orth, row=lex.rank) # type: ignore[attr-defined]
else:
if vectors_loc:
logger.info(f"Reading vectors from {vectors_loc}")
@@ -251,14 +251,14 @@ def open_file(loc: Union[str, Path]) -> IO:
"""Handle .gz, .tar.gz or unzipped files"""
loc = ensure_path(loc)
if tarfile.is_tarfile(str(loc)):
- return tarfile.open(str(loc), "r:gz")
+ return tarfile.open(str(loc), "r:gz") # type: ignore[return-value]
elif loc.parts[-1].endswith("gz"):
- return (line.decode("utf8") for line in gzip.open(str(loc), "r"))
+ return (line.decode("utf8") for line in gzip.open(str(loc), "r")) # type: ignore[return-value]
elif loc.parts[-1].endswith("zip"):
zip_file = zipfile.ZipFile(str(loc))
names = zip_file.namelist()
file_ = zip_file.open(names[0])
- return (line.decode("utf8") for line in file_)
+ return (line.decode("utf8") for line in file_) # type: ignore[return-value]
else:
return loc.open("r", encoding="utf8")
diff --git a/spacy/training/iob_utils.py b/spacy/training/iob_utils.py
index 42dae8fc4..64492c2bc 100644
--- a/spacy/training/iob_utils.py
+++ b/spacy/training/iob_utils.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Iterable, Union, Iterator
+from typing import List, Dict, Tuple, Iterable, Union, Iterator
import warnings
from ..errors import Errors, Warnings
@@ -6,7 +6,7 @@ from ..tokens import Span, Doc
def iob_to_biluo(tags: Iterable[str]) -> List[str]:
- out = []
+ out: List[str] = []
tags = list(tags)
while tags:
out.extend(_consume_os(tags))
@@ -90,7 +90,7 @@ def offsets_to_biluo_tags(
>>> assert tags == ["O", "O", 'U-LOC', "O"]
"""
# Ensure no overlapping entity labels exist
- tokens_in_ents = {}
+ tokens_in_ents: Dict[int, Tuple[int, int, Union[str, int]]] = {}
starts = {token.idx: token.i for token in doc}
ends = {token.idx + len(token): token.i for token in doc}
biluo = ["-" for _ in doc]
@@ -199,14 +199,18 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
pass
elif tag.startswith("I"):
if start is None:
- raise ValueError(Errors.E067.format(start="I", tags=tags[: i + 1]))
+ raise ValueError(
+ Errors.E067.format(start="I", tags=list(tags)[: i + 1])
+ )
elif tag.startswith("U"):
entities.append((tag[2:], i, i))
elif tag.startswith("B"):
start = i
elif tag.startswith("L"):
if start is None:
- raise ValueError(Errors.E067.format(start="L", tags=tags[: i + 1]))
+ raise ValueError(
+ Errors.E067.format(start="L", tags=list(tags)[: i + 1])
+ )
entities.append((tag[2:], start, i))
start = None
else:
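
Since the hunk above touches both conversion helpers, a small sketch of what they do (import path assumed to be `spacy.training.iob_utils`):

```python
from spacy.training.iob_utils import iob_to_biluo, tags_to_entities

iob = ["O", "B-ORG", "I-ORG", "O", "B-PER"]
biluo = iob_to_biluo(iob)
print(biluo)                    # e.g. ['O', 'B-ORG', 'L-ORG', 'O', 'U-PER']

# (label, start_token, end_token) triples with inclusive end indices
print(tags_to_entities(biluo))  # e.g. [('ORG', 1, 2), ('PER', 4, 4)]
```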
diff --git a/spacy/training/loggers.py b/spacy/training/loggers.py
index d80c77b6a..edd0f1959 100644
--- a/spacy/training/loggers.py
+++ b/spacy/training/loggers.py
@@ -98,4 +98,3 @@ def console_logger(progress_bar: bool = False):
return log_step, finalize
return setup_printer
-
diff --git a/spacy/training/loop.py b/spacy/training/loop.py
index 09c54fc9f..06372cbb0 100644
--- a/spacy/training/loop.py
+++ b/spacy/training/loop.py
@@ -32,7 +32,7 @@ def train(
"""Train a pipeline.
nlp (Language): The initialized nlp object with the full config.
- output_path (Path): Optional output path to save trained model to.
+ output_path (Optional[Path]): Optional output path to save trained model to.
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
before calling this function.
stdout (file): A file-like object to write output messages. To disable
@@ -194,17 +194,17 @@ def train_while_improving(
else:
dropouts = dropout
results = []
- losses = {}
+ losses: Dict[str, float] = {}
words_seen = 0
start_time = timer()
for step, (epoch, batch) in enumerate(train_data):
- dropout = next(dropouts)
+ dropout = next(dropouts) # type: ignore
for subbatch in subdivide_batch(batch, accumulate_gradient):
nlp.update(
subbatch,
drop=dropout,
losses=losses,
- sgd=False,
+ sgd=False, # type: ignore[arg-type]
exclude=exclude,
annotates=annotating_components,
)
@@ -214,9 +214,9 @@ def train_while_improving(
name not in exclude
and hasattr(proc, "is_trainable")
and proc.is_trainable
- and proc.model not in (True, False, None)
+ and proc.model not in (True, False, None) # type: ignore[attr-defined]
):
- proc.finish_update(optimizer)
+ proc.finish_update(optimizer) # type: ignore[attr-defined]
optimizer.step_schedules()
if not (step % eval_frequency):
if optimizer.averages:
@@ -310,13 +310,13 @@ def create_train_batches(
):
epoch = 0
if max_epochs >= 0:
- examples = list(corpus(nlp))
+ examples = list(corpus(nlp)) # type: Iterable[Example]
if not examples:
# Raise error if no data
raise ValueError(Errors.E986)
while max_epochs < 1 or epoch != max_epochs:
if max_epochs >= 0:
- random.shuffle(examples)
+ random.shuffle(examples) # type: ignore
else:
examples = corpus(nlp)
for batch in batcher(examples):
@@ -353,7 +353,7 @@ def create_before_to_disk_callback(
return before_to_disk
-def clean_output_dir(path: Union[str, Path]) -> None:
+def clean_output_dir(path: Optional[Path]) -> None:
"""Remove an existing output directory. Typically used to ensure that that
a directory like model-best and its contents aren't just being overwritten
by nlp.to_disk, which could preserve existing subdirectories (e.g.
diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py
index de4f80e5d..7830196bc 100644
--- a/spacy/training/pretrain.py
+++ b/spacy/training/pretrain.py
@@ -101,7 +101,7 @@ def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]:
def _resume_model(
- model: Model, resume_path: Path, epoch_resume: int, silent: bool = True
+ model: Model, resume_path: Path, epoch_resume: Optional[int], silent: bool = True
) -> int:
msg = Printer(no_print=silent)
msg.info(f"Resume training tok2vec from: {resume_path}")
diff --git a/spacy/ty.py b/spacy/ty.py
new file mode 100644
index 000000000..8f2903d78
--- /dev/null
+++ b/spacy/ty.py
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING
+from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List
+from .compat import Protocol, runtime_checkable
+
+from thinc.api import Optimizer, Model
+
+if TYPE_CHECKING:
+ from .language import Language
+ from .training import Example
+
+
+@runtime_checkable
+class TrainableComponent(Protocol):
+ model: Any
+ is_trainable: bool
+
+ def update(
+ self,
+ examples: Iterable["Example"],
+ *,
+ drop: float = 0.0,
+ sgd: Optional[Optimizer] = None,
+ losses: Optional[Dict[str, float]] = None
+ ) -> Dict[str, float]:
+ ...
+
+ def finish_update(self, sgd: Optimizer) -> None:
+ ...
+
+
+@runtime_checkable
+class InitializableComponent(Protocol):
+ def initialize(
+ self,
+ get_examples: Callable[[], Iterable["Example"]],
+ nlp: Iterable["Example"],
+ **kwargs: Any
+ ):
+ ...
+
+
+@runtime_checkable
+class ListenedToComponent(Protocol):
+ model: Any
+ listeners: Sequence[Model]
+ listener_map: Dict[str, Sequence[Model]]
+ listening_components: List[str]
+
+ def add_listener(self, listener: Model, component_name: str) -> None:
+ ...
+
+ def remove_listener(self, listener: Model, component_name: str) -> bool:
+ ...
+
+ def find_listeners(self, component) -> None:
+ ...
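
The new `spacy/ty.py` protocols are `runtime_checkable`, so the intended use is duck typing with `isinstance()` instead of `hasattr()` chains. A minimal sketch of that pattern:

```python
import spacy
from spacy import ty

nlp = spacy.blank("en")
nlp.add_pipe("entity_ruler")
nlp.add_pipe("tagger")

for name, proc in nlp.pipeline:
    if isinstance(proc, ty.TrainableComponent):
        print(f"{name}: has update()/finish_update()")  # e.g. the tagger
    if isinstance(proc, ty.InitializableComponent):
        print(f"{name}: has initialize()")
```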
diff --git a/spacy/util.py b/spacy/util.py
index b25be5361..e14f6030f 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -1,4 +1,5 @@
-from typing import List, Union, Dict, Any, Optional, Iterable, Callable, Tuple
+from typing import List, Mapping, NoReturn, Union, Dict, Any, Set
+from typing import Optional, Iterable, Callable, Tuple, Type
from typing import Iterator, Type, Pattern, Generator, TYPE_CHECKING
from types import ModuleType
import os
@@ -52,6 +53,7 @@ from . import about
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401
+ from .pipeline import Pipe # noqa: F401
from .tokens import Doc, Span # noqa: F401
from .vocab import Vocab # noqa: F401
@@ -291,7 +293,7 @@ def find_matching_language(lang: str) -> Optional[str]:
# Find out which language modules we have
possible_languages = []
- for modinfo in pkgutil.iter_modules(spacy.lang.__path__):
+ for modinfo in pkgutil.iter_modules(spacy.lang.__path__): # type: ignore
code = modinfo.name
if code == 'xx':
# Temporarily make 'xx' into a valid language code
@@ -314,7 +316,7 @@ def find_matching_language(lang: str) -> Optional[str]:
return match
-def get_lang_class(lang: str) -> "Language":
+def get_lang_class(lang: str) -> Type["Language"]:
"""Import and load a Language class.
lang (str): IETF language code, such as 'en'.
@@ -341,7 +343,7 @@ def get_lang_class(lang: str) -> "Language":
module = importlib.import_module(f".lang.{lang}", "spacy")
else:
raise ImportError(Errors.E048.format(lang=lang, err=err)) from err
- set_lang_class(lang, getattr(module, module.__all__[0]))
+ set_lang_class(lang, getattr(module, module.__all__[0])) # type: ignore[attr-defined]
return registry.languages.get(lang)
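
The new `Type["Language"]` return annotation makes explicit that `get_lang_class` returns the language subclass itself, not an instance:

```python
from spacy.util import get_lang_class

English = get_lang_class("en")   # Type[Language]
nlp = English()                  # instantiate to get an nlp object
print(type(nlp).__name__)        # "English"
```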
@@ -416,13 +418,13 @@ def load_model(
if name.startswith("blank:"): # shortcut for blank model
return get_lang_class(name.replace("blank:", ""))()
if is_package(name): # installed as package
- return load_model_from_package(name, **kwargs)
+ return load_model_from_package(name, **kwargs) # type: ignore[arg-type]
if Path(name).exists(): # path to model data directory
- return load_model_from_path(Path(name), **kwargs)
+ return load_model_from_path(Path(name), **kwargs) # type: ignore[arg-type]
elif hasattr(name, "exists"): # Path or Path-like to model data
- return load_model_from_path(name, **kwargs)
+ return load_model_from_path(name, **kwargs) # type: ignore[arg-type]
if name in OLD_MODEL_SHORTCUTS:
- raise IOError(Errors.E941.format(name=name, full=OLD_MODEL_SHORTCUTS[name]))
+ raise IOError(Errors.E941.format(name=name, full=OLD_MODEL_SHORTCUTS[name])) # type: ignore[index]
raise IOError(Errors.E050.format(name=name))
@@ -449,11 +451,11 @@ def load_model_from_package(
RETURNS (Language): The loaded nlp object.
"""
cls = importlib.import_module(name)
- return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config)
+ return cls.load(vocab=vocab, disable=disable, exclude=exclude, config=config) # type: ignore[attr-defined]
def load_model_from_path(
- model_path: Union[str, Path],
+ model_path: Path,
*,
meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True,
@@ -464,7 +466,7 @@ def load_model_from_path(
"""Load a model from a data directory path. Creates Language class with
pipeline from config.cfg and then calls from_disk() with path.
- name (str): Package name or model path.
+ model_path (Path): Model path.
meta (Dict[str, Any]): Optional model meta.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
@@ -546,7 +548,9 @@ def get_sourced_components(
}
-def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]:
+def resolve_dot_names(
+ config: Config, dot_names: List[Optional[str]]
+) -> Tuple[Any, ...]:
"""Resolve one or more "dot notation" names, e.g. corpora.train.
The paths could point anywhere into the config, so we don't know which
top-level section we'll be looking within.
@@ -556,7 +560,7 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
"""
# TODO: include schema?
resolved = {}
- output = []
+ output: List[Any] = []
errors = []
for name in dot_names:
if name is None:
@@ -572,7 +576,7 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
result = registry.resolve(config[section])
resolved[section] = result
try:
- output.append(dot_to_object(resolved, name))
+ output.append(dot_to_object(resolved, name)) # type: ignore[arg-type]
except KeyError:
msg = f"not a valid section reference: {name}"
errors.append({"loc": name.split("."), "msg": msg})
@@ -676,8 +680,8 @@ def get_package_version(name: str) -> Optional[str]:
RETURNS (str / None): The version or None if package not installed.
"""
try:
- return importlib_metadata.version(name)
- except importlib_metadata.PackageNotFoundError:
+ return importlib_metadata.version(name) # type: ignore[attr-defined]
+ except importlib_metadata.PackageNotFoundError: # type: ignore[attr-defined]
return None
@@ -700,7 +704,7 @@ def is_compatible_version(
constraint = f"=={constraint}"
try:
spec = SpecifierSet(constraint)
- version = Version(version)
+ version = Version(version) # type: ignore[assignment]
except (InvalidSpecifier, InvalidVersion):
return None
spec.prereleases = prereleases
@@ -814,7 +818,7 @@ def load_meta(path: Union[str, Path]) -> Dict[str, Any]:
if "spacy_version" in meta:
if not is_compatible_version(about.__version__, meta["spacy_version"]):
lower_version = get_model_lower_version(meta["spacy_version"])
- lower_version = get_minor_version(lower_version)
+ lower_version = get_minor_version(lower_version) # type: ignore[arg-type]
if lower_version is not None:
lower_version = "v" + lower_version
elif "spacy_git_version" in meta:
@@ -856,7 +860,7 @@ def is_package(name: str) -> bool:
RETURNS (bool): True if installed package, False if not.
"""
try:
- importlib_metadata.distribution(name)
+ importlib_metadata.distribution(name) # type: ignore[attr-defined]
return True
except: # noqa: E722
return False
@@ -917,7 +921,7 @@ def run_command(
*,
stdin: Optional[Any] = None,
capture: bool = False,
-) -> Optional[subprocess.CompletedProcess]:
+) -> subprocess.CompletedProcess:
"""Run a command on the command line as a subprocess. If the subprocess
returns a non-zero exit code, a system exit is performed.
@@ -960,8 +964,8 @@ def run_command(
message += f"\n\nProcess log (stdout and stderr):\n\n"
message += ret.stdout
error = subprocess.SubprocessError(message)
- error.ret = ret
- error.command = cmd_str
+ error.ret = ret # type: ignore[attr-defined]
+ error.command = cmd_str # type: ignore[attr-defined]
raise error
elif ret.returncode != 0:
sys.exit(ret.returncode)
@@ -969,7 +973,7 @@ def run_command(
@contextmanager
-def working_dir(path: Union[str, Path]) -> None:
+def working_dir(path: Union[str, Path]) -> Iterator[Path]:
"""Change current working directory and returns to previous on exit.
path (str / Path): The directory to navigate to.
@@ -1017,7 +1021,7 @@ def is_in_jupyter() -> bool:
"""
# https://stackoverflow.com/a/39662359/6400719
try:
- shell = get_ipython().__class__.__name__
+ shell = get_ipython().__class__.__name__ # type: ignore[name-defined]
if shell == "ZMQInteractiveShell":
return True # Jupyter notebook or qtconsole
except NameError:
@@ -1099,7 +1103,7 @@ def compile_prefix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
spacy.lang.punctuation.TOKENIZER_PREFIXES.
RETURNS (Pattern): The regex object. to be used for Tokenizer.prefix_search.
"""
- expression = "|".join(["^" + piece for piece in entries if piece.strip()])
+ expression = "|".join(["^" + piece for piece in entries if piece.strip()]) # type: ignore[operator, union-attr]
return re.compile(expression)
@@ -1110,7 +1114,7 @@ def compile_suffix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
spacy.lang.punctuation.TOKENIZER_SUFFIXES.
RETURNS (Pattern): The regex object. to be used for Tokenizer.suffix_search.
"""
- expression = "|".join([piece + "$" for piece in entries if piece.strip()])
+ expression = "|".join([piece + "$" for piece in entries if piece.strip()]) # type: ignore[operator, union-attr]
return re.compile(expression)
@@ -1121,7 +1125,7 @@ def compile_infix_regex(entries: Iterable[Union[str, Pattern]]) -> Pattern:
spacy.lang.punctuation.TOKENIZER_INFIXES.
RETURNS (regex object): The regex object. to be used for Tokenizer.infix_finditer.
"""
- expression = "|".join([piece for piece in entries if piece.strip()])
+ expression = "|".join([piece for piece in entries if piece.strip()]) # type: ignore[misc, union-attr]
return re.compile(expression)
@@ -1143,7 +1147,7 @@ def _get_attr_unless_lookup(
) -> Any:
for lookup in lookups:
if string in lookup:
- return lookup[string]
+ return lookup[string] # type: ignore[index]
return default_func(string)
@@ -1225,7 +1229,7 @@ def filter_spans(spans: Iterable["Span"]) -> List["Span"]:
get_sort_key = lambda span: (span.end - span.start, -span.start)
sorted_spans = sorted(spans, key=get_sort_key, reverse=True)
result = []
- seen_tokens = set()
+ seen_tokens: Set[int] = set()
for span in sorted_spans:
# Check for end - 1 here because boundaries are inclusive
if span.start not in seen_tokens and span.end - 1 not in seen_tokens:
@@ -1244,7 +1248,7 @@ def from_bytes(
setters: Dict[str, Callable[[bytes], Any]],
exclude: Iterable[str],
) -> None:
- return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude)
+ return from_dict(srsly.msgpack_loads(bytes_data), setters, exclude) # type: ignore[return-value]
def to_dict(
@@ -1306,8 +1310,8 @@ def import_file(name: str, loc: Union[str, Path]) -> ModuleType:
RETURNS: The loaded module.
"""
spec = importlib.util.spec_from_file_location(name, str(loc))
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
+ module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
+ spec.loader.exec_module(module) # type: ignore[union-attr]
return module
@@ -1397,7 +1401,7 @@ def dot_to_dict(values: Dict[str, Any]) -> Dict[str, dict]:
values (Dict[str, Any]): The key/value pairs to convert.
RETURNS (Dict[str, dict]): The converted values.
"""
- result = {}
+ result: Dict[str, dict] = {}
for key, value in values.items():
path = result
parts = key.lower().split(".")
@@ -1479,9 +1483,9 @@ def get_arg_names(func: Callable) -> List[str]:
def combine_score_weights(
- weights: List[Dict[str, float]],
- overrides: Dict[str, Optional[Union[float, int]]] = SimpleFrozenDict(),
-) -> Dict[str, float]:
+ weights: List[Dict[str, Optional[float]]],
+ overrides: Dict[str, Optional[float]] = SimpleFrozenDict(),
+) -> Dict[str, Optional[float]]:
"""Combine and normalize score weights defined by components, e.g.
{"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
@@ -1493,7 +1497,9 @@ def combine_score_weights(
# We divide each weight by the total weight sum.
# We first need to extract all None/null values for score weights that
# shouldn't be shown in the table *or* be weighted
- result = {key: value for w_dict in weights for (key, value) in w_dict.items()}
+ result: Dict[str, Optional[float]] = {
+ key: value for w_dict in weights for (key, value) in w_dict.items()
+ }
result.update(overrides)
weight_sum = sum([v if v else 0.0 for v in result.values()])
for key, value in result.items():
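
The widened `Optional[float]` types reflect that score weights may legitimately be `None` (shown in the results table but excluded from the weighted score). A quick sketch of the behaviour the annotation now describes:

```python
from spacy.util import combine_score_weights

weights = [
    {"ents_f": 0.5, "ents_p": 0.25, "ents_r": 0.25},
    {"speed": None},  # None: displayed but not weighted
]
print(combine_score_weights(weights))
# e.g. {'ents_f': 0.5, 'ents_p': 0.25, 'ents_r': 0.25, 'speed': None}
```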
@@ -1515,13 +1521,13 @@ class DummyTokenizer:
def to_bytes(self, **kwargs):
return b""
- def from_bytes(self, _bytes_data, **kwargs):
+ def from_bytes(self, data: bytes, **kwargs) -> "DummyTokenizer":
return self
- def to_disk(self, _path, **kwargs):
+ def to_disk(self, path: Union[str, Path], **kwargs) -> None:
return None
- def from_disk(self, _path, **kwargs):
+ def from_disk(self, path: Union[str, Path], **kwargs) -> "DummyTokenizer":
return self
@@ -1583,7 +1589,13 @@ def check_bool_env_var(env_var: str) -> bool:
return bool(value)
-def _pipe(docs, proc, name, default_error_handler, kwargs):
+def _pipe(
+ docs: Iterable["Doc"],
+ proc: "Pipe",
+ name: str,
+ default_error_handler: Callable[[str, "Pipe", List["Doc"], Exception], NoReturn],
+ kwargs: Mapping[str, Any],
+) -> Iterator["Doc"]:
if hasattr(proc, "pipe"):
yield from proc.pipe(docs, **kwargs)
else:
@@ -1597,7 +1609,7 @@ def _pipe(docs, proc, name, default_error_handler, kwargs):
kwargs.pop(arg)
for doc in docs:
try:
- doc = proc(doc, **kwargs)
+ doc = proc(doc, **kwargs) # type: ignore[call-arg]
yield doc
except Exception as e:
error_handler(name, proc, [doc], e)
@@ -1661,7 +1673,7 @@ def packages_distributions() -> Dict[str, List[str]]:
it's not available in the builtin importlib.metadata.
"""
pkg_to_dist = defaultdict(list)
- for dist in importlib_metadata.distributions():
+ for dist in importlib_metadata.distributions(): # type: ignore[attr-defined]
for pkg in (dist.read_text("top_level.txt") or "").split():
pkg_to_dist[pkg].append(dist.metadata["Name"])
return dict(pkg_to_dist)
diff --git a/spacy/vocab.pyi b/spacy/vocab.pyi
index 7c0d0598e..713e85c01 100644
--- a/spacy/vocab.pyi
+++ b/spacy/vocab.pyi
@@ -1,26 +1,27 @@
-from typing import (
- Callable,
- Iterator,
- Optional,
- Union,
- Tuple,
- List,
- Dict,
- Any,
-)
+from typing import Callable, Iterator, Optional, Union, List, Dict
+from typing import Any, Iterable
from thinc.types import Floats1d, FloatsXd
from . import Language
from .strings import StringStore
from .lexeme import Lexeme
from .lookups import Lookups
+from .morphology import Morphology
from .tokens import Doc, Span
+from .vectors import Vectors
from pathlib import Path
def create_vocab(
- lang: Language, defaults: Any, vectors_name: Optional[str] = ...
+ lang: Optional[str], defaults: Any, vectors_name: Optional[str] = ...
) -> Vocab: ...
class Vocab:
+ cfg: Dict[str, Any]
+ get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]]
+ lookups: Lookups
+ morphology: Morphology
+ strings: StringStore
+ vectors: Vectors
+ writing_system: Dict[str, Any]
def __init__(
self,
lex_attr_getters: Optional[Dict[str, Callable[[str], Any]]] = ...,
@@ -32,7 +33,7 @@ class Vocab:
get_noun_chunks: Optional[Callable[[Union[Doc, Span]], Iterator[Span]]] = ...,
) -> None: ...
@property
- def lang(self) -> Language: ...
+ def lang(self) -> str: ...
def __len__(self) -> int: ...
def add_flag(
self, flag_getter: Callable[[str], bool], flag_id: int = ...
@@ -54,16 +55,15 @@ class Vocab:
) -> FloatsXd: ...
def set_vector(self, orth: Union[int, str], vector: Floats1d) -> None: ...
def has_vector(self, orth: Union[int, str]) -> bool: ...
- lookups: Lookups
def to_disk(
- self, path: Union[str, Path], *, exclude: Union[List[str], Tuple[str]] = ...
+ self, path: Union[str, Path], *, exclude: Iterable[str] = ...
) -> None: ...
def from_disk(
- self, path: Union[str, Path], *, exclude: Union[List[str], Tuple[str]] = ...
+ self, path: Union[str, Path], *, exclude: Iterable[str] = ...
) -> Vocab: ...
- def to_bytes(self, *, exclude: Union[List[str], Tuple[str]] = ...) -> bytes: ...
+ def to_bytes(self, *, exclude: Iterable[str] = ...) -> bytes: ...
def from_bytes(
- self, bytes_data: bytes, *, exclude: Union[List[str], Tuple[str]] = ...
+ self, bytes_data: bytes, *, exclude: Iterable[str] = ...
) -> Vocab: ...
def pickle_vocab(vocab: Vocab) -> Any: ...
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index ef4435656..9840603f5 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -61,7 +61,7 @@ cdef class Vocab:
lookups (Lookups): Container for large lookup tables and dictionaries.
oov_prob (float): Default OOV probability.
vectors_name (str): Optional name to identify the vectors table.
- get_noun_chunks (Optional[Callable[[Union[Doc, Span], Iterator[Span]]]]):
+ get_noun_chunks (Optional[Callable[[Union[Doc, Span]], Iterator[Tuple[int, int, int]]]]):
A function that yields base noun phrases used for Doc.noun_chunks.
"""
lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
@@ -450,7 +450,7 @@ cdef class Vocab:
path (str or Path): A path to a directory, which will be created if
it doesn't exist.
- exclude (list): String names of serialization fields to exclude.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
DOCS: https://spacy.io/api/vocab#to_disk
"""
@@ -470,7 +470,7 @@ cdef class Vocab:
returns it.
path (str or Path): A path to a directory.
- exclude (list): String names of serialization fields to exclude.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Vocab): The modified `Vocab` object.
DOCS: https://spacy.io/api/vocab#to_disk
@@ -495,7 +495,7 @@ cdef class Vocab:
def to_bytes(self, *, exclude=tuple()):
"""Serialize the current state to a binary string.
- exclude (list): String names of serialization fields to exclude.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (bytes): The serialized form of the `Vocab` object.
DOCS: https://spacy.io/api/vocab#to_bytes
@@ -517,7 +517,7 @@ cdef class Vocab:
"""Load state from a binary string.
bytes_data (bytes): The data to load from.
- exclude (list): String names of serialization fields to exclude.
+ exclude (Iterable[str]): String names of serialization fields to exclude.
RETURNS (Vocab): The `Vocab` object.
DOCS: https://spacy.io/api/vocab#from_bytes
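
The `exclude (Iterable[str])` docstring updates mirror the `.pyi` changes above: any iterable of field names is accepted, not just a list or tuple. For example:

```python
import spacy
from spacy.vocab import Vocab

nlp = spacy.blank("en")
data = nlp.vocab.to_bytes(exclude={"vectors"})  # a set works too
vocab2 = Vocab().from_bytes(data)
```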
diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md
index ceeb388ab..7044a7d02 100644
--- a/website/docs/api/architectures.md
+++ b/website/docs/api/architectures.md
@@ -332,15 +332,18 @@ for details and system requirements.
-### spacy-transformers.TransformerModel.v1 {#TransformerModel}
+### spacy-transformers.TransformerModel.v3 {#TransformerModel}
> #### Example Config
>
> ```ini
> [model]
-> @architectures = "spacy-transformers.TransformerModel.v1"
+> @architectures = "spacy-transformers.TransformerModel.v3"
> name = "roberta-base"
> tokenizer_config = {"use_fast": true}
+> transformer_config = {}
+> mixed_precision = true
+> grad_scaler_config = {"init_scale": 32768}
>
> [model.get_spans]
> @span_getters = "spacy-transformers.strided_spans.v1"
@@ -366,12 +369,31 @@ transformer weights across your pipeline. For a layer that's configured for use
in other components, see
[Tok2VecTransformer](/api/architectures#Tok2VecTransformer).
-| Name | Description |
-| ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `name` | Any model name that can be loaded by [`transformers.AutoModel`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoModel). ~~str~~ |
-| `get_spans` | Function that takes a batch of [`Doc`](/api/doc) object and returns lists of [`Span`](/api) objects to process by the transformer. [See here](/api/transformer#span_getters) for built-in options and examples. ~~Callable[[List[Doc]], List[Span]]~~ |
-| `tokenizer_config` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). ~~Dict[str, Any]~~ |
-| **CREATES** | The model using the architecture. ~~Model[List[Doc], FullTransformerBatch]~~ |
+| Name | Description |
+| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `name` | Any model name that can be loaded by [`transformers.AutoModel`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoModel). ~~str~~ |
+| `get_spans` | Function that takes a batch of [`Doc`](/api/doc) object and returns lists of [`Span`](/api) objects to process by the transformer. [See here](/api/transformer#span_getters) for built-in options and examples. ~~Callable[[List[Doc]], List[Span]]~~ |
+| `tokenizer_config` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). ~~Dict[str, Any]~~ |
+| `transformer_config` | Transformer settings passed to [`transformers.AutoConfig`](https://huggingface.co/transformers/model_doc/auto.html?highlight=autoconfig#transformers.AutoConfig) ~~Dict[str, Any]~~ |
+| `mixed_precision` | Replace whitelisted ops by half-precision counterparts. Speeds up training and prediction on GPUs with [Tensor Cores](https://developer.nvidia.com/tensor-cores) and reduces GPU memory use. ~~bool~~ |
+| `grad_scaler_config` | Configuration to pass to `thinc.api.PyTorchGradScaler` during training when `mixed_precision` is enabled. ~~Dict[str, Any]~~ |
+| **CREATES** | The model using the architecture. ~~Model[List[Doc], FullTransformerBatch]~~ |
+
+
+Mixed-precision support is currently an experimental feature.
+
+
+
+
+- The `transformer_config` argument was added in
+ `spacy-transformers.TransformerModel.v2`.
+- The `mixed_precision` and `grad_scaler_config` arguments were added in
+ `spacy-transformers.TransformerModel.v3`.
+
+The other arguments are shared between all versions.
+
+
### spacy-transformers.TransformerListener.v1 {#TransformerListener}
@@ -403,16 +425,19 @@ a single token vector given zero or more wordpiece vectors.
| `upstream` | A string to identify the "upstream" `Transformer` component to communicate with. By default, the upstream name is the wildcard string `"*"`, but you could also specify the name of the `Transformer` component. You'll almost never have multiple upstream `Transformer` components, so the wildcard string will almost always be fine. ~~str~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
-### spacy-transformers.Tok2VecTransformer.v1 {#Tok2VecTransformer}
+### spacy-transformers.Tok2VecTransformer.v3 {#Tok2VecTransformer}
> #### Example Config
>
> ```ini
> [model]
-> @architectures = "spacy-transformers.Tok2VecTransformer.v1"
+> @architectures = "spacy-transformers.Tok2VecTransformer.v3"
> name = "albert-base-v2"
> tokenizer_config = {"use_fast": false}
+> transformer_config = {}
> grad_factor = 1.0
+> mixed_precision = true
+> grad_scaler_config = {"init_scale": 32768}
> ```
Use a transformer as a [`Tok2Vec`](/api/tok2vec) layer directly. This does
@@ -421,13 +446,31 @@ Use a transformer as a [`Tok2Vec`](/api/tok2vec) layer directly. This does
object, but it's a **simpler solution** if you only need the transformer within
one component.
-| Name | Description |
-| ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `get_spans` | Function that takes a batch of [`Doc`](/api/doc) object and returns lists of [`Span`](/api) objects to process by the transformer. [See here](/api/transformer#span_getters) for built-in options and examples. ~~Callable[[List[Doc]], List[Span]]~~ |
-| `tokenizer_config` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). ~~Dict[str, Any]~~ |
-| `pooling` | A reduction layer used to calculate the token vectors based on zero or more wordpiece vectors. If in doubt, mean pooling (see [`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean)) is usually a good choice. ~~Model[Ragged, Floats2d]~~ |
-| `grad_factor` | Reweight gradients from the component before passing them upstream. You can set this to `0` to "freeze" the transformer weights with respect to the component, or use it to make some components more significant than others. Leaving it at `1.0` is usually fine. ~~float~~ |
-| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
+| Name | Description |
+| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `get_spans` | Function that takes a batch of [`Doc`](/api/doc) object and returns lists of [`Span`](/api) objects to process by the transformer. [See here](/api/transformer#span_getters) for built-in options and examples. ~~Callable[[List[Doc]], List[Span]]~~ |
+| `tokenizer_config` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). ~~Dict[str, Any]~~ |
+| `transformer_config` | Settings to pass to the transformers forward pass. ~~Dict[str, Any]~~ |
+| `pooling` | A reduction layer used to calculate the token vectors based on zero or more wordpiece vectors. If in doubt, mean pooling (see [`reduce_mean`](https://thinc.ai/docs/api-layers#reduce_mean)) is usually a good choice. ~~Model[Ragged, Floats2d]~~ |
+| `grad_factor` | Reweight gradients from the component before passing them upstream. You can set this to `0` to "freeze" the transformer weights with respect to the component, or use it to make some components more significant than others. Leaving it at `1.0` is usually fine. ~~float~~ |
+| `mixed_precision` | Replace whitelisted ops by half-precision counterparts. Speeds up training and prediction on GPUs with [Tensor Cores](https://developer.nvidia.com/tensor-cores) and reduces GPU memory use. ~~bool~~ |
+| `grad_scaler_config` | Configuration to pass to `thinc.api.PyTorchGradScaler` during training when `mixed_precision` is enabled. ~~Dict[str, Any]~~ |
+| **CREATES** | The model using the architecture. ~~Model[List[Doc], List[Floats2d]]~~ |
+
+
+Mixed-precision support is currently an experimental feature.
+
+
+
+
+- The `transformer_config` argument was added in
+ `spacy-transformers.Tok2VecTransformer.v2`.
+- The `mixed_precision` and `grad_scaler_config` arguments were added in
+ `spacy-transformers.Tok2VecTransformer.v3`.
+
+The other arguments are shared between all versions.
+
+
## Pretraining architectures {#pretrain source="spacy/ml/models/multi_task.py"}
diff --git a/website/docs/api/entityruler.md b/website/docs/api/entityruler.md
index ed4ebbd10..fb33642f8 100644
--- a/website/docs/api/entityruler.md
+++ b/website/docs/api/entityruler.md
@@ -289,7 +289,7 @@ All labels present in the match patterns.
| ----------- | -------------------------------------- |
| **RETURNS** | The string labels. ~~Tuple[str, ...]~~ |
-## EntityRuler.ent_ids {#labels tag="property" new="2.2.2"}
+## EntityRuler.ent_ids {#ent_ids tag="property" new="2.2.2"}
All entity IDs present in the `id` properties of the match patterns.
diff --git a/website/docs/api/language.md b/website/docs/api/language.md
index 4cf063fcc..f756f14b5 100644
--- a/website/docs/api/language.md
+++ b/website/docs/api/language.md
@@ -1077,9 +1077,9 @@ customize the default language data:
| --------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `stop_words` | List of stop words, used for `Token.is_stop`. **Example:** [`stop_words.py`](%%GITHUB_SPACY/spacy/lang/en/stop_words.py) ~~Set[str]~~ |
| `tokenizer_exceptions` | Tokenizer exception rules, string mapped to list of token attributes. **Example:** [`de/tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/de/tokenizer_exceptions.py) ~~Dict[str, List[dict]]~~ |
-| `prefixes`, `suffixes`, `infixes` | Prefix, suffix and infix rules for the default tokenizer. **Example:** [`puncutation.py`](%%GITHUB_SPACY/spacy/lang/punctuation.py) ~~Optional[List[Union[str, Pattern]]]~~ |
-| `token_match` | Optional regex for matching strings that should never be split, overriding the infix rules. **Example:** [`fr/tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/fr/tokenizer_exceptions.py) ~~Optional[Pattern]~~ |
-| `url_match` | Regular expression for matching URLs. Prefixes and suffixes are removed before applying the match. **Example:** [`tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/tokenizer_exceptions.py) ~~Optional[Pattern]~~ |
+| `prefixes`, `suffixes`, `infixes` | Prefix, suffix and infix rules for the default tokenizer. **Example:** [`puncutation.py`](%%GITHUB_SPACY/spacy/lang/punctuation.py) ~~Optional[Sequence[Union[str, Pattern]]]~~ |
+| `token_match` | Optional regex for matching strings that should never be split, overriding the infix rules. **Example:** [`fr/tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/fr/tokenizer_exceptions.py) ~~Optional[Callable]~~ |
+| `url_match` | Regular expression for matching URLs. Prefixes and suffixes are removed before applying the match. **Example:** [`tokenizer_exceptions.py`](%%GITHUB_SPACY/spacy/lang/tokenizer_exceptions.py) ~~Optional[Callable]~~ |
| `lex_attr_getters` | Custom functions for setting lexical attributes on tokens, e.g. `like_num`. **Example:** [`lex_attrs.py`](%%GITHUB_SPACY/spacy/lang/en/lex_attrs.py) ~~Dict[int, Callable[[str], Any]]~~ |
| `syntax_iterators` | Functions that compute views of a `Doc` object based on its syntax. At the moment, only used for [noun chunks](/usage/linguistic-features#noun-chunks). **Example:** [`syntax_iterators.py`](%%GITHUB_SPACY/spacy/lang/en/syntax_iterators.py). ~~Dict[str, Callable[[Union[Doc, Span]], Iterator[Span]]]~~ |
| `writing_system` | Information about the language's writing system, available via `Vocab.writing_system`. Defaults to: `{"direction": "ltr", "has_case": True, "has_letters": True}.`. **Example:** [`zh/__init__.py`](%%GITHUB_SPACY/spacy/lang/zh/__init__.py) ~~Dict[str, Any]~~ |
diff --git a/website/docs/api/phrasematcher.md b/website/docs/api/phrasematcher.md
index 71ee4b7d1..2cef9ac2a 100644
--- a/website/docs/api/phrasematcher.md
+++ b/website/docs/api/phrasematcher.md
@@ -150,7 +150,7 @@ patterns = [nlp("health care reform"), nlp("healthcare reform")]
| Name | Description |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `match_id` | An ID for the thing you're matching. ~~str~~ | |
+| `key` | An ID for the thing you're matching. ~~str~~ |
| `docs` | `Doc` objects of the phrases to match. ~~List[Doc]~~ |
| _keyword-only_ | |
| `on_match` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. ~~Optional[Callable[[Matcher, Doc, int, List[tuple], Any]]~~ |
diff --git a/website/docs/api/spancategorizer.md b/website/docs/api/spancategorizer.md
index d5d44239e..26fcaefdf 100644
--- a/website/docs/api/spancategorizer.md
+++ b/website/docs/api/spancategorizer.md
@@ -54,7 +54,7 @@ architectures and their arguments and hyperparameters.
| Setting | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[List[Doc], Ragged]~~ |
+| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"spans"`. ~~str~~ |
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
@@ -90,7 +90,7 @@ shortcut for this and instantiate the component using its string name and
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `vocab` | The shared vocabulary. ~~Vocab~~ |
| `model` | A model instance that is given a a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
-| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[List[Doc], Ragged]~~ |
+| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | |
| `spans_key` | Key of the [`Doc.spans`](/api/doc#sans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"spans"`. ~~str~~ |
@@ -252,11 +252,11 @@ predicted scores.
> loss, d_loss = spancat.get_loss(examples, scores)
> ```
-| Name | Description |
-| ----------- | --------------------------------------------------------------------------- |
-| `examples` | The batch of examples. ~~Iterable[Example]~~ |
-| `scores` | Scores representing the model's predictions. |
-| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
+| Name | Description |
+| -------------- | --------------------------------------------------------------------------- |
+| `examples` | The batch of examples. ~~Iterable[Example]~~ |
+| `spans_scores` | Scores representing the model's predictions. ~~Tuple[Ragged, Floats2d]~~ |
+| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
## SpanCategorizer.create_optimizer {#create_optimizer tag="method"}
@@ -451,7 +451,7 @@ integers. The array has two columns, indicating the start and end position.
| Name | Description |
| ----------- | -------------------------------------------------------------------------------------------------------------------- |
| `sizes` | The phrase lengths to suggest. For example, `[1, 2]` will suggest phrases consisting of 1 or 2 tokens. ~~List[int]~~ |
-| **CREATES** | The suggester function. ~~Callable[[List[Doc]], Ragged]~~ |
+| **CREATES** | The suggester function. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
### spacy.ngram_range_suggester.v1 {#ngram_range_suggester}
@@ -468,8 +468,8 @@ Suggest all spans of at least length `min_size` and at most length `max_size`
(both inclusive). Spans are returned as a ragged array of integers. The array
has two columns, indicating the start and end position.
-| Name | Description |
-| ----------- | ------------------------------------------------------------ |
-| `min_size` | The minimal phrase lengths to suggest (inclusive). ~~[int]~~ |
-| `max_size` | The maximal phrase lengths to suggest (exclusive). ~~[int]~~ |
-| **CREATES** | The suggester function. ~~Callable[[List[Doc]], Ragged]~~ |
+| Name | Description |
+| ----------- | ---------------------------------------------------------------------------- |
+| `min_size` | The minimal phrase lengths to suggest (inclusive). ~~[int]~~ |
+| `max_size` | The maximal phrase lengths to suggest (exclusive). ~~[int]~~ |
+| **CREATES** | The suggester function. ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
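
The updated `~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~` type means custom suggesters now take an optional `Ops` argument. A hedged sketch of a suggester matching that signature (the registry name `single_token_suggester.v1` is made up for illustration):

```python
from typing import Iterable, Optional

from thinc.api import Ops, get_current_ops
from thinc.types import Ragged
from spacy.tokens import Doc
from spacy.util import registry


@registry.misc("single_token_suggester.v1")
def make_single_token_suggester():
    def suggest(docs: Iterable[Doc], ops: Optional[Ops] = None) -> Ragged:
        if ops is None:
            ops = get_current_ops()
        rows = []
        lengths = []
        for doc in docs:
            starts = ops.xp.arange(len(doc), dtype="i")
            # one (start, end) token offset pair per row
            rows.append(ops.xp.stack((starts, starts + 1), axis=1))
            lengths.append(len(doc))
        data = ops.xp.concatenate(rows) if rows else ops.xp.zeros((0, 2), dtype="i")
        return Ragged(data, ops.asarray(lengths, dtype="i"))

    return suggest
```

It could then be referenced from a config as `[components.spancat.suggester]` with `@misc = "single_token_suggester.v1"`.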
diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md
index b48cd47f3..6bd29cf5f 100644
--- a/website/docs/api/top-level.md
+++ b/website/docs/api/top-level.md
@@ -764,6 +764,26 @@ from the specified model. Intended for use in `[initialize.before_init]`.
| `vocab` | The pipeline to copy the vocab from. The vocab includes the lookups and vectors. Defaults to `None`. ~~Optional[str]~~ |
| **CREATES** | A function that takes the current `nlp` object and modifies its `tokenizer` and `vocab`. ~~Callable[[Language], None]~~ |
+### spacy.models_with_nvtx_range.v1 {#models_with_nvtx_range tag="registered function"}
+
+> #### Example config
+>
+> ```ini
+> [nlp]
+> after_pipeline_creation = {"@callbacks":"spacy.models_with_nvtx_range.v1"}
+> ```
+
+Recursively wrap the models in each pipe using [NVTX](https://nvidia.github.io/NVTX/)
+range markers. These markers aid in GPU profiling by attributing specific operations
+to a ~~Model~~'s forward or backprop passes.
+
+| Name | Description |
+|------------------|------------------------------------------------------------------------------------------------------------------------------|
+| `forward_color` | Color identifier for forward passes. Defaults to `-1`. ~~int~~ |
+| `backprop_color` | Color identifier for backpropagation passes. Defaults to `-1`. ~~int~~ |
+| **CREATES** | A function that takes the current `nlp` and wraps forward/backprop passes in NVTX ranges. ~~Callable[[Language], Language]~~ |
+
+
## Training data and alignment {#gold source="spacy/training"}
### training.offsets_to_biluo_tags {#offsets_to_biluo_tags tag="function"}
diff --git a/website/docs/api/transformer.md b/website/docs/api/transformer.md
index 6e68ac599..b1673cdbe 100644
--- a/website/docs/api/transformer.md
+++ b/website/docs/api/transformer.md
@@ -92,9 +92,12 @@ https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/p
> # Construction via add_pipe with custom config
> config = {
> "model": {
-> "@architectures": "spacy-transformers.TransformerModel.v1",
+> "@architectures": "spacy-transformers.TransformerModel.v3",
> "name": "bert-base-uncased",
-> "tokenizer_config": {"use_fast": True}
+> "tokenizer_config": {"use_fast": True},
+> "transformer_config": {"output_attentions": True},
+> "mixed_precision": True,
+> "grad_scaler_config": {"init_scale": 32768}
> }
> }
> trf = nlp.add_pipe("transformer", config=config)
@@ -394,12 +397,13 @@ are wrapped into the
by this class. Instances of this class are typically assigned to the
[`Doc._.trf_data`](/api/transformer#assigned-attributes) extension attribute.
-| Name | Description |
-| --------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `tokens` | A slice of the tokens data produced by the tokenizer. This may have several fields, including the token IDs, the texts and the attention mask. See the [`transformers.BatchEncoding`](https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding) object for details. ~~dict~~ |
-| `tensors` | The activations for the `Doc` from the transformer. Usually the last tensor that is 3-dimensional will be the most important, as that will provide the final hidden state. Generally activations that are 2-dimensional will be attention weights. Details of this variable will differ depending on the underlying transformer model. ~~List[FloatsXd]~~ |
-| `align` | Alignment from the `Doc`'s tokenization to the wordpieces. This is a ragged array, where `align.lengths[i]` indicates the number of wordpiece tokens that token `i` aligns against. The actual indices are provided at `align[i].dataXd`. ~~Ragged~~ |
-| `width` | The width of the last hidden layer. ~~int~~ |
+| Name | Description |
+| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `tokens` | A slice of the tokens data produced by the tokenizer. This may have several fields, including the token IDs, the texts and the attention mask. See the [`transformers.BatchEncoding`](https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding) object for details. ~~dict~~ |
+| `model_output` | The model output from the transformer model, determined by the model and transformer config. New in `spacy-transformers` v1.1.0. ~~transformers.file_utils.ModelOutput~~ |
+| `tensors` | The `model_output` in the earlier `transformers` tuple format converted using [`ModelOutput.to_tuple()`](https://huggingface.co/transformers/main_classes/output.html#transformers.file_utils.ModelOutput.to_tuple). Returns `Tuple` instead of `List` as of `spacy-transformers` v1.1.0. ~~Tuple[Union[FloatsXd, List[FloatsXd]]]~~ |
+| `align` | Alignment from the `Doc`'s tokenization to the wordpieces. This is a ragged array, where `align.lengths[i]` indicates the number of wordpiece tokens that token `i` aligns against. The actual indices are provided at `align[i].dataXd`. ~~Ragged~~ |
+| `width` | The width of the last hidden layer. ~~int~~ |
### TransformerData.empty {#transformerdata-emoty tag="classmethod"}
@@ -409,19 +413,32 @@ Create an empty `TransformerData` container.
| ----------- | ---------------------------------- |
| **RETURNS** | The container. ~~TransformerData~~ |
+
+
+In `spacy-transformers` v1.0, the model output is stored in
+`TransformerData.tensors` as `List[FloatsXd]` and only includes the
+activations for the `Doc` from the transformer. Usually the last tensor that is
+3-dimensional will be the most important, as that will provide the final hidden
+state. Generally activations that are 2-dimensional will be attention weights.
+Details of this variable will differ depending on the underlying transformer
+model.
+
+
+
## FullTransformerBatch {#fulltransformerbatch tag="dataclass"}
Holds a batch of input and output objects for a transformer model. The data can
then be split to a list of [`TransformerData`](/api/transformer#transformerdata)
objects to associate the outputs to each [`Doc`](/api/doc) in the batch.
-| Name | Description |
-| ---------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `spans` | The batch of input spans. The outer list refers to the Doc objects in the batch, and the inner list are the spans for that `Doc`. Note that spans are allowed to overlap or exclude tokens, but each `Span` can only refer to one `Doc` (by definition). This means that within a `Doc`, the regions of the output tensors that correspond to each `Span` may overlap or have gaps, but for each `Doc`, there is a non-overlapping contiguous slice of the outputs. ~~List[List[Span]]~~ |
-| `tokens` | The output of the tokenizer. ~~transformers.BatchEncoding~~ |
-| `tensors` | The output of the transformer model. ~~List[torch.Tensor]~~ |
-| `align` | Alignment from the spaCy tokenization to the wordpieces. This is a ragged array, where `align.lengths[i]` indicates the number of wordpiece tokens that token `i` aligns against. The actual indices are provided at `align[i].dataXd`. ~~Ragged~~ |
-| `doc_data` | The outputs, split per `Doc` object. ~~List[TransformerData]~~ |
+| Name | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `spans` | The batch of input spans. The outer list refers to the Doc objects in the batch, and the inner list are the spans for that `Doc`. Note that spans are allowed to overlap or exclude tokens, but each `Span` can only refer to one `Doc` (by definition). This means that within a `Doc`, the regions of the output tensors that correspond to each `Span` may overlap or have gaps, but for each `Doc`, there is a non-overlapping contiguous slice of the outputs. ~~List[List[Span]]~~ |
+| `tokens` | The output of the tokenizer. ~~transformers.BatchEncoding~~ |
+| `model_output` | The model output from the transformer model, determined by the model and transformer config. New in `spacy-transformers` v1.1.0. ~~transformers.file_utils.ModelOutput~~ |
+| `tensors` | The `model_output` in the earlier `transformers` tuple format converted using [`ModelOutput.to_tuple()`](https://huggingface.co/transformers/main_classes/output.html#transformers.file_utils.ModelOutput.to_tuple). Returns `Tuple` instead of `List` as of `spacy-transformers` v1.1.0. ~~Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]~~ |
+| `align` | Alignment from the spaCy tokenization to the wordpieces. This is a ragged array, where `align.lengths[i]` indicates the number of wordpiece tokens that token `i` aligns against. The actual indices are provided at `align[i].dataXd`. ~~Ragged~~ |
+| `doc_data` | The outputs, split per `Doc` object. ~~List[TransformerData]~~ |
### FullTransformerBatch.unsplit_by_doc {#fulltransformerbatch-unsplit_by_doc tag="method"}
@@ -444,6 +461,13 @@ Split a `TransformerData` object that represents a batch into a list with one
| ----------- | ------------------------------------------ |
| **RETURNS** | The split batch. ~~List[TransformerData]~~ |
+
+
+In `spacy-transformers` v1.0, the model output is stored in
+`FullTransformerBatch.tensors` as `List[torch.Tensor]`.
+
+
+
## Span getters {#span_getters source="github.com/explosion/spacy-transformers/blob/master/spacy_transformers/span_getters.py"}
Span getters are functions that take a batch of [`Doc`](/api/doc) objects and
diff --git a/website/docs/api/vocab.md b/website/docs/api/vocab.md
index c37b27a0e..c0a269d95 100644
--- a/website/docs/api/vocab.md
+++ b/website/docs/api/vocab.md
@@ -21,15 +21,15 @@ Create the vocabulary.
> vocab = Vocab(strings=["hello", "world"])
> ```
-| Name | Description |
-| ------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
-| `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ |
-| `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
-| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
-| `vectors_name` 2.2 | A name to identify the vectors table. ~~str~~ |
-| `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |
-| `get_noun_chunks` | A function that yields base noun phrases used for [`Doc.noun_chunks`](/api/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span], Iterator[Span]]]]~~ |
+| Name | Description |
+| ------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `lex_attr_getters` | A dictionary mapping attribute IDs to functions to compute them. Defaults to `None`. ~~Optional[Dict[str, Callable[[str], Any]]]~~ |
+| `strings` | A [`StringStore`](/api/stringstore) that maps strings to hash values, and vice versa, or a list of strings. ~~Union[List[str], StringStore]~~ |
+| `lookups` | A [`Lookups`](/api/lookups) that stores the `lexeme_norm` and other large lookup tables. Defaults to `None`. ~~Optional[Lookups]~~ |
+| `oov_prob` | The default OOV probability. Defaults to `-20.0`. ~~float~~ |
+| `vectors_name` 2.2 | A name to identify the vectors table. ~~str~~ |
+| `writing_system` | A dictionary describing the language's writing system. Typically provided by [`Language.Defaults`](/api/language#defaults). ~~Dict[str, Any]~~ |
+| `get_noun_chunks`                           | A function that yields base noun phrases used for [`Doc.noun_chunks`](/api/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span]], Iterator[Tuple[int, int, int]]]]~~ |
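+
+The `get_noun_chunks` callable yields `(start, end, label)` token-index tuples
+rather than `Span` objects. A minimal sketch of a custom iterator passed to
+the constructor (the function name and the naive noun-matching logic are made
+up for illustration, and assume POS tags are available):
+
+```python
+from spacy.vocab import Vocab
+
+def simple_noun_chunks(doclike):
+    # The label has to be an integer hash from the StringStore
+    np_label = doclike.vocab.strings.add("NP")
+    for token in doclike:
+        if token.pos_ == "NOUN":
+            yield token.i, token.i + 1, np_label
+
+vocab = Vocab(strings=["hello", "world"], get_noun_chunks=simple_noun_chunks)
+```
+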
## Vocab.\_\_len\_\_ {#len tag="method"}
@@ -300,14 +300,14 @@ Load state from a binary string.
> assert type(PERSON) == int
> ```
-| Name | Description |
-| ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `strings` | A table managing the string-to-int mapping. ~~StringStore~~ |
-| `vectors` 2 | A table associating word IDs to word vectors. ~~Vectors~~ |
-| `vectors_length` | Number of dimensions for each word vector. ~~int~~ |
-| `lookups` | The available lookup tables in this vocab. ~~Lookups~~ |
-| `writing_system` 2.1 | A dict with information about the language's writing system. ~~Dict[str, Any]~~ |
-| `get_noun_chunks` 3.0 | A function that yields base noun phrases used for [`Doc.noun_chunks`](/ap/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span], Iterator[Span]]]]~~ |
+| Name | Description |
+| ---------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `strings` | A table managing the string-to-int mapping. ~~StringStore~~ |
+| `vectors` 2 | A table associating word IDs to word vectors. ~~Vectors~~ |
+| `vectors_length` | Number of dimensions for each word vector. ~~int~~ |
+| `lookups` | The available lookup tables in this vocab. ~~Lookups~~ |
+| `writing_system` 2.1 | A dict with information about the language's writing system. ~~Dict[str, Any]~~ |
+| `get_noun_chunks` 3.0          | A function that yields base noun phrases used for [`Doc.noun_chunks`](/api/doc#noun_chunks). ~~Optional[Callable[[Union[Doc, Span]], Iterator[Tuple[int, int, int]]]]~~ |
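+
+A quick way to inspect these attributes on a loaded pipeline. The pipeline
+name `en_core_web_sm` is only an example, and the printed values depend on
+the pipeline you load:
+
+```python
+import spacy
+
+nlp = spacy.load("en_core_web_sm")
+vocab = nlp.vocab
+print(vocab.strings.add("coffee"))  # hash value from the StringStore
+print(vocab.vectors_length)         # 0 for pipelines without static vectors
+print(vocab.writing_system)         # e.g. {"direction": "ltr", ...}
+print(vocab.lookups.tables)         # names of the available lookup tables
+```
+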
## Serialization fields {#serialization-fields}
diff --git a/website/docs/usage/embeddings-transformers.md b/website/docs/usage/embeddings-transformers.md
index 88fb39f61..febed6f2f 100644
--- a/website/docs/usage/embeddings-transformers.md
+++ b/website/docs/usage/embeddings-transformers.md
@@ -351,7 +351,7 @@ factory = "transformer"
max_batch_items = 4096
[components.transformer.model]
-@architectures = "spacy-transformers.TransformerModel.v1"
+@architectures = "spacy-transformers.TransformerModel.v3"
name = "bert-base-cased"
tokenizer_config = {"use_fast": true}
@@ -367,7 +367,7 @@ The `[components.transformer.model]` block describes the `model` argument passed
to the transformer component. It's a Thinc
[`Model`](https://thinc.ai/docs/api-model) object that will be passed into the
component. Here, it references the function
-[spacy-transformers.TransformerModel.v1](/api/architectures#TransformerModel)
+[spacy-transformers.TransformerModel.v3](/api/architectures#TransformerModel)
registered in the [`architectures` registry](/api/top-level#registry). If a key
in a block starts with `@`, it's **resolved to a function** and all other
settings are passed to the function as arguments. In this case, `name`,
@@ -379,6 +379,21 @@ of potentially overlapping `Span` objects to process by the transformer. Several
to process the whole document or individual sentences. When the config is
resolved, the function is created and passed into the model as an argument.
+The `name` value is the name of any [HuggingFace model][huggingface-models],
+which will be downloaded automatically the first time it's used. You can also
+use a local file path. For full details, see the
+[`TransformerModel` docs](/api/architectures#TransformerModel).
+
+[huggingface-models]:
+ https://huggingface.co/models?library=pytorch&sort=downloads
+
+A wide variety of PyTorch models are supported, but some might not work. If a
+model doesn't seem to work, feel free to open an
+[issue](https://github.com/explosion/spacy/issues). Additionally, note that
+transformers loaded in spaCy can only be used to produce tensors: pretrained
+task-specific heads and text generation features cannot be used as part of
+the `transformer` pipeline component.
+
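+The same model can also be configured directly in Python, which can be handy
+for quick experiments. A sketch, assuming `spacy-transformers` v1.1 or later
+is installed; the model name is just an example and will be downloaded on
+first use:
+
+```python
+import spacy
+
+nlp = spacy.blank("en")
+nlp.add_pipe(
+    "transformer",
+    config={
+        "model": {
+            "@architectures": "spacy-transformers.TransformerModel.v3",
+            "name": "bert-base-cased",  # HuggingFace model name or local path
+            "tokenizer_config": {"use_fast": True},
+        }
+    },
+)
+nlp.initialize()
+doc = nlp("This is a sentence.")
+print(doc._.trf_data.tensors[0].shape)
+```
+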
Remember that the `config.cfg` used for training should contain **no missing
@@ -697,9 +712,11 @@ given you a 10% error reduction, pretraining with spaCy might give you another
The [`spacy pretrain`](/api/cli#pretrain) command will take a **specific
subnetwork** within one of your components, and add additional layers to build a
network for a temporary task that forces the model to learn something about
-sentence structure and word cooccurrence statistics. Pretraining produces a
-**binary weights file** that can be loaded back in at the start of training. The
-weights file specifies an initial set of weights. Training then proceeds as
+sentence structure and word cooccurrence statistics.
+
+Pretraining produces a **binary weights file** that can be loaded back in at the
+start of training, using the configuration option `initialize.init_tok2vec`.
+The weights file specifies an initial set of weights. Training then proceeds as
normal.
You can only pretrain one subnetwork from your pipeline at a time, and the
@@ -732,6 +749,40 @@ component = "textcat"
layer = "tok2vec"
```
+#### Connecting pretraining to training {#pretraining-training}
+
+To benefit from pretraining, your training step needs to initialize its
+`tok2vec` component with the weights learned during the pretraining step. You
+do this by setting `initialize.init_tok2vec` to the path of the `.bin` weights
+file produced by pretraining.
+
+For example, a pretraining step that runs for 5 epochs with an output path of
+`pretrain/` produces `pretrain/model0.bin` through `pretrain/model4.bin`. To
+make use of the final output, you could fill in this value in your config
+file:
+
+```ini
+### config.cfg
+
+[paths]
+init_tok2vec = "pretrain/model4.bin"
+
+[initialize]
+init_tok2vec = ${paths.init_tok2vec}
+```
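+
+If you want to double-check the value that ends up in the resolved config, a
+small sketch (the `config.cfg` filename is whatever your training config is
+called):
+
+```python
+from pathlib import Path
+import spacy
+
+config = spacy.util.load_config("config.cfg", interpolate=True)
+weights = Path(config["initialize"]["init_tok2vec"])
+print(weights, weights.exists())
+```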
+
+The outputs of `spacy pretrain` are not in the same data format as the
+pre-packaged static word vectors that would go into
+[`initialize.vectors`](/api/data-formats#config-initialize). The pretraining
+output consists of the weights that the `tok2vec` component should start with
+in an existing pipeline, so it goes in `initialize.init_tok2vec`.
+
#### Pretraining objectives {#pretraining-objectives}
> ```ini
diff --git a/website/docs/usage/layers-architectures.md b/website/docs/usage/layers-architectures.md
index 17043d599..2e23b3684 100644
--- a/website/docs/usage/layers-architectures.md
+++ b/website/docs/usage/layers-architectures.md
@@ -833,7 +833,7 @@ retrieve and add to them.
self.cfg = {"labels": []}
@property
- def labels(self) -> Tuple[str]:
+ def labels(self) -> Tuple[str, ...]:
"""Returns the labels currently added to the component."""
return tuple(self.cfg["labels"])
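+
+    # Why Tuple[str, ...] rather than Tuple[str]: the latter annotates a tuple
+    # with exactly one str element, while the ellipsis form annotates a tuple
+    # of any length whose items are all str, which fits a growing label set.
+    # A hypothetical helper (an illustrative sketch, not part of the original
+    # example) that keeps `labels` consistent with that annotation:
+    def add_label(self, label: str) -> None:
+        """Add a new label to the component, if it isn't present yet."""
+        if label not in self.cfg["labels"]:
+            self.cfg["labels"].append(label)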
diff --git a/website/meta/type-annotations.json b/website/meta/type-annotations.json
index 8136b3e96..0ffcbfb33 100644
--- a/website/meta/type-annotations.json
+++ b/website/meta/type-annotations.json
@@ -43,6 +43,7 @@
"cymem.Pool": "https://github.com/explosion/cymem",
"preshed.BloomFilter": "https://github.com/explosion/preshed",
"transformers.BatchEncoding": "https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding",
+ "transformers.file_utils.ModelOutput": "https://huggingface.co/transformers/main_classes/output.html#modeloutput",
"torch.Tensor": "https://pytorch.org/docs/stable/tensors.html",
"numpy.ndarray": "https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html",
"Match": "https://docs.python.org/3/library/re.html#match-objects",
diff --git a/website/meta/universe.json b/website/meta/universe.json
index 7438a8932..df8077419 100644
--- a/website/meta/universe.json
+++ b/website/meta/universe.json
@@ -516,7 +516,7 @@
"title": "NeuroNER",
"slogan": "Named-entity recognition using neural networks",
"github": "Franck-Dernoncourt/NeuroNER",
- "category": ["ner"],
+ "category": ["models"],
"pip": "pyneuroner[cpu]",
"code_example": [
"from neuroner import neuromodel",
@@ -3550,6 +3550,11 @@
"title": "Scientific",
"description": "Frameworks and utilities for scientific text processing"
},
+ {
+ "id": "biomedical",
+ "title": "Biomedical",
+ "description": "Frameworks and utilities for processing biomedical text"
+ },
{
"id": "visualizers",
"title": "Visualizers",