diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py
index de2e01818..66b22b131 100644
--- a/spacy/cli/evaluate.py
+++ b/spacy/cli/evaluate.py
@@ -68,41 +68,43 @@ def evaluate(
nlp = util.load_model(model)
dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc))
begin = timer()
- scorer = nlp.evaluate(dev_dataset, verbose=False)
+ scores = nlp.evaluate(dev_dataset, verbose=False)
end = timer()
nwords = sum(len(ex.predicted) for ex in dev_dataset)
- results = {
- "Time": f"{end - begin:.2f} s",
- "Words": nwords,
- "Words/s": f"{nwords / (end - begin):.0f}",
- "TOK": f"{scorer.token_acc:.2f}",
- "TAG": f"{scorer.tags_acc:.2f}",
- "POS": f"{scorer.pos_acc:.2f}",
- "MORPH": f"{scorer.morphs_acc:.2f}",
- "UAS": f"{scorer.uas:.2f}",
- "LAS": f"{scorer.las:.2f}",
- "NER P": f"{scorer.ents_p:.2f}",
- "NER R": f"{scorer.ents_r:.2f}",
- "NER F": f"{scorer.ents_f:.2f}",
- "Textcat AUC": f"{scorer.textcat_auc:.2f}",
- "Textcat F": f"{scorer.textcat_f:.2f}",
- "Sent P": f"{scorer.sent_p:.2f}",
- "Sent R": f"{scorer.sent_r:.2f}",
- "Sent F": f"{scorer.sent_f:.2f}",
+ metrics = {
+ "TOK": "token_acc",
+ "TAG": "tag_acc",
+ "POS": "pos_acc",
+ "MORPH": "morph_acc",
+ "LEMMA": "lemma_acc",
+ "UAS": "dep_uas",
+ "LAS": "dep_las",
+ "NER P": "ents_p",
+ "NER R": "ents_r",
+ "NER F": "ents_f",
+        "Textcat AUC": "textcat_macro_auc",
+        "Textcat F": "textcat_macro_f",
+        "Sent P": "sents_p",
+        "Sent R": "sents_r",
+        "Sent F": "sents_f",
}
+ results = {}
+ for metric, key in metrics.items():
+ if key in scores:
+ results[metric] = f"{scores[key]*100:.2f}"
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
msg.table(results, title="Results")
- if scorer.ents_per_type:
- data["ents_per_type"] = scorer.ents_per_type
- print_ents_per_type(msg, scorer.ents_per_type)
- if scorer.textcats_f_per_cat:
- data["textcats_f_per_cat"] = scorer.textcats_f_per_cat
- print_textcats_f_per_cat(msg, scorer.textcats_f_per_cat)
- if scorer.textcats_auc_per_cat:
- data["textcats_auc_per_cat"] = scorer.textcats_auc_per_cat
- print_textcats_auc_per_cat(msg, scorer.textcats_auc_per_cat)
+ if "ents_per_type" in scores:
+ if scores["ents_per_type"]:
+ print_ents_per_type(msg, scores["ents_per_type"])
+ if "textcat_f_per_cat" in scores:
+ if scores["textcat_f_per_cat"]:
+ print_textcats_f_per_cat(msg, scores["textcat_f_per_cat"])
+ if "textcat_auc_per_cat" in scores:
+ if scores["textcat_auc_per_cat"]:
+ print_textcats_auc_per_cat(msg, scores["textcat_auc_per_cat"])
if displacy_path:
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
@@ -148,7 +150,7 @@ def render_parses(
def print_ents_per_type(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None:
data = [
- (k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}")
+ (k, f"{v['p']*100:.2f}", f"{v['r']*100:.2f}", f"{v['f']*100:.2f}")
for k, v in scores.items()
]
msg.table(
@@ -161,7 +163,7 @@ def print_ents_per_type(msg: Printer, scores: Dict[str, Dict[str, float]]) -> No
def print_textcats_f_per_cat(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None:
data = [
- (k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}")
+ (k, f"{v['p']*100:.2f}", f"{v['r']*100:.2f}", f"{v['f']*100:.2f}")
for k, v in scores.items()
]
msg.table(
@@ -176,7 +178,7 @@ def print_textcats_auc_per_cat(
msg: Printer, scores: Dict[str, Dict[str, float]]
) -> None:
msg.table(
- [(k, f"{v['roc_auc_score']:.2f}") for k, v in scores.items()],
+ [(k, f"{v:.2f}") for k, v in scores.items()],
header=("", "ROC AUC"),
aligns=("l", "r"),
title="Textcat ROC AUC (per label)",
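
With this change nlp.evaluate() returns a plain dict of scores instead of a
Scorer object, and the CLI looks each metric up by key. A minimal usage sketch,
assuming a loaded pipeline nlp and a list of Example objects dev_dataset (key
names follow the metrics mapping above):

    scores = nlp.evaluate(dev_dataset)         # plain dict, values in 0.0-1.0
    print(scores.get("token_acc"))             # tokenization accuracy
    print(scores.get("ents_f"))                # NER F-score, if an NER component ran
    print(scores.get("ents_per_type", {}))     # per-label P/R/F dicts, if available
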
diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index c51aac974..f3580ea10 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -179,6 +179,7 @@ def train(
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
except Exception as e:
if output_path is not None:
+ raise e
msg.warn(
f"Aborting and saving the final best model. "
f"Encountered exception: {str(e)}",
@@ -259,12 +260,11 @@ def create_evaluation_callback(
start_time = timer()
if optimizer.averages:
with nlp.use_params(optimizer.averages):
- scorer = nlp.evaluate(dev_examples, batch_size=batch_size)
+ scores = nlp.evaluate(dev_examples, batch_size=batch_size)
else:
- scorer = nlp.evaluate(dev_examples, batch_size=batch_size)
+ scores = nlp.evaluate(dev_examples, batch_size=batch_size)
end_time = timer()
wps = n_words / (end_time - start_time)
- scores = scorer.scores
# Calculate a weighted sum based on score_weights for the main score
weights = cfg["score_weights"]
try:
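
The evaluation callback now works with the scores dict directly; as the comment
above notes, the main score is a weighted sum over cfg["score_weights"]. A
sketch of that calculation, assuming scores is the dict from nlp.evaluate() and
weights is e.g. {"tag_acc": 0.2, "dep_las": 0.4, "ents_f": 0.4}:

    # assumes every weighted metric is present in `scores`
    weighted_score = sum(scores[name] * weight for name, weight in weights.items())
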
diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg
index 7e6c7a6ec..db8f03f2a 100644
--- a/spacy/default_config.cfg
+++ b/spacy/default_config.cfg
@@ -40,8 +40,8 @@ seed = 0
accumulate_gradient = 1
use_pytorch_for_gpu_memory = false
# Control how scores are printed and checkpoints are evaluated.
-scores = ["speed", "tags_acc", "uas", "las", "ents_f"]
-score_weights = {"tags_acc": 0.2, "las": 0.4, "ents_f": 0.4}
+scores = ["speed", "tag_acc", "dep_uas", "dep_las", "ents_f"]
+score_weights = {"tag_acc": 0.2, "dep_las": 0.4, "ents_f": 0.4}
# These settings are invalid for the transformer models.
init_tok2vec = null
discard_oversize = false
diff --git a/spacy/gold/example.pyx b/spacy/gold/example.pyx
index 355578de3..9101cefce 100644
--- a/spacy/gold/example.pyx
+++ b/spacy/gold/example.pyx
@@ -45,6 +45,7 @@ cdef class Example:
def __set__(self, doc):
self.x = doc
+ self._alignment = None
property reference:
def __get__(self):
@@ -52,6 +53,7 @@ cdef class Example:
def __set__(self, doc):
self.y = doc
+ self._alignment = None
def copy(self):
return Example(
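
Resetting self._alignment in both setters keeps the cached token alignment
consistent with whichever doc was just swapped in. An illustrative pure-Python
sketch of the pattern (a hypothetical stand-in, not the Cython class;
compute_alignment is a placeholder):

    class ExampleLike:
        def __init__(self, predicted, reference):
            self.x, self.y = predicted, reference
            self._alignment = None             # filled in lazily

        @property
        def alignment(self):
            if self._alignment is None:        # recomputed after x or y is replaced
                self._alignment = compute_alignment(self.x, self.y)  # placeholder
            return self._alignment
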
diff --git a/spacy/language.py b/spacy/language.py
index 09429a04c..bd816e948 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -1011,10 +1011,13 @@ class Language:
name="language", method="evaluate", types=wrong_types
)
raise TypeError(err)
- if scorer is None:
- scorer = Scorer(pipeline=self.pipeline)
if component_cfg is None:
component_cfg = {}
+ if scorer is None:
+ kwargs = component_cfg.get("scorer", {})
+ kwargs.setdefault("verbose", verbose)
+ kwargs.setdefault("nlp", self)
+ scorer = Scorer(**kwargs)
docs = list(eg.predicted for eg in examples)
for name, pipe in self.pipeline:
kwargs = component_cfg.get(name, {})
@@ -1027,10 +1030,7 @@ class Language:
if verbose:
print(doc)
eg.predicted = doc
- kwargs = component_cfg.get("scorer", {})
- kwargs.setdefault("verbose", verbose)
- scorer.score(eg, **kwargs)
- return scorer
+ return scorer.score(examples)
@contextmanager
def use_params(self, params: dict, **cfg):
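
Language.evaluate now builds the Scorer from the optional "scorer" entry in
component_cfg and returns the dict produced by scorer.score(examples). A usage
sketch, assuming nlp is a loaded pipeline and examples is a list of Example
objects:

    # kwargs under "scorer" are passed through to the Scorer constructor
    scores = nlp.evaluate(examples, component_cfg={"scorer": {"verbose": False}})
    print(scores["token_acc"])
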
diff --git a/spacy/pipeline/dep_parser.pyx b/spacy/pipeline/dep_parser.pyx
index 1651119f8..78926a984 100644
--- a/spacy/pipeline/dep_parser.pyx
+++ b/spacy/pipeline/dep_parser.pyx
@@ -8,6 +8,7 @@ from ..syntax.arc_eager cimport ArcEager
from .functions import merge_subtokens
from ..language import Language
from ..syntax import nonproj
+from ..scorer import Scorer
default_model_config = """
@@ -102,3 +103,14 @@ cdef class DependencyParser(Parser):
label = label.split("||")[1]
labels.add(label)
return tuple(sorted(labels))
+
+ def score(self, examples, **kwargs):
+ def dep_getter(token, attr):
+ dep = getattr(token, attr)
+ dep = token.vocab.strings.as_string(dep).lower()
+ return dep
+ results = {}
+ results.update(Scorer.score_spans(examples, "sents", **kwargs))
+ results.update(Scorer.score_deps(examples, "dep", getter=dep_getter,
+ ignore_labels=("p", "punct"), **kwargs))
+ return results
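
Each component now exposes score(examples, **kwargs) and returns a flat dict
that the pipeline-level Scorer merges. A sketch of calling the parser's scorer
directly, assuming parser is the pipeline's DependencyParser and examples is a
list of Example objects (key names such as dep_uas, dep_las and sents_f follow
the CLI mapping earlier in this patch):

    results = parser.score(examples)
    print(results.get("dep_uas"), results.get("dep_las"), results.get("sents_f"))
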
diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx
index a5a54f139..fb80a9d86 100644
--- a/spacy/pipeline/morphologizer.pyx
+++ b/spacy/pipeline/morphologizer.pyx
@@ -14,6 +14,7 @@ from ..errors import Errors
from .pipe import deserialize_config
from .tagger import Tagger
from .. import util
+from ..scorer import Scorer
default_model_config = """
@@ -162,6 +163,14 @@ class Morphologizer(Tagger):
raise ValueError("nan value when computing loss")
return float(loss), d_scores
+ def score(self, examples, **kwargs):
+ results = {}
+ results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
+ results.update(Scorer.score_token_attr(examples, "morph", **kwargs))
+ results.update(Scorer.score_token_attr_per_feat(examples,
+ "morph", **kwargs))
+ return results
+
def to_bytes(self, exclude=tuple()):
serialize = {}
serialize["model"] = self.model.to_bytes
diff --git a/spacy/pipeline/ner.pyx b/spacy/pipeline/ner.pyx
index ea904f69e..cb2ca89d8 100644
--- a/spacy/pipeline/ner.pyx
+++ b/spacy/pipeline/ner.pyx
@@ -6,6 +6,7 @@ from ..syntax.nn_parser cimport Parser
from ..syntax.ner cimport BiluoPushDown
from ..language import Language
+from ..scorer import Scorer
default_model_config = """
@@ -88,3 +89,6 @@ cdef class EntityRecognizer(Parser):
labels = set(move.split("-")[1] for move in self.move_names
if move[0] in ("B", "I", "L", "U"))
return tuple(sorted(labels))
+
+ def score(self, examples, **kwargs):
+ return Scorer.score_spans(examples, "ents", **kwargs)
diff --git a/spacy/pipeline/pipe.pyx b/spacy/pipeline/pipe.pyx
index 5fa7d82db..e7702aa59 100644
--- a/spacy/pipeline/pipe.pyx
+++ b/spacy/pipeline/pipe.pyx
@@ -117,6 +117,9 @@ class Pipe:
with self.model.use_params(params):
yield
+ def score(self, examples, **kwargs):
+ return {}
+
def to_bytes(self, exclude=tuple()):
"""Serialize the pipe to a bytestring.
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
new file mode 100644
index 000000000..5c71a33e7
--- /dev/null
+++ b/spacy/pipeline/pipes.pyx
@@ -0,0 +1,1519 @@
+# cython: infer_types=True, profile=True
+import numpy
+import srsly
+import random
+
+from thinc.api import CosineDistance, to_categorical, get_array_module
+from thinc.api import set_dropout_rate, SequenceCategoricalCrossentropy
+import warnings
+
+from ..tokens.doc cimport Doc
+from ..syntax.nn_parser cimport Parser
+from ..syntax.ner cimport BiluoPushDown
+from ..syntax.arc_eager cimport ArcEager
+from ..morphology cimport Morphology
+from ..vocab cimport Vocab
+
+from .defaults import default_tagger, default_parser, default_ner, default_textcat
+from .defaults import default_nel, default_senter
+from .functions import merge_subtokens
+from ..language import Language, component
+from ..syntax import nonproj
+from ..gold.example import Example
+from ..attrs import POS, ID
+from ..util import link_vectors_to_models, create_default_optimizer
+from ..parts_of_speech import X
+from ..kb import KnowledgeBase
+from ..errors import Errors, TempErrors, Warnings
+from .. import util
+from ..scorer import Scorer
+
+
+def _load_cfg(path):
+ if path.exists():
+ return srsly.read_json(path)
+ else:
+ return {}
+
+
+class Pipe:
+ """This class is not instantiated directly. Components inherit from it, and
+ it defines the interface that components should follow to function as
+ components in a spaCy analysis pipeline.
+ """
+
+ name = None
+
+ @classmethod
+ def from_nlp(cls, nlp, model, **cfg):
+ return cls(nlp.vocab, model, **cfg)
+
+ def __init__(self, vocab, model, **cfg):
+ """Create a new pipe instance."""
+ raise NotImplementedError
+
+ def __call__(self, Doc doc):
+ """Apply the pipe to one document. The document is
+ modified in-place, and returned.
+
+ Both __call__ and pipe should delegate to the `predict()`
+ and `set_annotations()` methods.
+ """
+ scores = self.predict([doc])
+ self.set_annotations([doc], scores)
+ return doc
+
+ def pipe(self, stream, batch_size=128):
+ """Apply the pipe to a stream of documents.
+
+ Both __call__ and pipe should delegate to the `predict()`
+ and `set_annotations()` methods.
+ """
+ for docs in util.minibatch(stream, size=batch_size):
+ scores = self.predict(docs)
+ self.set_annotations(docs, scores)
+ yield from docs
+
+ def predict(self, docs):
+ """Apply the pipeline's model to a batch of docs, without
+ modifying them.
+ """
+ raise NotImplementedError
+
+ def set_annotations(self, docs, scores):
+ """Modify a batch of documents, using pre-computed scores."""
+ raise NotImplementedError
+
+ def rehearse(self, examples, sgd=None, losses=None, **config):
+ pass
+
+ def get_loss(self, examples, scores):
+ """Find the loss and gradient of loss for the batch of
+ examples (with embedded docs) and their predicted scores."""
+ raise NotImplementedError
+
+ def add_label(self, label):
+ """Add an output label, to be predicted by the model.
+
+ It's possible to extend pretrained models with new labels,
+ but care should be taken to avoid the "catastrophic forgetting"
+ problem.
+ """
+ raise NotImplementedError
+
+ def create_optimizer(self):
+ return create_default_optimizer()
+
+ def begin_training(
+ self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs
+ ):
+        """Initialize the pipe for training, using data examples if available.
+ If no model has been initialized yet, the model is added."""
+ self.model.initialize()
+ if hasattr(self, "vocab"):
+ link_vectors_to_models(self.vocab)
+ if sgd is None:
+ sgd = self.create_optimizer()
+ return sgd
+
+ def set_output(self, nO):
+ if self.model.has_dim("nO") is not False:
+ self.model.set_dim("nO", nO)
+ if self.model.has_ref("output_layer"):
+ self.model.get_ref("output_layer").set_dim("nO", nO)
+
+ def get_gradients(self):
+ """Get non-zero gradients of the model's parameters, as a dictionary
+ keyed by the parameter ID. The values are (weights, gradients) tuples.
+ """
+ gradients = {}
+ queue = [self.model]
+ seen = set()
+ for node in queue:
+ if node.id in seen:
+ continue
+ seen.add(node.id)
+ if hasattr(node, "_mem") and node._mem.gradient.any():
+ gradients[node.id] = [node._mem.weights, node._mem.gradient]
+ if hasattr(node, "_layers"):
+ queue.extend(node._layers)
+ return gradients
+
+ def use_params(self, params):
+ """Modify the pipe's model, to use the given parameter values."""
+ with self.model.use_params(params):
+ yield
+
+ def score(self, examples, **kwargs):
+ return {}
+
+ def to_bytes(self, exclude=tuple()):
+ """Serialize the pipe to a bytestring.
+
+ exclude (list): String names of serialization fields to exclude.
+ RETURNS (bytes): The serialized object.
+ """
+ serialize = {}
+ serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+ serialize["model"] = self.model.to_bytes
+ if hasattr(self, "vocab"):
+ serialize["vocab"] = self.vocab.to_bytes
+ return util.to_bytes(serialize, exclude)
+
+ def from_bytes(self, bytes_data, exclude=tuple()):
+ """Load the pipe from a bytestring."""
+
+ def load_model(b):
+ try:
+ self.model.from_bytes(b)
+ except AttributeError:
+ raise ValueError(Errors.E149)
+
+ deserialize = {}
+ if hasattr(self, "vocab"):
+ deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+ deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
+ deserialize["model"] = load_model
+ util.from_bytes(bytes_data, deserialize, exclude)
+ return self
+
+ def to_disk(self, path, exclude=tuple()):
+ """Serialize the pipe to disk."""
+ serialize = {}
+ serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
+ serialize["vocab"] = lambda p: self.vocab.to_disk(p)
+ serialize["model"] = lambda p: self.model.to_disk(p)
+ util.to_disk(path, serialize, exclude)
+
+ def from_disk(self, path, exclude=tuple()):
+ """Load the pipe from disk."""
+
+ def load_model(p):
+ try:
+ self.model.from_bytes(p.open("rb").read())
+ except AttributeError:
+ raise ValueError(Errors.E149)
+
+ deserialize = {}
+ deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
+ deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
+ deserialize["model"] = load_model
+ util.from_disk(path, deserialize, exclude)
+ return self
+
+
+@component("tagger", assigns=["token.tag", "token.pos", "token.lemma"], default_model=default_tagger)
+class Tagger(Pipe):
+ """Pipeline component for part-of-speech tagging.
+
+ DOCS: https://spacy.io/api/tagger
+ """
+
+ def __init__(self, vocab, model, **cfg):
+ self.vocab = vocab
+ self.model = model
+ self._rehearsal_model = None
+ self.cfg = dict(sorted(cfg.items()))
+
+ @property
+ def labels(self):
+ return tuple(self.vocab.morphology.tag_names)
+
+ def __call__(self, doc):
+ tags = self.predict([doc])
+ self.set_annotations([doc], tags)
+ return doc
+
+ def pipe(self, stream, batch_size=128):
+ for docs in util.minibatch(stream, size=batch_size):
+ tag_ids = self.predict(docs)
+ self.set_annotations(docs, tag_ids)
+ yield from docs
+
+ def predict(self, docs):
+ if not any(len(doc) for doc in docs):
+ # Handle cases where there are no tokens in any docs.
+ n_labels = len(self.labels)
+ guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
+ assert len(guesses) == len(docs)
+ return guesses
+ scores = self.model.predict(docs)
+ assert len(scores) == len(docs), (len(scores), len(docs))
+ guesses = self._scores2guesses(scores)
+ assert len(guesses) == len(docs)
+ return guesses
+
+ def _scores2guesses(self, scores):
+ guesses = []
+ for doc_scores in scores:
+ doc_guesses = doc_scores.argmax(axis=1)
+ if not isinstance(doc_guesses, numpy.ndarray):
+ doc_guesses = doc_guesses.get()
+ guesses.append(doc_guesses)
+ return guesses
+
+ def set_annotations(self, docs, batch_tag_ids):
+ if isinstance(docs, Doc):
+ docs = [docs]
+ cdef Doc doc
+ cdef int idx = 0
+ cdef Vocab vocab = self.vocab
+ assign_morphology = self.cfg.get("set_morphology", True)
+ for i, doc in enumerate(docs):
+ doc_tag_ids = batch_tag_ids[i]
+ if hasattr(doc_tag_ids, "get"):
+ doc_tag_ids = doc_tag_ids.get()
+ for j, tag_id in enumerate(doc_tag_ids):
+ # Don't clobber preset POS tags
+ if doc.c[j].tag == 0:
+ if doc.c[j].pos == 0 and assign_morphology:
+ # Don't clobber preset lemmas
+ lemma = doc.c[j].lemma
+ vocab.morphology.assign_tag_id(&doc.c[j], tag_id)
+ if lemma != 0 and lemma != doc.c[j].lex.orth:
+ doc.c[j].lemma = lemma
+ else:
+ doc.c[j].tag = self.vocab.strings[self.labels[tag_id]]
+ idx += 1
+ doc.is_tagged = True
+
+ def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
+
+ try:
+ if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
+ # Handle cases where there are no tokens in any docs.
+ return
+ except AttributeError:
+ types = set([type(eg) for eg in examples])
+ raise TypeError(Errors.E978.format(name="Tagger", method="update", types=types))
+ set_dropout_rate(self.model, drop)
+ tag_scores, bp_tag_scores = self.model.begin_update(
+ [eg.predicted for eg in examples])
+ for sc in tag_scores:
+ if self.model.ops.xp.isnan(sc.sum()):
+ raise ValueError("nan value in scores")
+ loss, d_tag_scores = self.get_loss(examples, tag_scores)
+ bp_tag_scores(d_tag_scores)
+ if sgd not in (None, False):
+ self.model.finish_update(sgd)
+
+ losses[self.name] += loss
+ if set_annotations:
+ docs = [eg.predicted for eg in examples]
+ self.set_annotations(docs, self._scores2guesses(tag_scores))
+ return losses
+
+ def rehearse(self, examples, drop=0., sgd=None, losses=None):
+ """Perform a 'rehearsal' update, where we try to match the output of
+ an initial model.
+ """
+ try:
+ docs = [eg.predicted for eg in examples]
+ except AttributeError:
+ types = set([type(eg) for eg in examples])
+ raise TypeError(Errors.E978.format(name="Tagger", method="rehearse", types=types))
+ if self._rehearsal_model is None:
+ return
+ if not any(len(doc) for doc in docs):
+ # Handle cases where there are no tokens in any docs.
+ return
+ set_dropout_rate(self.model, drop)
+ guesses, backprop = self.model.begin_update(docs)
+ target = self._rehearsal_model(examples)
+ gradient = guesses - target
+ backprop(gradient)
+ self.model.finish_update(sgd)
+ if losses is not None:
+ losses.setdefault(self.name, 0.0)
+ losses[self.name] += (gradient**2).sum()
+
+ def get_loss(self, examples, scores):
+ loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
+ truths = [eg.get_aligned("tag", as_string=True) for eg in examples]
+ d_scores, loss = loss_func(scores, truths)
+ if self.model.ops.xp.isnan(loss):
+ raise ValueError("nan value when computing loss")
+ return float(loss), d_scores
+
+ def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
+ **kwargs):
+ lemma_tables = ["lemma_rules", "lemma_index", "lemma_exc", "lemma_lookup"]
+ if not any(table in self.vocab.lookups for table in lemma_tables):
+ warnings.warn(Warnings.W022)
+ if len(self.vocab.lookups.get_table("lexeme_norm", {})) == 0:
+ warnings.warn(Warnings.W033.format(model="part-of-speech tagger"))
+ orig_tag_map = dict(self.vocab.morphology.tag_map)
+ new_tag_map = {}
+ for example in get_examples():
+ try:
+ y = example.y
+ except AttributeError:
+ raise TypeError(Errors.E978.format(name="Tagger", method="begin_training", types=type(example)))
+ for token in y:
+ tag = token.tag_
+ if tag in orig_tag_map:
+ new_tag_map[tag] = orig_tag_map[tag]
+ else:
+ new_tag_map[tag] = {POS: X}
+
+ cdef Vocab vocab = self.vocab
+ if new_tag_map:
+ if "_SP" in orig_tag_map:
+ new_tag_map["_SP"] = orig_tag_map["_SP"]
+ vocab.morphology = Morphology(vocab.strings, new_tag_map,
+ vocab.morphology.lemmatizer,
+ exc=vocab.morphology.exc)
+ self.set_output(len(self.labels))
+ doc_sample = [Doc(self.vocab, words=["hello", "world"])]
+ if pipeline is not None:
+ for name, component in pipeline:
+ if component is self:
+ break
+ if hasattr(component, "pipe"):
+ doc_sample = list(component.pipe(doc_sample))
+ else:
+ doc_sample = [component(doc) for doc in doc_sample]
+ self.model.initialize(X=doc_sample)
+ # Get batch of example docs, example outputs to call begin_training().
+ # This lets the model infer shapes.
+ link_vectors_to_models(self.vocab)
+ if sgd is None:
+ sgd = self.create_optimizer()
+ return sgd
+
+ def add_label(self, label, values=None):
+ if not isinstance(label, str):
+ raise ValueError(Errors.E187)
+ if label in self.labels:
+ return 0
+ if self.model.has_dim("nO"):
+ # Here's how the model resizing will work, once the
+ # neuron-to-tag mapping is no longer controlled by
+ # the Morphology class, which sorts the tag names.
+ # The sorting makes adding labels difficult.
+ # smaller = self.model._layers[-1]
+ # larger = Softmax(len(self.labels)+1, smaller.nI)
+ # copy_array(larger.W[:smaller.nO], smaller.W)
+ # copy_array(larger.b[:smaller.nO], smaller.b)
+ # self.model._layers[-1] = larger
+ raise ValueError(TempErrors.T003)
+ tag_map = dict(self.vocab.morphology.tag_map)
+ if values is None:
+ values = {POS: "X"}
+ tag_map[label] = values
+ self.vocab.morphology = Morphology(
+ self.vocab.strings, tag_map=tag_map,
+ lemmatizer=self.vocab.morphology.lemmatizer,
+ exc=self.vocab.morphology.exc)
+ return 1
+
+ def use_params(self, params):
+ with self.model.use_params(params):
+ yield
+
+ def score(self, examples, **kwargs):
+ scores = {}
+ scores.update(Scorer.score_token_attr(examples, "tag", **kwargs))
+ scores.update(Scorer.score_token_attr(examples, "pos", **kwargs))
+ scores.update(Scorer.score_token_attr(examples, "lemma", **kwargs))
+ return scores
+
+ def to_bytes(self, exclude=tuple()):
+ serialize = {}
+ serialize["model"] = self.model.to_bytes
+ serialize["vocab"] = self.vocab.to_bytes
+ serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+ tag_map = dict(sorted(self.vocab.morphology.tag_map.items()))
+ serialize["tag_map"] = lambda: srsly.msgpack_dumps(tag_map)
+ return util.to_bytes(serialize, exclude)
+
+ def from_bytes(self, bytes_data, exclude=tuple()):
+ def load_model(b):
+ try:
+ self.model.from_bytes(b)
+ except AttributeError:
+ raise ValueError(Errors.E149)
+
+ def load_tag_map(b):
+ tag_map = srsly.msgpack_loads(b)
+ self.vocab.morphology = Morphology(
+ self.vocab.strings, tag_map=tag_map,
+ lemmatizer=self.vocab.morphology.lemmatizer,
+ exc=self.vocab.morphology.exc)
+
+ deserialize = {
+ "vocab": lambda b: self.vocab.from_bytes(b),
+ "tag_map": load_tag_map,
+ "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+ "model": lambda b: load_model(b),
+ }
+ util.from_bytes(bytes_data, deserialize, exclude)
+ return self
+
+ def to_disk(self, path, exclude=tuple()):
+ tag_map = dict(sorted(self.vocab.morphology.tag_map.items()))
+ serialize = {
+ "vocab": lambda p: self.vocab.to_disk(p),
+ "tag_map": lambda p: srsly.write_msgpack(p, tag_map),
+ "model": lambda p: self.model.to_disk(p),
+ "cfg": lambda p: srsly.write_json(p, self.cfg),
+ }
+ util.to_disk(path, serialize, exclude)
+
+ def from_disk(self, path, exclude=tuple()):
+ def load_model(p):
+ with p.open("rb") as file_:
+ try:
+ self.model.from_bytes(file_.read())
+ except AttributeError:
+ raise ValueError(Errors.E149)
+
+ def load_tag_map(p):
+ tag_map = srsly.read_msgpack(p)
+ self.vocab.morphology = Morphology(
+ self.vocab.strings, tag_map=tag_map,
+ lemmatizer=self.vocab.morphology.lemmatizer,
+ exc=self.vocab.morphology.exc)
+
+ deserialize = {
+ "vocab": lambda p: self.vocab.from_disk(p),
+ "cfg": lambda p: self.cfg.update(_load_cfg(p)),
+ "tag_map": load_tag_map,
+ "model": load_model,
+ }
+ util.from_disk(path, deserialize, exclude)
+ return self
+
+
+@component("senter", assigns=["token.is_sent_start"], default_model=default_senter)
+class SentenceRecognizer(Tagger):
+ """Pipeline component for sentence segmentation.
+
+ DOCS: https://spacy.io/api/sentencerecognizer
+ """
+
+ def __init__(self, vocab, model, **cfg):
+ self.vocab = vocab
+ self.model = model
+ self._rehearsal_model = None
+ self.cfg = dict(sorted(cfg.items()))
+
+ @property
+ def labels(self):
+ # labels are numbered by index internally, so this matches GoldParse
+ # and Example where the sentence-initial tag is 1 and other positions
+ # are 0
+ return tuple(["I", "S"])
+
+ def set_annotations(self, docs, batch_tag_ids):
+ if isinstance(docs, Doc):
+ docs = [docs]
+ cdef Doc doc
+ for i, doc in enumerate(docs):
+ doc_tag_ids = batch_tag_ids[i]
+ if hasattr(doc_tag_ids, "get"):
+ doc_tag_ids = doc_tag_ids.get()
+ for j, tag_id in enumerate(doc_tag_ids):
+ # Don't clobber existing sentence boundaries
+ if doc.c[j].sent_start == 0:
+ if tag_id == 1:
+ doc.c[j].sent_start = 1
+ else:
+ doc.c[j].sent_start = -1
+
+ def get_loss(self, examples, scores):
+ labels = self.labels
+ loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False)
+ truths = []
+ for eg in examples:
+ eg_truth = []
+ for x in eg.get_aligned("sent_start"):
+                if x is None:
+ eg_truth.append(None)
+ elif x == 1:
+ eg_truth.append(labels[1])
+ else:
+ # anything other than 1: 0, -1, -1 as uint64
+ eg_truth.append(labels[0])
+ truths.append(eg_truth)
+ d_scores, loss = loss_func(scores, truths)
+ if self.model.ops.xp.isnan(loss):
+ raise ValueError("nan value when computing loss")
+ return float(loss), d_scores
+
+ def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
+ **kwargs):
+ self.set_output(len(self.labels))
+ self.model.initialize()
+ link_vectors_to_models(self.vocab)
+ if sgd is None:
+ sgd = self.create_optimizer()
+ return sgd
+
+ def add_label(self, label, values=None):
+ raise NotImplementedError
+
+ def score(self, examples, **kwargs):
+ return Scorer.score_spans(examples, "sents", **kwargs)
+
+ def to_bytes(self, exclude=tuple()):
+ serialize = {}
+ serialize["model"] = self.model.to_bytes
+ serialize["vocab"] = self.vocab.to_bytes
+ serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
+ return util.to_bytes(serialize, exclude)
+
+ def from_bytes(self, bytes_data, exclude=tuple()):
+ def load_model(b):
+ try:
+ self.model.from_bytes(b)
+ except AttributeError:
+ raise ValueError(Errors.E149)
+
+ deserialize = {
+ "vocab": lambda b: self.vocab.from_bytes(b),
+ "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+ "model": lambda b: load_model(b),
+ }
+ util.from_bytes(bytes_data, deserialize, exclude)
+ return self
+
+ def to_disk(self, path, exclude=tuple()):
+ serialize = {
+ "vocab": lambda p: self.vocab.to_disk(p),
+ "model": lambda p: p.open("wb").write(self.model.to_bytes()),
+ "cfg": lambda p: srsly.write_json(p, self.cfg),
+ }
+ util.to_disk(path, serialize, exclude)
+
+ def from_disk(self, path, exclude=tuple()):
+ def load_model(p):
+ with p.open("rb") as file_:
+ try:
+ self.model.from_bytes(file_.read())
+ except AttributeError:
+ raise ValueError(Errors.E149)
+
+ deserialize = {
+ "vocab": lambda p: self.vocab.from_disk(p),
+ "cfg": lambda p: self.cfg.update(_load_cfg(p)),
+ "model": load_model,
+ }
+ util.from_disk(path, deserialize, exclude)
+ return self
+
+
+@component("nn_labeller")
+class MultitaskObjective(Tagger):
+ """Experimental: Assist training of a parser or tagger, by training a
+ side-objective.
+ """
+
+ def __init__(self, vocab, model, **cfg):
+ self.vocab = vocab
+ self.model = model
+ target = cfg["target"] # default: 'dep_tag_offset'
+ if target == "dep":
+ self.make_label = self.make_dep
+ elif target == "tag":
+ self.make_label = self.make_tag
+ elif target == "ent":
+ self.make_label = self.make_ent
+ elif target == "dep_tag_offset":
+ self.make_label = self.make_dep_tag_offset
+ elif target == "ent_tag":
+ self.make_label = self.make_ent_tag
+ elif target == "sent_start":
+ self.make_label = self.make_sent_start
+ elif hasattr(target, "__call__"):
+ self.make_label = target
+ else:
+ raise ValueError(Errors.E016)
+ self.cfg = dict(cfg)
+
+ @property
+ def labels(self):
+ return self.cfg.setdefault("labels", {})
+
+ @labels.setter
+ def labels(self, value):
+ self.cfg["labels"] = value
+
+ def set_annotations(self, docs, dep_ids):
+ pass
+
+ def begin_training(self, get_examples=lambda: [], pipeline=None,
+ sgd=None, **kwargs):
+ gold_examples = nonproj.preprocess_training_data(get_examples())
+ # for raw_text, doc_annot in gold_tuples:
+ for example in gold_examples:
+ for token in example.y:
+ label = self.make_label(token)
+ if label is not None and label not in self.labels:
+ self.labels[label] = len(self.labels)
+ self.model.initialize()
+ link_vectors_to_models(self.vocab)
+ if sgd is None:
+ sgd = self.create_optimizer()
+ return sgd
+
+ def predict(self, docs):
+ tokvecs = self.model.get_ref("tok2vec")(docs)
+ scores = self.model.get_ref("softmax")(tokvecs)
+ return tokvecs, scores
+
+ def get_loss(self, examples, scores):
+ cdef int idx = 0
+ correct = numpy.zeros((scores.shape[0],), dtype="i")
+ guesses = scores.argmax(axis=1)
+ docs = [eg.predicted for eg in examples]
+ for i, eg in enumerate(examples):
+ # Handles alignment for tokenization differences
+ doc_annots = eg.get_aligned() # TODO
+ for j in range(len(eg.predicted)):
+                tok_annots = {key: values[j] for key, values in doc_annots.items()}
+ label = self.make_label(j, tok_annots)
+ if label is None or label not in self.labels:
+ correct[idx] = guesses[idx]
+ else:
+ correct[idx] = self.labels[label]
+ idx += 1
+ correct = self.model.ops.xp.array(correct, dtype="i")
+ d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
+ loss = (d_scores**2).sum()
+ return float(loss), d_scores
+
+ @staticmethod
+ def make_dep(token):
+ return token.dep_
+
+ @staticmethod
+ def make_tag(token):
+ return token.tag_
+
+ @staticmethod
+ def make_ent(token):
+ if token.ent_iob_ == "O":
+ return "O"
+ else:
+ return token.ent_iob_ + "-" + token.ent_type_
+
+ @staticmethod
+ def make_dep_tag_offset(token):
+ dep = token.dep_
+ tag = token.tag_
+ offset = token.head.i - token.i
+ offset = min(offset, 2)
+ offset = max(offset, -2)
+ return f"{dep}-{tag}:{offset}"
+
+ @staticmethod
+ def make_ent_tag(token):
+ if token.ent_iob_ == "O":
+ ent = "O"
+ else:
+ ent = token.ent_iob_ + "-" + token.ent_type_
+ tag = token.tag_
+ return f"{tag}-{ent}"
+
+ @staticmethod
+ def make_sent_start(token):
+ """A multi-task objective for representing sentence boundaries,
+ using BILU scheme. (O is impossible)
+ """
+ if token.is_sent_start and token.is_sent_end:
+ return "U-SENT"
+ elif token.is_sent_start:
+ return "B-SENT"
+ else:
+ return "I-SENT"
+
+
+class ClozeMultitask(Pipe):
+ def __init__(self, vocab, model, **cfg):
+ self.vocab = vocab
+ self.model = model
+ self.cfg = cfg
+ self.distance = CosineDistance(ignore_zeros=True, normalize=False) # TODO: in config
+
+ def set_annotations(self, docs, dep_ids):
+ pass
+
+ def begin_training(self, get_examples=lambda: [], pipeline=None,
+ sgd=None, **kwargs):
+ link_vectors_to_models(self.vocab)
+ self.model.initialize()
+ X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
+ self.model.output_layer.begin_training(X)
+ if sgd is None:
+ sgd = self.create_optimizer()
+ return sgd
+
+ def predict(self, docs):
+ tokvecs = self.model.get_ref("tok2vec")(docs)
+ vectors = self.model.get_ref("output_layer")(tokvecs)
+ return tokvecs, vectors
+
+ def get_loss(self, examples, vectors, prediction):
+ # The simplest way to implement this would be to vstack the
+ # token.vector values, but that's a bit inefficient, especially on GPU.
+ # Instead we fetch the index into the vectors table for each of our tokens,
+ # and look them up all at once. This prevents data copying.
+ ids = self.model.ops.flatten([eg.predicted.to_array(ID).ravel() for eg in examples])
+ target = vectors[ids]
+ gradient = self.distance.get_grad(prediction, target)
+ loss = self.distance.get_loss(prediction, target)
+ return loss, gradient
+
+ def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
+ pass
+
+ def rehearse(self, examples, drop=0., sgd=None, losses=None):
+ if losses is not None and self.name not in losses:
+ losses[self.name] = 0.
+ set_dropout_rate(self.model, drop)
+ try:
+ predictions, bp_predictions = self.model.begin_update([eg.predicted for eg in examples])
+ except AttributeError:
+ types = set([type(eg) for eg in examples])
+ raise TypeError(Errors.E978.format(name="ClozeMultitask", method="rehearse", types=types))
+ loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
+ bp_predictions(d_predictions)
+ if sgd is not None:
+ self.model.finish_update(sgd)
+
+ if losses is not None:
+ losses[self.name] += loss
+
+
+@component("textcat", assigns=["doc.cats"], default_model=default_textcat)
+class TextCategorizer(Pipe):
+ """Pipeline component for text classification.
+
+ DOCS: https://spacy.io/api/textcategorizer
+ """
+ def __init__(self, vocab, model, **cfg):
+ self.vocab = vocab
+ self.model = model
+ self._rehearsal_model = None
+ self.cfg = dict(cfg)
+
+ @property
+ def labels(self):
+ return tuple(self.cfg.setdefault("labels", []))
+
+ def require_labels(self):
+ """Raise an error if the component's model has no labels defined."""
+ if not self.labels:
+ raise ValueError(Errors.E143.format(name=self.name))
+
+ @labels.setter
+ def labels(self, value):
+ self.cfg["labels"] = tuple(value)
+
+ def pipe(self, stream, batch_size=128):
+ for docs in util.minibatch(stream, size=batch_size):
+ scores = self.predict(docs)
+ self.set_annotations(docs, scores)
+ yield from docs
+
+ def predict(self, docs):
+ tensors = [doc.tensor for doc in docs]
+
+ if not any(len(doc) for doc in docs):
+ # Handle cases where there are no tokens in any docs.
+ xp = get_array_module(tensors)
+ scores = xp.zeros((len(docs), len(self.labels)))
+ return scores
+
+ scores = self.model.predict(docs)
+ scores = self.model.ops.asarray(scores)
+ return scores
+
+ def set_annotations(self, docs, scores):
+ for i, doc in enumerate(docs):
+ for j, label in enumerate(self.labels):
+ doc.cats[label] = float(scores[i, j])
+
+ def update(self, examples, *, drop=0., set_annotations=False, sgd=None, losses=None):
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
+ try:
+ if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
+ # Handle cases where there are no tokens in any docs.
+ return losses
+ except AttributeError:
+ types = set([type(eg) for eg in examples])
+ raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=types))
+ set_dropout_rate(self.model, drop)
+ scores, bp_scores = self.model.begin_update(
+ [eg.predicted for eg in examples]
+ )
+ loss, d_scores = self.get_loss(examples, scores)
+ bp_scores(d_scores)
+ if sgd is not None:
+ self.model.finish_update(sgd)
+ losses[self.name] += loss
+ if set_annotations:
+ docs = [eg.predicted for eg in examples]
+ self.set_annotations(docs, scores=scores)
+ return losses
+
+ def rehearse(self, examples, drop=0., sgd=None, losses=None):
+ if self._rehearsal_model is None:
+ return
+ try:
+ docs = [eg.predicted for eg in examples]
+ except AttributeError:
+ types = set([type(eg) for eg in examples])
+ raise TypeError(Errors.E978.format(name="TextCategorizer", method="rehearse", types=types))
+ if not any(len(doc) for doc in docs):
+ # Handle cases where there are no tokens in any docs.
+ return
+ set_dropout_rate(self.model, drop)
+ scores, bp_scores = self.model.begin_update(docs)
+ target = self._rehearsal_model(examples)
+ gradient = scores - target
+ bp_scores(gradient)
+ if sgd is not None:
+ self.model.finish_update(sgd)
+ if losses is not None:
+ losses.setdefault(self.name, 0.0)
+ losses[self.name] += (gradient**2).sum()
+
+ def _examples_to_truth(self, examples):
+ truths = numpy.zeros((len(examples), len(self.labels)), dtype="f")
+ not_missing = numpy.ones((len(examples), len(self.labels)), dtype="f")
+ for i, eg in enumerate(examples):
+ for j, label in enumerate(self.labels):
+ if label in eg.reference.cats:
+ truths[i, j] = eg.reference.cats[label]
+ else:
+ not_missing[i, j] = 0.
+ truths = self.model.ops.asarray(truths)
+ return truths, not_missing
+
+ def get_loss(self, examples, scores):
+ truths, not_missing = self._examples_to_truth(examples)
+ not_missing = self.model.ops.asarray(not_missing)
+ d_scores = (scores-truths) / scores.shape[0]
+ d_scores *= not_missing
+ mean_square_error = (d_scores**2).sum(axis=1).mean()
+ return float(mean_square_error), d_scores
+
+ def add_label(self, label):
+ if not isinstance(label, str):
+ raise ValueError(Errors.E187)
+ if label in self.labels:
+ return 0
+ if self.model.has_dim("nO"):
+ # This functionality was available previously, but was broken.
+ # The problem is that we resize the last layer, but the last layer
+ # is actually just an ensemble. We're not resizing the child layers
+ # - a huge problem.
+ raise ValueError(Errors.E116)
+ # smaller = self.model._layers[-1]
+ # larger = Linear(len(self.labels)+1, smaller.nI)
+ # copy_array(larger.W[:smaller.nO], smaller.W)
+ # copy_array(larger.b[:smaller.nO], smaller.b)
+ # self.model._layers[-1] = larger
+ self.labels = tuple(list(self.labels) + [label])
+ return 1
+
+ def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
+ # TODO: begin_training is not guaranteed to see all data / labels ?
+ examples = list(get_examples())
+ for example in examples:
+ try:
+ y = example.y
+ except AttributeError:
+ raise TypeError(Errors.E978.format(name="TextCategorizer", method="update", types=type(example)))
+ for cat in y.cats:
+ self.add_label(cat)
+ self.require_labels()
+ docs = [Doc(Vocab(), words=["hello"])]
+ truths, _ = self._examples_to_truth(examples)
+ self.set_output(len(self.labels))
+ link_vectors_to_models(self.vocab)
+ self.model.initialize(X=docs, Y=truths)
+ if sgd is None:
+ sgd = self.create_optimizer()
+ return sgd
+
+ def score(self, examples, **kwargs):
+ return Scorer.score_cats(examples, "cats", labels=self.labels,
+ multi_label=self.model.attrs["multi_label"],
+ positive_label=self.cfg.get("positive_label", None),
+ **kwargs
+ )
+
+
+cdef class DependencyParser(Parser):
+ """Pipeline component for dependency parsing.
+
+ DOCS: https://spacy.io/api/dependencyparser
+ """
+ # cdef classes can't have decorators, so we're defining this here
+ name = "parser"
+ factory = "parser"
+ assigns = ["token.dep", "token.is_sent_start", "doc.sents"]
+ requires = []
+ TransitionSystem = ArcEager
+
+ @property
+ def postprocesses(self):
+ output = [nonproj.deprojectivize]
+ if self.cfg.get("learn_tokens") is True:
+ output.append(merge_subtokens)
+ return tuple(output)
+
+ def add_multitask_objective(self, mt_component):
+ self._multitasks.append(mt_component)
+
+ def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
+ # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
+ for labeller in self._multitasks:
+ labeller.model.set_dim("nO", len(self.labels))
+ if labeller.model.has_ref("output_layer"):
+ labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
+ labeller.begin_training(get_examples, pipeline=pipeline, sgd=sgd)
+
+ def __reduce__(self):
+ return (DependencyParser, (self.vocab, self.model), (self.moves, self.cfg))
+
+ def __getstate__(self):
+ return (self.moves, self.cfg)
+
+ def __setstate__(self, state):
+ moves, config = state
+ self.moves = moves
+ self.cfg = config
+
+ @property
+ def labels(self):
+ labels = set()
+ # Get the labels from the model by looking at the available moves
+ for move in self.move_names:
+ if "-" in move:
+ label = move.split("-")[1]
+ if "||" in label:
+ label = label.split("||")[1]
+ labels.add(label)
+ return tuple(sorted(labels))
+
+ def score(self, examples, **kwargs):
+ def dep_getter(token, attr):
+ dep = getattr(token, attr)
+ dep = token.vocab.strings.as_string(dep).lower()
+ return dep
+ results = {}
+ results.update(Scorer.score_spans(examples, "sents", **kwargs))
+ results.update(Scorer.score_deps(examples, "dep", getter=dep_getter,
+ ignore_labels=("p", "punct"), **kwargs))
+ return results
+
+
+cdef class EntityRecognizer(Parser):
+ """Pipeline component for named entity recognition.
+
+ DOCS: https://spacy.io/api/entityrecognizer
+ """
+ name = "ner"
+ factory = "ner"
+ assigns = ["doc.ents", "token.ent_iob", "token.ent_type"]
+ requires = []
+ TransitionSystem = BiluoPushDown
+
+ def add_multitask_objective(self, mt_component):
+ self._multitasks.append(mt_component)
+
+ def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
+ # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
+ for labeller in self._multitasks:
+ labeller.model.set_dim("nO", len(self.labels))
+ if labeller.model.has_ref("output_layer"):
+ labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
+ labeller.begin_training(get_examples, pipeline=pipeline)
+
+ def __reduce__(self):
+ return (EntityRecognizer, (self.vocab, self.model), (self.moves, self.cfg))
+
+ def __getstate__(self):
+ return self.moves, self.cfg
+
+ def __setstate__(self, state):
+ moves, config = state
+ self.moves = moves
+ self.cfg = config
+
+ @property
+ def labels(self):
+ # Get the labels from the model by looking at the available moves, e.g.
+ # B-PERSON, I-PERSON, L-PERSON, U-PERSON
+ labels = set(move.split("-")[1] for move in self.move_names
+ if move[0] in ("B", "I", "L", "U"))
+ return tuple(sorted(labels))
+
+ def score(self, examples, **kwargs):
+ return Scorer.score_spans(examples, "ents", **kwargs)
+
+@component(
+ "entity_linker",
+ requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
+ assigns=["token.ent_kb_id"],
+ default_model=default_nel,
+)
+class EntityLinker(Pipe):
+ """Pipeline component for named entity linking.
+
+ DOCS: https://spacy.io/api/entitylinker
+ """
+ NIL = "NIL" # string used to refer to a non-existing link
+
+ def __init__(self, vocab, model, **cfg):
+ self.vocab = vocab
+ self.model = model
+ self.kb = None
+ self.kb = cfg.get("kb", None)
+ if self.kb is None:
+ # create an empty KB that should be filled by calling from_disk
+ self.kb = KnowledgeBase(vocab=vocab)
+ else:
+ del cfg["kb"] # we don't want to duplicate its serialization
+ if not isinstance(self.kb, KnowledgeBase):
+ raise ValueError(Errors.E990.format(type=type(self.kb)))
+ self.cfg = dict(cfg)
+ self.distance = CosineDistance(normalize=False)
+        # how many neighbouring sentences to take into account
+ self.n_sents = cfg.get("n_sents", 0)
+
+ def require_kb(self):
+ # Raise an error if the knowledge base is not initialized.
+ if len(self.kb) == 0:
+ raise ValueError(Errors.E139.format(name=self.name))
+
+ def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
+ self.require_kb()
+ nO = self.kb.entity_vector_length
+ self.set_output(nO)
+ self.model.initialize()
+ if sgd is None:
+ sgd = self.create_optimizer()
+ return sgd
+
+ def update(self, examples, *, set_annotations=False, drop=0.0, sgd=None, losses=None):
+ self.require_kb()
+ if losses is None:
+ losses = {}
+ losses.setdefault(self.name, 0.0)
+ if not examples:
+ return losses
+ sentence_docs = []
+ try:
+ docs = [eg.predicted for eg in examples]
+ except AttributeError:
+ types = set([type(eg) for eg in examples])
+ raise TypeError(Errors.E978.format(name="EntityLinker", method="update", types=types))
+ if set_annotations:
+ # This seems simpler than other ways to get that exact output -- but
+ # it does run the model twice :(
+ predictions = self.model.predict(docs)
+
+ for eg in examples:
+ sentences = [s for s in eg.predicted.sents]
+ kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+ for ent in eg.predicted.ents:
+ kb_id = kb_ids[ent.start] # KB ID of the first token is the same as the whole span
+ if kb_id:
+ try:
+ # find the sentence in the list of sentences.
+ sent_index = sentences.index(ent.sent)
+ except AttributeError:
+ # Catch the exception when ent.sent is None and provide a user-friendly warning
+ raise RuntimeError(Errors.E030)
+ # get n previous sentences, if there are any
+ start_sentence = max(0, sent_index - self.n_sents)
+
+ # get n posterior sentences, or as many < n as there are
+ end_sentence = min(len(sentences) -1, sent_index + self.n_sents)
+
+ # get token positions
+ start_token = sentences[start_sentence].start
+ end_token = sentences[end_sentence].end
+
+ # append that span as a doc to training
+ sent_doc = eg.predicted[start_token:end_token].as_doc()
+ sentence_docs.append(sent_doc)
+ set_dropout_rate(self.model, drop)
+ if not sentence_docs:
+ warnings.warn(Warnings.W093.format(name="Entity Linker"))
+ return 0.0
+ sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
+ loss, d_scores = self.get_similarity_loss(
+ sentence_encodings=sentence_encodings,
+ examples=examples
+ )
+ bp_context(d_scores)
+ if sgd is not None:
+ self.model.finish_update(sgd)
+
+ losses[self.name] += loss
+ if set_annotations:
+ self.set_annotations(docs, predictions)
+ return losses
+
+ def get_similarity_loss(self, examples, sentence_encodings):
+ entity_encodings = []
+ for eg in examples:
+ kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+ for ent in eg.predicted.ents:
+ kb_id = kb_ids[ent.start]
+ if kb_id:
+ entity_encoding = self.kb.get_vector(kb_id)
+ entity_encodings.append(entity_encoding)
+
+ entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
+
+ if sentence_encodings.shape != entity_encodings.shape:
+ raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up"))
+
+ gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
+ loss = self.distance.get_loss(sentence_encodings, entity_encodings)
+ loss = loss / len(entity_encodings)
+ return loss, gradients
+
+ def __call__(self, doc):
+ kb_ids = self.predict([doc])
+ self.set_annotations([doc], kb_ids)
+ return doc
+
+ def pipe(self, stream, batch_size=128):
+ for docs in util.minibatch(stream, size=batch_size):
+ kb_ids = self.predict(docs)
+ self.set_annotations(docs, kb_ids)
+ yield from docs
+
+ def predict(self, docs):
+ """ Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
+ self.require_kb()
+ entity_count = 0
+ final_kb_ids = []
+
+ if not docs:
+ return final_kb_ids
+
+ if isinstance(docs, Doc):
+ docs = [docs]
+
+ for i, doc in enumerate(docs):
+ sentences = [s for s in doc.sents]
+
+ if len(doc) > 0:
+ # Looping through each sentence and each entity
+ # This may go wrong if there are entities across sentences - which shouldn't happen normally.
+ for sent_index, sent in enumerate(sentences):
+ if sent.ents:
+                        # get n neighbouring sentences, clipped to the length of the document
+ start_sentence = max(0, sent_index - self.n_sents)
+ end_sentence = min(len(sentences) -1, sent_index + self.n_sents)
+
+ start_token = sentences[start_sentence].start
+ end_token = sentences[end_sentence].end
+
+ sent_doc = doc[start_token:end_token].as_doc()
+ # currently, the context is the same for each entity in a sentence (should be refined)
+ sentence_encoding = self.model.predict([sent_doc])[0]
+ xp = get_array_module(sentence_encoding)
+ sentence_encoding_t = sentence_encoding.T
+ sentence_norm = xp.linalg.norm(sentence_encoding_t)
+
+ for ent in sent.ents:
+ entity_count += 1
+
+ to_discard = self.cfg.get("labels_discard", [])
+ if to_discard and ent.label_ in to_discard:
+ # ignoring this entity - setting to NIL
+ final_kb_ids.append(self.NIL)
+
+ else:
+ candidates = self.kb.get_candidates(ent.text)
+ if not candidates:
+ # no prediction possible for this entity - setting to NIL
+ final_kb_ids.append(self.NIL)
+
+ elif len(candidates) == 1:
+ # shortcut for efficiency reasons: take the 1 candidate
+
+ # TODO: thresholding
+ final_kb_ids.append(candidates[0].entity_)
+
+ else:
+ random.shuffle(candidates)
+
+ # this will set all prior probabilities to 0 if they should be excluded from the model
+ prior_probs = xp.asarray([c.prior_prob for c in candidates])
+ if not self.cfg.get("incl_prior", True):
+ prior_probs = xp.asarray([0.0 for c in candidates])
+ scores = prior_probs
+
+ # add in similarity from the context
+ if self.cfg.get("incl_context", True):
+ entity_encodings = xp.asarray([c.entity_vector for c in candidates])
+ entity_norm = xp.linalg.norm(entity_encodings, axis=1)
+
+ if len(entity_encodings) != len(prior_probs):
+ raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
+
+ # cosine similarity
+ sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
+ if sims.shape != prior_probs.shape:
+ raise ValueError(Errors.E161)
+ scores = prior_probs + sims - (prior_probs*sims)
+
+ # TODO: thresholding
+ best_index = scores.argmax().item()
+ best_candidate = candidates[best_index]
+ final_kb_ids.append(best_candidate.entity_)
+
+        if len(final_kb_ids) != entity_count:
+ raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
+
+ return final_kb_ids
+
+ def set_annotations(self, docs, kb_ids):
+ count_ents = len([ent for doc in docs for ent in doc.ents])
+ if count_ents != len(kb_ids):
+ raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
+
+        i = 0
+ for doc in docs:
+ for ent in doc.ents:
+ kb_id = kb_ids[i]
+ i += 1
+ for token in ent:
+ token.ent_kb_id_ = kb_id
+
+ def to_disk(self, path, exclude=tuple()):
+ serialize = {}
+ self.cfg["entity_width"] = self.kb.entity_vector_length
+ serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
+ serialize["vocab"] = lambda p: self.vocab.to_disk(p)
+ serialize["kb"] = lambda p: self.kb.dump(p)
+ serialize["model"] = lambda p: self.model.to_disk(p)
+ util.to_disk(path, serialize, exclude)
+
+ def from_disk(self, path, exclude=tuple()):
+ def load_model(p):
+ try:
+ self.model.from_bytes(p.open("rb").read())
+ except AttributeError:
+ raise ValueError(Errors.E149)
+
+ def load_kb(p):
+ self.kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"])
+ self.kb.load_bulk(p)
+
+ deserialize = {}
+ deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
+ deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
+ deserialize["kb"] = load_kb
+ deserialize["model"] = load_model
+ util.from_disk(path, deserialize, exclude)
+ return self
+
+ def rehearse(self, examples, sgd=None, losses=None, **config):
+ raise NotImplementedError
+
+ def add_label(self, label):
+ raise NotImplementedError
+
+
+@component("sentencizer", assigns=["token.is_sent_start", "doc.sents"])
+class Sentencizer(Pipe):
+ """Segment the Doc into sentences using a rule-based strategy.
+
+ DOCS: https://spacy.io/api/sentencizer
+ """
+
+ default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
+ '।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄',
+ '᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿',
+ '‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶',
+ '꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒',
+ '﹖', '﹗', '!', '.', '?', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
+ '𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
+ '𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
+ '𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
+ '𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
+ '。', '。']
+
+ def __init__(self, punct_chars=None, **kwargs):
+ """Initialize the sentencizer.
+
+ punct_chars (list): Punctuation characters to split on. Will be
+ serialized with the nlp object.
+ RETURNS (Sentencizer): The sentencizer component.
+
+ DOCS: https://spacy.io/api/sentencizer#init
+ """
+ if punct_chars:
+ self.punct_chars = set(punct_chars)
+ else:
+ self.punct_chars = set(self.default_punct_chars)
+
+ @classmethod
+ def from_nlp(cls, nlp, model=None, **cfg):
+ return cls(**cfg)
+
+ def begin_training(
+ self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs
+ ):
+ pass
+
+ def __call__(self, doc):
+ """Apply the sentencizer to a Doc and set Token.is_sent_start.
+
+ example (Doc or Example): The document to process.
+ RETURNS (Doc or Example): The processed Doc or Example.
+
+ DOCS: https://spacy.io/api/sentencizer#call
+ """
+ start = 0
+ seen_period = False
+ for i, token in enumerate(doc):
+ is_in_punct_chars = token.text in self.punct_chars
+ token.is_sent_start = i == 0
+ if seen_period and not token.is_punct and not is_in_punct_chars:
+ doc[start].is_sent_start = True
+ start = token.i
+ seen_period = False
+ elif is_in_punct_chars:
+ seen_period = True
+ if start < len(doc):
+ doc[start].is_sent_start = True
+ return doc
+
+ def pipe(self, stream, batch_size=128):
+ for docs in util.minibatch(stream, size=batch_size):
+ predictions = self.predict(docs)
+ self.set_annotations(docs, predictions)
+ yield from docs
+
+ def predict(self, docs):
+ """Apply the pipeline's model to a batch of docs, without
+ modifying them.
+ """
+ if not any(len(doc) for doc in docs):
+ # Handle cases where there are no tokens in any docs.
+ guesses = [[] for doc in docs]
+ return guesses
+ guesses = []
+ for doc in docs:
+ doc_guesses = [False] * len(doc)
+ if len(doc) > 0:
+ start = 0
+ seen_period = False
+ doc_guesses[0] = True
+ for i, token in enumerate(doc):
+ is_in_punct_chars = token.text in self.punct_chars
+ if seen_period and not token.is_punct and not is_in_punct_chars:
+ doc_guesses[start] = True
+ start = token.i
+ seen_period = False
+ elif is_in_punct_chars:
+ seen_period = True
+ if start < len(doc):
+ doc_guesses[start] = True
+ guesses.append(doc_guesses)
+ return guesses
+
+ def set_annotations(self, docs, batch_tag_ids):
+ if isinstance(docs, Doc):
+ docs = [docs]
+ cdef Doc doc
+ cdef int idx = 0
+ for i, doc in enumerate(docs):
+ doc_tag_ids = batch_tag_ids[i]
+ for j, tag_id in enumerate(doc_tag_ids):
+ # Don't clobber existing sentence boundaries
+ if doc.c[j].sent_start == 0:
+ if tag_id:
+ doc.c[j].sent_start = 1
+ else:
+ doc.c[j].sent_start = -1
+
+ def score(self, examples, **kwargs):
+ return Scorer.score_spans(examples, "sents", **kwargs)
+
+ def to_bytes(self, **kwargs):
+ """Serialize the sentencizer to a bytestring.
+
+ RETURNS (bytes): The serialized object.
+
+ DOCS: https://spacy.io/api/sentencizer#to_bytes
+ """
+ return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)})
+
+ def from_bytes(self, bytes_data, **kwargs):
+ """Load the sentencizer from a bytestring.
+
+ bytes_data (bytes): The data to load.
+ returns (Sentencizer): The loaded object.
+
+ DOCS: https://spacy.io/api/sentencizer#from_bytes
+ """
+ cfg = srsly.msgpack_loads(bytes_data)
+ self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
+ return self
+
+ def to_disk(self, path, exclude=tuple(), **kwargs):
+ """Serialize the sentencizer to disk.
+
+ DOCS: https://spacy.io/api/sentencizer#to_disk
+ """
+ path = util.ensure_path(path)
+ path = path.with_suffix(".json")
+ srsly.write_json(path, {"punct_chars": list(self.punct_chars)})
+
+
+ def from_disk(self, path, exclude=tuple(), **kwargs):
+ """Load the sentencizer from disk.
+
+ DOCS: https://spacy.io/api/sentencizer#from_disk
+ """
+ path = util.ensure_path(path)
+ path = path.with_suffix(".json")
+ cfg = srsly.read_json(path)
+ self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
+ return self
+
+
+# Cython classes can't be decorated, so we need to add the factories here
+Language.factories["parser"] = lambda nlp, model, **cfg: parser_factory(nlp, model, **cfg)
+Language.factories["ner"] = lambda nlp, model, **cfg: ner_factory(nlp, model, **cfg)
+
+def parser_factory(nlp, model, **cfg):
+ default_config = {"learn_tokens": False, "min_action_freq": 30, "beam_width": 1, "beam_update_prob": 1.0}
+ if model is None:
+ model = default_parser()
+ warnings.warn(Warnings.W098.format(name="parser"))
+ for key, value in default_config.items():
+ if key not in cfg:
+ cfg[key] = value
+ return DependencyParser.from_nlp(nlp, model, **cfg)
+
+def ner_factory(nlp, model, **cfg):
+ default_config = {"learn_tokens": False, "min_action_freq": 30, "beam_width": 1, "beam_update_prob": 1.0}
+ if model is None:
+ model = default_ner()
+ warnings.warn(Warnings.W098.format(name="ner"))
+ for key, value in default_config.items():
+ if key not in cfg:
+ cfg[key] = value
+ return EntityRecognizer.from_nlp(nlp, model, **cfg)
+
+__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"]
diff --git a/spacy/pipeline/sentencizer.pyx b/spacy/pipeline/sentencizer.pyx
index c827ffc5c..70188f856 100644
--- a/spacy/pipeline/sentencizer.pyx
+++ b/spacy/pipeline/sentencizer.pyx
@@ -6,6 +6,7 @@ from ..tokens.doc cimport Doc
from .pipe import Pipe
from ..language import Language
+from ..scorer import Scorer
from .. import util
@@ -130,6 +131,9 @@ class Sentencizer(Pipe):
else:
doc.c[j].sent_start = -1
+ def score(self, examples, **kwargs):
+ return Scorer.score_spans(examples, "sents", **kwargs)
+
def to_bytes(self, exclude=tuple()):
"""Serialize the sentencizer to a bytestring.
diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx
index 603b7965e..c065ae72f 100644
--- a/spacy/pipeline/senter.pyx
+++ b/spacy/pipeline/senter.pyx
@@ -8,6 +8,7 @@ from .pipe import deserialize_config
from .tagger import Tagger
from ..language import Language
from ..errors import Errors
+from ..scorer import Scorer
from .. import util
@@ -104,6 +105,9 @@ class SentenceRecognizer(Tagger):
def add_label(self, label, values=None):
raise NotImplementedError
+ def score(self, examples, **kwargs):
+ return Scorer.score_spans(examples, "sents", **kwargs)
+
def to_bytes(self, exclude=tuple()):
serialize = {}
serialize["model"] = self.model.to_bytes
diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx
index e4250b932..1c4105921 100644
--- a/spacy/pipeline/tagger.pyx
+++ b/spacy/pipeline/tagger.pyx
@@ -14,6 +14,7 @@ from ..language import Language
from ..attrs import POS, ID
from ..parts_of_speech import X
from ..errors import Errors, TempErrors, Warnings
+from ..scorer import Scorer
from .. import util
@@ -250,6 +251,13 @@ class Tagger(Pipe):
with self.model.use_params(params):
yield
+ def score(self, examples, **kwargs):
+ scores = {}
+ scores.update(Scorer.score_token_attr(examples, "tag", **kwargs))
+ scores.update(Scorer.score_token_attr(examples, "pos", **kwargs))
+ scores.update(Scorer.score_token_attr(examples, "lemma", **kwargs))
+ return scores
+
def to_bytes(self, exclude=tuple()):
serialize = {}
serialize["model"] = self.model.to_bytes
diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index ff79a600a..bc68bb806 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -6,6 +6,7 @@ from .pipe import Pipe
from ..language import Language
from ..gold import Example
from ..errors import Errors
+from ..scorer import Scorer
from .. import util
from ..tokens import Doc
from ..vocab import Vocab
@@ -250,3 +251,9 @@ class TextCategorizer(Pipe):
if sgd is None:
sgd = self.create_optimizer()
return sgd
+
+ def score(self, examples, positive_label=None, **kwargs):
+ return Scorer.score_cats(examples, "cats", labels=self.labels,
+ multi_label=self.model.attrs["multi_label"],
+ positive_label=positive_label, **kwargs
+ )
diff --git a/spacy/scorer.py b/spacy/scorer.py
index 512f27e07..a95fe70cf 100644
--- a/spacy/scorer.py
+++ b/spacy/scorer.py
@@ -1,6 +1,8 @@
import numpy as np
from .errors import Errors
+from .util import get_lang_class
+from .morphology import Morphology
class PRFScore:
@@ -32,6 +34,9 @@ class PRFScore:
r = self.recall
return 2 * ((p * r) / (p + r + 1e-100))
+ def to_dict(self):
+ return {"p": self.precision, "r": self.recall, "f": self.fscore}
+
class ROCAUCScore:
"""
@@ -65,391 +70,405 @@ class ROCAUCScore:
class Scorer:
"""Compute evaluation scores."""
- def __init__(self, eval_punct=False, pipeline=None):
+ def __init__(self, nlp=None, **cfg):
"""Initialize the Scorer.
-
- eval_punct (bool): Evaluate the dependency attachments to and from
- punctuation.
RETURNS (Scorer): The newly created object.
DOCS: https://spacy.io/api/scorer#init
"""
- self.tokens = PRFScore()
- self.sbd = PRFScore()
- self.unlabelled = PRFScore()
- self.labelled = PRFScore()
- self.labelled_per_dep = dict()
- self.tags = PRFScore()
- self.pos = PRFScore()
- self.morphs = PRFScore()
- self.morphs_per_feat = dict()
- self.sent_starts = PRFScore()
- self.ner = PRFScore()
- self.ner_per_ents = dict()
- self.eval_punct = eval_punct
- self.textcat = PRFScore()
- self.textcat_f_per_cat = dict()
- self.textcat_auc_per_cat = dict()
- self.textcat_positive_label = None
- self.textcat_multilabel = False
+ self.nlp = nlp
+ self.cfg = cfg
- if pipeline:
- for name, component in pipeline:
- if name == "textcat":
- self.textcat_multilabel = component.model.attrs["multi_label"]
- self.textcat_positive_label = component.cfg.get(
- "positive_label", None
- )
- for label in component.cfg.get("labels", []):
- self.textcat_auc_per_cat[label] = ROCAUCScore()
- self.textcat_f_per_cat[label] = PRFScore()
+ if not nlp:
+ # create a default pipeline
+ nlp = get_lang_class("xx")()
+ nlp.add_pipe("senter")
+ nlp.add_pipe("tagger")
+ nlp.add_pipe("morphologizer")
+ nlp.add_pipe("parser")
+ nlp.add_pipe("ner")
+ nlp.add_pipe("textcat")
+ self.nlp = nlp
- @property
- def tags_acc(self):
- """RETURNS (float): Part-of-speech tag accuracy (fine grained tags,
- i.e. `Token.tag`).
- """
- return self.tags.fscore * 100
-
- @property
- def pos_acc(self):
- """RETURNS (float): Part-of-speech tag accuracy (coarse grained pos,
- i.e. `Token.pos`).
- """
- return self.pos.fscore * 100
-
- @property
- def morphs_acc(self):
- """RETURNS (float): Morph tag accuracy (morphological features,
- i.e. `Token.morph`).
- """
- return self.morphs.fscore * 100
-
- @property
- def morphs_per_type(self):
- """RETURNS (dict): Scores per dependency label.
- """
- return {
- k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
- for k, v in self.morphs_per_feat.items()
- }
-
- @property
- def sent_p(self):
- """RETURNS (float): F-score for identification of sentence starts.
- i.e. `Token.is_sent_start`).
- """
- return self.sent_starts.precision * 100
-
- @property
- def sent_r(self):
- """RETURNS (float): F-score for identification of sentence starts.
- i.e. `Token.is_sent_start`).
- """
- return self.sent_starts.recall * 100
-
- @property
- def sent_f(self):
- """RETURNS (float): F-score for identification of sentence starts.
- i.e. `Token.is_sent_start`).
- """
- return self.sent_starts.fscore * 100
-
- @property
- def token_acc(self):
- """RETURNS (float): Tokenization accuracy."""
- return self.tokens.precision * 100
-
- @property
- def uas(self):
- """RETURNS (float): Unlabelled dependency score."""
- return self.unlabelled.fscore * 100
-
- @property
- def las(self):
- """RETURNS (float): Labelled dependency score."""
- return self.labelled.fscore * 100
-
- @property
- def las_per_type(self):
- """RETURNS (dict): Scores per dependency label.
- """
- return {
- k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
- for k, v in self.labelled_per_dep.items()
- }
-
- @property
- def ents_p(self):
- """RETURNS (float): Named entity accuracy (precision)."""
- return self.ner.precision * 100
-
- @property
- def ents_r(self):
- """RETURNS (float): Named entity accuracy (recall)."""
- return self.ner.recall * 100
-
- @property
- def ents_f(self):
- """RETURNS (float): Named entity accuracy (F-score)."""
- return self.ner.fscore * 100
-
- @property
- def ents_per_type(self):
- """RETURNS (dict): Scores per entity label.
- """
- return {
- k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
- for k, v in self.ner_per_ents.items()
- }
-
- @property
- def textcat_f(self):
- """RETURNS (float): f-score on positive label for binary classification,
- macro-averaged f-score for multilabel classification
- """
- if not self.textcat_multilabel:
- if self.textcat_positive_label:
- # binary classification
- return self.textcat.fscore * 100
- # multi-class and/or multi-label
- return (
- sum([score.fscore for label, score in self.textcat_f_per_cat.items()])
- / (len(self.textcat_f_per_cat) + 1e-100)
- * 100
- )
-
- @property
- def textcat_auc(self):
- """RETURNS (float): macro-averaged AUC ROC score for multilabel classification (-1 if undefined)
- """
- return max(
- sum([score.score for label, score in self.textcat_auc_per_cat.items()])
- / (len(self.textcat_auc_per_cat) + 1e-100),
- -1,
- )
-
- @property
- def textcats_auc_per_cat(self):
- """RETURNS (dict): AUC ROC Scores per textcat label.
- """
- return {
- k: {"roc_auc_score": max(v.score, -1)}
- for k, v in self.textcat_auc_per_cat.items()
- }
-
- @property
- def textcats_f_per_cat(self):
- """RETURNS (dict): F-scores per textcat label.
- """
- return {
- k: {"p": v.precision * 100, "r": v.recall * 100, "f": v.fscore * 100}
- for k, v in self.textcat_f_per_cat.items()
- }
-
- @property
- def scores(self):
- """RETURNS (dict): All scores mapped by key.
- """
- return {
- "uas": self.uas,
- "las": self.las,
- "las_per_type": self.las_per_type,
- "ents_p": self.ents_p,
- "ents_r": self.ents_r,
- "ents_f": self.ents_f,
- "ents_per_type": self.ents_per_type,
- "tags_acc": self.tags_acc,
- "pos_acc": self.pos_acc,
- "morphs_acc": self.morphs_acc,
- "morphs_per_type": self.morphs_per_type,
- "sent_p": self.sent_p,
- "sent_r": self.sent_r,
- "sent_f": self.sent_f,
- "token_acc": self.token_acc,
- "textcat_f": self.textcat_f,
- "textcat_auc": self.textcat_auc,
- "textcats_f_per_cat": self.textcats_f_per_cat,
- "textcats_auc_per_cat": self.textcats_auc_per_cat,
- }
-
- def score(self, example, verbose=False, punct_labels=("p", "punct")):
- """Update the evaluation scores from a single Example.
-
- example (Example): The predicted annotations + correct annotations.
- verbose (bool): Print debugging information.
- punct_labels (tuple): Dependency labels for punctuation. Used to
- evaluate dependency attachments to punctuation if `eval_punct` is
- `True`.
+ def score(self, examples):
+ """Evaluate a list of Examples.
+ examples (Iterable[Example]): The predicted annotations + correct annotations.
+ RETURNS (Dict): A dictionary of scores.
DOCS: https://spacy.io/api/scorer#score
"""
- doc = example.predicted
- gold_doc = example.reference
- align = example.alignment
- gold_deps = set()
- gold_deps_per_dep = {}
- gold_tags = set()
- gold_pos = set()
- gold_morphs = set()
- gold_morphs_per_feat = {}
- gold_sent_starts = set()
- for gold_i, token in enumerate(gold_doc):
- gold_tags.add((gold_i, token.tag_))
- gold_pos.add((gold_i, token.pos_))
- gold_morphs.add((gold_i, token.morph_))
- if token.morph_:
- for feat in token.morph_.split("|"):
- field, values = feat.split("=")
- if field not in self.morphs_per_feat:
- self.morphs_per_feat[field] = PRFScore()
- if field not in gold_morphs_per_feat:
- gold_morphs_per_feat[field] = set()
- gold_morphs_per_feat[field].add((gold_i, feat))
- if token.sent_start:
- gold_sent_starts.add(gold_i)
- dep = token.dep_.lower()
- if dep not in punct_labels:
- gold_deps.add((gold_i, token.head.i, dep))
- if dep not in self.labelled_per_dep:
- self.labelled_per_dep[dep] = PRFScore()
- if dep not in gold_deps_per_dep:
- gold_deps_per_dep[dep] = set()
- gold_deps_per_dep[dep].add((gold_i, token.head.i, dep))
- cand_deps = set()
- cand_deps_per_dep = {}
- cand_tags = set()
- cand_pos = set()
- cand_morphs = set()
- cand_morphs_per_feat = {}
- cand_sent_starts = set()
- for token in doc:
- if token.orth_.isspace():
+ scores = {}
+
+ if hasattr(self.nlp.tokenizer, "score"):
+ scores.update(self.nlp.tokenizer.score(examples, **self.cfg))
+ for name, component in self.nlp.pipeline:
+ if hasattr(component, "score"):
+ scores.update(component.score(examples, **self.cfg))
+
+ return scores
+
+ @staticmethod
+ def score_tokenization(examples, **cfg):
+ """Returns accuracy and PRF scores for tokenization.
+
+ * token_acc: # correct tokens / # gold tokens
+ * token_p/r/f: PRF for token character spans
+
+ examples (Iterable[Example]): Examples to score
+ RETURNS (dict): A dictionary containing the scores token_acc/p/r/f.
+ """
+ acc_score = PRFScore()
+ prf_score = PRFScore()
+ for example in examples:
+ gold_doc = example.reference
+ pred_doc = example.predicted
+ align = example.alignment
+ gold_spans = set()
+ pred_spans = set()
+ for token in gold_doc:
+ if token.orth_.isspace():
+ continue
+ gold_spans.add((token.idx, token.idx + len(token)))
+ for token in pred_doc:
+ if token.orth_.isspace():
+ continue
+ pred_spans.add((token.idx, token.idx + len(token)))
+ if align.x2y.lengths[token.i] != 1:
+ acc_score.fp += 1
+ else:
+ acc_score.tp += 1
+ prf_score.score_set(pred_spans, gold_spans)
+ return {
+ "token_acc": acc_score.fscore,
+ "token_p": prf_score.precision,
+ "token_r": prf_score.recall,
+ "token_f": prf_score.fscore,
+ }
+
+ @staticmethod
+ def score_token_attr(examples, attr, getter=getattr, **cfg):
+ """Returns an accuracy score for a token-level attribute.
+
+ examples (Iterable[Example]): Examples to score
+ attr (str): The attribute to score.
+ getter (callable): Defaults to getattr. If provided,
+ getter(token, attr) should return the value of the attribute for an
+ individual token.
+ RETURNS (dict): A dictionary containing the accuracy score under the
+ key attr_acc.
+ """
+ tag_score = PRFScore()
+ for example in examples:
+ gold_doc = example.reference
+ pred_doc = example.predicted
+ align = example.alignment
+ gold_tags = set()
+ for gold_i, token in enumerate(gold_doc):
+ gold_tags.add((gold_i, getter(token, attr)))
+ pred_tags = set()
+ for token in pred_doc:
+ if token.orth_.isspace():
+ continue
+ if align.x2y.lengths[token.i] == 1:
+ gold_i = align.x2y[token.i].dataXd[0, 0]
+ pred_tags.add((gold_i, getter(token, attr)))
+ tag_score.score_set(pred_tags, gold_tags)
+ return {
+ attr + "_acc": tag_score.fscore,
+ }
+
+ @staticmethod
+ def score_token_attr_per_feat(examples, attr, getter=getattr, **cfg):
+ """Return PRF scores per feat for a token attribute in UFEATS format.
+
+ examples (Iterable[Example]): Examples to score
+ attr (str): The attribute to score.
+ getter (callable): Defaults to getattr. If provided,
+ getter(token, attr) should return the value of the attribute for an
+ individual token.
+ RETURNS (dict): A dictionary containing the per-feat PRF scores under
+ the key attr_per_feat.
+ """
+ per_feat = {}
+ for example in examples:
+ pred_doc = example.predicted
+ gold_doc = example.reference
+ align = example.alignment
+ gold_per_feat = {}
+ for gold_i, token in enumerate(gold_doc):
+ morph = str(getter(token, attr))
+ if morph:
+ for feat in morph.split(Morphology.FEATURE_SEP):
+ field, values = feat.split(Morphology.FIELD_SEP)
+ if field not in per_feat:
+ per_feat[field] = PRFScore()
+ if field not in gold_per_feat:
+ gold_per_feat[field] = set()
+ gold_per_feat[field].add((gold_i, feat))
+ pred_per_feat = {}
+ for token in pred_doc:
+ if token.orth_.isspace():
+ continue
+ if align.x2y.lengths[token.i] == 1:
+ gold_i = align.x2y[token.i].dataXd[0, 0]
+ morph = str(getter(token, attr))
+ if morph:
+ for feat in morph.split("|"):
+ field, values = feat.split("=")
+ if field not in per_feat:
+ per_feat[field] = PRFScore()
+ if field not in pred_per_feat:
+ pred_per_feat[field] = set()
+ pred_per_feat[field].add((gold_i, feat))
+ for field in per_feat:
+ per_feat[field].score_set(
+ pred_per_feat.get(field, set()), gold_per_feat.get(field, set()),
+ )
+ return {
+ attr + "_per_feat": per_feat,
+ }
+
+ @staticmethod
+ def score_spans(examples, attr, getter=getattr, **cfg):
+ """Returns PRF scores for labeled spans.
+
+ examples (Iterable[Example]): Examples to score
+ attr (str): The attribute to score.
+ getter (callable): Defaults to getattr. If provided,
+ getter(doc, attr) should return the spans for the individual doc.
+ RETURNS (dict): A dictionary containing the PRF scores under the
+ keys attr_p/r/f and the per-type PRF scores under attr_per_type.
+ """
+ score = PRFScore()
+ score_per_type = dict()
+ for example in examples:
+ pred_doc = example.predicted
+ gold_doc = example.reference
+ # Find all labels in gold and doc
+ labels = set(
+ [k.label_ for k in getter(gold_doc, attr)]
+ + [k.label_ for k in getter(pred_doc, attr)]
+ )
+ # Set up all labels for per type scoring and prepare gold per type
+ gold_per_type = {label: set() for label in labels}
+ for label in labels:
+ if label not in score_per_type:
+ score_per_type[label] = PRFScore()
+ # Find all gold and predicted spans, for all labels and per type
+ gold_spans = set()
+ pred_spans = set()
+
+ # Special case for ents:
+ # If we have missing values in the gold, we can't easily tell
+ # whether our NER predictions are true.
+ # It seems bad but it's what we've always done.
+ if attr == "ents" and not all(token.ent_iob != 0 for token in gold_doc):
continue
- if align.x2y.lengths[token.i] != 1:
- self.tokens.fp += 1
- gold_i = None
- else:
- gold_i = align.x2y[token.i].dataXd[0, 0]
- self.tokens.tp += 1
- cand_tags.add((gold_i, token.tag_))
- cand_pos.add((gold_i, token.pos_))
- cand_morphs.add((gold_i, token.morph_))
- if token.morph_:
- for feat in token.morph_.split("|"):
- field, values = feat.split("=")
- if field not in self.morphs_per_feat:
- self.morphs_per_feat[field] = PRFScore()
- if field not in cand_morphs_per_feat:
- cand_morphs_per_feat[field] = set()
- cand_morphs_per_feat[field].add((gold_i, feat))
- if token.is_sent_start:
- cand_sent_starts.add(gold_i)
- if token.dep_.lower() not in punct_labels and token.orth_.strip():
- if align.x2y.lengths[token.head.i] == 1:
- gold_head = align.x2y[token.head.i].dataXd[0, 0]
- else:
- gold_head = None
- # None is indistinct, so we can't just add it to the set
- # Multiple (None, None) deps are possible
- if gold_i is None or gold_head is None:
- self.unlabelled.fp += 1
- self.labelled.fp += 1
- else:
- cand_deps.add((gold_i, gold_head, token.dep_.lower()))
- if token.dep_.lower() not in self.labelled_per_dep:
- self.labelled_per_dep[token.dep_.lower()] = PRFScore()
- if token.dep_.lower() not in cand_deps_per_dep:
- cand_deps_per_dep[token.dep_.lower()] = set()
- cand_deps_per_dep[token.dep_.lower()].add(
- (gold_i, gold_head, token.dep_.lower())
+
+ for span in getter(gold_doc, attr):
+ gold_span = (span.label_, span.start, span.end - 1)
+ gold_spans.add(gold_span)
+ gold_per_type[span.label_].add((span.label_, span.start, span.end - 1))
+ pred_per_type = {label: set() for label in labels}
+ for span in example.get_aligned_spans_x2y(getter(pred_doc, attr)):
+ pred_spans.add((span.label_, span.start, span.end - 1))
+ pred_per_type[span.label_].add((span.label_, span.start, span.end - 1))
+ # Scores per label
+ for k, v in score_per_type.items():
+ if k in pred_per_type:
+ v.score_set(pred_per_type[k], gold_per_type[k])
+ # Score for all labels
+ score.score_set(pred_spans, gold_spans)
+ results = {
+ attr + "_p": score.precision,
+ attr + "_r": score.recall,
+ attr + "_f": score.fscore,
+ attr + "_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
+ }
+ return results
+
+ @staticmethod
+ def score_cats(
+ examples,
+ attr,
+ getter=getattr,
+ labels=[],
+ multi_label=True,
+ positive_label=None,
+ **cfg
+ ):
+ """Returns PRF and ROC AUC scores for a doc-level attribute with a
+ dict with scores for each label like Doc.cats.
+
+ examples (Iterable[Example]): Examples to score
+ attr (str): The attribute to score.
+ getter (callable): Defaults to getattr. If provided,
+ getter(doc, attr) should return the values for the individual doc.
+ labels (Iterable[str]): The set of possible labels. Defaults to [].
+ multi_label (bool): Whether the attribute allows multiple labels.
+ Defaults to True.
+ positive_label (str): The positive label for a binary task with
+ exclusive classes. Defaults to None.
+ RETURNS (dict): A dictionary containing the scores:
+ for binary exclusive with positive label: attr_p/r/f,
+ for 3+ exclusive classes, macro-averaged fscore: attr_macro_f,
+ for multilabel, macro-averaged AUC: attr_macro_auc,
+ for all: attr_f_per_type, attr_auc_per_type
+ """
+ score = PRFScore()
+ f_per_type = dict()
+ auc_per_type = dict()
+ for label in labels:
+ f_per_type[label] = PRFScore()
+ auc_per_type[label] = ROCAUCScore()
+ for example in examples:
+ gold_doc = example.reference
+ pred_doc = example.predicted
+ gold_values = getter(gold_doc, attr)
+ pred_values = getter(pred_doc, attr)
+ if (
+ len(gold_values) > 0
+ and set(f_per_type) == set(auc_per_type) == set(gold_values)
+ and set(gold_values) == set(pred_values)
+ ):
+ gold_val = max(gold_values, key=gold_values.get)
+ pred_val = max(pred_values, key=pred_values.get)
+ if positive_label:
+ score.score_set(
+ set([positive_label]) & set([pred_val]),
+ set([positive_label]) & set([gold_val]),
+ )
+ for label in set(gold_values):
+ auc_per_type[label].score_set(
+ pred_values[label], gold_values[label]
+ )
+ f_per_type[label].score_set(
+ set([label]) & set([pred_val]), set([label]) & set([gold_val])
+ )
+ elif len(f_per_type) > 0:
+ model_labels = set(f_per_type)
+ eval_labels = set(gold_values)
+ raise ValueError(
+ Errors.E162.format(
+ model_labels=model_labels, eval_labels=eval_labels
)
- # Find all NER labels in gold and doc
- ent_labels = set(
- [k.label_ for k in gold_doc.ents] + [k.label_ for k in doc.ents]
- )
- # Set up all labels for per type scoring and prepare gold per type
- gold_per_ents = {ent_label: set() for ent_label in ent_labels}
- for ent_label in ent_labels:
- if ent_label not in self.ner_per_ents:
- self.ner_per_ents[ent_label] = PRFScore()
- # Find all candidate labels, for all and per type
- gold_ents = set()
- cand_ents = set()
- # If we have missing values in the gold, we can't easily tell whether
- # our NER predictions are true.
- # It seems bad but it's what we've always done.
- if all(token.ent_iob != 0 for token in gold_doc):
- for ent in gold_doc.ents:
- gold_ent = (ent.label_, ent.start, ent.end - 1)
- gold_ents.add(gold_ent)
- gold_per_ents[ent.label_].add((ent.label_, ent.start, ent.end - 1))
- cand_per_ents = {ent_label: set() for ent_label in ent_labels}
- for ent in example.get_aligned_spans_x2y(doc.ents):
- cand_ents.add((ent.label_, ent.start, ent.end - 1))
- cand_per_ents[ent.label_].add((ent.label_, ent.start, ent.end - 1))
- # Scores per ent
- for k, v in self.ner_per_ents.items():
- if k in cand_per_ents:
- v.score_set(cand_per_ents[k], gold_per_ents[k])
- # Score for all ents
- self.ner.score_set(cand_ents, gold_ents)
- self.tags.score_set(cand_tags, gold_tags)
- self.pos.score_set(cand_pos, gold_pos)
- self.morphs.score_set(cand_morphs, gold_morphs)
- for field in self.morphs_per_feat:
- self.morphs_per_feat[field].score_set(
- cand_morphs_per_feat.get(field, set()),
- gold_morphs_per_feat.get(field, set()),
- )
- self.sent_starts.score_set(cand_sent_starts, gold_sent_starts)
- self.labelled.score_set(cand_deps, gold_deps)
- for dep in self.labelled_per_dep:
- self.labelled_per_dep[dep].score_set(
- cand_deps_per_dep.get(dep, set()), gold_deps_per_dep.get(dep, set())
- )
- self.unlabelled.score_set(
- set(item[:2] for item in cand_deps), set(item[:2] for item in gold_deps)
- )
- if (
- len(gold_doc.cats) > 0
- and set(self.textcat_f_per_cat)
- == set(self.textcat_auc_per_cat)
- == set(gold_doc.cats)
- and set(gold_doc.cats) == set(doc.cats)
- ):
- goldcat = max(gold_doc.cats, key=gold_doc.cats.get)
- candcat = max(doc.cats, key=doc.cats.get)
- if self.textcat_positive_label:
- self.textcat.score_set(
- set([self.textcat_positive_label]) & set([candcat]),
- set([self.textcat_positive_label]) & set([goldcat]),
)
- for label in set(gold_doc.cats):
- self.textcat_auc_per_cat[label].score_set(
- doc.cats[label], gold_doc.cats[label]
+ elif len(auc_per_type) > 0:
+ model_labels = set(auc_per_type)
+ eval_labels = set(gold_values)
+ raise ValueError(
+ Errors.E162.format(
+ model_labels=model_labels, eval_labels=eval_labels
+ )
)
- self.textcat_f_per_cat[label].score_set(
- set([label]) & set([candcat]), set([label]) & set([goldcat])
+ results = {
+ attr + "_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()},
+ attr + "_auc_per_type": {k: v.score for k, v in auc_per_type.items()},
+ }
+ if len(labels) == 2 and not multi_label and positive_label:
+ results[attr + "_p"] = score.precision
+ results[attr + "_r"] = score.recall
+ results[attr + "_f"] = score.fscore
+ elif not multi_label:
+ results[attr + "_macro_f"] = sum(
+ [score.fscore for label, score in f_per_type.items()]
+ ) / (len(f_per_type) + 1e-100)
+ else:
+ results[attr + "_macro_auc"] = max(
+ sum([score.score for label, score in auc_per_type.items()])
+ / (len(auc_per_type) + 1e-100),
+ -1,
+ )
+ return results
+
+ @staticmethod
+ def score_deps(
+ examples,
+ attr,
+ getter=getattr,
+ head_attr="head",
+ head_getter=getattr,
+ ignore_labels=tuple(),
+ **cfg
+ ):
+ """Returns the UAS, LAS, and LAS per type scores for dependency
+ parses.
+
+ examples (Iterable[Example]): Examples to score
+ attr (str): The attribute containing the dependency label.
+ getter (callable): Defaults to getattr. If provided,
+ getter(token, attr) should return the value of the attribute for an
+ individual token.
+ head_attr (str): The attribute containing the head token. Defaults to
+ 'head'.
+ head_getter (callable): Defaults to getattr. If provided,
+ head_getter(token, attr) should return the value of the head for an
+ individual token.
+ ignore_labels (Tuple): Labels to ignore while scoring (e.g., punct).
+ RETURNS (dict): A dictionary containing the scores:
+ attr_uas, attr_las, and attr_las_per_type.
+ """
+ unlabelled = PRFScore()
+ labelled = PRFScore()
+ labelled_per_dep = dict()
+ for example in examples:
+ gold_doc = example.reference
+ pred_doc = example.predicted
+ align = example.alignment
+ gold_deps = set()
+ gold_deps_per_dep = {}
+ for gold_i, token in enumerate(gold_doc):
+ dep = getter(token, attr)
+ head = head_getter(token, head_attr)
+ if dep not in ignore_labels:
+ gold_deps.add((gold_i, head.i, dep))
+ if dep not in labelled_per_dep:
+ labelled_per_dep[dep] = PRFScore()
+ if dep not in gold_deps_per_dep:
+ gold_deps_per_dep[dep] = set()
+ gold_deps_per_dep[dep].add((gold_i, head.i, dep))
+ pred_deps = set()
+ pred_deps_per_dep = {}
+ for token in pred_doc:
+ if token.orth_.isspace():
+ continue
+ if align.x2y.lengths[token.i] != 1:
+ gold_i = None
+ else:
+ gold_i = align.x2y[token.i].dataXd[0, 0]
+ dep = getter(token, attr)
+ head = head_getter(token, head_attr)
+ if dep not in ignore_labels and token.orth_.strip():
+ if align.x2y.lengths[head.i] == 1:
+ gold_head = align.x2y[head.i].dataXd[0, 0]
+ else:
+ gold_head = None
+ # None is indistinct, so we can't just add it to the set
+ # Multiple (None, None) deps are possible
+ if gold_i is None or gold_head is None:
+ unlabelled.fp += 1
+ labelled.fp += 1
+ else:
+ pred_deps.add((gold_i, gold_head, dep))
+ if dep not in labelled_per_dep:
+ labelled_per_dep[dep] = PRFScore()
+ if dep not in pred_deps_per_dep:
+ pred_deps_per_dep[dep] = set()
+ pred_deps_per_dep[dep].add((gold_i, gold_head, dep))
+ labelled.score_set(pred_deps, gold_deps)
+ for dep in labelled_per_dep:
+ labelled_per_dep[dep].score_set(
+ pred_deps_per_dep.get(dep, set()), gold_deps_per_dep.get(dep, set())
)
- elif len(self.textcat_f_per_cat) > 0:
- model_labels = set(self.textcat_f_per_cat)
- eval_labels = set(gold_doc.cats)
- raise ValueError(
- Errors.E162.format(model_labels=model_labels, eval_labels=eval_labels)
+ unlabelled.score_set(
+ set(item[:2] for item in pred_deps), set(item[:2] for item in gold_deps)
)
- elif len(self.textcat_auc_per_cat) > 0:
- model_labels = set(self.textcat_auc_per_cat)
- eval_labels = set(gold_doc.cats)
- raise ValueError(
- Errors.E162.format(model_labels=model_labels, eval_labels=eval_labels)
- )
- if verbose:
- gold_words = gold_doc.words
- for w_id, h_id, dep in cand_deps - gold_deps:
- print("F", gold_words[w_id], dep, gold_words[h_id])
- for w_id, h_id, dep in gold_deps - cand_deps:
- print("M", gold_words[w_id], dep, gold_words[h_id])
+ return {
+ attr + "_uas": unlabelled.fscore,
+ attr + "_las": labelled.fscore,
+ attr
+ + "_las_per_type": {k: v.to_dict() for k, v in labelled_per_dep.items()},
+ }
#############################################################################
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 5eb09a007..ff284873d 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -85,6 +85,8 @@ def test_overfitting_IO():
fix_random_seed(0)
nlp = English()
textcat = nlp.add_pipe("textcat")
+ # Set exclusive labels
+ textcat.model.attrs["multi_label"] = False
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
@@ -114,6 +116,10 @@ def test_overfitting_IO():
assert cats2["POSITIVE"] > 0.9
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
+ # Test scoring
+ scores = nlp.evaluate(train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}})
+ assert scores["cats_f"] == 1.0
+
# fmt: off
@pytest.mark.parametrize(
diff --git a/spacy/tests/test_scorer.py b/spacy/tests/test_scorer.py
index a6684b706..fea263df5 100644
--- a/spacy/tests/test_scorer.py
+++ b/spacy/tests/test_scorer.py
@@ -7,6 +7,7 @@ from spacy.scorer import Scorer, ROCAUCScore
from spacy.scorer import _roc_auc_score, _roc_curve
from .util import get_doc
from spacy.lang.en import English
+from spacy.tokens import Doc
test_las_apple = [
@@ -77,13 +78,61 @@ def tagged_doc():
doc[i].tag_ = tags[i]
doc[i].pos_ = pos[i]
doc[i].morph_ = morphs[i]
+ if i > 0:
+ doc[i].is_sent_start = False
doc.is_tagged = True
return doc
+@pytest.fixture
+def sented_doc():
+ text = "One sentence. Two sentences. Three sentences."
+ nlp = English()
+ doc = nlp(text)
+ for i in range(len(doc)):
+ if i % 3 == 0:
+ doc[i].is_sent_start = True
+ else:
+ doc[i].is_sent_start = False
+ return doc
+
+
+def test_tokenization(sented_doc):
+ scorer = Scorer()
+ gold = {"sent_starts": [t.sent_start for t in sented_doc]}
+ example = Example.from_dict(sented_doc, gold)
+ scores = scorer.score([example])
+ assert scores["token_acc"] == 1.0
+
+ nlp = English()
+ example.predicted = Doc(nlp.vocab, words=["One", "sentence.", "Two", "sentences.", "Three", "sentences."], spaces=[True, True, True, True, True, False])
+ example.predicted[1].is_sent_start = False
+ scores = scorer.score([example])
+ assert scores["token_acc"] == approx(0.66666666)
+ assert scores["token_p"] == 0.5
+ assert scores["token_r"] == approx(0.33333333)
+ assert scores["token_f"] == 0.4
+
+
+def test_sents(sented_doc):
+ scorer = Scorer()
+ gold = {"sent_starts": [t.sent_start for t in sented_doc]}
+ example = Example.from_dict(sented_doc, gold)
+ scores = scorer.score([example])
+ assert scores["sents_f"] == 1.0
+
+ # One sentence start is moved
+ gold["sent_starts"][3] = 0
+ gold["sent_starts"][4] = 1
+ example = Example.from_dict(sented_doc, gold)
+ scores = scorer.score([example])
+ assert scores["sents_f"] == approx(0.3333333)
+
+
def test_las_per_type(en_vocab):
# Gold and Doc are identical
scorer = Scorer()
+ examples = []
for input_, annot in test_las_apple:
doc = get_doc(
en_vocab,
@@ -93,20 +142,21 @@ def test_las_per_type(en_vocab):
)
gold = {"heads": annot["heads"], "deps": annot["deps"]}
example = Example.from_dict(doc, gold)
- scorer.score(example)
- results = scorer.scores
+ examples.append(example)
+ results = scorer.score(examples)
- assert results["uas"] == 100
- assert results["las"] == 100
- assert results["las_per_type"]["nsubj"]["p"] == 100
- assert results["las_per_type"]["nsubj"]["r"] == 100
- assert results["las_per_type"]["nsubj"]["f"] == 100
- assert results["las_per_type"]["compound"]["p"] == 100
- assert results["las_per_type"]["compound"]["r"] == 100
- assert results["las_per_type"]["compound"]["f"] == 100
+ assert results["dep_uas"] == 1.0
+ assert results["dep_las"] == 1.0
+ assert results["dep_las_per_type"]["nsubj"]["p"] == 1.0
+ assert results["dep_las_per_type"]["nsubj"]["r"] == 1.0
+ assert results["dep_las_per_type"]["nsubj"]["f"] == 1.0
+ assert results["dep_las_per_type"]["compound"]["p"] == 1.0
+ assert results["dep_las_per_type"]["compound"]["r"] == 1.0
+ assert results["dep_las_per_type"]["compound"]["f"] == 1.0
# One dep is incorrect in Doc
scorer = Scorer()
+ examples = []
for input_, annot in test_las_apple:
doc = get_doc(
en_vocab,
@@ -117,22 +167,23 @@ def test_las_per_type(en_vocab):
gold = {"heads": annot["heads"], "deps": annot["deps"]}
doc[0].dep_ = "compound"
example = Example.from_dict(doc, gold)
- scorer.score(example)
- results = scorer.scores
+ examples.append(example)
+ results = scorer.score(examples)
- assert results["uas"] == 100
- assert_almost_equal(results["las"], 90.9090909)
- assert results["las_per_type"]["nsubj"]["p"] == 0
- assert results["las_per_type"]["nsubj"]["r"] == 0
- assert results["las_per_type"]["nsubj"]["f"] == 0
- assert_almost_equal(results["las_per_type"]["compound"]["p"], 66.6666666)
- assert results["las_per_type"]["compound"]["r"] == 100
- assert results["las_per_type"]["compound"]["f"] == 80
+ assert results["dep_uas"] == 1.0
+ assert_almost_equal(results["dep_las"], 0.9090909)
+ assert results["dep_las_per_type"]["nsubj"]["p"] == 0
+ assert results["dep_las_per_type"]["nsubj"]["r"] == 0
+ assert results["dep_las_per_type"]["nsubj"]["f"] == 0
+ assert_almost_equal(results["dep_las_per_type"]["compound"]["p"], 0.666666666)
+ assert results["dep_las_per_type"]["compound"]["r"] == 1.0
+ assert results["dep_las_per_type"]["compound"]["f"] == 0.8
def test_ner_per_type(en_vocab):
# Gold and Doc are identical
scorer = Scorer()
+ examples = []
for input_, annot in test_ner_cardinal:
doc = get_doc(
en_vocab,
@@ -140,20 +191,24 @@ def test_ner_per_type(en_vocab):
ents=[[0, 1, "CARDINAL"], [2, 3, "CARDINAL"]],
)
entities = biluo_tags_from_offsets(doc, annot["entities"])
- ex = Example.from_dict(doc, {"entities": entities})
- scorer.score(ex)
- results = scorer.scores
+ example = Example.from_dict(doc, {"entities": entities})
+ # a hack for sentence boundaries
+ example.predicted[1].is_sent_start = False
+ example.reference[1].is_sent_start = False
+ examples.append(example)
+ results = scorer.score(examples)
- assert results["ents_p"] == 100
- assert results["ents_f"] == 100
- assert results["ents_r"] == 100
- assert results["ents_per_type"]["CARDINAL"]["p"] == 100
- assert results["ents_per_type"]["CARDINAL"]["f"] == 100
- assert results["ents_per_type"]["CARDINAL"]["r"] == 100
+ assert results["ents_p"] == 1.0
+ assert results["ents_r"] == 1.0
+ assert results["ents_f"] == 1.0
+ assert results["ents_per_type"]["CARDINAL"]["p"] == 1.0
+ assert results["ents_per_type"]["CARDINAL"]["r"] == 1.0
+ assert results["ents_per_type"]["CARDINAL"]["f"] == 1.0
# Doc has one missing and one extra entity
# Entity type MONEY is not present in Doc
scorer = Scorer()
+ examples = []
for input_, annot in test_ner_apple:
doc = get_doc(
en_vocab,
@@ -161,25 +216,28 @@ def test_ner_per_type(en_vocab):
ents=[[0, 1, "ORG"], [5, 6, "GPE"], [6, 7, "ORG"]],
)
entities = biluo_tags_from_offsets(doc, annot["entities"])
- ex = Example.from_dict(doc, {"entities": entities})
- scorer.score(ex)
- results = scorer.scores
+ example = Example.from_dict(doc, {"entities": entities})
+ # a hack for sentence boundaries
+ example.predicted[1].is_sent_start = False
+ example.reference[1].is_sent_start = False
+ examples.append(example)
+ results = scorer.score(examples)
- assert results["ents_p"] == approx(66.66666)
- assert results["ents_r"] == approx(66.66666)
- assert results["ents_f"] == approx(66.66666)
+ assert results["ents_p"] == approx(0.6666666)
+ assert results["ents_r"] == approx(0.6666666)
+ assert results["ents_f"] == approx(0.6666666)
assert "GPE" in results["ents_per_type"]
assert "MONEY" in results["ents_per_type"]
assert "ORG" in results["ents_per_type"]
- assert results["ents_per_type"]["GPE"]["p"] == 100
- assert results["ents_per_type"]["GPE"]["r"] == 100
- assert results["ents_per_type"]["GPE"]["f"] == 100
+ assert results["ents_per_type"]["GPE"]["p"] == 1.0
+ assert results["ents_per_type"]["GPE"]["r"] == 1.0
+ assert results["ents_per_type"]["GPE"]["f"] == 1.0
assert results["ents_per_type"]["MONEY"]["p"] == 0
assert results["ents_per_type"]["MONEY"]["r"] == 0
assert results["ents_per_type"]["MONEY"]["f"] == 0
- assert results["ents_per_type"]["ORG"]["p"] == 50
- assert results["ents_per_type"]["ORG"]["r"] == 100
- assert results["ents_per_type"]["ORG"]["f"] == approx(66.66666)
+ assert results["ents_per_type"]["ORG"]["p"] == 0.5
+ assert results["ents_per_type"]["ORG"]["r"] == 1.0
+ assert results["ents_per_type"]["ORG"]["f"] == approx(0.6666666)
def test_tag_score(tagged_doc):
@@ -189,17 +247,17 @@ def test_tag_score(tagged_doc):
"tags": [t.tag_ for t in tagged_doc],
"pos": [t.pos_ for t in tagged_doc],
"morphs": [t.morph_ for t in tagged_doc],
+ "sent_starts": [1 if t.is_sent_start else -1 for t in tagged_doc],
}
example = Example.from_dict(tagged_doc, gold)
- scorer.score(example)
- results = scorer.scores
+ results = scorer.score([example])
- assert results["tags_acc"] == 100
- assert results["pos_acc"] == 100
- assert results["morphs_acc"] == 100
- assert results["morphs_per_type"]["NounType"]["f"] == 100
+ assert results["tag_acc"] == 1.0
+ assert results["pos_acc"] == 1.0
+ assert results["morph_acc"] == 1.0
+ assert results["morph_per_feat"]["NounType"].fscore == 1.0
- # Gold and Doc are identical
+ # Gold annotation is modified
scorer = Scorer()
tags = [t.tag_ for t in tagged_doc]
tags[0] = "NN"
@@ -208,16 +266,21 @@ def test_tag_score(tagged_doc):
morphs = [t.morph_ for t in tagged_doc]
morphs[1] = "Number=sing"
morphs[2] = "Number=plur"
- gold = {"tags": tags, "pos": pos, "morphs": morphs}
+ gold = {
+ "tags": tags,
+ "pos": pos,
+ "morphs": morphs,
+ "sent_starts": gold["sent_starts"],
+ }
example = Example.from_dict(tagged_doc, gold)
- scorer.score(example)
- results = scorer.scores
+ results = scorer.score([example])
- assert results["tags_acc"] == 90
- assert results["pos_acc"] == 90
- assert results["morphs_acc"] == approx(80)
- assert results["morphs_per_type"]["Poss"]["f"] == 0.0
- assert results["morphs_per_type"]["Number"]["f"] == approx(72.727272)
+ assert results["tag_acc"] == 0.9
+ assert results["pos_acc"] == 0.9
+ assert results["morph_acc"] == approx(0.8)
+ assert results["morph_per_feat"]["NounType"].fscore == 1.0
+ assert results["morph_per_feat"]["Poss"].fscore == 0.0
+ assert results["morph_per_feat"]["Number"].fscore == approx(0.72727272)
def test_roc_auc_score():
diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx
index 114d227c8..c84dd8627 100644
--- a/spacy/tokenizer.pyx
+++ b/spacy/tokenizer.pyx
@@ -24,6 +24,7 @@ from . import util
from .util import registry
from .attrs import intify_attrs
from .symbols import ORTH
+from .scorer import Scorer
@registry.tokenizers("spacy.Tokenizer.v1")
@@ -743,6 +744,9 @@ cdef class Tokenizer:
tokens.extend(reversed(suffixes))
return tokens
+ def score(self, examples, **kwargs):
+ return Scorer.score_tokenization(examples)
+
def to_disk(self, path, **kwargs):
"""Save the current state to a directory.
diff --git a/website/docs/api/language.md b/website/docs/api/language.md
index 3ba93b360..be402532c 100644
--- a/website/docs/api/language.md
+++ b/website/docs/api/language.md
@@ -108,8 +108,8 @@ Evaluate a model's pipeline components.
> #### Example
>
> ```python
-> scorer = nlp.evaluate(examples, verbose=True)
-> print(scorer.scores)
+> scores = nlp.evaluate(examples, verbose=True)
+> print(scores)
> ```
| Name | Type | Description |
@@ -119,7 +119,7 @@ Evaluate a model's pipeline components.
| `batch_size` | int | The batch size to use. |
| `scorer` | `Scorer` | Optional [`Scorer`](/api/scorer) to use. If not passed in, a new one will be created. |
| `component_cfg` 2.1 | `Dict[str, Dict]` | Config parameters for specific pipeline components, keyed by component name. |
-| **RETURNS** | Scorer | The scorer containing the evaluation scores. |
+| **RETURNS** | `Dict[str, Union[float, Dict]]` | A dictionary of evaluation scores. |
## Language.begin_training {#begin_training tag="method"}
diff --git a/website/docs/api/scorer.md b/website/docs/api/scorer.md
index cd720d26c..ef4396e1b 100644
--- a/website/docs/api/scorer.md
+++ b/website/docs/api/scorer.md
@@ -5,9 +5,12 @@ tag: class
source: spacy/scorer.py
---
-The `Scorer` computes and stores evaluation scores. It's typically created by
+The `Scorer` computes evaluation scores. It's typically created by
[`Language.evaluate`](/api/language#evaluate).
+In addition, the `Scorer` exposes a number of static methods for evaluating
+individual `Token` and `Doc` attributes.
+
## Scorer.\_\_init\_\_ {#init tag="method"}
Create a new `Scorer`.
@@ -17,46 +20,114 @@ Create a new `Scorer`.
> ```python
> from spacy.scorer import Scorer
>
+> # default scoring pipeline
> scorer = Scorer()
+>
+> # provided scoring pipeline
+> nlp = spacy.load("en_core_web_sm")
+> scorer = Scorer(nlp)
> ```
| Name | Type | Description |
| ------------ | -------- | ------------------------------------------------------------ |
-| `eval_punct` | bool | Evaluate the dependency attachments to and from punctuation. |
+| `nlp` | Language | The pipeline to use for scoring, where each pipeline component may provide a scoring method. If none is provided, then a default pipeline for the multi-language code `xx` is constructed containing: `senter`, `tagger`, `morphologizer`, `parser`, `ner`, `textcat`. |
| **RETURNS** | `Scorer` | The newly created object. |
## Scorer.score {#score tag="method"}
-Update the evaluation scores from a single [`Example`](/api/example) object.
+Calculate the scores for a list of [`Example`](/api/example) objects using the
+scoring methods provided by the components in the pipeline.
+The returned `Dict` contains the scores provided by the individual pipeline
+components. For the scoring methods provided by the `Scorer` and used by the
+core pipeline components, the individual score names start with the `Token` or
+`Doc` attribute being scored: `token_acc`, `token_p/r/f`, `sents_p/r/f`,
+`tag_acc`, `pos_acc`, `morph_acc`, `morph_per_feat`, `lemma_acc`, `dep_uas`,
+`dep_las`, `dep_las_per_type`, `ents_p/r/f`, `ents_per_type`,
+`textcat_macro_auc`, `textcat_macro_f`.
+
> #### Example
>
> ```python
> scorer = Scorer()
-> scorer.score(example)
+> scorer.score(examples)
> ```
-| Name | Type | Description |
-| -------------- | --------- | -------------------------------------------------------------------------------------------------------------------- |
-| `example` | `Example` | The `Example` object holding both the predictions and the correct gold-standard annotations. |
-| `verbose` | bool | Print debugging information. |
-| `punct_labels` | tuple | Dependency labels for punctuation. Used to evaluate dependency attachments to punctuation if `eval_punct` is `True`. |
+| Name | Type | Description |
+| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
+| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
+| **RETURNS** | `Dict` | A dictionary of scores. |
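+
+For example, assuming `examples` is a list of annotated [`Example`](/api/example)
+objects and `nlp` includes a tagger and an entity recognizer, the individual
+metrics can be read straight from the returned dictionary:
+
+> #### Example
+>
+> ```python
+> from spacy.scorer import Scorer
+>
+> scorer = Scorer(nlp)  # score with the components of an existing pipeline
+> scores = scorer.score(examples)
+> print(scores["token_acc"], scores["tag_acc"], scores["ents_f"])
+> ```
+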
+## Scorer.score_tokenization {#score_tokenization tag="staticmethod"}
-## Properties
+Scores the tokenization:
+
+* `token_acc`: # correct tokens / # gold tokens
+* `token_p/r/f`: PRF for token character spans
+
+| Name | Type | Description |
+| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
+| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
+| **RETURNS** | `Dict` | A dictionary containing the scores `token_acc/p/r/f`. |
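+
+A minimal sketch, assuming `examples` is a list of [`Example`](/api/example)
+objects holding both the predicted and the gold-standard tokenization:
+
+> #### Example
+>
+> ```python
+> from spacy.scorer import Scorer
+>
+> # accuracy and PRF for token character spans
+> scores = Scorer.score_tokenization(examples)
+> print(scores["token_acc"], scores["token_f"])
+> ```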
+
+## Scorer.score_token_attr {#score_token_attr tag="staticmethod"}
+
+Scores a single token attribute.
+
+| Name | Type | Description |
+| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
+| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
+| `attr` | `str` | The attribute to score. |
+| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
+| **RETURNS** | `Dict` | A dictionary containing the score `attr_acc`. |
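+
+A minimal sketch, assuming `examples` is a list of [`Example`](/api/example)
+objects; this is the same call the tagger component uses to produce `pos_acc`:
+
+> #### Example
+>
+> ```python
+> from spacy.scorer import Scorer
+>
+> scores = Scorer.score_token_attr(examples, "pos")
+> print(scores["pos_acc"])
+> ```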
+
+## Scorer.score_token_attr_per_feat {#score_token_attr_per_feat tag="staticmethod"}
+
+Scores a single token attribute per feature, for token attributes in UFEATS format such as `Token.morph`.
+
+| Name | Type | Description |
+| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
+| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
+| `attr` | `str` | The attribute to score. |
+| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
+| **RETURNS** | `Dict` | A dictionary containing the per-feature PRF scores under the key `attr_per_feat`. |
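+
+A minimal sketch, assuming `examples` is a list of [`Example`](/api/example)
+objects whose `morph` values use UFEATS annotation; `Number` is just one
+feature that might be present in the data:
+
+> #### Example
+>
+> ```python
+> from spacy.scorer import Scorer
+>
+> scores = Scorer.score_token_attr_per_feat(examples, "morph")
+> # the returned values are PRFScore objects keyed by feature name
+> print(scores["morph_per_feat"]["Number"].fscore)
+> ```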
+
+## Scorer.score_spans {#score_spans tag="staticmethod"}
+
+Returns PRF scores for labeled or unlabeled spans.
+
+| Name | Type | Description |
+| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
+| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
+| `attr` | `str` | The attribute to score. |
+| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. |
+| **RETURNS** | `Dict` | A dictionary containing the PRF scores under the keys `attr_p/r/f` and the per-type PRF scores under `attr_per_type`. |
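+
+A minimal sketch, assuming `examples` is a list of [`Example`](/api/example)
+objects; the same method backs both the NER scores (`attr="ents"`) and the
+sentence segmentation scores (`attr="sents"`):
+
+> #### Example
+>
+> ```python
+> from spacy.scorer import Scorer
+>
+> ner_scores = Scorer.score_spans(examples, "ents")
+> sent_scores = Scorer.score_spans(examples, "sents")
+> print(ner_scores["ents_f"], sent_scores["sents_f"])
+> ```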
+
+## Scorer.score_deps {#score_deps tag="staticmethod"}
+
+Calculate the UAS, LAS, and LAS per type scores for dependency parses.
+
+| Name | Type | Description |
+| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
+| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
+| `attr` | `str` | The attribute containing the dependency label. |
+| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
+| `head_attr` | `str` | The attribute containing the head token. |
+| `head_getter` | `callable` | Defaults to `getattr`. If provided, `head_getter(token, attr)` should return the head for an individual `Token`. |
+| `ignore_labels` | `Tuple` | Labels to ignore while scoring (e.g., `punct`). |
+| **RETURNS** | `Dict` | A dictionary containing the scores: `attr_uas`, `attr_las`, and `attr_las_per_type`. |
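+
+A minimal sketch, assuming `examples` is a list of [`Example`](/api/example)
+objects; the `getter` shown here is one illustrative way to map the `dep` key
+onto the lowercased `Token.dep_` string, with punctuation labels excluded via
+`ignore_labels`:
+
+> #### Example
+>
+> ```python
+> from spacy.scorer import Scorer
+>
+> # illustrative getter: "dep" -> Token.dep_, lowercased
+> dep_getter = lambda token, attr: getattr(token, attr + "_").lower()
+> scores = Scorer.score_deps(
+>     examples, "dep", getter=dep_getter, ignore_labels=("p", "punct")
+> )
+> print(scores["dep_uas"], scores["dep_las"])
+> ```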
+
+## Scorer.score_cats {#score_cats tag="staticmethod"}
+
+Calculate PRF and ROC AUC scores for a doc-level attribute that is a dict
+containing scores for each label like `Doc.cats`.
+
+| Name | Type | Description |
+| ----------- | --------- | --------------------------------------------------------------------------------------------------------|
+| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
+| `attr` | `str` | The attribute to score. |
+| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the cats for an individual `Doc`. |
+| `labels` | `Iterable[str]` | The set of possible labels. Defaults to `[]`. |
+| `multi_label` | `bool` | Whether the attribute allows multiple labels. Defaults to `True`. |
+| `positive_label` | `str` | The positive label for a binary task with exclusive classes. Defaults to `None`. |
+| **RETURNS** | `Dict` | A dictionary containing the scores: 1) for binary exclusive with positive label: `attr_p/r/f`; 2) for 3+ exclusive classes, macro-averaged fscore: `attr_macro_f`; 3) for multilabel, macro-averaged AUC: `attr_macro_auc`; 4) for all: `attr_f_per_type`, `attr_auc_per_type` |
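+
+A minimal sketch for a binary task with mutually exclusive classes, assuming
+`examples` is a list of [`Example`](/api/example) objects annotated with the
+two labels `POSITIVE` and `NEGATIVE`:
+
+> #### Example
+>
+> ```python
+> from spacy.scorer import Scorer
+>
+> scores = Scorer.score_cats(
+>     examples,
+>     "cats",
+>     labels=["POSITIVE", "NEGATIVE"],
+>     multi_label=False,
+>     positive_label="POSITIVE",
+> )
+> print(scores["cats_p"], scores["cats_r"], scores["cats_f"])
+> ```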
-| Name | Type | Description |
-| --------------------------------------------------- | ----- | -------------------------------------------------------------------------------------- |
-| `token_acc` | float | Tokenization accuracy. |
-| `tags_acc` | float | Part-of-speech tag accuracy (fine grained tags, i.e. `Token.tag`). |
-| `uas` | float | Unlabelled dependency score. |
-| `las` | float | Labelled dependency score. |
-| `ents_p` | float | Named entity accuracy (precision). |
-| `ents_r` | float | Named entity accuracy (recall). |
-| `ents_f` | float | Named entity accuracy (F-score). |
-| `ents_per_type` 2.1.5 | dict | Scores per entity label. Keyed by label, mapped to a dict of `p`, `r` and `f` scores. |
-| `textcat_f` 3.0 | float | F-score on positive label for binary classification, macro-averaged F-score otherwise. |
-| `textcat_auc` 3.0 | float | Macro-averaged AUC ROC score for multilabel classification (`-1` if undefined). |
-| `textcats_f_per_cat` 3.0 | dict | F-scores per textcat label, keyed by label. |
-| `textcats_auc_per_cat` 3.0 | dict | ROC AUC scores per textcat label, keyed by label. |
-| `las_per_type` 2.2.3 | dict | Labelled dependency scores, keyed by label. |
-| `scores` | dict | All scores, keyed by type. |