diff --git a/.gitignore b/.gitignore
index ac333f958..af75a4d47 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,16 +10,6 @@ spacy/tests/package/setup.cfg
spacy/tests/package/pyproject.toml
spacy/tests/package/requirements.txt
-# Website
-website/.cache/
-website/public/
-website/node_modules
-website/.npm
-website/logs
-*.log
-npm-debug.log*
-quickstart-training-generator.js
-
# Cython / C extensions
cythonize.json
spacy/*.html
diff --git a/README.md b/README.md
index 842a5d839..bf8083e0e 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ production-ready [**training system**](https://spacy.io/usage/training) and easy
model packaging, deployment and workflow management. spaCy is commercial
open-source software, released under the [MIT license](https://github.com/explosion/spaCy/blob/master/LICENSE).
-💫 **Version 3.4 out now!**
+💫 **Version 3.5 out now!**
[Check out the release notes here.](https://github.com/explosion/spaCy/releases)
[](https://dev.azure.com/explosion-ai/public/_build?definitionId=8)
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 99f1b8aff..a6a575315 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -11,18 +11,28 @@ trigger:
exclude:
- "website/*"
- "*.md"
+ - "*.mdx"
- ".github/workflows/*"
pr:
paths:
exclude:
- "*.md"
+ - "*.mdx"
- "website/docs/*"
- "website/src/*"
+ - "website/meta/*.tsx"
+ - "website/meta/*.mjs"
+ - "website/meta/languages.json"
+ - "website/meta/site.json"
+ - "website/meta/sidebars.json"
+ - "website/meta/type-annotations.json"
+ - "website/pages/*"
- ".github/workflows/*"
jobs:
- # Perform basic checks for most important errors (syntax etc.) Uses the config
- # defined in .flake8 and overwrites the selected codes.
+ # Check formatting and linting. Perform basic checks for most important errors
+  # (syntax etc.). Uses the config defined in setup.cfg and overwrites the
+ # selected codes.
- job: "Validate"
pool:
vmImage: "ubuntu-latest"
@@ -30,6 +40,10 @@ jobs:
- task: UsePythonVersion@0
inputs:
versionSpec: "3.8"
+ - script: |
+ pip install black==22.3.0
+ python -m black spacy --check
+ displayName: "black"
- script: |
pip install flake8==5.0.4
python -m flake8 spacy --count --select=E901,E999,F821,F822,F823,W605 --show-source --statistics
diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py
index aabd1cfef..868526b42 100644
--- a/spacy/cli/__init__.py
+++ b/spacy/cli/__init__.py
@@ -4,6 +4,7 @@ from ._util import app, setup_cli # noqa: F401
# These are the actual functions, NOT the wrapped CLI commands. The CLI commands
# are registered automatically and won't have to be imported here.
+from .benchmark_speed import benchmark_speed_cli # noqa: F401
from .download import download # noqa: F401
from .info import info # noqa: F401
from .package import package # noqa: F401
diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py
index 6dd3eadfc..eb4869666 100644
--- a/spacy/cli/_util.py
+++ b/spacy/cli/_util.py
@@ -45,6 +45,7 @@ DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes
commands to check and validate your config files, training and evaluation data,
and custom model implementations.
"""
+BENCHMARK_HELP = """Commands for benchmarking pipelines."""
INIT_HELP = """Commands for initializing configs and pipeline packages."""
# Wrappers for Typer's annotations. Initially created to set defaults and to
@@ -53,12 +54,14 @@ Arg = typer.Argument
Opt = typer.Option
app = typer.Typer(name=NAME, help=HELP)
+benchmark_cli = typer.Typer(name="benchmark", help=BENCHMARK_HELP, no_args_is_help=True)
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
app.add_typer(project_cli)
app.add_typer(debug_cli)
+app.add_typer(benchmark_cli)
app.add_typer(init_cli)
diff --git a/spacy/cli/benchmark_speed.py b/spacy/cli/benchmark_speed.py
new file mode 100644
index 000000000..4eb20a5fa
--- /dev/null
+++ b/spacy/cli/benchmark_speed.py
@@ -0,0 +1,174 @@
+from typing import Iterable, List, Optional
+import random
+from itertools import islice
+import numpy
+from pathlib import Path
+import time
+from tqdm import tqdm
+import typer
+from wasabi import msg
+
+from .. import util
+from ..language import Language
+from ..tokens import Doc
+from ..training import Corpus
+from ._util import Arg, Opt, benchmark_cli, setup_gpu
+
+
+@benchmark_cli.command(
+ "speed",
+ context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def benchmark_speed_cli(
+ # fmt: off
+ ctx: typer.Context,
+ model: str = Arg(..., help="Model name or path"),
+ data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
+ batch_size: Optional[int] = Opt(None, "--batch-size", "-b", min=1, help="Override the pipeline batch size"),
+ no_shuffle: bool = Opt(False, "--no-shuffle", help="Do not shuffle benchmark data"),
+ use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
+ n_batches: int = Opt(50, "--batches", help="Minimum number of batches to benchmark", min=30,),
+ warmup_epochs: int = Opt(3, "--warmup", "-w", min=0, help="Number of iterations over the data for warmup"),
+ # fmt: on
+):
+ """
+ Benchmark a pipeline. Expects a loadable spaCy pipeline and benchmark
+ data in the binary .spacy format.
+ """
+ setup_gpu(use_gpu=use_gpu, silent=False)
+
+ nlp = util.load_model(model)
+ batch_size = batch_size if batch_size is not None else nlp.batch_size
+ corpus = Corpus(data_path)
+ docs = [eg.predicted for eg in corpus(nlp)]
+
+ if len(docs) == 0:
+ msg.fail("Cannot benchmark speed using an empty corpus.", exits=1)
+
+ print(f"Warming up for {warmup_epochs} epochs...")
+ warmup(nlp, docs, warmup_epochs, batch_size)
+
+ print()
+ print(f"Benchmarking {n_batches} batches...")
+ wps = benchmark(nlp, docs, n_batches, batch_size, not no_shuffle)
+
+ print()
+ print_outliers(wps)
+ print_mean_with_ci(wps)
+
+
+# Lowercased, behaves as a context manager function.
+class time_context:
+ """Register the running time of a context."""
+
+ def __enter__(self):
+ self.start = time.perf_counter()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.elapsed = time.perf_counter() - self.start
+
+
+class Quartiles:
+ """Calculate the q1, q2, q3 quartiles and the inter-quartile range (iqr)
+ of a sample."""
+
+ q1: float
+ q2: float
+ q3: float
+ iqr: float
+
+ def __init__(self, sample: numpy.ndarray) -> None:
+ self.q1 = numpy.quantile(sample, 0.25)
+ self.q2 = numpy.quantile(sample, 0.5)
+ self.q3 = numpy.quantile(sample, 0.75)
+ self.iqr = self.q3 - self.q1
+
+
+def annotate(
+ nlp: Language, docs: List[Doc], batch_size: Optional[int]
+) -> numpy.ndarray:
+ docs = nlp.pipe(tqdm(docs, unit="doc"), batch_size=batch_size)
+ wps = []
+ while True:
+ with time_context() as elapsed:
+ batch_docs = list(
+ islice(docs, batch_size if batch_size else nlp.batch_size)
+ )
+ if len(batch_docs) == 0:
+ break
+ n_tokens = count_tokens(batch_docs)
+ wps.append(n_tokens / elapsed.elapsed)
+
+ return numpy.array(wps)
+
+
+def benchmark(
+ nlp: Language,
+ docs: List[Doc],
+ n_batches: int,
+ batch_size: int,
+ shuffle: bool,
+) -> numpy.ndarray:
+ if shuffle:
+ bench_docs = [
+ nlp.make_doc(random.choice(docs).text)
+ for _ in range(n_batches * batch_size)
+ ]
+ else:
+ bench_docs = [
+ nlp.make_doc(docs[i % len(docs)].text)
+ for i in range(n_batches * batch_size)
+ ]
+
+ return annotate(nlp, bench_docs, batch_size)
+
+
+def bootstrap(x, statistic=numpy.mean, iterations=10000) -> numpy.ndarray:
+ """Apply a statistic to repeated random samples of an array."""
+ return numpy.fromiter(
+ (
+ statistic(numpy.random.choice(x, len(x), replace=True))
+ for _ in range(iterations)
+ ),
+ numpy.float64,
+ )
+
+
+def count_tokens(docs: Iterable[Doc]) -> int:
+ return sum(len(doc) for doc in docs)
+
+
+def print_mean_with_ci(sample: numpy.ndarray):
+ mean = numpy.mean(sample)
+ bootstrap_means = bootstrap(sample)
+ bootstrap_means.sort()
+
+ # 95% confidence interval
+ low = bootstrap_means[int(len(bootstrap_means) * 0.025)]
+ high = bootstrap_means[int(len(bootstrap_means) * 0.975)]
+
+ print(f"Mean: {mean:.1f} words/s (95% CI: {low-mean:.1f} +{high-mean:.1f})")
+
+
+def print_outliers(sample: numpy.ndarray):
+ quartiles = Quartiles(sample)
+
+ n_outliers = numpy.sum(
+ (sample < (quartiles.q1 - 1.5 * quartiles.iqr))
+ | (sample > (quartiles.q3 + 1.5 * quartiles.iqr))
+ )
+ n_extreme_outliers = numpy.sum(
+ (sample < (quartiles.q1 - 3.0 * quartiles.iqr))
+ | (sample > (quartiles.q3 + 3.0 * quartiles.iqr))
+ )
+ print(
+ f"Outliers: {(100 * n_outliers) / len(sample):.1f}%, extreme outliers: {(100 * n_extreme_outliers) / len(sample)}%"
+ )
+
+
+def warmup(
+ nlp: Language, docs: List[Doc], warmup_epochs: int, batch_size: Optional[int]
+) -> numpy.ndarray:
+ docs = warmup_epochs * docs
+ return annotate(nlp, docs, batch_size)
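For reference, here is a toy sketch (not part of the change itself) of how `bootstrap()` and `print_mean_with_ci()` above turn per-batch words-per-second samples into the reported mean and 95% confidence interval; the sample values are made up:

```python
import numpy

# Hypothetical per-batch words-per-second measurements, as produced by annotate().
wps = numpy.array([18500.0, 19100.0, 18800.0, 19400.0, 18200.0, 19000.0])

# Resample with replacement and record the mean of each resample
# (the same idea as bootstrap() with numpy.mean as the statistic).
bootstrap_means = numpy.fromiter(
    (
        numpy.mean(numpy.random.choice(wps, len(wps), replace=True))
        for _ in range(10000)
    ),
    numpy.float64,
)
bootstrap_means.sort()

# The 2.5th and 97.5th percentiles of the sorted bootstrap means bound the 95% CI.
mean = numpy.mean(wps)
low = bootstrap_means[int(len(bootstrap_means) * 0.025)]
high = bootstrap_means[int(len(bootstrap_means) * 0.975)]
print(f"Mean: {mean:.1f} words/s (95% CI: {low - mean:.1f} +{high - mean:.1f})")
```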
diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py
index 0df1049e5..1c242cec8 100644
--- a/spacy/cli/debug_data.py
+++ b/spacy/cli/debug_data.py
@@ -17,6 +17,7 @@ from ..pipeline import TrainablePipe
from ..pipeline._parser_internals import nonproj
from ..pipeline._parser_internals.nonproj import DELIMITER
from ..pipeline import Morphologizer, SpanCategorizer
+from ..pipeline._edit_tree_internals.edit_trees import EditTrees
from ..morphology import Morphology
from ..language import Language
from ..util import registry, resolve_dot_names
@@ -670,6 +671,59 @@ def debug_data(
f"Found {gold_train_data['n_cycles']} projectivized train sentence(s) with cycles"
)
+ if "trainable_lemmatizer" in factory_names:
+ msg.divider("Trainable Lemmatizer")
+ trees_train: Set[str] = gold_train_data["lemmatizer_trees"]
+ trees_dev: Set[str] = gold_dev_data["lemmatizer_trees"]
+ # This is necessary context when someone is attempting to interpret whether the
+ # number of trees exclusively in the dev set is meaningful.
+ msg.info(f"{len(trees_train)} lemmatizer trees generated from training data")
+ msg.info(f"{len(trees_dev)} lemmatizer trees generated from dev data")
+ dev_not_train = trees_dev - trees_train
+
+ if len(dev_not_train) != 0:
+ pct = len(dev_not_train) / len(trees_dev)
+ msg.info(
+ f"{len(dev_not_train)} lemmatizer trees ({pct*100:.1f}% of dev trees)"
+ " were found exclusively in the dev data."
+ )
+ else:
+ # Would we ever expect this case? It seems like it would be pretty rare,
+ # and we might actually want a warning?
+ msg.info("All trees in dev data present in training data.")
+
+ if gold_train_data["n_low_cardinality_lemmas"] > 0:
+ n = gold_train_data["n_low_cardinality_lemmas"]
+ msg.warn(f"{n} training docs with 0 or 1 unique lemmas.")
+
+ if gold_dev_data["n_low_cardinality_lemmas"] > 0:
+ n = gold_dev_data["n_low_cardinality_lemmas"]
+ msg.warn(f"{n} dev docs with 0 or 1 unique lemmas.")
+
+ if gold_train_data["no_lemma_annotations"] > 0:
+ n = gold_train_data["no_lemma_annotations"]
+ msg.warn(f"{n} training docs with no lemma annotations.")
+ else:
+ msg.good("All training docs have lemma annotations.")
+
+ if gold_dev_data["no_lemma_annotations"] > 0:
+ n = gold_dev_data["no_lemma_annotations"]
+ msg.warn(f"{n} dev docs with no lemma annotations.")
+ else:
+ msg.good("All dev docs have lemma annotations.")
+
+ if gold_train_data["partial_lemma_annotations"] > 0:
+ n = gold_train_data["partial_lemma_annotations"]
+ msg.info(f"{n} training docs with partial lemma annotations.")
+ else:
+ msg.good("All training docs have complete lemma annotations.")
+
+ if gold_dev_data["partial_lemma_annotations"] > 0:
+ n = gold_dev_data["partial_lemma_annotations"]
+ msg.info(f"{n} dev docs with partial lemma annotations.")
+ else:
+ msg.good("All dev docs have complete lemma annotations.")
+
msg.divider("Summary")
good_counts = msg.counts[MESSAGES.GOOD]
warn_counts = msg.counts[MESSAGES.WARN]
@@ -731,7 +785,13 @@ def _compile_gold(
"n_cats_multilabel": 0,
"n_cats_bad_values": 0,
"texts": set(),
+ "lemmatizer_trees": set(),
+ "no_lemma_annotations": 0,
+ "partial_lemma_annotations": 0,
+ "n_low_cardinality_lemmas": 0,
}
+ if "trainable_lemmatizer" in factory_names:
+ trees = EditTrees(nlp.vocab.strings)
for eg in examples:
gold = eg.reference
doc = eg.predicted
@@ -861,6 +921,25 @@ def _compile_gold(
data["n_nonproj"] += 1
if nonproj.contains_cycle(aligned_heads):
data["n_cycles"] += 1
+ if "trainable_lemmatizer" in factory_names:
+ # from EditTreeLemmatizer._labels_from_data
+ if all(token.lemma == 0 for token in gold):
+ data["no_lemma_annotations"] += 1
+ continue
+ if any(token.lemma == 0 for token in gold):
+ data["partial_lemma_annotations"] += 1
+ lemma_set = set()
+ for token in gold:
+ if token.lemma != 0:
+ lemma_set.add(token.lemma)
+ tree_id = trees.add(token.text, token.lemma_)
+ tree_str = trees.tree_to_str(tree_id)
+ data["lemmatizer_trees"].add(tree_str)
+ # We want to identify cases where lemmas aren't assigned
+ # or are all assigned the same value, as this would indicate
+ # an issue since we're expecting a large set of lemmas
+ if len(lemma_set) < 2 and len(gold) > 1:
+ data["n_low_cardinality_lemmas"] += 1
return data
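As a quick sketch of what the collected `lemmatizer_trees` strings represent, using only the `EditTrees` calls that appear above (the example tokens are made up):

```python
from spacy.strings import StringStore
from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees

trees = EditTrees(StringStore())
# Learn an edit tree that maps the form "eggs" to the lemma "egg".
tree_id = trees.add("eggs", "egg")
# String representation, as stored in data["lemmatizer_trees"] above.
print(trees.tree_to_str(tree_id))
# Applying a tree to a form returns the lemma, or None if the tree is not
# applicable, which is how incompatible guesses are filtered out elsewhere.
print(trees.apply(tree_id, "eggs"))  # "egg"
```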
diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py
index 0d08d2c5e..8f3d6b859 100644
--- a/spacy/cli/evaluate.py
+++ b/spacy/cli/evaluate.py
@@ -7,12 +7,15 @@ from thinc.api import fix_random_seed
from ..training import Corpus
from ..tokens import Doc
-from ._util import app, Arg, Opt, setup_gpu, import_code
+from ._util import app, Arg, Opt, setup_gpu, import_code, benchmark_cli
from ..scorer import Scorer
from .. import util
from .. import displacy
+@benchmark_cli.command(
+ "accuracy",
+)
@app.command("evaluate")
def evaluate_cli(
# fmt: off
@@ -36,7 +39,7 @@ def evaluate_cli(
    dependency parses in an HTML file, set the output directory as the
displacy_path argument.
- DOCS: https://spacy.io/api/cli#evaluate
+ DOCS: https://spacy.io/api/cli#benchmark-accuracy
"""
import_code(code_path)
evaluate(
diff --git a/spacy/displacy/__init__.py b/spacy/displacy/__init__.py
index a3cfd96dd..ea6bba2c9 100644
--- a/spacy/displacy/__init__.py
+++ b/spacy/displacy/__init__.py
@@ -106,9 +106,7 @@ def serve(
if is_in_jupyter():
warnings.warn(Warnings.W011)
- render(
- docs, style=style, page=page, minify=minify, options=options, manual=manual
- )
+ render(docs, style=style, page=page, minify=minify, options=options, manual=manual)
httpd = simple_server.make_server(host, port, app)
print(f"\nUsing the '{style}' visualizer")
print(f"Serving on http://{host}:{port} ...\n")
diff --git a/spacy/errors.py b/spacy/errors.py
index 18811b725..5f480c16c 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -949,8 +949,8 @@ class Errors(metaclass=ErrorsWithCodes):
E1047 = ("`find_threshold()` only supports components with a `scorer` attribute.")
E1048 = ("Got '{unexpected}' as console progress bar type, but expected one of the following: {expected}")
E1049 = ("No available port found for displaCy on host {host}. Please specify an available port "
- "with `displacy.serve(doc, port)`")
- E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port)` "
+ "with `displacy.serve(doc, port=port)`")
+ E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
"or use `auto_switch_port=True` to pick an available port automatically.")
# v4 error strings
diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx
index 485e52c2f..edba523cf 100644
--- a/spacy/kb/kb_in_memory.pyx
+++ b/spacy/kb/kb_in_memory.pyx
@@ -25,7 +25,7 @@ cdef class InMemoryLookupKB(KnowledgeBase):
"""An `InMemoryLookupKB` instance stores unique identifiers for entities and their textual aliases,
to support entity linking of named entities to real-world concepts.
- DOCS: https://spacy.io/api/kb_in_memory
+ DOCS: https://spacy.io/api/inmemorylookupkb
"""
def __init__(self, Vocab vocab, entity_vector_length):
diff --git a/spacy/matcher/levenshtein.pyx b/spacy/matcher/levenshtein.pyx
index 0e8cd26da..e823ce99d 100644
--- a/spacy/matcher/levenshtein.pyx
+++ b/spacy/matcher/levenshtein.pyx
@@ -22,7 +22,7 @@ cpdef bint levenshtein_compare(input_text: str, pattern_text: str, fuzzy: int =
max_edits = fuzzy
else:
# allow at least two edits (to allow at least one transposition) and up
- # to 20% of the pattern string length
+ # to 30% of the pattern string length
max_edits = max(2, round(0.3 * len(pattern_text)))
return levenshtein(input_text, pattern_text, max_edits) <= max_edits
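A quick worked example of the default allowance above (pattern string chosen for illustration):

```python
# For a 10-character pattern such as "definitely", the default fuzzy matching
# allows max(2, round(0.3 * 10)) == 3 edits.
max_edits = max(2, round(0.3 * len("definitely")))
assert max_edits == 3
```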
diff --git a/spacy/matcher/matcher.pyi b/spacy/matcher/matcher.pyi
index ad214c120..9797463aa 100644
--- a/spacy/matcher/matcher.pyi
+++ b/spacy/matcher/matcher.pyi
@@ -4,8 +4,12 @@ from ..vocab import Vocab
from ..tokens import Doc, Span
class Matcher:
- def __init__(self, vocab: Vocab, validate: bool = ...,
- fuzzy_compare: Callable[[str, str, int], bool] = ...) -> None: ...
+ def __init__(
+ self,
+ vocab: Vocab,
+ validate: bool = ...,
+ fuzzy_compare: Callable[[str, str, int], bool] = ...,
+ ) -> None: ...
def __reduce__(self) -> Any: ...
def __len__(self) -> int: ...
def __contains__(self, key: str) -> bool: ...
diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py
index 20f83fffc..3198b7509 100644
--- a/spacy/pipeline/edit_tree_lemmatizer.py
+++ b/spacy/pipeline/edit_tree_lemmatizer.py
@@ -5,7 +5,7 @@ from itertools import islice
import numpy as np
import srsly
-from thinc.api import Config, Model
+from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps
from thinc.types import ArrayXd, Floats2d, Ints1d
from thinc.legacy import LegacySequenceCategoricalCrossentropy
@@ -22,6 +22,8 @@ from .. import util
ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
+# The cutoff value of *top_k* above which an alternative method is used to process guesses.
+TOP_K_GUARDRAIL = 20
default_model_config = """
@@ -125,6 +127,7 @@ class EditTreeLemmatizer(TrainablePipe):
self.cfg: Dict[str, Any] = {"labels": []}
self.scorer = scorer
self.save_activations = save_activations
+ self.numpy_ops = NumpyOps()
def get_loss(
self, examples: Iterable[Example], scores: List[Floats2d]
@@ -140,7 +143,7 @@ class EditTreeLemmatizer(TrainablePipe):
for (predicted, gold_lemma) in zip(
eg.predicted, eg.get_aligned("LEMMA", as_string=True)
):
- if gold_lemma is None:
+ if gold_lemma is None or gold_lemma == "":
label = -1
else:
tree_id = self.trees.add(predicted.text, gold_lemma)
@@ -165,7 +168,7 @@ class EditTreeLemmatizer(TrainablePipe):
student_scores: Scores representing the student model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
-
+
DOCS: https://spacy.io/api/edittreelemmatizer#get_teacher_student_loss
"""
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)
@@ -175,6 +178,18 @@ class EditTreeLemmatizer(TrainablePipe):
return float(loss), d_scores
def predict(self, docs: Iterable[Doc]) -> ActivationsT:
+ if self.top_k == 1:
+ scores2guesses = self._scores2guesses_top_k_equals_1
+ elif self.top_k <= TOP_K_GUARDRAIL:
+ scores2guesses = self._scores2guesses_top_k_greater_1
+ else:
+ scores2guesses = self._scores2guesses_top_k_guardrail
+ # The behaviour of *_scores2guesses_top_k_greater_1()* is efficient for values
+ # of *top_k>1* that are likely to be useful when the edit tree lemmatizer is used
+ # for its principal purpose of lemmatizing tokens. However, the code could also
+ # be used for other purposes, and with very large values of *top_k* the method
+ # becomes inefficient. In such cases, *_scores2guesses_top_k_guardrail()* is used
+ # instead.
n_docs = len(list(docs))
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
@@ -189,20 +204,52 @@ class EditTreeLemmatizer(TrainablePipe):
return {"probabilities": scores, "tree_ids": guesses}
scores = self.model.predict(docs)
assert len(scores) == n_docs
- guesses = self._scores2guesses(docs, scores)
+ guesses = scores2guesses(docs, scores)
assert len(guesses) == n_docs
return {"probabilities": scores, "tree_ids": guesses}
- def _scores2guesses(self, docs, scores):
+ def _scores2guesses_top_k_equals_1(self, docs, scores):
guesses = []
for doc, doc_scores in zip(docs, scores):
- if self.top_k == 1:
- doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
- else:
- doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
+ doc_guesses = doc_scores.argmax(axis=1)
+ doc_guesses = self.numpy_ops.asarray(doc_guesses)
- if not isinstance(doc_guesses, np.ndarray):
- doc_guesses = doc_guesses.get()
+ doc_compat_guesses = []
+ for i, token in enumerate(doc):
+ tree_id = self.cfg["labels"][doc_guesses[i]]
+ if self.trees.apply(tree_id, token.text) is not None:
+ doc_compat_guesses.append(tree_id)
+ else:
+ doc_compat_guesses.append(-1)
+ guesses.append(np.array(doc_compat_guesses))
+
+ return guesses
+
+ def _scores2guesses_top_k_greater_1(self, docs, scores):
+ guesses = []
+ top_k = min(self.top_k, len(self.labels))
+ for doc, doc_scores in zip(docs, scores):
+ doc_scores = self.numpy_ops.asarray(doc_scores)
+ doc_compat_guesses = []
+ for i, token in enumerate(doc):
+ for _ in range(top_k):
+ candidate = int(doc_scores[i].argmax())
+ candidate_tree_id = self.cfg["labels"][candidate]
+ if self.trees.apply(candidate_tree_id, token.text) is not None:
+ doc_compat_guesses.append(candidate_tree_id)
+ break
+ doc_scores[i, candidate] = np.finfo(np.float32).min
+ else:
+ doc_compat_guesses.append(-1)
+ guesses.append(np.array(doc_compat_guesses))
+
+ return guesses
+
+ def _scores2guesses_top_k_guardrail(self, docs, scores):
+ guesses = []
+ for doc, doc_scores in zip(docs, scores):
+ doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
+ doc_guesses = self.numpy_ops.asarray(doc_guesses)
doc_compat_guesses = []
for token, candidates in zip(doc, doc_guesses):
diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py
index 19c355238..fa4dea75a 100644
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -453,7 +453,11 @@ class EntityLinker(TrainablePipe):
docs_ents: List[Ragged] = []
docs_scores: List[Ragged] = []
if not docs:
- return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
+ return {
+ KNOWLEDGE_BASE_IDS: final_kb_ids,
+ "ents": docs_ents,
+ "scores": docs_scores,
+ }
if isinstance(docs, Doc):
docs = [docs]
for doc in docs:
@@ -585,7 +589,11 @@ class EntityLinker(TrainablePipe):
method="predict", msg="result variables not of equal length"
)
raise RuntimeError(err)
- return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
+ return {
+ KNOWLEDGE_BASE_IDS: final_kb_ids,
+ "ents": docs_ents,
+ "scores": docs_scores,
+ }
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
"""Modify a batch of documents, using pre-computed scores.
diff --git a/spacy/pipeline/ner.py b/spacy/pipeline/ner.py
index 651a0b3e3..7e44b2835 100644
--- a/spacy/pipeline/ner.py
+++ b/spacy/pipeline/ner.py
@@ -252,8 +252,11 @@ class EntityRecognizer(Parser):
def labels(self):
# Get the labels from the model by looking at the available moves, e.g.
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
- labels = set(remove_bilu_prefix(move) for move in self.move_names
- if move[0] in ("B", "I", "L", "U"))
+ labels = set(
+ remove_bilu_prefix(move)
+ for move in self.move_names
+ if move[0] in ("B", "I", "L", "U")
+ )
return tuple(sorted(labels))
def scored_ents(self, beams):
diff --git a/spacy/schemas.py b/spacy/schemas.py
index aea3cc4f7..c8467fea8 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -162,15 +162,33 @@ class TokenPatternString(BaseModel):
IS_SUPERSET: Optional[List[StrictStr]] = Field(None, alias="is_superset")
INTERSECTS: Optional[List[StrictStr]] = Field(None, alias="intersects")
FUZZY: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy")
- FUZZY1: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy1")
- FUZZY2: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy2")
- FUZZY3: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy3")
- FUZZY4: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy4")
- FUZZY5: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy5")
- FUZZY6: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy6")
- FUZZY7: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy7")
- FUZZY8: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy8")
- FUZZY9: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy9")
+ FUZZY1: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy1"
+ )
+ FUZZY2: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy2"
+ )
+ FUZZY3: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy3"
+ )
+ FUZZY4: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy4"
+ )
+ FUZZY5: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy5"
+ )
+ FUZZY6: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy6"
+ )
+ FUZZY7: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy7"
+ )
+ FUZZY8: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy8"
+ )
+ FUZZY9: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
+ None, alias="fuzzy9"
+ )
class Config:
extra = "forbid"
diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
index b855c7a26..c5c50c77f 100644
--- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
+++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py
@@ -103,14 +103,15 @@ def test_initialize_from_labels():
}
-def test_no_data():
+@pytest.mark.parametrize("top_k", (1, 5, 30))
+def test_no_data(top_k):
# Test that the lemmatizer provides a nice error when there's no tagging data / labels
TEXTCAT_DATA = [
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
]
nlp = English()
- nlp.add_pipe("trainable_lemmatizer")
+ nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
nlp.add_pipe("textcat")
train_examples = []
@@ -121,10 +122,11 @@ def test_no_data():
nlp.initialize(get_examples=lambda: train_examples)
-def test_incomplete_data():
+@pytest.mark.parametrize("top_k", (1, 5, 30))
+def test_incomplete_data(top_k):
# Test that the lemmatizer works with incomplete information
nlp = English()
- lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+ lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
lemmatizer.min_tree_freq = 1
train_examples = []
for t in PARTIAL_DATA:
@@ -141,10 +143,25 @@ def test_incomplete_data():
assert doc[1].lemma_ == "like"
assert doc[2].lemma_ == "blue"
+ # Check that incomplete annotations are ignored.
+ scores, _ = lemmatizer.model([eg.predicted for eg in train_examples], is_train=True)
+ _, dX = lemmatizer.get_loss(train_examples, scores)
+ xp = lemmatizer.model.ops.xp
-def test_overfitting_IO():
+ # Missing annotations.
+ assert xp.count_nonzero(dX[0][0]) == 0
+ assert xp.count_nonzero(dX[0][3]) == 0
+ assert xp.count_nonzero(dX[1][0]) == 0
+ assert xp.count_nonzero(dX[1][3]) == 0
+
+ # Misaligned annotations.
+ assert xp.count_nonzero(dX[1][1]) == 0
+
+
+@pytest.mark.parametrize("top_k", (1, 5, 30))
+def test_overfitting_IO(top_k):
nlp = English()
- lemmatizer = nlp.add_pipe("trainable_lemmatizer")
+ lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
lemmatizer.min_tree_freq = 1
train_examples = []
for t in TRAIN_DATA:
@@ -177,7 +194,7 @@ def test_overfitting_IO():
# Check model after a {to,from}_bytes roundtrip
nlp_bytes = nlp.to_bytes()
nlp3 = English()
- nlp3.add_pipe("trainable_lemmatizer")
+ nlp3.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
nlp3.from_bytes(nlp_bytes)
doc3 = nlp3(test_text)
assert doc3[0].lemma_ == "she"
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index c88e20de2..42ffae22d 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -618,7 +618,6 @@ def test_string_to_list_intify(value):
assert string_to_list(value, intify=True) == [1, 2, 3]
-@pytest.mark.skip(reason="Temporarily skip for dev version")
def test_download_compatibility():
spec = SpecifierSet("==" + about.__version__)
spec.prereleases = False
@@ -629,7 +628,6 @@ def test_download_compatibility():
assert get_minor_version(about.__version__) == get_minor_version(version)
-@pytest.mark.skip(reason="Temporarily skip for dev version")
def test_validate_compatibility_table():
spec = SpecifierSet("==" + about.__version__)
spec.prereleases = False
@@ -1076,7 +1074,7 @@ def test_cli_find_threshold(capsys):
)
with make_tempdir() as nlp_dir:
nlp.to_disk(nlp_dir)
- res = find_threshold(
+ best_threshold, best_score, res = find_threshold(
model=nlp_dir,
data_path=docs_dir / "docs.spacy",
pipe_name="tc_multi",
@@ -1084,10 +1082,10 @@ def test_cli_find_threshold(capsys):
scores_key="cats_macro_f",
silent=True,
)
- assert res[0] != thresholds[0]
- assert thresholds[0] < res[0] < thresholds[9]
- assert res[1] == 1.0
- assert res[2][1.0] == 0.0
+ assert best_threshold != thresholds[0]
+ assert thresholds[0] < best_threshold < thresholds[9]
+ assert best_score == max(res.values())
+ assert res[1.0] == 0.0
# Test with spancat.
nlp, _ = init_nlp((("spancat", {}),))
@@ -1209,3 +1207,69 @@ def test_walk_directory():
assert (len(walk_directory(d, suffix="iob"))) == 2
assert (len(walk_directory(d, suffix="conll"))) == 3
assert (len(walk_directory(d, suffix="pdf"))) == 0
+
+
+def test_debug_data_trainable_lemmatizer_basic():
+ examples = [
+ ("She likes green eggs", {"lemmas": ["she", "like", "green", "egg"]}),
+ ("Eat blue ham", {"lemmas": ["eat", "blue", "ham"]}),
+ ]
+ nlp = Language()
+ train_examples = []
+ for t in examples:
+ train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
+ data = _compile_gold(train_examples, ["trainable_lemmatizer"], nlp, True)
+ # ref test_edit_tree_lemmatizer::test_initialize_from_labels
+ # this results in 4 trees
+ assert len(data["lemmatizer_trees"]) == 4
+
+
+def test_debug_data_trainable_lemmatizer_partial():
+ partial_examples = [
+ # partial annotation
+ ("She likes green eggs", {"lemmas": ["", "like", "green", ""]}),
+ # misaligned partial annotation
+ (
+ "He hates green eggs",
+ {
+ "words": ["He", "hat", "es", "green", "eggs"],
+ "lemmas": ["", "hat", "e", "green", ""],
+ },
+ ),
+ ]
+ nlp = Language()
+ train_examples = []
+ for t in partial_examples:
+ train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
+ data = _compile_gold(train_examples, ["trainable_lemmatizer"], nlp, True)
+ assert data["partial_lemma_annotations"] == 2
+
+
+def test_debug_data_trainable_lemmatizer_low_cardinality():
+ low_cardinality_examples = [
+ ("She likes green eggs", {"lemmas": ["no", "no", "no", "no"]}),
+ ("Eat blue ham", {"lemmas": ["no", "no", "no"]}),
+ ]
+ nlp = Language()
+ train_examples = []
+ for t in low_cardinality_examples:
+ train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
+ data = _compile_gold(train_examples, ["trainable_lemmatizer"], nlp, True)
+ assert data["n_low_cardinality_lemmas"] == 2
+
+
+def test_debug_data_trainable_lemmatizer_not_annotated():
+ unannotated_examples = [
+ ("She likes green eggs", {}),
+ ("Eat blue ham", {}),
+ ]
+ nlp = Language()
+ train_examples = []
+ for t in unannotated_examples:
+ train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
+ data = _compile_gold(train_examples, ["trainable_lemmatizer"], nlp, True)
+ assert data["no_lemma_annotations"] == 2
diff --git a/spacy/tests/test_cli_app.py b/spacy/tests/test_cli_app.py
index 873a3ff66..80da5a447 100644
--- a/spacy/tests/test_cli_app.py
+++ b/spacy/tests/test_cli_app.py
@@ -1,6 +1,7 @@
import os
from pathlib import Path
from typer.testing import CliRunner
+from spacy.tokens import DocBin, Doc
from spacy.cli._util import app
from .util import make_tempdir
@@ -31,3 +32,60 @@ def test_convert_auto_conflict():
assert "All input files must be same type" in result.stdout
out_files = os.listdir(d_out)
assert len(out_files) == 0
+
+
+def test_benchmark_accuracy_alias():
+ # Verify that the `evaluate` alias works correctly.
+ result_benchmark = CliRunner().invoke(app, ["benchmark", "accuracy", "--help"])
+ result_evaluate = CliRunner().invoke(app, ["evaluate", "--help"])
+ assert result_benchmark.stdout == result_evaluate.stdout.replace(
+ "spacy evaluate", "spacy benchmark accuracy"
+ )
+
+
+def test_debug_data_trainable_lemmatizer_cli(en_vocab):
+ train_docs = [
+ Doc(en_vocab, words=["I", "like", "cats"], lemmas=["I", "like", "cat"]),
+ Doc(
+ en_vocab,
+ words=["Dogs", "are", "great", "too"],
+ lemmas=["dog", "be", "great", "too"],
+ ),
+ ]
+ dev_docs = [
+ Doc(en_vocab, words=["Cats", "are", "cute"], lemmas=["cat", "be", "cute"]),
+ Doc(en_vocab, words=["Pets", "are", "great"], lemmas=["pet", "be", "great"]),
+ ]
+ with make_tempdir() as d_in:
+ train_bin = DocBin(docs=train_docs)
+ train_bin.to_disk(d_in / "train.spacy")
+ dev_bin = DocBin(docs=dev_docs)
+ dev_bin.to_disk(d_in / "dev.spacy")
+ # `debug data` requires an input pipeline config
+ CliRunner().invoke(
+ app,
+ [
+ "init",
+ "config",
+ f"{d_in}/config.cfg",
+ "--lang",
+ "en",
+ "--pipeline",
+ "trainable_lemmatizer",
+ ],
+ )
+ result_debug_data = CliRunner().invoke(
+ app,
+ [
+ "debug",
+ "data",
+ f"{d_in}/config.cfg",
+ "--paths.train",
+ f"{d_in}/train.spacy",
+ "--paths.dev",
+ f"{d_in}/dev.spacy",
+ ],
+ )
+ # Instead of checking specific wording of the output, which may change,
+ # we'll check that this section of the debug output is present.
+ assert "= Trainable Lemmatizer =" in result_debug_data.stdout
diff --git a/spacy/tests/training/test_corpus.py b/spacy/tests/training/test_corpus.py
new file mode 100644
index 000000000..b4f9cc13a
--- /dev/null
+++ b/spacy/tests/training/test_corpus.py
@@ -0,0 +1,78 @@
+from typing import IO, Generator, Iterable, List, TextIO, Tuple
+from contextlib import contextmanager
+from pathlib import Path
+import pytest
+import tempfile
+
+from spacy.lang.en import English
+from spacy.training import Example, PlainTextCorpus
+from spacy.util import make_tempdir
+
+# Intentional newlines to check that they are skipped.
+PLAIN_TEXT_DOC = """
+
+This is a doc. It contains two sentences.
+This is another doc.
+
+A third doc.
+
+"""
+
+PLAIN_TEXT_DOC_TOKENIZED = [
+ [
+ "This",
+ "is",
+ "a",
+ "doc",
+ ".",
+ "It",
+ "contains",
+ "two",
+ "sentences",
+ ".",
+ ],
+ ["This", "is", "another", "doc", "."],
+ ["A", "third", "doc", "."],
+]
+
+
+@pytest.mark.parametrize("min_length", [0, 5])
+@pytest.mark.parametrize("max_length", [0, 5])
+def test_plain_text_reader(min_length, max_length):
+ nlp = English()
+ with _string_to_tmp_file(PLAIN_TEXT_DOC) as file_path:
+ corpus = PlainTextCorpus(
+ file_path, min_length=min_length, max_length=max_length
+ )
+
+ check = [
+ doc
+ for doc in PLAIN_TEXT_DOC_TOKENIZED
+ if len(doc) >= min_length and (max_length == 0 or len(doc) <= max_length)
+ ]
+ reference, predicted = _examples_to_tokens(corpus(nlp))
+
+ assert reference == check
+ assert predicted == check
+
+
+@contextmanager
+def _string_to_tmp_file(s: str) -> Generator[Path, None, None]:
+ with make_tempdir() as d:
+ file_path = Path(d) / "string.txt"
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.write(s)
+ yield file_path
+
+
+def _examples_to_tokens(
+ examples: Iterable[Example],
+) -> Tuple[List[List[str]], List[List[str]]]:
+ reference = []
+ predicted = []
+
+ for eg in examples:
+ reference.append([t.text for t in eg.reference])
+ predicted.append([t.text for t in eg.predicted])
+
+ return reference, predicted
diff --git a/spacy/training/__init__.py b/spacy/training/__init__.py
index 454437104..f8e69b1c8 100644
--- a/spacy/training/__init__.py
+++ b/spacy/training/__init__.py
@@ -1,4 +1,4 @@
-from .corpus import Corpus, JsonlCorpus # noqa: F401
+from .corpus import Corpus, JsonlCorpus, PlainTextCorpus # noqa: F401
from .example import Example, validate_examples, validate_get_examples # noqa: F401
from .example import validate_distillation_examples # noqa: F401
from .alignment import Alignment # noqa: F401
diff --git a/spacy/training/corpus.py b/spacy/training/corpus.py
index b9f929fcd..d626ad0e0 100644
--- a/spacy/training/corpus.py
+++ b/spacy/training/corpus.py
@@ -58,6 +58,28 @@ def read_labels(path: Path, *, require: bool = False):
return srsly.read_json(path)
+@util.registry.readers("spacy.PlainTextCorpus.v1")
+def create_plain_text_reader(
+ path: Optional[Path],
+ min_length: int = 0,
+ max_length: int = 0,
+) -> Callable[["Language"], Iterable[Example]]:
+ """Iterate Example objects from a file or directory of plain text
+ UTF-8 files with one line per doc.
+
+ path (Path): The directory or filename to read from.
+ min_length (int): Minimum document length (in tokens). Shorter documents
+ will be skipped. Defaults to 0, which indicates no limit.
+ max_length (int): Maximum document length (in tokens). Longer documents will
+ be skipped. Defaults to 0, which indicates no limit.
+
+ DOCS: https://spacy.io/api/corpus#plaintextcorpus
+ """
+ if path is None:
+ raise ValueError(Errors.E913)
+ return PlainTextCorpus(path, min_length=min_length, max_length=max_length)
+
+
def walk_corpus(path: Union[str, Path], file_type) -> List[Path]:
path = util.ensure_path(path)
if not path.is_dir() and path.parts[-1].endswith(file_type):
@@ -257,3 +279,52 @@ class JsonlCorpus:
# We don't *need* an example here, but it seems nice to
# make it match the Corpus signature.
yield Example(doc, Doc(nlp.vocab, words=words, spaces=spaces))
+
+
+class PlainTextCorpus:
+ """Iterate Example objects from a file or directory of plain text
+ UTF-8 files with one line per doc.
+
+ path (Path): The directory or filename to read from.
+ min_length (int): Minimum document length (in tokens). Shorter documents
+ will be skipped. Defaults to 0, which indicates no limit.
+ max_length (int): Maximum document length (in tokens). Longer documents will
+ be skipped. Defaults to 0, which indicates no limit.
+
+ DOCS: https://spacy.io/api/corpus#plaintextcorpus
+ """
+
+ file_type = "txt"
+
+ def __init__(
+ self,
+ path: Optional[Union[str, Path]],
+ *,
+ min_length: int = 0,
+ max_length: int = 0,
+ ) -> None:
+ self.path = util.ensure_path(path)
+ self.min_length = min_length
+ self.max_length = max_length
+
+ def __call__(self, nlp: "Language") -> Iterator[Example]:
+ """Yield examples from the data.
+
+ nlp (Language): The current nlp object.
+ YIELDS (Example): The example objects.
+
+ DOCS: https://spacy.io/api/corpus#plaintextcorpus-call
+ """
+ for loc in walk_corpus(self.path, ".txt"):
+ with open(loc, encoding="utf-8") as f:
+ for text in f:
+ text = text.rstrip("\r\n")
+ if len(text):
+ doc = nlp.make_doc(text)
+ if self.min_length >= 1 and len(doc) < self.min_length:
+ continue
+ elif self.max_length >= 1 and len(doc) > self.max_length:
+ continue
+ # We don't *need* an example here, but it seems nice to
+ # make it match the Corpus signature.
+ yield Example(doc, doc.copy())
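A minimal usage sketch for the new corpus class (the directory name and its contents are hypothetical):

```python
from spacy.lang.en import English
from spacy.training import PlainTextCorpus

nlp = English()
# Reads UTF-8 .txt files (one doc per line) from a file or directory.
corpus = PlainTextCorpus("my_texts/", min_length=1, max_length=0)
for example in corpus(nlp):
    print([token.text for token in example.reference])
```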
diff --git a/website/.dockerignore b/website/.dockerignore
new file mode 100644
index 000000000..e4a88552e
--- /dev/null
+++ b/website/.dockerignore
@@ -0,0 +1,9 @@
+.cache/
+.next/
+public/
+node_modules
+.npm
+logs
+*.log
+npm-debug.log*
+quickstart-training-generator.js
diff --git a/website/.gitignore b/website/.gitignore
index 70ef99fa5..599c0953a 100644
--- a/website/.gitignore
+++ b/website/.gitignore
@@ -1,5 +1,7 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
+quickstart-training-generator.js
+
# dependencies
/node_modules
/.pnp
@@ -41,4 +43,4 @@ next-env.d.ts
public/robots.txt
public/sitemap*
public/sw.js*
-public/workbox*
\ No newline at end of file
+public/workbox*
diff --git a/website/Dockerfile b/website/Dockerfile
index f71733e55..9b2f6cac4 100644
--- a/website/Dockerfile
+++ b/website/Dockerfile
@@ -1,16 +1,14 @@
-FROM node:11.15.0
+FROM node:18
-WORKDIR /spacy-io
-
-RUN npm install -g gatsby-cli@2.7.4
-
-COPY package.json .
-COPY package-lock.json .
-
-RUN npm install
+USER node
# This is so the installed node_modules will be up one directory
# from where a user mounts files, so that they don't accidentally mount
# their own node_modules from a different build
# https://nodejs.org/api/modules.html#modules_loading_from_node_modules_folders
-WORKDIR /spacy-io/website/
+WORKDIR /home/node
+COPY --chown=node package.json .
+COPY --chown=node package-lock.json .
+RUN npm install
+
+WORKDIR /home/node/website/
diff --git a/website/README.md b/website/README.md
index e9d7aec26..a434efe9a 100644
--- a/website/README.md
+++ b/website/README.md
@@ -41,33 +41,27 @@ If you'd like to do this, **be sure you do _not_ include your local
`node_modules` folder**, since there are some dependencies that need to be built
for the image system. Rename it before using.
-```bash
-docker run -it \
- -v $(pwd):/spacy-io/website \
- -p 8000:8000 \
- ghcr.io/explosion/spacy-io \
- gatsby develop -H 0.0.0.0
-```
-
-This will allow you to access the built website at http://0.0.0.0:8000/ in your
-browser, and still edit code in your editor while having the site reflect those
-changes.
-
-**Note**: If you're working on a Mac with an M1 processor, you might see
-segfault errors from `qemu` if you use the default image. To fix this use the
-`arm64` tagged image in the `docker run` command
-(ghcr.io/explosion/spacy-io:arm64).
-
-### Building the Docker image
-
-If you'd like to build the image locally, you can do so like this:
+First build the Docker image. This only needs to be done on the first run
+or when changes are made to `Dockerfile` or the website dependencies:
```bash
docker build -t spacy-io .
```
-This will take some time, so if you want to use the prebuilt image you'll save a
-bit of time.
+You can then build and run the website with:
+
+```bash
+docker run -it \
+ --rm \
+ -v $(pwd):/home/node/website \
+ -p 3000:3000 \
+ spacy-io \
+ npm run dev -- -H 0.0.0.0
+```
+
+This will allow you to access the built website at http://0.0.0.0:3000/ in your
+browser, and still edit code in your editor while having the site reflect those
+changes.
## Project structure
diff --git a/website/docs/api/cli.mdx b/website/docs/api/cli.mdx
index 80b1362bc..d96f8b743 100644
--- a/website/docs/api/cli.mdx
+++ b/website/docs/api/cli.mdx
@@ -12,6 +12,7 @@ menu:
- ['train', 'train']
- ['pretrain', 'pretrain']
- ['evaluate', 'evaluate']
+ - ['benchmark', 'benchmark']
- ['apply', 'apply']
- ['find-threshold', 'find-threshold']
- ['assemble', 'assemble']
@@ -269,10 +270,10 @@ $ python -m spacy convert [input_file] [output_dir] [--converter] [--file-type]
| `--file-type`, `-t` | Type of file to create. Either `spacy` (default) for binary [`DocBin`](/api/docbin) data or `json` for v2.x JSON format. ~~str (option)~~ |
| `--n-sents`, `-n` | Number of sentences per document. Supported for: `conll`, `conllu`, `iob`, `ner` ~~int (option)~~ |
| `--seg-sents`, `-s` | Segment sentences. Supported for: `conll`, `ner` ~~bool (flag)~~ |
-| `--base`, `-b`, `--model` | Trained spaCy pipeline for sentence segmentation to use as base (for `--seg-sents`). ~~Optional[str](option)~~ |
+| `--base`, `-b`, `--model` | Trained spaCy pipeline for sentence segmentation to use as base (for `--seg-sents`). ~~Optional[str] (option)~~ |
| `--morphology`, `-m` | Enable appending morphology to tags. Supported for: `conllu` ~~bool (flag)~~ |
| `--merge-subtokens`, `-T` | Merge CoNLL-U subtokens ~~bool (flag)~~ |
-| `--ner-map`, `-nm` | NER tag mapping (as JSON-encoded dict of entity types). Supported for: `conllu` ~~Optional[Path](option)~~ |
+| `--ner-map`, `-nm` | NER tag mapping (as JSON-encoded dict of entity types). Supported for: `conllu` ~~Optional[Path] (option)~~ |
| `--lang`, `-l` | Language code (if tokenizer required). ~~Optional[str] \(option)~~ |
| `--concatenate`, `-C` | Concatenate output to a single file ~~bool (flag)~~ |
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
@@ -1135,8 +1136,19 @@ $ python -m spacy pretrain [config_path] [output_dir] [--code] [--resume-path] [
## evaluate {id="evaluate",version="2",tag="command"}
-Evaluate a trained pipeline. Expects a loadable spaCy pipeline (package name or
-path) and evaluation data in the
+The `evaluate` subcommand is superseded by
+[`spacy benchmark accuracy`](#benchmark-accuracy). `evaluate` is provided as an
+alias to `benchmark accuracy` for compatibility.
+
+## benchmark {id="benchmark", version="3.5"}
+
+The `spacy benchmark` CLI includes commands for benchmarking the accuracy and
+speed of your spaCy pipelines.
+
+### accuracy {id="benchmark-accuracy", version="3.5", tag="command"}
+
+Evaluate the accuracy of a trained pipeline. Expects a loadable spaCy pipeline
+(package name or path) and evaluation data in the
[binary `.spacy` format](/api/data-formats#binary-training). The
`--gold-preproc` option sets up the evaluation examples with gold-standard
sentences and tokens for the predictions. Gold preprocessing helps the
@@ -1147,7 +1159,7 @@ skew. To render a sample of dependency parses in a HTML file using the
`--displacy-path` argument.
```bash
-$ python -m spacy evaluate [model] [data_path] [--output] [--code] [--gold-preproc] [--gpu-id] [--displacy-path] [--displacy-limit]
+$ python -m spacy benchmark accuracy [model] [data_path] [--output] [--code] [--gold-preproc] [--gpu-id] [--displacy-path] [--displacy-limit]
```
| Name | Description |
@@ -1163,6 +1175,29 @@ $ python -m spacy evaluate [model] [data_path] [--output] [--code] [--gold-prepr
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
| **CREATES** | Training results and optional metrics and visualizations. |
+### speed {id="benchmark-speed", version="3.5", tag="command"}
+
+Benchmark the speed of a trained pipeline with a 95% confidence interval.
+Expects a loadable spaCy pipeline (package name or path) and benchmark data in
+the [binary `.spacy` format](/api/data-formats#binary-training). The pipeline is
+warmed up before any measurements are taken.
+
+```cli
+$ python -m spacy benchmark speed [model] [data_path] [--batch-size] [--no-shuffle] [--gpu-id] [--batches] [--warmup]
+```
+
+| Name | Description |
+| -------------------- | -------------------------------------------------------------------------------------------------------- |
+| `model` | Pipeline to benchmark the speed of. Can be a package or a path to a data directory. ~~str (positional)~~ |
+| `data_path` | Location of benchmark data in spaCy's [binary format](/api/data-formats#training). ~~Path (positional)~~ |
+| `--batch-size`, `-b` | Set the batch size. If not set, the pipeline's batch size is used. ~~Optional[int] \(option)~~ |
+| `--no-shuffle` | Do not shuffle documents in the benchmark data. ~~bool (flag)~~ |
+| `--gpu-id`, `-g` | GPU to use, if any. Defaults to `-1` for CPU. ~~int (option)~~ |
+| `--batches` | Number of batches to benchmark on. Defaults to `50`. ~~Optional[int] \(option)~~ |
+| `--warmup`, `-w`     | Iterations over the benchmark data for warmup. Defaults to `3`. ~~Optional[int] \(option)~~               |
+| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
+| **PRINTS** | Pipeline speed in words per second with a 95% confidence interval. |
+
## apply {id="apply", version="3.5", tag="command"}
Applies a trained pipeline to data and stores the resulting annotated documents
@@ -1176,24 +1211,23 @@ input formats are:
When a directory is provided it is traversed recursively to collect all files.
-```cli
+```bash
$ python -m spacy apply [model] [data-path] [output-file] [--code] [--text-key] [--force-overwrite] [--gpu-id] [--batch-size] [--n-process]
```
-| Name | Description |
-| ----------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `model` | Pipeline to apply to the data. Can be a package or a path to a data directory. ~~str (positional)~~ |
-| `data_path` | Location of data to be evaluated in spaCy's [binary format](/api/data-formats#training), jsonl, or plain text. ~~Path (positional)~~ |
-| `output-file`, `-o` | Output `DocBin` path. ~~str (positional)~~ |
-| `--code`, `-c`
+
```markdown
[](https://spacy.io)
@@ -575,8 +578,9 @@ project is using spaCy, you can grab one of our **spaCy badges** here:
```markdown
-[](https://spacy.io)
+[](https://spacy.io)
```
diff --git a/website/docs/usage/v3-5.mdx b/website/docs/usage/v3-5.mdx
new file mode 100644
index 000000000..ac61338e3
--- /dev/null
+++ b/website/docs/usage/v3-5.mdx
@@ -0,0 +1,215 @@
+---
+title: What's New in v3.5
+teaser: New features and how to upgrade
+menu:
+ - ['New Features', 'features']
+ - ['Upgrading Notes', 'upgrading']
+---
+
+## New features {id="features",hidden="true"}
+
+spaCy v3.5 introduces three new CLI commands, `apply`, `benchmark` and
+`find-threshold`, adds fuzzy matching, provides improvements to our entity
+linking functionality, and includes a range of language updates and bug fixes.
+
+### New CLI commands {id="cli"}
+
+#### apply CLI
+
+The [`apply` CLI](/api/cli#apply) can be used to apply a pipeline to one or more
+`.txt`, `.jsonl` or `.spacy` input files, saving the annotated docs in a single
+`.spacy` file.
+
+```bash
+$ spacy apply en_core_web_sm my_texts/ output.spacy
+```
+
+#### benchmark CLI
+
+The [`benchmark` CLI](/api/cli#benchmark) has been added to extend the existing
+`evaluate` functionality with a wider range of profiling subcommands.
+
+The `benchmark accuracy` CLI is introduced as an alias for `evaluate`. The new
+`benchmark speed` CLI performs warmup rounds before measuring the speed in words
+per second on batches of randomly shuffled documents from the provided data.
+
+```bash
+$ spacy benchmark speed my_pipeline data.spacy
+```
+
+The output is the mean performance using batches (`nlp.pipe`) with a 95%
+confidence interval, e.g., profiling `en_core_web_sm` on CPU:
+
+```none
+Outliers: 2.0%, extreme outliers: 0.0%
+Mean: 18904.1 words/s (95% CI: -256.9 +244.1)
+```
+
+#### find-threshold CLI
+
+The [`find-threshold` CLI](/api/cli#find-threshold) runs a series of trials
+across threshold values from `0.0` to `1.0` and identifies the best threshold
+for the provided score metric.
+
+The following command runs 20 trials for the `spancat` component in
+`my_pipeline`, recording the `spans_sc_f` score for each value of the threshold
+`[components.spancat.threshold]` from `0.0` to `1.0`:
+
+```bash
+$ spacy find-threshold my_pipeline data.spacy spancat threshold spans_sc_f --n_trials 20
+```
+
+The `find-threshold` CLI can be used with `textcat_multilabel`, `spancat` and
+custom components with thresholds that are applied while predicting or scoring.
+
+### Fuzzy matching {id="fuzzy"}
+
+New `FUZZY` operators support [fuzzy matching](/usage/rule-based-matching#fuzzy)
+with the `Matcher`. By default, the `FUZZY` operator allows a Levenshtein edit
+distance of 2 and up to 30% of the pattern string length. `FUZZY1`..`FUZZY9` can
+be used to specify the exact number of allowed edits.
+
+```python
+# Match lowercase with fuzzy matching (allows up to 3 edits)
+pattern = [{"LOWER": {"FUZZY": "definitely"}}]
+
+# Match custom attribute values with fuzzy matching (allows up to 3 edits)
+pattern = [{"_": {"country": {"FUZZY": "Kyrgyzstan"}}}]
+
+# Match with exact Levenshtein edit distance limits (allows up to 4 edits)
+pattern = [{"_": {"country": {"FUZZY4": "Kyrgyzstan"}}}]
+```
+
+Note that `FUZZY` uses Levenshtein edit distance rather than Damerau-Levenshtein
+edit distance, so a transposition like `teh` for `the` counts as two edits, one
+insertion and one deletion.
+
+If you'd prefer an alternate fuzzy matching algorithm, you can provide your own
+custom method to the `Matcher` or as a config option for an entity ruler and
+span ruler.
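+
+As a sketch, a custom comparison function follows the same
+`(input_text, pattern_text, fuzzy)` signature as the built-in
+`levenshtein_compare`; the comparison used here is purely illustrative:
+
+```python
+from spacy.lang.en import English
+from spacy.matcher import Matcher
+
+def my_fuzzy_compare(input_text: str, pattern_text: str, fuzzy: int) -> bool:
+    # Toy rule: accept a match if the strings are equal, ignoring case.
+    return input_text.lower() == pattern_text.lower()
+
+nlp = English()
+matcher = Matcher(nlp.vocab, fuzzy_compare=my_fuzzy_compare)
+```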
+
+### FUZZY and REGEX with lists {id="fuzzy-regex-lists"}
+
+The `FUZZY` and `REGEX` operators are also now supported for lists with `IN` and
+`NOT_IN`:
+
+```python
+pattern = [{"TEXT": {"FUZZY": {"IN": ["awesome", "cool", "wonderful"]}}}]
+pattern = [{"TEXT": {"REGEX": {"NOT_IN": ["^awe(some)?$", "^wonder(ful)?"]}}}]
+```
+
+### Entity linking generalization {id="el"}
+
+The knowledge base used for entity linking is now easier to customize and has a
+new default implementation [`InMemoryLookupKB`](/api/inmemorylookupkb).
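+
+For illustration, a minimal sketch of creating a knowledge base directly (the
+entity ID, frequency and vector values below are placeholders):
+
+```python
+from spacy.lang.en import English
+from spacy.kb import InMemoryLookupKB
+
+nlp = English()
+kb = InMemoryLookupKB(vocab=nlp.vocab, entity_vector_length=3)
+kb.add_entity(entity="Q1004791", freq=6, entity_vector=[0.0, 3.0, 5.0])
+kb.add_alias(alias="Douglas", entities=["Q1004791"], probabilities=[1.0])
+```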
+
+### Additional features and improvements {id="additional-features-and-improvements"}
+
+- Language updates:
+ - Extended support for Slovenian
+ - Fixed lookup fallback for French and Catalan lemmatizers
+ - Switch Russian and Ukrainian lemmatizers to `pymorphy3`
+ - Support for editorial punctuation in Ancient Greek
+ - Update to Russian tokenizer exceptions
+ - Small fix for Dutch stop words
+- Allow up to `typer` v0.7.x, `mypy` 0.990 and `typing_extensions` v4.4.x.
+- New `spacy.ConsoleLogger.v3` with expanded progress
+ [tracking](/api/top-level#ConsoleLogger).
+- Improved scoring behavior for `textcat` with `spacy.textcat_scorer.v2` and
+ `spacy.textcat_multilabel_scorer.v2`.
+- Updates so that downstream components can train properly on a frozen `tok2vec`
+ or `transformer` layer.
+- Allow interpolation of variables in directory names in projects.
+- Support for local file system [remotes](/usage/projects#remote) for projects.
+- Improve UX around `displacy.serve` when the default port is in use.
+- Optional `before_update` callback that is invoked at the start of each
+ [training step](/api/data-formats#config-training).
+- Improve performance of `SpanGroup` and fix typing issues for `SpanGroup` and
+ `Span` objects.
+- Patch a
+ [security vulnerability](https://github.com/advisories/GHSA-gw9q-c7gh-j9vm) in
+ extracting tar files.
+- Add equality definition for `Vectors`.
+- Ensure `Vocab.to_disk` respects the exclude setting for `lookups` and
+ `vectors`.
+- Correctly handle missing annotations in the edit tree lemmatizer.
+
+### Trained pipeline updates {id="pipelines"}
+
+- The CNN pipelines add `IS_SPACE` as a `tok2vec` feature for `tagger` and
+ `morphologizer` components to improve tagging of non-whitespace vs. whitespace
+ tokens.
+- The transformer pipelines require `spacy-transformers` v1.2, which uses the
+ exact alignment from `tokenizers` for fast tokenizers instead of the heuristic
+ alignment from `spacy-alignments`. For all trained pipelines except
+ `ja_core_news_trf`, the alignments between spaCy tokens and transformer tokens
+  may be slightly different. More details about the `spacy-transformers` changes
+  can be found in the
+ [v1.2.0 release notes](https://github.com/explosion/spacy-transformers/releases/tag/v1.2.0).
+
+## Notes about upgrading from v3.4 {id="upgrading"}
+
+### Validation of textcat values {id="textcat-validation"}
+
+An error is now raised when unsupported values are given as input to train a
+`textcat` or `textcat_multilabel` model - ensure that values are `0.0` or `1.0`
+as explained in the [docs](/api/textcategorizer#assigned-attributes).
+
+### Updated scorers for tokenization and textcat {id="scores"}
+
+We fixed a bug that inflated the `token_acc` scores in v3.0-v3.4. The reported
+`token_acc` will drop from v3.4 to v3.5, but if `token_p/r/f` stay the same,
+your tokenization performance has not changed from v3.4.
+
+For new `textcat` or `textcat_multilabel` configs, the new default `v2` scorers:
+
+- ignore `threshold` for `textcat`, so the reported `cats_p/r/f` may increase
+ slightly in v3.5 even though the underlying predictions are unchanged
+- report the performance of only the **final** `textcat` or `textcat_multilabel`
+ component in the pipeline by default
+- allow custom scorers to be used to score multiple `textcat` and
+ `textcat_multilabel` components with `Scorer.score_cats` by restricting the
+ evaluation to the component's provided labels
+
+### Pipeline package version compatibility {id="version-compat"}
+
+> #### Using legacy implementations
+>
+> In spaCy v3, you'll still be able to load and reference legacy implementations
+> via [`spacy-legacy`](https://github.com/explosion/spacy-legacy), even if the
+> components or architectures change and newer versions are available in the
+> core library.
+
+When you're loading a pipeline package trained with an earlier version of spaCy
+v3, you will see a warning telling you that the pipeline may be incompatible.
+This doesn't necessarily have to be true, but we recommend running your
+pipelines against your test suite or evaluation data to make sure there are no
+unexpected results.
+
+If you're using one of the [trained pipelines](/models) we provide, you should
+run [`spacy download`](/api/cli#download) to update to the latest version. To
+see an overview of all installed packages and their compatibility, you can run
+[`spacy validate`](/api/cli#validate).
+
+If you've trained your own custom pipeline and you've confirmed that it's still
+working as expected, you can update the spaCy version requirements in the
+[`meta.json`](/api/data-formats#meta):
+
+```diff
+- "spacy_version": ">=3.4.0,<3.5.0",
++ "spacy_version": ">=3.4.0,<3.6.0",
+```
+
+### Updating v3.4 configs
+
+To update a config from spaCy v3.4 with the new v3.5 settings, run
+[`init fill-config`](/api/cli#init-fill-config):
+
+```cli
+$ python -m spacy init fill-config config-v3.4.cfg config-v3.5.cfg
+```
+
+In many cases ([`spacy train`](/api/cli#train),
+[`spacy.load`](/api/top-level#spacy.load)), the new defaults will be filled in
+automatically, but you'll need to fill in the new settings to run
+[`debug config`](/api/cli#debug) and [`debug data`](/api/cli#debug-data).
diff --git a/website/docs/usage/visualizers.mdx b/website/docs/usage/visualizers.mdx
index f1ff6dd3d..1d3682af4 100644
--- a/website/docs/usage/visualizers.mdx
+++ b/website/docs/usage/visualizers.mdx
@@ -437,6 +437,6 @@ Alternatively, if you're using [Streamlit](https://streamlit.io), check out the
helps you integrate spaCy visualizations into your apps. It includes a full
embedded visualizer, as well as individual components.
-
+
diff --git a/website/meta/sidebars.json b/website/meta/sidebars.json
index 339e4085b..b5c555da6 100644
--- a/website/meta/sidebars.json
+++ b/website/meta/sidebars.json
@@ -13,7 +13,8 @@
{ "text": "New in v3.1", "url": "/usage/v3-1" },
{ "text": "New in v3.2", "url": "/usage/v3-2" },
{ "text": "New in v3.3", "url": "/usage/v3-3" },
- { "text": "New in v3.4", "url": "/usage/v3-4" }
+ { "text": "New in v3.4", "url": "/usage/v3-4" },
+ { "text": "New in v3.5", "url": "/usage/v3-5" }
]
},
{
@@ -129,6 +130,7 @@
"items": [
{ "text": "Attributes", "url": "/api/attributes" },
{ "text": "Corpus", "url": "/api/corpus" },
+ { "text": "InMemoryLookupKB", "url": "/api/inmemorylookupkb" },
{ "text": "KnowledgeBase", "url": "/api/kb" },
{ "text": "Lookups", "url": "/api/lookups" },
{ "text": "MorphAnalysis", "url": "/api/morphology#morphanalysis" },
diff --git a/website/meta/site.json b/website/meta/site.json
index 5dcb89443..3d4f2d5ee 100644
--- a/website/meta/site.json
+++ b/website/meta/site.json
@@ -27,7 +27,7 @@
"indexName": "spacy"
},
"binderUrl": "explosion/spacy-io-binder",
- "binderVersion": "3.4",
+ "binderVersion": "3.5",
"sections": [
{ "id": "usage", "title": "Usage Documentation", "theme": "blue" },
{ "id": "models", "title": "Models Documentation", "theme": "blue" },
diff --git a/website/meta/universe.json b/website/meta/universe.json
index f15d461e8..e35a4f045 100644
--- a/website/meta/universe.json
+++ b/website/meta/universe.json
@@ -2381,7 +2381,7 @@
"author": "Nikita Kitaev",
"author_links": {
"github": "nikitakit",
- "website": " http://kitaev.io"
+ "website": "http://kitaev.io"
},
"category": ["research", "pipeline"]
},
diff --git a/website/pages/_app.tsx b/website/pages/_app.tsx
index 8db80a672..a837d9ce8 100644
--- a/website/pages/_app.tsx
+++ b/website/pages/_app.tsx
@@ -17,7 +17,7 @@ export default function App({ Component, pageProps }: AppProps) {
diff --git a/website/pages/index.tsx b/website/pages/index.tsx
index 170bca137..fc0dba378 100644
--- a/website/pages/index.tsx
+++ b/website/pages/index.tsx
@@ -13,7 +13,7 @@ import {
LandingBanner,
} from '../src/components/landing'
import { H2 } from '../src/components/typography'
-import { InlineCode } from '../src/components/code'
+import { InlineCode } from '../src/components/inlineCode'
import { Ul, Li } from '../src/components/list'
import Button from '../src/components/button'
import Link from '../src/components/link'
@@ -89,8 +89,8 @@ const Landing = () => {
-
+
diff --git a/website/src/components/accordion.js b/website/src/components/accordion.js
index 504f415a5..9ff145bd2 100644
--- a/website/src/components/accordion.js
+++ b/website/src/components/accordion.js
@@ -33,7 +33,7 @@ export default function Accordion({ title, id, expanded = false, spaced = false,
event.stopPropagation()}
>
¶
diff --git a/website/src/components/card.js b/website/src/components/card.js
index 9eb597b7b..ef43eb866 100644
--- a/website/src/components/card.js
+++ b/website/src/components/card.js
@@ -1,6 +1,7 @@
import React from 'react'
import PropTypes from 'prop-types'
import classNames from 'classnames'
+import ImageNext from 'next/image'
import Link from './link'
import { H5 } from './typography'
@@ -10,7 +11,7 @@ export default function Card({ title, to, image, header, small, onClick, childre
return (
{image && (
)}
-
+
{children}
+
-
-)
-
-export default CodeBlock
-
-export const Pre = (props) => {
- return
-
{props.children}
-}
-
-export const InlineCode = ({ wrap = false, className, children, ...props }) => {
- const codeClassNames = classNames(classes['inline-code'], className, {
- [classes['wrap']]: wrap || (isString(children) && children.length >= WRAP_THRESHOLD),
- })
- return (
-
- {children}
-
- )
-}
-
-InlineCode.propTypes = {
- wrap: PropTypes.bool,
- className: PropTypes.string,
- children: PropTypes.node,
-}
-
-function linkType(el, showLink = true) {
- if (!isString(el) || !el.length) return el
- const elStr = el.trim()
- if (!elStr) return el
- const typeUrl = CUSTOM_TYPES[elStr]
- const url = typeUrl == true ? DEFAULT_TYPE_URL : typeUrl
- const ws = el[0] == ' '
- return url && showLink ? (
- {props.children}
+}
+
+const CodeBlock = (props) => (
+
+
+)
+export default CodeBlock
diff --git a/website/src/components/codeDynamic.js b/website/src/components/codeDynamic.js
new file mode 100644
index 000000000..8c9483567
--- /dev/null
+++ b/website/src/components/codeDynamic.js
@@ -0,0 +1,5 @@
+import dynamic from 'next/dynamic'
+
+export default dynamic(() => import('./code'), {
+ loading: () =>
+
{data.spacy_version &&
- ))}
+
+