Merge remote-tracking branch 'upstream/develop' into fix/cli-debug

# Conflicts:
#	pyproject.toml
#	requirements.txt
#	setup.cfg
svlandeg 2020-08-01 18:38:59 +02:00
commit 6f4e46ee93
79 changed files with 1358 additions and 736 deletions

View File

@ -16,7 +16,7 @@ from bin.ud import conll17_ud_eval
from spacy.tokens import Token, Doc
from spacy.gold import Example
from spacy.util import compounding, minibatch, minibatch_by_words
from spacy.syntax.nonproj import projectivize
from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.matcher import Matcher
from spacy import displacy
from collections import defaultdict

View File

@ -13,7 +13,7 @@ import spacy
import spacy.util
from spacy.tokens import Token, Doc
from spacy.gold import Example
from spacy.syntax.nonproj import projectivize
from spacy.pipeline._parser_internals.nonproj import projectivize
from collections import defaultdict
from spacy.matcher import Matcher

View File

@ -31,6 +31,7 @@ MOD_NAMES = [
"spacy.vocab",
"spacy.attrs",
"spacy.kb",
"spacy.ml.parser_model",
"spacy.morphology",
"spacy.pipeline.dep_parser",
"spacy.pipeline.morphologizer",
@ -40,14 +41,14 @@ MOD_NAMES = [
"spacy.pipeline.sentencizer",
"spacy.pipeline.senter",
"spacy.pipeline.tagger",
"spacy.syntax.stateclass",
"spacy.syntax._state",
"spacy.pipeline.transition_parser",
"spacy.pipeline._parser_internals.arc_eager",
"spacy.pipeline._parser_internals.ner",
"spacy.pipeline._parser_internals.nonproj",
"spacy.pipeline._parser_internals._state",
"spacy.pipeline._parser_internals.stateclass",
"spacy.pipeline._parser_internals.transition_system",
"spacy.tokenizer",
"spacy.syntax.nn_parser",
"spacy.syntax._parser_model",
"spacy.syntax.nonproj",
"spacy.syntax.transition_system",
"spacy.syntax.arc_eager",
"spacy.gold.gold_io",
"spacy.tokens.doc",
"spacy.tokens.span",
@ -57,7 +58,6 @@ MOD_NAMES = [
"spacy.matcher.matcher",
"spacy.matcher.phrasematcher",
"spacy.matcher.dependencymatcher",
"spacy.syntax.ner",
"spacy.symbols",
"spacy.vectors",
]

View File

@ -10,7 +10,7 @@ from thinc.api import Config
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli
from ..gold import Corpus, Example
from ..syntax import nonproj
from ..pipeline._parser_internals import nonproj
from ..language import Language
from .. import util

View File

@ -63,8 +63,6 @@ class Warnings:
"have the spacy-lookups-data package installed.")
W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
"the Knowledge Base.")
W025 = ("'{name}' requires '{attr}' to be assigned, but none of the "
"previous components in the pipeline declare that they assign it.")
W026 = ("Unable to set all sentence boundaries from dependency parses.")
W027 = ("Found a large training file of {size} bytes. Note that it may "
"be more efficient to split your training data into multiple "

View File

@ -10,7 +10,7 @@ from .align import Alignment
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
from .iob_utils import spans_from_biluo_tags
from ..errors import Errors, Warnings
from ..syntax import nonproj
from ..pipeline._parser_internals import nonproj
cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):

View File

@ -18,7 +18,7 @@ from timeit import default_timer as timer
from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab
from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .gold import Example
from .scorer import Scorer
from .util import create_default_optimizer, registry
@ -37,8 +37,6 @@ from . import util
from . import about
# TODO: integrate pipeline analysis
ENABLE_PIPELINE_ANALYSIS = False
# This is the base config with all settings (training etc.)
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH)
@ -522,6 +520,25 @@ class Language:
return add_component(func)
return add_component
def analyze_pipes(
self,
*,
keys: List[str] = ["assigns", "requires", "scores", "retokenizes"],
pretty: bool = False,
) -> Optional[Dict[str, Any]]:
"""Analyze the current pipeline components, print a summary of what
they assign or require and check that all requirements are met.
keys (List[str]): The meta values to display in the table. Corresponds
to values in FactoryMeta, defined by @Language.factory decorator.
pretty (bool): Pretty-print the results.
RETURNS (dict): The data.
"""
analysis = analyze_pipes(self, keys=keys)
if pretty:
print_pipe_analysis(analysis, keys=keys)
return analysis
def get_pipe(self, name: str) -> Callable[[Doc], Doc]:
"""Get a pipeline component for a given component name.
@ -666,8 +683,6 @@ class Language:
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.pipeline.insert(pipe_index, (name, pipe_component))
if ENABLE_PIPELINE_ANALYSIS:
analyze_pipes(self, name, pipe_index)
return pipe_component
def _get_pipe_index(
@ -758,8 +773,6 @@ class Language:
self.add_pipe(factory_name, name=name)
else:
self.add_pipe(factory_name, name=name, before=pipe_index)
if ENABLE_PIPELINE_ANALYSIS:
analyze_all_pipes(self)
def rename_pipe(self, old_name: str, new_name: str) -> None:
"""Rename a pipeline component.
@ -793,8 +806,6 @@ class Language:
# because factory may be used for something else
self._pipe_meta.pop(name)
self._pipe_configs.pop(name)
if ENABLE_PIPELINE_ANALYSIS:
analyze_all_pipes(self)
return removed
def __call__(
@ -1099,6 +1110,7 @@ class Language:
batch_size: int = 256,
scorer: Optional[Scorer] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
scorer_cfg: Optional[Dict[str, Any]] = None,
) -> Dict[str, Union[float, dict]]:
"""Evaluate a model's pipeline components.
@ -1109,6 +1121,8 @@ class Language:
will be created.
component_cfg (dict): An optional dictionary with extra keyword
arguments for specific components.
scorer_cfg (dict): An optional dictionary with extra keyword arguments
for the scorer.
RETURNS (Scorer): The scorer containing the evaluation results.
DOCS: https://spacy.io/api/language#evaluate
@ -1126,8 +1140,10 @@ class Language:
raise TypeError(err)
if component_cfg is None:
component_cfg = {}
if scorer_cfg is None:
scorer_cfg = {}
if scorer is None:
kwargs = component_cfg.get("scorer", {})
kwargs = dict(scorer_cfg)
kwargs.setdefault("verbose", verbose)
kwargs.setdefault("nlp", self)
scorer = Scorer(**kwargs)
@ -1136,9 +1152,9 @@ class Language:
start_time = timer()
# tokenize the texts only for timing purposes
if not hasattr(self.tokenizer, "pipe"):
_ = [self.tokenizer(text) for text in texts]
_ = [self.tokenizer(text) for text in texts] # noqa: F841
else:
_ = list(self.tokenizer.pipe(texts))
_ = list(self.tokenizer.pipe(texts)) # noqa: F841
for name, pipe in self.pipeline:
kwargs = component_cfg.get(name, {})
kwargs.setdefault("batch_size", batch_size)

View File

@ -1,6 +1,7 @@
from typing import List
from thinc.api import Model
from thinc.types import Floats2d
from ..tokens import Doc
@ -15,14 +16,14 @@ def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]:
)
def init(model, X=None, Y=None):
def init(model: Model, X=None, Y=None):
vectors_table = model.ops.alloc3f(
model.get_dim("nC"), model.get_dim("nV"), model.get_dim("nM")
)
model.set_param("E", vectors_table)
def forward(model, docs, is_train):
def forward(model: Model, docs: List[Doc], is_train: bool):
if docs is None:
return []
ids = []

View File

@ -14,7 +14,7 @@ def IOB() -> Model[Padded, Padded]:
)
def init(model, X: Optional[Padded] = None, Y: Optional[Padded] = None):
def init(model: Model, X: Optional[Padded] = None, Y: Optional[Padded] = None) -> None:
if X is not None and Y is not None:
if X.data.shape != Y.data.shape:
# TODO: Fix error

View File

@ -4,14 +4,14 @@ from thinc.api import Model
from ..attrs import LOWER
def extract_ngrams(ngram_size, attr=LOWER) -> Model:
def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model:
model = Model("extract_ngrams", forward)
model.attrs["ngram_size"] = ngram_size
model.attrs["attr"] = attr
return model
def forward(model, docs, is_train: bool):
def forward(model: Model, docs, is_train: bool):
batch_keys = []
batch_vals = []
for doc in docs:

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear
@ -9,7 +8,7 @@ from ...vocab import Vocab
@registry.architectures.register("spacy.EntityLinker.v1")
def build_nel_encoder(tok2vec, nO=None):
def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
with Model.define_operators({">>": chain, "**": clone}):
token_width = tok2vec.get_dim("nO")
output_layer = Linear(nO=nO, nI=token_width)
@ -26,7 +25,7 @@ def build_nel_encoder(tok2vec, nO=None):
@registry.assets.register("spacy.KBFromFile.v1")
def load_kb(vocab_path, kb_path) -> KnowledgeBase:
def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
vocab = Vocab().from_disk(vocab_path)
kb = KnowledgeBase(vocab=vocab)
kb.load_bulk(kb_path)

View File

@ -1,10 +1,20 @@
from typing import Optional, Iterable, Tuple, List, TYPE_CHECKING
import numpy
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
from thinc.api import MultiSoftmax, list2array
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from ...vocab import Vocab # noqa: F401
from ...tokens import Doc # noqa: F401
def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
def build_multi_task_model(
tok2vec: Model,
maxout_pieces: int,
token_vector_width: int,
nO: Optional[int] = None,
) -> Model:
softmax = Softmax(nO=nO, nI=token_vector_width * 2)
model = chain(
tok2vec,
@ -22,7 +32,13 @@ def build_multi_task_model(tok2vec, maxout_pieces, token_vector_width, nO=None):
return model
def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, hidden_size, nO=None):
def build_cloze_multi_task_model(
vocab: "Vocab",
tok2vec: Model,
maxout_pieces: int,
hidden_size: int,
nO: Optional[int] = None,
) -> Model:
# nO = vocab.vectors.data.shape[1]
output_layer = chain(
list2array(),
@ -43,24 +59,24 @@ def build_cloze_multi_task_model(vocab, tok2vec, maxout_pieces, hidden_size, nO=
def build_cloze_characters_multi_task_model(
vocab, tok2vec, maxout_pieces, hidden_size, nr_char
):
vocab: "Vocab", tok2vec: Model, maxout_pieces: int, hidden_size: int, nr_char: int
) -> Model:
output_layer = chain(
list2array(),
Maxout(hidden_size, nP=maxout_pieces),
LayerNorm(nI=hidden_size),
MultiSoftmax([256] * nr_char, nI=hidden_size),
)
model = build_masked_language_model(vocab, chain(tok2vec, output_layer))
model.set_ref("tok2vec", tok2vec)
model.set_ref("output_layer", output_layer)
return model
def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
def build_masked_language_model(
vocab: "Vocab", wrapped_model: Model, mask_prob: float = 0.15
) -> Model:
"""Convert a model into a BERT-style masked language model"""
random_words = _RandomWords(vocab)
def mlm_forward(model, docs, is_train):
@ -74,7 +90,7 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
return output, mlm_backward
def mlm_initialize(model, X=None, Y=None):
def mlm_initialize(model: Model, X=None, Y=None):
wrapped = model.layers[0]
wrapped.initialize(X=X, Y=Y)
for dim in wrapped.dim_names:
@ -90,12 +106,11 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
dims={dim: None for dim in wrapped_model.dim_names},
)
mlm_model.set_ref("wrapped", wrapped_model)
return mlm_model
class _RandomWords:
def __init__(self, vocab):
def __init__(self, vocab: "Vocab") -> None:
self.words = [lex.text for lex in vocab if lex.prob != 0.0]
self.probs = [lex.prob for lex in vocab if lex.prob != 0.0]
self.words = self.words[:10000]
@ -104,7 +119,7 @@ class _RandomWords:
self.probs /= self.probs.sum()
self._cache = []
def next(self):
def next(self) -> str:
if not self._cache:
self._cache.extend(
numpy.random.choice(len(self.words), 10000, p=self.probs)
@ -113,9 +128,11 @@ class _RandomWords:
return self.words[index]
def _apply_mask(docs, random_words, mask_prob=0.15):
def _apply_mask(
docs: Iterable["Doc"], random_words: _RandomWords, mask_prob: float = 0.15
) -> Tuple[numpy.ndarray, List["Doc"]]:
# This needs to be here to avoid circular imports
from ...tokens import Doc
from ...tokens import Doc # noqa: F811
N = sum(len(doc) for doc in docs)
mask = numpy.random.uniform(0.0, 1.0, (N,))
@ -141,7 +158,7 @@ def _apply_mask(docs, random_words, mask_prob=0.15):
return mask, masked_docs
def _replace_word(word, random_words, mask="[MASK]"):
def _replace_word(word: str, random_words: _RandomWords, mask: str = "[MASK]") -> str:
roll = numpy.random.random()
if roll < 0.8:
return mask
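
For context, _replace_word above applies the usual BERT-style corruption rule: of the tokens picked for masking, roughly 80% become the mask token, 10% are swapped for a random word, and 10% are left unchanged. A self-contained sketch of that rule, using a plain list in place of spaCy's _RandomWords helper:

import random

def replace_word(word: str, random_words: list, mask: str = "[MASK]") -> str:
    # 80% -> mask token, 10% -> random word, 10% -> unchanged
    roll = random.random()
    if roll < 0.8:
        return mask
    elif roll < 0.9:
        return random.choice(random_words)
    return word

print(replace_word("spacy", ["cat", "dog", "tree"]))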

View File

@ -1,6 +1,5 @@
from pydantic import StrictInt
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops, with_array
from thinc.api import LayerNorm, Maxout, Mish
from typing import Optional
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from ...util import registry
from .._precomputable_affine import PrecomputableAffine
@ -10,16 +9,15 @@ from ..tb_framework import TransitionModel
@registry.architectures.register("spacy.TransitionBasedParser.v1")
def build_tb_parser_model(
tok2vec: Model,
nr_feature_tokens: StrictInt,
hidden_width: StrictInt,
maxout_pieces: StrictInt,
use_upper=True,
nO=None,
):
nr_feature_tokens: int,
hidden_width: int,
maxout_pieces: int,
use_upper: bool = True,
nO: Optional[int] = None,
) -> Model:
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width),)
tok2vec.set_dim("nO", hidden_width)
lower = PrecomputableAffine(
nO=hidden_width if use_upper else nO,
nF=nr_feature_tokens,

View File

@ -26,7 +26,6 @@ def BiluoTagger(
with_array(softmax_activation()),
padded2list(),
)
return Model(
"biluo-tagger",
forward,
@ -52,7 +51,6 @@ def IOBTagger(
with_array(softmax_activation()),
padded2list(),
)
return Model(
"iob-tagger",
forward,

View File

@ -1,10 +1,11 @@
from typing import Optional
from thinc.api import zero_init, with_array, Softmax, chain, Model
from ...util import registry
@registry.architectures.register("spacy.Tagger.v1")
def build_tagger_model(tok2vec, nO=None) -> Model:
def build_tagger_model(tok2vec: Model, nO: Optional[int] = None) -> Model:
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
output_layer = Softmax(nO, t2v_width, init_W=zero_init)

View File

@ -2,10 +2,9 @@ from typing import Optional
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued
from thinc.api import HashEmbed, with_array, with_cpu, uniqued
from thinc.api import Relu, residual, expand_window, FeatureExtractor
from ... import util
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
from ...util import registry
from ..extract_ngrams import extract_ngrams
@ -40,7 +39,12 @@ def build_simple_cnn_text_classifier(
@registry.architectures.register("spacy.TextCatBOW.v1")
def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO=None):
def build_bow_text_classifier(
exclusive_classes: bool,
ngram_size: int,
no_output_layer: bool,
nO: Optional[int] = None,
) -> Model:
with Model.define_operators({">>": chain}):
sparse_linear = SparseLinear(nO)
model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
@ -55,16 +59,16 @@ def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO
@registry.architectures.register("spacy.TextCatEnsemble.v1")
def build_text_classifier(
width,
embed_size,
pretrained_vectors,
exclusive_classes,
ngram_size,
window_size,
conv_depth,
dropout,
nO=None,
):
width: int,
embed_size: int,
pretrained_vectors: Optional[bool],
exclusive_classes: bool,
ngram_size: int,
window_size: int,
conv_depth: int,
dropout: Optional[float],
nO: Optional[int] = None,
) -> Model:
cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
lower = HashEmbed(
@ -91,7 +95,6 @@ def build_text_classifier(
dropout=dropout,
seed=13,
)
width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
trained_vectors = FeatureExtractor(cols) >> with_array(
uniqued(
@ -100,7 +103,6 @@ def build_text_classifier(
column=cols.index(ORTH),
)
)
if pretrained_vectors:
static_vectors = StaticVectors(width)
vector_layer = trained_vectors | static_vectors
@ -152,7 +154,12 @@ def build_text_classifier(
@registry.architectures.register("spacy.TextCatLowData.v1")
def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
def build_text_classifier_lowdata(
width: int,
pretrained_vectors: Optional[bool],
dropout: Optional[float],
nO: Optional[int] = None,
) -> Model:
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
with Model.define_operators({">>": chain, "**": clone}):
model = (
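
Architectures like the ones registered above are resolved by name at config time; a hedged sketch of constructing the BOW classifier directly through the registry (the parameter values are illustrative):

from spacy.util import registry

make_bow = registry.architectures.get("spacy.TextCatBOW.v1")
model = make_bow(
    exclusive_classes=True,  # labels are mutually exclusive
    ngram_size=1,            # unigram features only
    no_output_layer=False,   # keep the softmax/Logistic output layer
    nO=3,                    # number of categories
)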

View File

@ -6,16 +6,15 @@ from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
from thinc.types import Floats2d
from ...tokens import Doc
from ... import util
from ...util import registry
from ...ml import _character_embed
from ..staticvectors import StaticVectors
from ...pipeline.tok2vec import Tok2VecListener
from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
from ...attrs import ORTH, NORM, PREFIX, SUFFIX, SHAPE
@registry.architectures.register("spacy.Tok2VecListener.v1")
def tok2vec_listener_v1(width, upstream="*"):
def tok2vec_listener_v1(width: int, upstream: str = "*"):
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
return tok2vec
@ -45,10 +44,11 @@ def build_hash_embed_cnn_tok2vec(
width=width,
depth=depth,
window_size=window_size,
maxout_pieces=maxout_pieces
)
maxout_pieces=maxout_pieces,
),
)
@registry.architectures.register("spacy.Tok2Vec.v1")
def build_Tok2Vec_model(
embed: Model[List[Doc], List[Floats2d]],
@ -68,7 +68,6 @@ def MultiHashEmbed(
width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
):
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
seed = 7
def make_hash_embed(feature):
@ -124,11 +123,11 @@ def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
chain(
FeatureExtractor([NORM]),
list2ragged(),
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5))
)
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
),
),
with_array(Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)),
ragged2list()
ragged2list(),
)
return model
@ -155,12 +154,7 @@ def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth:
def MishWindowEncoder(width, window_size, depth):
cnn = chain(
expand_window(window_size=window_size),
Mish(
nO=width,
nI=width * ((window_size * 2) + 1),
dropout=0.0,
normalize=True
),
Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
)
model = clone(residual(cnn), depth)
model.set_dim("nO", width)

View File

@ -1,8 +1,6 @@
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc
from ..typedefs cimport weight_t, class_t, hash_t
from ._state cimport StateC
from ..typedefs cimport weight_t, hash_t
from ..pipeline._parser_internals._state cimport StateC
cdef struct SizesC:

View File

@ -1,29 +1,18 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
cimport cython.parallel
cimport numpy as np
from libc.math cimport exp
from libcpp.vector cimport vector
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free, realloc
from cymem.cymem cimport Pool
from thinc.extra.search cimport Beam
from thinc.backends.linalg cimport Vec, VecVec
cimport blis.cy
import numpy
import numpy.random
from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop
from thinc.api import Model, CupyOps, NumpyOps
from ..typedefs cimport weight_t, class_t, hash_t
from ..tokens.doc cimport Doc
from .stateclass cimport StateClass
from .transition_system cimport Transition
from ..compat import copy_array
from ..errors import Errors, TempErrors
from ..util import create_default_optimizer
from .. import util
from . import nonproj
from ..typedefs cimport weight_t, class_t, hash_t
from ..pipeline._parser_internals.stateclass cimport StateClass
cdef WeightsC get_c_weights(model) except *:

View File

@ -1,5 +1,5 @@
from thinc.api import Model, noop, use_ops, Linear
from ..syntax._parser_model import ParserStepModel
from .parser_model import ParserStepModel
def TransitionModel(tok2vec, lower, upper, dropout=0.2, unseen_classes=set()):

View File

@ -1,9 +1,8 @@
from typing import List, Dict, Iterable, Optional, Union, TYPE_CHECKING
from wasabi import Printer
import warnings
from wasabi import msg
from .tokens import Doc, Token, Span
from .errors import Errors, Warnings
from .errors import Errors
from .util import dot_to_dict
if TYPE_CHECKING:
@ -11,48 +10,7 @@ if TYPE_CHECKING:
from .language import Language # noqa: F401
def analyze_pipes(
nlp: "Language", name: str, index: int, warn: bool = True
) -> List[str]:
"""Analyze a pipeline component with respect to its position in the current
pipeline and the other components. Will check whether requirements are
fulfilled (e.g. if previous components assign the attributes).
nlp (Language): The current nlp object.
name (str): The name of the pipeline component to analyze.
index (int): The index of the component in the pipeline.
warn (bool): Show user warning if problem is found.
RETURNS (List[str]): The problems found for the given pipeline component.
"""
assert nlp.pipeline[index][0] == name
prev_pipes = nlp.pipeline[:index]
meta = nlp.get_pipe_meta(name)
requires = {annot: False for annot in meta.requires}
if requires:
for prev_name, prev_pipe in prev_pipes:
prev_meta = nlp.get_pipe_meta(prev_name)
for annot in prev_meta.assigns:
requires[annot] = True
problems = []
for annot, fulfilled in requires.items():
if not fulfilled:
problems.append(annot)
if warn:
warnings.warn(Warnings.W025.format(name=name, attr=annot))
return problems
def analyze_all_pipes(nlp: "Language", warn: bool = True) -> Dict[str, List[str]]:
"""Analyze all pipes in the pipeline in order.
nlp (Language): The current nlp object.
warn (bool): Show user warning if problem is found.
RETURNS (Dict[str, List[str]]): The problems found, keyed by component name.
"""
problems = {}
for i, name in enumerate(nlp.pipe_names):
problems[name] = analyze_pipes(nlp, name, i, warn=warn)
return problems
DEFAULT_KEYS = ["requires", "assigns", "scores", "retokenizes"]
def validate_attrs(values: Iterable[str]) -> Iterable[str]:
@ -101,89 +59,77 @@ def validate_attrs(values: Iterable[str]) -> Iterable[str]:
return values
def _get_feature_for_attr(nlp: "Language", attr: str, feature: str) -> List[str]:
assert feature in ["assigns", "requires"]
result = []
def get_attr_info(nlp: "Language", attr: str) -> Dict[str, List[str]]:
"""Check which components in the pipeline assign or require an attribute.
nlp (Language): The current nlp object.
attr (str): The attribute, e.g. "doc.tensor".
RETURNS (Dict[str, List[str]]): A dict keyed by "assigns" and "requires",
mapped to a list of component names.
"""
result = {"assigns": [], "requires": []}
for pipe_name in nlp.pipe_names:
meta = nlp.get_pipe_meta(pipe_name)
pipe_assigns = getattr(meta, feature, [])
if attr in pipe_assigns:
result.append(pipe_name)
if attr in meta.assigns:
result["assigns"].append(pipe_name)
if attr in meta.requires:
result["requires"].append(pipe_name)
return result
def get_assigns_for_attr(nlp: "Language", attr: str) -> List[str]:
"""Get all pipeline components that assign an attr, e.g. "doc.tensor".
pipeline (Language): The current nlp object.
attr (str): The attribute to check.
RETURNS (List[str]): Names of components that require the attr.
"""
return _get_feature_for_attr(nlp, attr, "assigns")
def get_requires_for_attr(nlp: "Language", attr: str) -> List[str]:
"""Get all pipeline components that require an attr, e.g. "doc.tensor".
pipeline (Language): The current nlp object.
attr (str): The attribute to check.
RETURNS (List[str]): Names of components that require the attr.
"""
return _get_feature_for_attr(nlp, attr, "requires")
def print_summary(
nlp: "Language", pretty: bool = True, no_print: bool = False
) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
def analyze_pipes(
nlp: "Language", *, keys: List[str] = DEFAULT_KEYS,
) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
"""Print a formatted summary for the current nlp object's pipeline. Shows
a table with the pipeline components and what they assign and require, as
well as any problems if available.
nlp (Language): The nlp object.
pretty (bool): Pretty-print the results (color etc).
no_print (bool): Don't print anything, just return the data.
RETURNS (dict): A dict with "overview" and "problems".
keys (List[str]): The meta keys to show in the table.
RETURNS (dict): A dict with "summary" and "problems".
"""
msg = Printer(pretty=pretty, no_print=no_print)
overview = []
problems = {}
result = {"summary": {}, "problems": {}}
all_attrs = set()
for i, name in enumerate(nlp.pipe_names):
meta = nlp.get_pipe_meta(name)
overview.append((i, name, meta.requires, meta.assigns, meta.retokenizes))
problems[name] = analyze_pipes(nlp, name, i, warn=False)
all_attrs.update(meta.assigns)
all_attrs.update(meta.requires)
result["summary"][name] = {key: getattr(meta, key, None) for key in keys}
prev_pipes = nlp.pipeline[:i]
requires = {annot: False for annot in meta.requires}
if requires:
for prev_name, prev_pipe in prev_pipes:
prev_meta = nlp.get_pipe_meta(prev_name)
for annot in prev_meta.assigns:
requires[annot] = True
result["problems"][name] = []
for annot, fulfilled in requires.items():
if not fulfilled:
result["problems"][name].append(annot)
result["attrs"] = {attr: get_attr_info(nlp, attr) for attr in all_attrs}
return result
def print_pipe_analysis(
analysis: Dict[str, Union[List[str], Dict[str, List[str]]]],
*,
keys: List[str] = DEFAULT_KEYS,
) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
"""Print a formatted version of the pipe analysis produced by analyze_pipes.
analysis (Dict[str, Union[List[str], Dict[str, List[str]]]]): The analysis.
keys (List[str]): The meta keys to show in the table.
"""
msg.divider("Pipeline Overview")
header = ("#", "Component", "Requires", "Assigns", "Retokenizes")
msg.table(overview, header=header, divider=True, multiline=True)
n_problems = sum(len(p) for p in problems.values())
if any(p for p in problems.values()):
header = ["#", "Component", *[key.capitalize() for key in keys]]
summary = analysis["summary"].items()
body = [[i, n, *[v for v in m.values()]] for i, (n, m) in enumerate(summary)]
msg.table(body, header=header, divider=True, multiline=True)
n_problems = sum(len(p) for p in analysis["problems"].values())
if any(p for p in analysis["problems"].values()):
msg.divider(f"Problems ({n_problems})")
for name, problem in problems.items():
for name, problem in analysis["problems"].items():
if problem:
msg.warn(f"'{name}' requirements not met: {', '.join(problem)}")
else:
msg.good("No problems found.")
if no_print:
return {"overview": overview, "problems": problems}
def count_pipeline_interdependencies(nlp: "Language") -> List[int]:
"""Count how many subsequent components require an annotation set by each
component in the pipeline.
nlp (Language): The current nlp object.
RETURNS (List[int]): The interdependency counts.
"""
pipe_assigns = []
pipe_requires = []
for name in nlp.pipe_names:
meta = nlp.get_pipe_meta(name)
pipe_assigns.append(set(meta.assigns))
pipe_requires.append(set(meta.requires))
counts = []
for i, assigns in enumerate(pipe_assigns):
count = 0
for requires in pipe_requires[i + 1 :]:
if assigns.intersection(requires):
count += 1
counts.append(count)
return counts
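
Taken together, the refactor turns pipe analysis into plain data plus an optional printer. A sketch of the new API on a toy pipeline (the component names here are made up for illustration):

from spacy.language import Language
from spacy.pipe_analysis import analyze_pipes, get_attr_info, print_pipe_analysis

@Language.component("demo_assigner", assigns=["token.tag"])
def demo_assigner(doc):
    return doc

@Language.component("demo_needer", requires=["token.pos"])
def demo_needer(doc):
    return doc

nlp = Language()
nlp.add_pipe("demo_assigner")
nlp.add_pipe("demo_needer")

analysis = analyze_pipes(nlp)
assert analysis["problems"]["demo_needer"] == ["token.pos"]  # nothing assigns token.pos
assert get_attr_info(nlp, "token.tag")["assigns"] == ["demo_assigner"]
print_pipe_analysis(analysis)  # renders the summary table via wasabi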

View File

@ -1,15 +1,14 @@
from libc.string cimport memcpy, memset, memmove
from libc.stdlib cimport malloc, calloc, free
from libc.string cimport memcpy, memset
from libc.stdlib cimport calloc, free
from libc.stdint cimport uint32_t, uint64_t
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from murmurhash.mrmr cimport hash64
from ..vocab cimport EMPTY_LEXEME
from ..structs cimport TokenC, SpanC
from ..lexeme cimport Lexeme
from ..symbols cimport punct
from ..attrs cimport IS_SPACE
from ..typedefs cimport attr_t
from ...vocab cimport EMPTY_LEXEME
from ...structs cimport TokenC, SpanC
from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE
from ...typedefs cimport attr_t
cdef inline bint is_space_token(const TokenC* token) nogil:

View File

@ -1,8 +1,6 @@
from cymem.cymem cimport Pool
from .stateclass cimport StateClass
from ..typedefs cimport weight_t, attr_t
from .transition_system cimport TransitionSystem, Transition
from ...typedefs cimport weight_t, attr_t
from .transition_system cimport Transition, TransitionSystem
cdef class ArcEager(TransitionSystem):

View File

@ -1,24 +1,17 @@
# cython: profile=True, cdivision=True, infer_types=True
from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool, Address
from libc.stdint cimport int32_t
from collections import defaultdict, Counter
import json
from ..typedefs cimport hash_t, attr_t
from ..strings cimport hash_string
from ..structs cimport TokenC
from ..tokens.doc cimport Doc, set_children_from_heads
from ...typedefs cimport hash_t, attr_t
from ...strings cimport hash_string
from ...structs cimport TokenC
from ...tokens.doc cimport Doc, set_children_from_heads
from ...gold.example cimport Example
from ...errors import Errors
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport move_cost_func_t, label_cost_func_t
from ..gold.example cimport Example
from ..errors import Errors
from .nonproj import is_nonproj_tree
from . import nonproj
# Calculate cost as gold/not gold. We don't use scalar value anyway.
cdef int BINARY_COSTS = 1

View File

@ -1,6 +1,4 @@
from .transition_system cimport TransitionSystem
from .transition_system cimport Transition
from ..typedefs cimport attr_t
cdef class BiluoPushDown(TransitionSystem):

View File

@ -2,17 +2,14 @@ from collections import Counter
from libc.stdint cimport int32_t
from cymem.cymem cimport Pool
from ..typedefs cimport weight_t
from ...typedefs cimport weight_t, attr_t
from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE
from ...gold.example cimport Example
from ...errors import Errors
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition
from .transition_system cimport do_func_t
from ..lexeme cimport Lexeme
from ..attrs cimport IS_SPACE
from ..gold.iob_utils import biluo_tags_from_offsets
from ..gold.example cimport Example
from ..errors import Errors
from .transition_system cimport Transition, do_func_t
cdef enum:

View File

@ -5,9 +5,9 @@ scheme.
"""
from copy import copy
from ..tokens.doc cimport Doc, set_children_from_heads
from ...tokens.doc cimport Doc, set_children_from_heads
from ..errors import Errors
from ...errors import Errors
DELIMITER = '||'

View File

@ -1,12 +1,8 @@
from libc.string cimport memcpy, memset
from cymem.cymem cimport Pool
cimport cython
from ..structs cimport TokenC, SpanC
from ..typedefs cimport attr_t
from ...structs cimport TokenC, SpanC
from ...typedefs cimport attr_t
from ..vocab cimport EMPTY_LEXEME
from ._state cimport StateC

View File

@ -1,7 +1,7 @@
# cython: infer_types=True
import numpy
from ..tokens.doc cimport Doc
from ...tokens.doc cimport Doc
cdef class StateClass:

View File

@ -1,11 +1,11 @@
from cymem.cymem cimport Pool
from ..typedefs cimport attr_t, weight_t
from ..structs cimport TokenC
from ..strings cimport StringStore
from ...typedefs cimport attr_t, weight_t
from ...structs cimport TokenC
from ...strings cimport StringStore
from ...gold.example cimport Example
from .stateclass cimport StateClass
from ._state cimport StateC
from ..gold.example cimport Example
cdef struct Transition:

View File

@ -1,19 +1,17 @@
# cython: infer_types=True
from __future__ import print_function
from cpython.ref cimport Py_INCREF
from cymem.cymem cimport Pool
from collections import Counter
import srsly
from ..typedefs cimport weight_t
from ..tokens.doc cimport Doc
from ..structs cimport TokenC
from ...typedefs cimport weight_t, attr_t
from ...tokens.doc cimport Doc
from ...structs cimport TokenC
from .stateclass cimport StateClass
from ..typedefs cimport attr_t
from ..errors import Errors
from .. import util
from ...errors import Errors
from ... import util
cdef weight_t MIN_SCORE = -90000

View File

@ -1,13 +1,13 @@
# cython: infer_types=True, profile=True, binding=True
from typing import Optional, Iterable
from thinc.api import CosineDistance, to_categorical, get_array_module, Model, Config
from thinc.api import Model, Config
from ..syntax.nn_parser cimport Parser
from ..syntax.arc_eager cimport ArcEager
from .transition_parser cimport Parser
from ._parser_internals.arc_eager cimport ArcEager
from .functions import merge_subtokens
from ..language import Language
from ..syntax import nonproj
from ._parser_internals import nonproj
from ..scorer import Scorer
@ -34,7 +34,7 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory(
"parser",
assigns=["token.dep", "token.is_sent_start", "doc.sents"],
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
default_config={
"moves": None,
"update_with_oracle_cut_size": 100,
@ -120,7 +120,8 @@ cdef class DependencyParser(Parser):
return dep
results = {}
results.update(Scorer.score_spans(examples, "sents", **kwargs))
results.update(Scorer.score_deps(examples, "dep", getter=dep_getter,
ignore_labels=("p", "punct"), **kwargs))
kwargs.setdefault("getter", dep_getter)
kwargs.setdefault("ignore_label", ("p", "punct"))
results.update(Scorer.score_deps(examples, "dep", **kwargs))
del results["sents_per_type"]
return results

View File

@ -222,9 +222,9 @@ class EntityLinker(Pipe):
set_dropout_rate(self.model, drop)
if not sentence_docs:
warnings.warn(Warnings.W093.format(name="Entity Linker"))
return 0.0
return losses
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
loss, d_scores = self.get_similarity_loss(
loss, d_scores = self.get_loss(
sentence_encodings=sentence_encodings, examples=examples
)
bp_context(d_scores)
@ -235,7 +235,7 @@ class EntityLinker(Pipe):
self.set_annotations(docs, predictions)
return losses
def get_similarity_loss(self, examples: Iterable[Example], sentence_encodings):
def get_loss(self, examples: Iterable[Example], sentence_encodings):
entity_encodings = []
for eg in examples:
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
@ -247,7 +247,7 @@ class EntityLinker(Pipe):
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
if sentence_encodings.shape != entity_encodings.shape:
err = Errors.E147.format(
method="get_similarity_loss", msg="gold entities do not match up"
method="get_loss", msg="gold entities do not match up"
)
raise RuntimeError(err)
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
@ -337,13 +337,13 @@ class EntityLinker(Pipe):
final_kb_ids.append(candidates[0].entity_)
else:
random.shuffle(candidates)
# this will set all prior probabilities to 0 if they should be excluded from the model
# set all prior probabilities to 0 if incl_prior=False
prior_probs = xp.asarray(
[c.prior_prob for c in candidates]
)
if not self.cfg.get("incl_prior"):
prior_probs = xp.asarray(
[0.0 for c in candidates]
[0.0 for _ in candidates]
)
scores = prior_probs
# add in similarity from the context
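
The incl_prior handling above amounts to zeroing the prior-probability vector before the context similarity is added on top; in effect (a numpy sketch, names illustrative):

import numpy as np

prior_probs = np.asarray([0.7, 0.2, 0.1])  # one prior per candidate entity
incl_prior = False
if not incl_prior:
    # equivalent to xp.asarray([0.0 for _ in candidates])
    prior_probs = np.zeros_like(prior_probs)
scores = prior_probs  # the model's context similarity is added later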

View File

@ -1,7 +1,7 @@
# cython: infer_types=True, profile=True, binding=True
from typing import Optional
import numpy
from thinc.api import CosineDistance, to_categorical, to_categorical, Model, Config
from thinc.api import CosineDistance, to_categorical, Model, Config
from thinc.api import set_dropout_rate
from ..tokens.doc cimport Doc
@ -9,7 +9,7 @@ from ..tokens.doc cimport Doc
from .pipe import Pipe
from .tagger import Tagger
from ..language import Language
from ..syntax import nonproj
from ._parser_internals import nonproj
from ..attrs import POS, ID
from ..errors import Errors
@ -219,3 +219,6 @@ class ClozeMultitask(Pipe):
if losses is not None:
losses[self.name] += loss
def add_label(self, label):
raise NotImplementedError

View File

@ -1,9 +1,9 @@
# cython: infer_types=True, profile=True, binding=True
from typing import Optional, Iterable
from thinc.api import CosineDistance, to_categorical, get_array_module, Model, Config
from thinc.api import Model, Config
from ..syntax.nn_parser cimport Parser
from ..syntax.ner cimport BiluoPushDown
from .transition_parser cimport Parser
from ._parser_internals.ner cimport BiluoPushDown
from ..language import Language
from ..scorer import Scorer

spacy/pipeline/pipe.pxd Normal file
View File

@ -0,0 +1,2 @@
cdef class Pipe:
cdef public str name

View File

@ -8,7 +8,7 @@ from ..errors import Errors
from .. import util
class Pipe:
cdef class Pipe:
"""This class is a base class and not instantiated directly. Trainable
pipeline components like the EntityRecognizer or TextCategorizer inherit
from it and it defines the interface that components should follow to
@ -17,8 +17,6 @@ class Pipe:
DOCS: https://spacy.io/api/pipe
"""
name = None
def __init__(self, vocab, model, name, **cfg):
"""Initialize a pipeline component.

View File

@ -203,3 +203,9 @@ class Sentencizer(Pipe):
cfg = srsly.read_json(path)
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self
def get_loss(self, examples, scores):
raise NotImplementedError
def add_label(self, label):
raise NotImplementedError

View File

@ -109,7 +109,7 @@ class SentenceRecognizer(Tagger):
for eg in examples:
eg_truth = []
for x in eg.get_aligned("sent_start"):
if x == None:
if x is None:
eg_truth.append(None)
elif x == 1:
eg_truth.append(labels[1])

View File

@ -131,8 +131,6 @@ class SimpleNER(Pipe):
return losses
def get_loss(self, examples: List[Example], scores) -> Tuple[List[Floats2d], float]:
loss = 0
d_scores = []
truths = []
for eg in examples:
tags = eg.get_aligned("TAG", as_string=True)
@ -159,7 +157,6 @@ class SimpleNER(Pipe):
if not hasattr(get_examples, "__call__"):
gold_tuples = get_examples
get_examples = lambda: gold_tuples
labels = _get_labels(get_examples())
for label in _get_labels(get_examples()):
self.add_label(label)
labels = self.labels

View File

@ -238,8 +238,11 @@ class TextCategorizer(Pipe):
DOCS: https://spacy.io/api/textcategorizer#rehearse
"""
if losses is not None:
losses.setdefault(self.name, 0.0)
if self._rehearsal_model is None:
return
return losses
try:
docs = [eg.predicted for eg in examples]
except AttributeError:
@ -250,7 +253,7 @@ class TextCategorizer(Pipe):
raise TypeError(err)
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return
return losses
set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update(docs)
target = self._rehearsal_model(examples)
@ -259,7 +262,6 @@ class TextCategorizer(Pipe):
if sgd is not None:
self.model.finish_update(sgd)
if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient ** 2).sum()
return losses

View File

@ -199,6 +199,9 @@ class Tok2Vec(Pipe):
docs = [Doc(self.vocab, words=["hello"])]
self.model.initialize(X=docs)
def add_label(self, label):
raise NotImplementedError
class Tok2VecListener(Model):
"""A layer that gets fed its answers from an upstream connection,

View File

@ -1,16 +1,15 @@
from .stateclass cimport StateClass
from .arc_eager cimport TransitionSystem
from cymem.cymem cimport Pool
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc
from ..structs cimport TokenC
from ._state cimport StateC
from ._parser_model cimport WeightsC, ActivationsC, SizesC
from .pipe cimport Pipe
from ._parser_internals.transition_system cimport Transition, TransitionSystem
from ._parser_internals._state cimport StateC
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC
cdef class Parser:
cdef class Parser(Pipe):
cdef readonly Vocab vocab
cdef public object model
cdef public str name
cdef public object _rehearsal_model
cdef readonly TransitionSystem moves
cdef readonly object cfg

View File

@ -1,42 +1,32 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
cimport cython.parallel
from __future__ import print_function
from cymem.cymem cimport Pool
cimport numpy as np
from itertools import islice
from cpython.ref cimport PyObject, Py_XDECREF
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from libc.math cimport exp
from libcpp.vector cimport vector
from libc.string cimport memset, memcpy
from libc.string cimport memset
from libc.stdlib cimport calloc, free
from cymem.cymem cimport Pool
from thinc.backends.linalg cimport Vec, VecVec
from thinc.api import chain, clone, Linear, list2array, NumpyOps, CupyOps, use_ops
from thinc.api import get_array_module, zero_init, set_dropout_rate
from itertools import islice
import srsly
from ._parser_internals.stateclass cimport StateClass
from ..ml.parser_model cimport alloc_activations, free_activations
from ..ml.parser_model cimport predict_states, arg_max_if_valid
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ..ml.parser_model cimport get_c_weights, get_c_sizes
from ..tokens.doc cimport Doc
from ..errors import Errors, Warnings
from .. import util
from ..util import create_default_optimizer
from thinc.api import set_dropout_rate
import numpy.random
import numpy
import warnings
from ..tokens.doc cimport Doc
from ..typedefs cimport weight_t, class_t, hash_t
from ._parser_model cimport alloc_activations, free_activations
from ._parser_model cimport predict_states, arg_max_if_valid
from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ._parser_model cimport get_c_weights, get_c_sizes
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition
from ..util import create_default_optimizer, registry
from ..compat import copy_array
from ..errors import Errors, Warnings
from .. import util
from . import nonproj
cdef class Parser:
cdef class Parser(Pipe):
"""
Base class of the DependencyParser and EntityRecognizer.
"""
@ -107,7 +97,7 @@ cdef class Parser:
@property
def tok2vec(self):
'''Return the embedding and convolutional layer of the model.'''
"""Return the embedding and convolutional layer of the model."""
return self.model.get_ref("tok2vec")
@property
@ -138,13 +128,13 @@ cdef class Parser:
raise NotImplementedError
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
'''Setup models for secondary objectives, to benefit from multi-task
"""Setup models for secondary objectives, to benefit from multi-task
learning. This method is intended to be overridden by subclasses.
For instance, the dependency parser can benefit from sharing
an input representation with a label prediction model. These auxiliary
models are discarded after training.
'''
"""
pass
def use_params(self, params):

View File

@ -1,55 +1,61 @@
from typing import Optional, Iterable, Dict, Any, Callable, Tuple, TYPE_CHECKING
import numpy as np
from .gold import Example
from .tokens import Token, Doc
from .errors import Errors
from .util import get_lang_class
from .morphology import Morphology
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language # noqa: F401
DEFAULT_PIPELINE = ["senter", "tagger", "morphologizer", "parser", "ner", "textcat"]
class PRFScore:
"""
A precision / recall / F score
"""
"""A precision / recall / F score."""
def __init__(self):
def __init__(self) -> None:
self.tp = 0
self.fp = 0
self.fn = 0
def score_set(self, cand, gold):
def score_set(self, cand: set, gold: set) -> None:
self.tp += len(cand.intersection(gold))
self.fp += len(cand - gold)
self.fn += len(gold - cand)
@property
def precision(self):
def precision(self) -> float:
return self.tp / (self.tp + self.fp + 1e-100)
@property
def recall(self):
def recall(self) -> float:
return self.tp / (self.tp + self.fn + 1e-100)
@property
def fscore(self):
def fscore(self) -> float:
p = self.precision
r = self.recall
return 2 * ((p * r) / (p + r + 1e-100))
def to_dict(self):
def to_dict(self) -> Dict[str, float]:
return {"p": self.precision, "r": self.recall, "f": self.fscore}
class ROCAUCScore:
"""
An AUC ROC score.
"""
"""An AUC ROC score."""
def __init__(self):
def __init__(self) -> None:
self.golds = []
self.cands = []
self.saved_score = 0.0
self.saved_score_at_len = 0
def score_set(self, cand, gold):
def score_set(self, cand, gold) -> None:
self.cands.append(cand)
self.golds.append(gold)
@ -70,51 +76,52 @@ class ROCAUCScore:
class Scorer:
"""Compute evaluation scores."""
def __init__(self, nlp=None, **cfg):
def __init__(
self,
nlp: Optional["Language"] = None,
default_lang: str = "xx",
default_pipeline=DEFAULT_PIPELINE,
**cfg,
) -> None:
"""Initialize the Scorer.
DOCS: https://spacy.io/api/scorer#init
"""
self.nlp = nlp
self.cfg = cfg
if not nlp:
# create a default pipeline
nlp = get_lang_class("xx")()
nlp.add_pipe("senter")
nlp.add_pipe("tagger")
nlp.add_pipe("morphologizer")
nlp.add_pipe("parser")
nlp.add_pipe("ner")
nlp.add_pipe("textcat")
nlp = get_lang_class(default_lang)()
for pipe in default_pipeline:
nlp.add_pipe(pipe)
self.nlp = nlp
def score(self, examples):
def score(self, examples: Iterable[Example]) -> Dict[str, Any]:
"""Evaluate a list of Examples.
examples (Iterable[Example]): The predicted annotations + correct annotations.
RETURNS (Dict): A dictionary of scores.
DOCS: https://spacy.io/api/scorer#score
"""
scores = {}
if hasattr(self.nlp.tokenizer, "score"):
scores.update(self.nlp.tokenizer.score(examples, **self.cfg))
for name, component in self.nlp.pipeline:
if hasattr(component, "score"):
scores.update(component.score(examples, **self.cfg))
return scores
@staticmethod
def score_tokenization(examples, **cfg):
def score_tokenization(examples: Iterable[Example], **cfg) -> Dict[str, float]:
"""Returns accuracy and PRF scores for tokenization.
* token_acc: # correct tokens / # gold tokens
* token_p/r/f: PRF for token character spans
examples (Iterable[Example]): Examples to score
RETURNS (dict): A dictionary containing the scores token_acc/p/r/f.
RETURNS (Dict[str, float]): A dictionary containing the scores
token_acc/p/r/f.
DOCS: https://spacy.io/api/scorer#score_tokenization
"""
acc_score = PRFScore()
prf_score = PRFScore()
@ -145,16 +152,24 @@ class Scorer:
}
@staticmethod
def score_token_attr(examples, attr, getter=getattr, **cfg):
def score_token_attr(
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
**cfg,
) -> Dict[str, float]:
"""Returns an accuracy score for a token-level attribute.
examples (Iterable[Example]): Examples to score
attr (str): The attribute to score.
getter (callable): Defaults to getattr. If provided,
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an
individual token.
RETURNS (dict): A dictionary containing the accuracy score under the
key attr_acc.
RETURNS (Dict[str, float]): A dictionary containing the accuracy score
under the key attr_acc.
DOCS: https://spacy.io/api/scorer#score_token_attr
"""
tag_score = PRFScore()
for example in examples:
@ -172,17 +187,21 @@ class Scorer:
gold_i = align.x2y[token.i].dataXd[0, 0]
pred_tags.add((gold_i, getter(token, attr)))
tag_score.score_set(pred_tags, gold_tags)
return {
attr + "_acc": tag_score.fscore,
}
return {f"{attr}_acc": tag_score.fscore}
@staticmethod
def score_token_attr_per_feat(examples, attr, getter=getattr, **cfg):
def score_token_attr_per_feat(
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
**cfg,
):
"""Return PRF scores per feat for a token attribute in UFEATS format.
examples (Iterable[Example]): Examples to score
attr (str): The attribute to score.
getter (callable): Defaults to getattr. If provided,
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an
individual token.
RETURNS (dict): A dictionary containing the per-feat PRF scores under
@ -223,20 +242,26 @@ class Scorer:
per_feat[field].score_set(
pred_per_feat.get(field, set()), gold_per_feat.get(field, set()),
)
return {
attr + "_per_feat": per_feat,
}
return {f"{attr}_per_feat": per_feat}
@staticmethod
def score_spans(examples, attr, getter=getattr, **cfg):
def score_spans(
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Doc, str], Any] = getattr,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF scores for labeled spans.
examples (Iterable[Example]): Examples to score
attr (str): The attribute to score.
getter (callable): Defaults to getattr. If provided,
getter (Callable[[Doc, str], Any]): Defaults to getattr. If provided,
getter(doc, attr) should return the spans for the individual doc.
RETURNS (dict): A dictionary containing the PRF scores under the
keys attr_p/r/f and the per-type PRF scores under attr_per_type.
RETURNS (Dict[str, Any]): A dictionary containing the PRF scores under
the keys attr_p/r/f and the per-type PRF scores under attr_per_type.
DOCS: https://spacy.io/api/scorer#score_spans
"""
score = PRFScore()
score_per_type = dict()
@ -256,14 +281,12 @@ class Scorer:
# Find all predicate labels, for all and per type
gold_spans = set()
pred_spans = set()
# Special case for ents:
# If we have missing values in the gold, we can't easily tell
# whether our NER predictions are true.
# It seems bad but it's what we've always done.
if attr == "ents" and not all(token.ent_iob != 0 for token in gold_doc):
continue
for span in getter(gold_doc, attr):
gold_span = (span.label_, span.start, span.end - 1)
gold_spans.add(gold_span)
@ -279,38 +302,39 @@ class Scorer:
# Score for all labels
score.score_set(pred_spans, gold_spans)
results = {
attr + "_p": score.precision,
attr + "_r": score.recall,
attr + "_f": score.fscore,
attr + "_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
f"{attr}_p": score.precision,
f"{attr}_r": score.recall,
f"{attr}_f": score.fscore,
f"{attr}_per_type": {k: v.to_dict() for k, v in score_per_type.items()},
}
return results
@staticmethod
def score_cats(
examples,
attr,
getter=getattr,
labels=[],
multi_label=True,
positive_label=None,
**cfg
):
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Doc, str], Any] = getattr,
labels: Iterable[str] = tuple(),
multi_label: bool = True,
positive_label: Optional[str] = None,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF and ROC AUC scores for a doc-level attribute with a
dict with scores for each label like Doc.cats. The reported overall
score depends on the scorer settings.
examples (Iterable[Example]): Examples to score
attr (str): The attribute to score.
getter (callable): Defaults to getattr. If provided,
getter (Callable[[Doc, str], Any]): Defaults to getattr. If provided,
getter(doc, attr) should return the values for the individual doc.
labels (Iterable[str]): The set of possible labels. Defaults to [].
multi_label (bool): Whether the attribute allows multiple labels.
Defaults to True.
positive_label (str): The positive label for a binary task with
exclusive classes. Defaults to None.
RETURNS (dict): A dictionary containing the scores, with inapplicable
scores as None:
RETURNS (Dict[str, Any]): A dictionary containing the scores, with
inapplicable scores as None:
for all:
attr_score (one of attr_f / attr_macro_f / attr_macro_auc),
attr_score_desc (text description of the overall score),
@ -319,6 +343,8 @@ class Scorer:
for binary exclusive with positive label: attr_p/r/f
for 3+ exclusive classes, macro-averaged fscore: attr_macro_f
for multilabel, macro-averaged AUC: attr_macro_auc
DOCS: https://spacy.io/api/scorer#score_cats
"""
score = PRFScore()
f_per_type = dict()
@ -367,64 +393,67 @@ class Scorer:
)
)
results = {
attr + "_score": None,
attr + "_score_desc": None,
attr + "_p": None,
attr + "_r": None,
attr + "_f": None,
attr + "_macro_f": None,
attr + "_macro_auc": None,
attr + "_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()},
attr + "_auc_per_type": {k: v.score for k, v in auc_per_type.items()},
f"{attr}_score": None,
f"{attr}_score_desc": None,
f"{attr}_p": None,
f"{attr}_r": None,
f"{attr}_f": None,
f"{attr}_macro_f": None,
f"{attr}_macro_auc": None,
f"{attr}_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()},
f"{attr}_auc_per_type": {k: v.score for k, v in auc_per_type.items()},
}
if len(labels) == 2 and not multi_label and positive_label:
results[attr + "_p"] = score.precision
results[attr + "_r"] = score.recall
results[attr + "_f"] = score.fscore
results[attr + "_score"] = results[attr + "_f"]
results[attr + "_score_desc"] = "F (" + positive_label + ")"
results[f"{attr}_p"] = score.precision
results[f"{attr}_r"] = score.recall
results[f"{attr}_f"] = score.fscore
results[f"{attr}_score"] = results[f"{attr}_f"]
results[f"{attr}_score_desc"] = f"F ({positive_label})"
elif not multi_label:
results[attr + "_macro_f"] = sum(
results[f"{attr}_macro_f"] = sum(
[score.fscore for label, score in f_per_type.items()]
) / (len(f_per_type) + 1e-100)
results[attr + "_score"] = results[attr + "_macro_f"]
results[attr + "_score_desc"] = "macro F"
results[f"{attr}_score"] = results[f"{attr}_macro_f"]
results[f"{attr}_score_desc"] = "macro F"
else:
results[attr + "_macro_auc"] = max(
results[f"{attr}_macro_auc"] = max(
sum([score.score for label, score in auc_per_type.items()])
/ (len(auc_per_type) + 1e-100),
-1,
)
results[attr + "_score"] = results[attr + "_macro_auc"]
results[attr + "_score_desc"] = "macro AUC"
results[f"{attr}_score"] = results[f"{attr}_macro_auc"]
results[f"{attr}_score_desc"] = "macro AUC"
return results
@staticmethod
def score_deps(
examples,
attr,
getter=getattr,
head_attr="head",
head_getter=getattr,
ignore_labels=tuple(),
**cfg
):
examples: Iterable[Example],
attr: str,
*,
getter: Callable[[Token, str], Any] = getattr,
head_attr: str = "head",
head_getter: Callable[[Token, str], Any] = getattr,
ignore_labels: Tuple[str] = tuple(),
**cfg,
) -> Dict[str, Any]:
"""Returns the UAS, LAS, and LAS per type scores for dependency
parses.
examples (Iterable[Example]): Examples to score
attr (str): The attribute containing the dependency label.
getter (callable): Defaults to getattr. If provided,
getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
getter(token, attr) should return the value of the attribute for an
individual token.
head_attr (str): The attribute containing the head token. Defaults to
'head'.
head_getter (callable): Defaults to getattr. If provided,
head_getter (Callable[[Token, str], Any]): Defaults to getattr. If provided,
head_getter(token, attr) should return the value of the head for an
individual token.
ignore_labels (Tuple): Labels to ignore while scoring (e.g., punct).
RETURNS (dict): A dictionary containing the scores:
RETURNS (Dict[str, Any]): A dictionary containing the scores:
attr_uas, attr_las, and attr_las_per_type.
DOCS: https://spacy.io/api/scorer#score_deps
"""
unlabelled = PRFScore()
labelled = PRFScore()
@ -482,10 +511,11 @@ class Scorer:
set(item[:2] for item in pred_deps), set(item[:2] for item in gold_deps)
)
return {
attr + "_uas": unlabelled.fscore,
attr + "_las": labelled.fscore,
attr
+ "_las_per_type": {k: v.to_dict() for k, v in labelled_per_dep.items()},
f"{attr}_uas": unlabelled.fscore,
f"{attr}_las": labelled.fscore,
f"{attr}_las_per_type": {
k: v.to_dict() for k, v in labelled_per_dep.items()
},
}
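
PRFScore above is just a tp/fp/fn counter with smoothed ratios; its arithmetic on a toy pair of candidate/gold span sets:

from spacy.scorer import PRFScore

score = PRFScore()
# Candidate vs. gold items, e.g. (label, start, end - 1) span tuples.
score.score_set({("PERSON", 0, 1), ("ORG", 3, 4)}, {("PERSON", 0, 1)})
print(score.precision)  # 1 tp / (1 tp + 1 fp) = 0.5
print(score.recall)     # 1 tp / (1 tp + 0 fn) = 1.0
print(score.to_dict())  # {'p': 0.5, 'r': 1.0, 'f': 0.666...}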

View File

@ -4,8 +4,8 @@ from spacy import registry
from spacy.gold import Example
from spacy.pipeline import DependencyParser
from spacy.tokens import Doc
from spacy.syntax.nonproj import projectivize
from spacy.syntax.arc_eager import ArcEager
from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL

View File

@ -5,7 +5,7 @@ from spacy.lang.en import English
from spacy.language import Language
from spacy.lookups import Lookups
from spacy.syntax.ner import BiluoPushDown
from spacy.pipeline._parser_internals.ner import BiluoPushDown
from spacy.gold import Example
from spacy.tokens import Doc
from spacy.vocab import Vocab

View File

@ -3,8 +3,8 @@ import pytest
from spacy import registry
from spacy.gold import Example
from spacy.vocab import Vocab
from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.nn_parser import Parser
from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.pipeline.transition_parser import Parser
from spacy.tokens.doc import Doc
from thinc.api import Model
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL

View File

@ -1,7 +1,7 @@
import pytest
from spacy.syntax.nonproj import ancestors, contains_cycle, is_nonproj_arc
from spacy.syntax.nonproj import is_nonproj_tree
from spacy.syntax import nonproj
from spacy.pipeline._parser_internals.nonproj import ancestors, contains_cycle, is_nonproj_arc
from spacy.pipeline._parser_internals.nonproj import is_nonproj_tree
from spacy.pipeline._parser_internals import nonproj
from ..util import get_doc
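For orientation, a minimal sketch of what these helpers operate on; the head indices are illustrative and the exact return values are assumptions based on how the tests use them:

```python
from spacy.pipeline._parser_internals.nonproj import contains_cycle, is_nonproj_tree

# A tree is encoded as one head index per token.
heads = [1, 2, 2, 4, 5, 2, 2]
print(contains_cycle(heads))   # None, assuming the tree is acyclic
print(is_nonproj_tree(heads))  # False, assuming the tree is projective
```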

View File

@ -1,15 +1,10 @@
import spacy.language
from spacy.language import Language
from spacy.pipe_analysis import print_summary, validate_attrs
from spacy.pipe_analysis import get_assigns_for_attr, get_requires_for_attr
from spacy.pipe_analysis import count_pipeline_interdependencies
from spacy.pipe_analysis import get_attr_info, validate_attrs
from mock import Mock
import pytest
def test_component_decorator_assigns():
spacy.language.ENABLE_PIPELINE_ANALYSIS = True
@Language.component("c1", assigns=["token.tag", "doc.tensor"])
def test_component1(doc):
return doc
@ -32,10 +27,11 @@ def test_component_decorator_assigns():
nlp = Language()
nlp.add_pipe("c1")
with pytest.warns(UserWarning):
nlp.add_pipe("c2")
nlp.add_pipe("c2")
problems = nlp.analyze_pipes()["problems"]
assert problems["c2"] == ["token.pos"]
nlp.add_pipe("c3")
assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2"]
assert get_attr_info(nlp, "doc.tensor")["assigns"] == ["c1", "c2"]
nlp.add_pipe("c1", name="c4")
test_component4_meta = nlp.get_pipe_meta("c1")
assert test_component4_meta.factory == "c1"
@ -43,9 +39,8 @@ def test_component_decorator_assigns():
assert not Language.has_factory("c4")
assert nlp.pipe_factories["c1"] == "c1"
assert nlp.pipe_factories["c4"] == "c1"
assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2", "c4"]
assert get_requires_for_attr(nlp, "token.pos") == ["c2"]
assert print_summary(nlp, no_print=True)
assert get_attr_info(nlp, "doc.tensor")["assigns"] == ["c1", "c2", "c4"]
assert get_attr_info(nlp, "token.pos")["requires"] == ["c2"]
assert nlp("hello world")
@ -100,7 +95,6 @@ def test_analysis_validate_attrs_invalid(attr):
def test_analysis_validate_attrs_remove_pipe():
"""Test that attributes are validated correctly on remove."""
spacy.language.ENABLE_PIPELINE_ANALYSIS = True
@Language.component("pipe_analysis_c6", assigns=["token.tag"])
def c1(doc):
@ -112,26 +106,9 @@ def test_analysis_validate_attrs_remove_pipe():
nlp = Language()
nlp.add_pipe("pipe_analysis_c6")
with pytest.warns(UserWarning):
nlp.add_pipe("pipe_analysis_c7")
with pytest.warns(None) as record:
nlp.remove_pipe("pipe_analysis_c7")
assert not record.list
def test_pipe_interdependencies():
prefix = "test_pipe_interdependencies"
@Language.component(f"{prefix}.fancifier", assigns=("doc._.fancy",))
def fancifier(doc):
return doc
@Language.component(f"{prefix}.needer", requires=("doc._.fancy",))
def needer(doc):
return doc
nlp = Language()
nlp.add_pipe(f"{prefix}.fancifier")
nlp.add_pipe(f"{prefix}.needer")
counts = count_pipeline_interdependencies(nlp)
assert counts == [1, 0]
nlp.add_pipe("pipe_analysis_c7")
problems = nlp.analyze_pipes()["problems"]
assert problems["pipe_analysis_c7"] == ["token.pos"]
nlp.remove_pipe("pipe_analysis_c7")
problems = nlp.analyze_pipes()["problems"]
assert all(p == [] for p in problems.values())
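The flow the updated test exercises, as a small sketch (component names are hypothetical):

```python
import spacy
from spacy.language import Language

@Language.component("assigner", assigns=["token.pos"])  # hypothetical name
def assigner(doc):
    return doc

@Language.component("needer", requires=["token.pos"])  # hypothetical name
def needer(doc):
    return doc

nlp = spacy.blank("en")
nlp.add_pipe("needer")
# The analysis is static: the requirement is unmet while nothing assigns token.pos
print(nlp.analyze_pipes()["problems"]["needer"])  # ["token.pos"]
nlp.add_pipe("assigner", before="needer")
print(nlp.analyze_pipes()["problems"]["needer"])  # []
```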

View File

@ -118,7 +118,7 @@ def test_overfitting_IO():
# Test scoring
scores = nlp.evaluate(
train_examples, component_cfg={"scorer": {"positive_label": "POSITIVE"}}
train_examples, scorer_cfg={"positive_label": "POSITIVE"}
)
assert scores["cats_f"] == 1.0
assert scores["cats_score"] == 1.0

View File

@ -7,7 +7,7 @@ import importlib.util
import re
from pathlib import Path
import thinc
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer, Model
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
import functools
import itertools
import numpy.random
@ -24,8 +24,6 @@ import tempfile
import shutil
import shlex
import inspect
from thinc.types import Unserializable
try:
import cupy.random

View File

@ -6,6 +6,7 @@ menu:
- ['Tok2Vec', 'tok2vec']
- ['Transformers', 'transformers']
- ['Parser & NER', 'parser']
- ['Tagging', 'tagger']
- ['Text Classification', 'textcat']
- ['Entity Linking', 'entitylinker']
---
@ -18,6 +19,30 @@ TODO: intro and how architectures work, link to
### spacy.HashEmbedCNN.v1 {#HashEmbedCNN}
<!-- TODO: intro -->
> #### Example Config
>
> ```ini
> [model]
> @architectures = "spacy.HashEmbedCNN.v1"
> # TODO: ...
>
> [model.tok2vec]
> # ...
> ```
| Name | Type | Description |
| -------------------- | ----- | ----------- |
| `width` | int | |
| `depth` | int | |
| `embed_size` | int | |
| `window_size` | int | |
| `maxout_pieces` | int | |
| `subword_features` | bool | |
| `dropout` | float | |
| `pretrained_vectors` | bool | |
### spacy.HashCharEmbedCNN.v1 {#HashCharEmbedCNN}
### spacy.HashCharEmbedBiLSTM.v1 {#HashCharEmbedBiLSTM}
@ -99,6 +124,28 @@ architectures into your training config.
| `use_upper` | bool | |
| `nO` | int | |
## Tagging architectures {#tagger source="spacy/ml/models/tagger.py"}
### spacy.Tagger.v1 {#Tagger}
<!-- TODO: intro -->
> #### Example Config
>
> ```ini
> [model]
> @architectures = "spacy.Tagger.v1"
> nO = null
>
> [model.tok2vec]
> # ...
> ```
| Name | Type | Description |
| --------- | ------------------------------------------ | ----------- |
| `tok2vec` | [`Model`](https://thinc.ai/docs/api-model) | |
| `nO` | int | |
## Text classification architectures {#textcat source="spacy/ml/models/textcat.py"}
### spacy.TextCatEnsemble.v1 {#TextCatEnsemble}
@ -112,3 +159,21 @@ architectures into your training config.
## Entity linking architectures {#entitylinker source="spacy/ml/models/entity_linker.py"}
### spacy.EntityLinker.v1 {#EntityLinker}
<!-- TODO: intro -->
> #### Example Config
>
> ```ini
> [model]
> @architectures = "spacy.EntityLinker.v1"
> nO = null
>
> [model.tok2vec]
> # ...
> ```
| Name | Type | Description |
| --------- | ------------------------------------------ | ----------- |
| `tok2vec` | [`Model`](https://thinc.ai/docs/api-model) | |
| `nO` | int | |

View File

@ -29,9 +29,11 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("parser", config=config)
> ```
<!-- TODO: finish API docs -->
| Setting | Type | Description | Default |
| ------- | ------------------------------------------ | ----------------- | ----------------------------------------------------------------- |
| `moves` | list | <!-- TODO: --> | `None` |
| `moves` | list | | `None` |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [TransitionBasedParser](/api/architectures#TransitionBasedParser) |
```python
@ -59,17 +61,19 @@ Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).
<!-- TODO: finish API docs -->
| Name | Type | Description |
| ----------------------------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| `moves` | list | <!-- TODO: --> |
| `moves` | list | |
| _keyword-only_ | | |
| `update_with_oracle_cut_size` | int | <!-- TODO: --> |
| `multitasks` | `Iterable` | <!-- TODO: --> |
| `learn_tokens` | bool | <!-- TODO: --> |
| `min_action_freq` | int | <!-- TODO: --> |
| `update_with_oracle_cut_size` | int | |
| `multitasks` | `Iterable` | |
| `learn_tokens` | bool | |
| `min_action_freq` | int | |
## DependencyParser.\_\_call\_\_ {#call tag="method"}

View File

@ -32,12 +32,14 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("entity_linker", config=config)
> ```
<!-- TODO: finish API docs -->
| Setting | Type | Description | Default |
| ---------------- | ------------------------------------------ | ----------------- | ----------------------------------------------- |
| `kb` | `KnowledgeBase` | <!-- TODO: --> | `None` |
| `labels_discard` | `Iterable[str]` | <!-- TODO: --> | `[]` |
| `incl_prior` | bool | <!-- TODO: --> |  `True` |
| `incl_context` | bool | <!-- TODO: --> | `True` |
| `kb` | `KnowledgeBase` | | `None` |
| `labels_discard` | `Iterable[str]` | | `[]` |
| `incl_prior` | bool | | `True` |
| `incl_context` | bool | | `True` |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [EntityLinker](/api/architectures#EntityLinker) |
```python
@ -65,16 +67,18 @@ Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).
<!-- TODO: finish API docs -->
| Name | Type | Description |
| ---------------- | --------------- | ------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| _keyword-only_ | | |
| `kb` | `KnowlegeBase` | <!-- TODO: --> |
| `labels_discard` | `Iterable[str]` | <!-- TODO: --> |
| `incl_prior` | bool | <!-- TODO: --> |
| `incl_context` | bool | <!-- TODO: --> |
| `kb` | `KnowledgeBase` | |
| `labels_discard` | `Iterable[str]` | |
| `incl_prior` | bool | |
| `incl_context` | bool | |
## EntityLinker.\_\_call\_\_ {#call tag="method"}

View File

@ -29,9 +29,11 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("ner", config=config)
> ```
<!-- TODO: finish API docs -->
| Setting | Type | Description | Default |
| ------- | ------------------------------------------ | ----------------- | ----------------------------------------------------------------- |
| `moves` | list | <!-- TODO: --> | `None` |
| `moves` | list | | `None` |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [TransitionBasedParser](/api/architectures#TransitionBasedParser) |
```python
@ -59,17 +61,19 @@ Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).
<!-- TODO: finish API docs -->
| Name | Type | Description |
| ----------------------------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| `moves` | list | <!-- TODO: --> |
| `moves` | list | |
| _keyword-only_ | | |
| `update_with_oracle_cut_size` | int | <!-- TODO: --> |
| `multitasks` | `Iterable` | <!-- TODO: --> |
| `learn_tokens` | bool | <!-- TODO: --> |
| `min_action_freq` | int | <!-- TODO: --> |
| `update_with_oracle_cut_size` | int | |
| `multitasks` | `Iterable` | |
| `learn_tokens` | bool | |
| `min_action_freq` | int | |
## EntityRecognizer.\_\_call\_\_ {#call tag="method"}

View File

@ -8,9 +8,8 @@ new: 3.0
An `Example` holds the information for one training instance. It stores two
`Doc` objects: one for holding the gold-standard reference data, and one for
holding the predictions of the pipeline. An `Alignment` <!-- TODO: link? -->
object stores the alignment between these two documents, as they can differ in
tokenization.
holding the predictions of the pipeline. An `Alignment` object stores the
alignment between these two documents, as they can differ in tokenization.
## Example.\_\_init\_\_ {#init tag="method"}

View File

@ -98,10 +98,10 @@ decorator. For more details and examples, see the
| ----------------------- | -------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `name` | str | The name of the component factory. |
| _keyword-only_ | | |
| `assigns` | `Iterable[str]` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for pipeline analysis. <!-- TODO: link to something --> |
| `requires` | `Iterable[str]` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for pipeline analysis. <!-- TODO: link to something --> |
| `retokenizes` | bool | Whether the component changes tokenization. Used for pipeline analysis. <!-- TODO: link to something --> |
| `scores` | `Iterable[str]` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. |
| `assigns` | `Iterable[str]` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `requires` | `Iterable[str]` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `retokenizes` | bool | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `scores` | `Iterable[str]` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `default_score_weights` | `Dict[str, float]` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. |
| `func` | `Optional[Callable]` | Optional function if not used as a decorator. |
@ -146,10 +146,10 @@ examples, see the
| `name` | str | The name of the component factory. |
| _keyword-only_ | | |
| `default_config` | `Dict[str, any]` | The default config, describing the default values of the factory arguments. |
| `assigns` | `Iterable[str]` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for pipeline analysis. <!-- TODO: link to something --> |
| `requires` | `Iterable[str]` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for pipeline analysis. <!-- TODO: link to something --> |
| `retokenizes` | bool | Whether the component changes tokenization. Used for pipeline analysis. <!-- TODO: link to something --> |
| `scores` | `Iterable[str]` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. |
| `assigns` | `Iterable[str]` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `requires` | `Iterable[str]` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `retokenizes` | bool | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `scores` | `Iterable[str]` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `default_score_weights` | `Dict[str, float]` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. |
| `func` | `Optional[Callable]` | Optional function if not used as a decorator. |
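For illustration, a hedged sketch of a factory that fills in these meta fields (the factory name, score name and component logic are hypothetical):

```python
from spacy.language import Language

@Language.factory(
    "my_sentiment",  # hypothetical factory name
    assigns=["doc._.sentiment"],
    default_config={"threshold": 0.5},
    scores=["sentiment_acc"],
    default_score_weights={"sentiment_acc": 1.0},
)
def create_sentiment(nlp, name, threshold: float):
    def sentiment_component(doc):
        # A real component would set doc._.sentiment here.
        return doc
    return sentiment_component
```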
@ -302,6 +302,7 @@ Evaluate a model's pipeline components.
| `batch_size` | int | The batch size to use. |
| `scorer` | `Scorer` | Optional [`Scorer`](/api/scorer) to use. If not passed in, a new one will be created. |
| `component_cfg` | `Dict[str, dict]` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. |
| `scorer_cfg` | `Dict[str, Any]` | Optional dictionary of keyword arguments for the `Scorer`. Defaults to `None`. |
| **RETURNS** | `Dict[str, Union[float, dict]]` | A dictionary of evaluation scores. |
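A minimal usage sketch for the new `scorer_cfg` argument, mirroring the updated textcat test above (`nlp` and `dev_examples` are assumed to exist):

```python
scores = nlp.evaluate(dev_examples, scorer_cfg={"positive_label": "POSITIVE"})
print(scores["cats_f"], scores["cats_score_desc"])  # e.g. 1.0 "F (POSITIVE)"
```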
## Language.use_params {#use_params tag="contextmanager, method"}
@ -597,6 +598,97 @@ contains the information about the component and its default provided by the
| `name` | str | The pipeline component name. |
| **RETURNS** | [`FactoryMeta`](#factorymeta) | The factory meta. |
## Language.analyze_pipes {#analyze_pipes tag="method" new="3"}
Analyze the current pipeline components and show a summary of the attributes
they assign and require, and the scores they set. The data is based on the
information provided in the [`@Language.component`](/api/language#component) and
[`@Language.factory`](/api/language#factory) decorators. If requirements aren't
met, e.g. if a component specifies a required property that is not set by a
previous component, a warning is shown.
<Infobox variant="warning" title="Important note">
The pipeline analysis is static and does **not actually run the components**.
This means that it relies on the information provided by the components
themselves. If a custom component declares that it assigns an attribute but it
doesn't, the pipeline analysis won't catch that.
</Infobox>
> #### Example
>
> ```python
> nlp = spacy.blank("en")
> nlp.add_pipe("tagger")
> nlp.add_pipe("entity_linker")
> analysis = nlp.analyze_pipes()
> ```
<Accordion title="Example output" spaced>
```json
### Structured
{
"summary": {
"tagger": {
"assigns": ["token.tag"],
"requires": [],
"scores": ["tag_acc", "pos_acc", "lemma_acc"],
"retokenizes": false
},
"entity_linker": {
"assigns": ["token.ent_kb_id"],
"requires": ["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
"scores": [],
"retokenizes": false
}
},
"problems": {
"tagger": [],
"entity_linker": ["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"]
},
"attrs": {
"token.ent_iob": { "assigns": [], "requires": ["entity_linker"] },
"doc.ents": { "assigns": [], "requires": ["entity_linker"] },
"token.ent_kb_id": { "assigns": ["entity_linker"], "requires": [] },
"doc.sents": { "assigns": [], "requires": ["entity_linker"] },
"token.tag": { "assigns": ["tagger"], "requires": [] },
"token.ent_type": { "assigns": [], "requires": ["entity_linker"] }
}
}
```
```
### Pretty
============================= Pipeline Overview =============================
# Component Assigns Requires Scores Retokenizes
- ------------- --------------- -------------- --------- -----------
0 tagger token.tag tag_acc False
pos_acc
lemma_acc
1 entity_linker token.ent_kb_id doc.ents False
doc.sents
token.ent_iob
token.ent_type
================================ Problems (4) ================================
⚠ 'entity_linker' requirements not met: doc.ents, doc.sents,
token.ent_iob, token.ent_type
```
</Accordion>
| Name | Type | Description |
| -------------- | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| _keyword-only_ | | |
| `keys` | `List[str]` | The values to display in the table. Corresponds to attributes of the [`FactoryMeta`](/api/language#factorymeta). Defaults to `["assigns", "requires", "scores", "retokenizes"]`. |
| `pretty` | bool | Pretty-print the results as a table. Defaults to `False`. |
| **RETURNS** | dict | Dictionary containing the pipe analysis, keyed by `"summary"` (component meta by pipe), `"problems"` (attribute names by pipe) and `"attrs"` (pipes that assign and require an attribute, keyed by attribute). |
## Language.meta {#meta tag="property"}
Custom meta data for the Language class. If a model is loaded, contains meta
@ -832,8 +924,8 @@ instance and factory instance.
| ----------------------- | ------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `factory` | str | The name of the registered component factory. |
| `default_config` | `Dict[str, Any]` | The default config, describing the default values of the factory arguments. |
| `assigns` | `Iterable[str]` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for pipeline analysis. <!-- TODO: link to something --> |
| `requires` | `Iterable[str]` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for pipeline analysis. <!-- TODO: link to something -->  |
| `retokenizes` | bool | Whether the component changes tokenization. Used for pipeline analysis. <!-- TODO: link to something -->  |
| `scores` | `Iterable[str]` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. |
| `assigns` | `Iterable[str]` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `requires` | `Iterable[str]` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `retokenizes` | bool | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `scores` | `Iterable[str]` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). |
| `default_score_weights` | `Dict[str, float]` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. |

View File

@ -63,14 +63,16 @@ Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).
<!-- TODO: finish API docs -->
| Name | Type | Description |
| -------------- | ------- | ------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| _keyword-only_ | | |
| `labels_morph` | dict | <!-- TODO: --> |
| `labels_pos` | dict | <!-- TODO: --> |
| `labels_morph` | dict | |
| `labels_pos` | dict | |
## Morphologizer.\_\_call\_\_ {#call tag="method"}

View File

@ -6,10 +6,9 @@ source: spacy/scorer.py
---
The `Scorer` computes evaluation scores. It's typically created by
[`Language.evaluate`](/api/language#evaluate).
In addition, the `Scorer` provides a number of evaluation methods for evaluating
`Token` and `Doc` attributes.
[`Language.evaluate`](/api/language#evaluate). In addition, the `Scorer`
provides a number of evaluation methods for evaluating [`Token`](/api/token) and
[`Doc`](/api/doc) attributes.
## Scorer.\_\_init\_\_ {#init tag="method"}
@ -20,10 +19,10 @@ Create a new `Scorer`.
> ```python
> from spacy.scorer import Scorer
>
> # default scoring pipeline
> # Default scoring pipeline
> scorer = Scorer()
>
> # provided scoring pipeline
> # Provided scoring pipeline
> nlp = spacy.load("en_core_web_sm")
> scorer = Scorer(nlp)
> ```
@ -40,16 +39,20 @@ scoring methods provided by the components in the pipeline.
The returned `Dict` contains the scores provided by the individual pipeline
components. For the scoring methods provided by the `Scorer` and used by the core
pipeline components, the individual score names start with the `Token` or `Doc`
attribute being scored: `token_acc`, `token_p/r/f`, `sents_p/r/f`, `tag_acc`,
`pos_acc`, `morph_acc`, `morph_per_feat`, `lemma_acc`, `dep_uas`, `dep_las`,
`dep_las_per_type`, `ents_p/r/f`, `ents_per_type`, `textcat_macro_auc`,
`textcat_macro_f`.
attribute being scored:
- `token_acc`, `token_p`, `token_r`, `token_f`
- `sents_p`, `sents_r`, `sents_f`
- `tag_acc`, `pos_acc`, `morph_acc`, `morph_per_feat`, `lemma_acc`
- `dep_uas`, `dep_las`, `dep_las_per_type`
- `ents_p`, `ents_r`, `ents_f`, `ents_per_type`
- `textcat_macro_auc`, `textcat_macro_f`
> #### Example
>
> ```python
> scorer = Scorer()
> scorer.score(examples)
> scores = scorer.score(examples)
> ```
| Name | Type | Description |
@ -57,78 +60,148 @@ attribute being scored: `token_acc`, `token_p/r/f`, `sents_p/r/f`, `tag_acc`,
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| **RETURNS** | `Dict` | A dictionary of scores. |
## Scorer.score_tokenization {#score_tokenization tag="staticmethod"}
## Scorer.score_tokenization {#score_tokenization tag="staticmethod" new="3"}
Scores the tokenization:
- `token_acc`: # correct tokens / # gold tokens
- `token_p/r/f`: PRF for token character spans
- `token_acc`: number of correct tokens / number of gold tokens
- `token_p`, `token_r`, `token_f`: precision, recall and F-score for token
character spans
> #### Example
>
> ```python
> scores = Scorer.score_tokenization(examples)
> ```
| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| **RETURNS** | `Dict` | A dictionary containing the scores `token_acc/p/r/f`. |
| **RETURNS** | `Dict` | A dictionary containing the scores `token_acc`, `token_p`, `token_r`, `token_f`. |
## Scorer.score_token_attr {#score_token_attr tag="staticmethod"}
## Scorer.score_token_attr {#score_token_attr tag="staticmethod" new="3"}
Scores a single token attribute.
| Name | Type | Description |
| ----------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| **RETURNS** | `Dict` | A dictionary containing the score `attr_acc`. |
> #### Example
>
> ```python
> scores = Scorer.score_token_attr(examples, "pos")
> print(scores["pos_acc"])
> ```
## Scorer.score_token_attr_per_feat {#score_token_attr_per_feat tag="staticmethod"}
| Name | Type | Description |
| -------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| _keyword-only_ | | |
| `getter` | `Callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| **RETURNS** | `Dict[str, float]` | A dictionary containing the score `{attr}_acc`. |
Scores a single token attribute per feature for a token attribute in UFEATS
## Scorer.score_token_attr_per_feat {#score_token_attr_per_feat tag="staticmethod" new="3"}
Scores a single token attribute per feature, for token attributes annotated in
[UFEATS](https://universaldependencies.org/format.html#morphological-annotation)
format.
| Name | Type | Description |
| ----------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| **RETURNS** | `Dict` | A dictionary containing the per-feature PRF scores unders the key `attr_per_feat`. |
> #### Example
>
> ```python
> scores = Scorer.score_token_attr_per_feat(examples, "morph")
> print(scores["morph_per_feat"])
> ```
## Scorer.score_spans {#score_spans tag="staticmethod"}
| Name | Type | Description |
| -------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| _keyword-only_ | | |
| `getter` | `Callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| **RETURNS** | `Dict` | A dictionary containing the per-feature PRF scores under the key `{attr}_per_feat`. |
## Scorer.score_spans {#score_spans tag="staticmethod" new="3"}
Returns PRF scores for labeled or unlabeled spans.
| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. |
| **RETURNS** | `Dict` | A dictionary containing the PRF scores under the keys `attr_p/r/f` and the per-type PRF scores under `attr_per_type`. |
> #### Example
>
> ```python
> scores = Scorer.score_spans(examples, "ents")
> print(scores["ents_f"])
> ```
## Scorer.score_deps {#score_deps tag="staticmethod"}
| Name | Type | Description |
| -------------- | ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| _keyword-only_ | | |
| `getter` | `Callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the `Span` objects for an individual `Doc`. |
| **RETURNS** | `Dict` | A dictionary containing the PRF scores under the keys `{attr}_p`, `{attr}_r`, `{attr}_f` and the per-type PRF scores under `{attr}_per_type`. |
## Scorer.score_deps {#score_deps tag="staticmethod" new="3"}
Calculate the UAS, LAS, and LAS per type scores for dependency parses.
> #### Example
>
> ```python
> def dep_getter(token, attr):
> dep = getattr(token, attr)
> dep = token.vocab.strings.as_string(dep).lower()
> return dep
>
> scores = Scorer.score_deps(
> examples,
> "dep",
> getter=dep_getter,
> ignore_labels=("p", "punct")
> )
> print(scores["dep_uas"], scores["dep_las"])
> ```
| Name | Type | Description |
| --------------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute containing the dependency label. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| _keyword-only_ | | |
| `getter` | `Callable` | Defaults to `getattr`. If provided, `getter(token, attr)` should return the value of the attribute for an individual `Token`. |
| `head_attr` | `str` | The attribute containing the head token. |
| `head_getter` | `callable` | Defaults to `getattr`. If provided, `head_getter(token, attr)` should return the head for an individual `Token`. |
| `ignore_labels` | `Tuple` | Labels to ignore while scoring (e.g., `punct`). |
| **RETURNS** | `Dict` | A dictionary containing the scores: `attr_uas`, `attr_las`, and `attr_las_per_type`. |
| **RETURNS** | `Dict` | A dictionary containing the scores: `{attr}_uas`, `{attr}_las`, and `{attr}_las_per_type`. |
## Scorer.score_cats {#score_cats tag="staticmethod"}
## Scorer.score_cats {#score_cats tag="staticmethod" new="3"}
Calculate PRF and ROC AUC scores for a doc-level attribute that is a dict
containing scores for each label like `Doc.cats`. The reported overall score
depends on the scorer settings.
depends on the scorer settings:
| Name | Type | Description |
| ---------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| `getter` | `callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the cats for an individual `Doc`. |
| labels | `Iterable[str]` | The set of possible labels. Defaults to `[]`. |
| `multi_label` | `bool` | Whether the attribute allows multiple labels. Defaults to `True`. |
| `positive_label` | `str` | The positive label for a binary task with exclusive classes. Defaults to `None`. |
| **RETURNS** | `Dict` | A dictionary containing the scores, with inapplicable scores as `None`: 1) for all: `attr_score` (one of `attr_f` / `attr_macro_f` / `attr_macro_auc`), `attr_score_desc` (text description of the overall score), `attr_f_per_type`, `attr_auc_per_type`; 2) for binary exclusive with positive label: `attr_p/r/f`; 3) for 3+ exclusive classes, macro-averaged fscore: `attr_macro_f`; 4) for multilabel, macro-averaged AUC: `attr_macro_auc` |
1. **all:** `{attr}_score` (one of `{attr}_f` / `{attr}_macro_f` /
`{attr}_macro_auc`), `{attr}_score_desc` (text description of the overall
score), `{attr}_f_per_type`, `{attr}_auc_per_type`
2. **binary exclusive with positive label:** `{attr}_p`, `{attr}_r`, `{attr}_f`
3. **3+ exclusive classes**, macro-averaged F-score: `{attr}_macro_f`
4. **multilabel**, macro-averaged AUC: `{attr}_macro_auc`
> #### Example
>
> ```python
> labels = ["LABEL_A", "LABEL_B", "LABEL_C"]
> scores = Scorer.score_cats(
> examples,
> "cats",
> labels=labels
> )
> print(scores["cats_macro_auc"])
> ```
| Name | Type | Description |
| ---------------- | ------------------- | ------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The `Example` objects holding both the predictions and the correct gold-standard annotations. |
| `attr` | `str` | The attribute to score. |
| _keyword-only_ | | |
| `getter` | `Callable` | Defaults to `getattr`. If provided, `getter(doc, attr)` should return the cats for an individual `Doc`. |
| `labels` | `Iterable[str]` | The set of possible labels. Defaults to `[]`. |
| `multi_label` | `bool` | Whether the attribute allows multiple labels. Defaults to `True`. |
| `positive_label` | `str` | The positive label for a binary task with exclusive classes. Defaults to `None`. |
| **RETURNS** | `Dict` | A dictionary containing the scores, with inapplicable scores as `None`. |
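A hedged sketch of reading the returned keys for a binary task with exclusive classes (`examples` is assumed to be an iterable of `Example` objects):

```python
from spacy.scorer import Scorer

scores = Scorer.score_cats(
    examples,
    "cats",
    labels=["POSITIVE", "NEGATIVE"],
    multi_label=False,
    positive_label="POSITIVE",
)
print(scores["cats_score_desc"])  # "F (POSITIVE)"
print(scores["cats_macro_auc"])   # None: inapplicable for this setting, per the docs
```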

View File

@ -290,6 +290,8 @@ factories.
> return Model("custom", forward, dims={"nO": nO})
> ```
<!-- TODO: finish table -->
| Registry name | Description |
| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `architectures` | Registry for functions that create [model architectures](/api/architectures). Can be used to register custom model architectures and reference them in the `config.cfg`. |
@ -297,7 +299,7 @@ factories.
| `languages` | Registry for language-specific `Language` subclasses. Automatically reads from [entry points](/usage/saving-loading#entry-points). |
| `lookups` | Registry for large lookup tables available via `vocab.lookups`. |
| `displacy_colors` | Registry for custom color scheme for the [`displacy` NER visualizer](/usage/visualizers). Automatically reads from [entry points](/usage/saving-loading#entry-points). |
| `assets` | <!-- TODO: what is this used for again?--> |
| `assets` | |
| `optimizers` | Registry for functions that create [optimizers](https://thinc.ai/docs/api-optimizers). |
| `schedules` | Registry for functions that create [schedules](https://thinc.ai/docs/api-schedules). |
| `layers` | Registry for functions that create [layers](https://thinc.ai/docs/api-layers). |
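As a concrete (hedged) example of the first row, a custom model architecture can be registered and then referenced from the config; the name and stub model below are hypothetical:

```python
import spacy
from thinc.api import Model

@spacy.registry.architectures("my_org.CustomClassifier.v1")  # hypothetical name
def create_classifier(nO: int) -> Model:
    # A real architecture would compose layers; this forward pass is a stub.
    return Model("custom", lambda model, X, is_train: (X, lambda dY: dY), dims={"nO": nO})

# Referenced from config.cfg as:
# [model]
# @architectures = "my_org.CustomClassifier.v1"
# nO = 2
```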

View File

@ -347,50 +347,52 @@ serialization by passing in the string names via the `exclude` argument.
Transformer tokens and outputs for one `Doc` object.
| Name | Type | Description |
| --------- | -------------------------------------------------- | ----------------------------------------- |
| `tokens` | `Dict` | <!-- TODO: --> |
| `tensors` | `List[FloatsXd]` | <!-- TODO: --> |
| `align` | [`Ragged`](https://thinc.ai/docs/api-types#ragged) | <!-- TODO: --> |
| `width` | int | <!-- TODO: also mention it's property --> |
<!-- TODO: finish API docs, also mention "width" is a property -->
| Name | Type | Description |
| --------- | -------------------------------------------------- | ----------- |
| `tokens` | `Dict` | |
| `tensors` | `List[FloatsXd]` | |
| `align` | [`Ragged`](https://thinc.ai/docs/api-types#ragged) | |
| `width` | int | |
### TransformerData.empty {#transformerdata-empty tag="classmethod"}
<!-- TODO: -->
<!-- TODO: finish API docs -->
| Name | Type | Description |
| ----------- | ----------------- | -------------- |
| **RETURNS** | `TransformerData` | <!-- TODO: --> |
| Name | Type | Description |
| ----------- | ----------------- | ----------- |
| **RETURNS** | `TransformerData` | |
## FullTransformerBatch {#fulltransformerbatch tag="dataclass"}
<!-- TODO: -->
<!-- TODO: write, also mention doc_data is a property -->
| Name | Type | Description |
| ---------- | -------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------- |
| `spans` | `List[List[Span]]` | <!-- TODO: --> |
| `tokens` | [`transformers.BatchEncoding`](https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding) | <!-- TODO: --> |
| `tensors` | `List[torch.Tensor]` | <!-- TODO: --> |
| `align` | [`Ragged`](https://thinc.ai/docs/api-types#ragged) | <!-- TODO: --> |
| `doc_data` | `List[TransformerData]` | <!-- TODO: also mention it's property --> |
| Name | Type | Description |
| ---------- | -------------------------------------------------------------------------------------------------------------------------- | ----------- |
| `spans` | `List[List[Span]]` | |
| `tokens` | [`transformers.BatchEncoding`](https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding) | |
| `tensors` | `List[torch.Tensor]` | |
| `align` | [`Ragged`](https://thinc.ai/docs/api-types#ragged) | |
| `doc_data` | `List[TransformerData]` | |
### FullTransformerBatch.unsplit_by_doc {#fulltransformerbatch-unsplit_by_doc tag="method"}
<!-- TODO: -->
<!-- TODO: write -->
| Name | Type | Description |
| ----------- | ---------------------- | -------------- |
| `arrays` | `List[List[Floats3d]]` | <!-- TODO: --> |
| **RETURNS** | `FullTransformerBatch` | <!-- TODO: --> |
| Name | Type | Description |
| ----------- | ---------------------- | ----------- |
| `arrays` | `List[List[Floats3d]]` | |
| **RETURNS** | `FullTransformerBatch` | |
### FullTransformerBatch.split_by_doc {#fulltransformerbatch-split_by_doc tag="method"}
Split a `TransformerData` object that represents a batch into a list with one
`TransformerData` per `Doc`.
| Name | Type | Description |
| ----------- | ----------------------- | -------------- |
| **RETURNS** | `List[TransformerData]` | <!-- TODO: --> |
| Name | Type | Description |
| ----------- | ----------------------- | ----------- |
| **RETURNS** | `List[TransformerData]` | |
## Span getters {#span_getters tag="registered functions" source="github.com/explosion/spacy-transformers/blob/master/spacy_transformers/span_getters.py"}
@ -421,11 +423,13 @@ getters using the `@registry.span_getters` decorator.
The following built-in functions are available:
<!-- TODO: finish API docs -->
| Name | Description |
| ------------------ | ------------------------------------------------------------------ |
| `doc_spans.v1` | Create a span for each doc (no transformation, process each text). |
| `sent_spans.v1` | Create a span for each sentence if sentence boundaries are set. |
| `strided_spans.v1` | <!-- TODO: --> |
| `strided_spans.v1` | |
## Annotation setters {#annotation_setters tag="registered functions" source="github.com/explosion/spacy-transformers/blob/master/spacy_transformers/annotation_setters.py"}

View File

@ -231,10 +231,10 @@ available pipeline components and component functions.
| `morphologizer` | [`Morphologizer`](/api/morphologizer) | Assign morphological features and coarse-grained POS tags. |
| `senter` | [`SentenceRecognizer`](/api/sentencerecognizer) | Assign sentence boundaries. |
| `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. |
| `tok2vec` | [`Tok2Vec`](/api/tok2vec) | <!-- TODO: --> |
| `tok2vec` | [`Tok2Vec`](/api/tok2vec) | |
| `transformer` | [`Transformer`](/api/transformer) | Assign the tokens and outputs of a transformer model. |
<!-- TODO: update with more components -->
<!-- TODO: finish and update with more components -->
<!-- TODO: explain default config and factories -->
@ -311,6 +311,99 @@ nlp.rename_pipe("ner", "entityrecognizer")
nlp.replace_pipe("tagger", my_custom_tagger)
```
### Analyzing pipeline components {#analysis new="3"}
The [`nlp.analyze_pipes`](/api/language#analyze_pipes) method analyzes the
components in the current pipeline and outputs information about them, like the
attributes they set on the [`Doc`](/api/doc) and [`Token`](/api/token), whether
they retokenize the `Doc` and which scores they produce during training. It will
also show warnings if components require values that aren't set by a previous
component, for instance if the entity linker is used but no component that
runs before it sets named entities. Setting `pretty=True` will pretty-print a
table instead of only returning the structured data.
> #### ✏️ Things to try
>
> 1. Add the components `"ner"` and `"sentencizer"` _before_ the entity linker.
> The analysis should now show no problems, because requirements are met.
```python
### {executable="true"}
import spacy
nlp = spacy.blank("en")
nlp.add_pipe("tagger")
# This is a problem because it needs entities and sentence boundaries
nlp.add_pipe("entity_linker")
analysis = nlp.analyze_pipes(pretty=True)
```
<Accordion title="Example output">
```json
### Structured
{
"summary": {
"tagger": {
"assigns": ["token.tag"],
"requires": [],
"scores": ["tag_acc", "pos_acc", "lemma_acc"],
"retokenizes": false
},
"entity_linker": {
"assigns": ["token.ent_kb_id"],
"requires": ["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
"scores": [],
"retokenizes": false
}
},
"problems": {
"tagger": [],
"entity_linker": ["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"]
},
"attrs": {
"token.ent_iob": { "assigns": [], "requires": ["entity_linker"] },
"doc.ents": { "assigns": [], "requires": ["entity_linker"] },
"token.ent_kb_id": { "assigns": ["entity_linker"], "requires": [] },
"doc.sents": { "assigns": [], "requires": ["entity_linker"] },
"token.tag": { "assigns": ["tagger"], "requires": [] },
"token.ent_type": { "assigns": [], "requires": ["entity_linker"] }
}
}
```
```
### Pretty
============================= Pipeline Overview =============================
# Component Assigns Requires Scores Retokenizes
- ------------- --------------- -------------- --------- -----------
0 tagger token.tag tag_acc False
pos_acc
lemma_acc
1 entity_linker token.ent_kb_id doc.ents False
doc.sents
token.ent_iob
token.ent_type
================================ Problems (4) ================================
⚠ 'entity_linker' requirements not met: doc.ents, doc.sents,
token.ent_iob, token.ent_type
```
</Accordion>
<Infobox variant="warning" title="Important note">
The pipeline analysis is static and does **not actually run the components**.
This means that it relies on the information provided by the components
themselves. If a custom component declares that it assigns an attribute but it
doesn't, the pipeline analysis won't catch that.
</Infobox>
## Creating custom pipeline components {#custom-components}
A pipeline component is a function that receives a `Doc` object, modifies it and
@ -489,6 +582,8 @@ All other settings can be passed in by the user via the `config` argument on
[`@Language.factory`](/api/language#factory) decorator also lets you define a
`default_config` that's used as a fallback.
<!-- TODO: add example of passing in a custom Python object via the config based on a registered function -->
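Until that example is written, here's a hedged sketch of the intended pattern; the registry entry, component name and the use of the `assets` registry for custom objects are all assumptions:

```python
import spacy
from spacy.language import Language

@spacy.registry.assets("my_org.acronyms.v1")  # hypothetical registry entry
def create_acronyms():
    return {"lol": "laughing out loud"}

@Language.factory(
    "acronym_expander",  # hypothetical component name
    default_config={"data": {"@assets": "my_org.acronyms.v1"}},
)
def create_expander(nlp, name, data):
    def expander(doc):
        # A real component would look tokens up in `data` here.
        return doc
    return expander
```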
```python
### With config {highlight="4,9"}
import spacy

View File

@ -15,8 +15,6 @@ import Serialization101 from 'usage/101/\_serialization.md'
### Serializing the pipeline {#pipeline}
<!-- TODO: update this -->
When serializing the pipeline, keep in mind that this will only save out the
**binary data for the individual components** to allow spaCy to restore them
not the entire objects. This is a good thing, because it makes serialization

View File

@ -3,7 +3,8 @@ title: Training Models
next: /usage/projects
menu:
- ['Introduction', 'basics']
- ['CLI & Config', 'cli-config']
- ['Quickstart', 'quickstart']
- ['Config System', 'config']
- ['Transfer Learning', 'transfer-learning']
- ['Custom Models', 'custom-models']
- ['Parallel Training', 'parallel-training']
@ -29,12 +30,13 @@ ready-to-use spaCy models.
</Infobox>
## Training CLI & config {#cli-config}
### Training CLI & config {#cli-config}
<!-- TODO: intro describing the new v3 training philosophy -->
The recommended way to train your spaCy models is via the
[`spacy train`](/api/cli#train) command on the command line.
[`spacy train`](/api/cli#train) command on the command line. You can pass in the
following data and information:
1. The **training and evaluation data** in spaCy's
[binary `.spacy` format](/api/data-formats#binary-training) created using
@ -68,38 +70,22 @@ workflows, from data preprocessing to training and packaging your model.
</Project>
<Accordion title="Understanding the training output">
## Quickstart {#quickstart}
When you train a model using the [`spacy train`](/api/cli#train) command, you'll
see a table showing metrics after each pass over the data. Here's what those
metrics means:
> #### Instructions
>
> 1. Select your requirements and settings. The quickstart widget will
> auto-generate a recommended starter config for you.
> 2. Use the buttons at the bottom to save the result to your clipboard or a
> file `config.cfg`.
> 3. TODO: recommended approach for filling config
> 4. Run [`spacy train`](/api/cli#train) with your config and data.
<!-- TODO: update table below and include note about scores in config -->
import QuickstartTraining from 'widgets/quickstart-training.js'
| Name | Description |
| ---------- | ------------------------------------------------------------------------------------------------- |
| `Dep Loss` | Training loss for dependency parser. Should decrease, but usually not to 0. |
| `NER Loss` | Training loss for named entity recognizer. Should decrease, but usually not to 0. |
| `UAS` | Unlabeled attachment score for parser. The percentage of unlabeled correct arcs. Should increase. |
| `NER P.` | NER precision on development data. Should increase. |
| `NER R.` | NER recall on development data. Should increase. |
| `NER F.` | NER F-score on development data. Should increase. |
| `Tag %` | Fine-grained part-of-speech tag accuracy on development data. Should increase. |
| `Token %` | Tokenization accuracy on development data. |
| `CPU WPS` | Prediction speed on CPU in words per second, if available. Should stay stable. |
| `GPU WPS` | Prediction speed on GPU in words per second, if available. Should stay stable. |
<QuickstartTraining />
Note that if the development data has raw text, some of the gold-standard
entities might not align to the predicted tokenization. These tokenization
errors are **excluded from the NER evaluation**. If your tokenization makes it
impossible for the model to predict 50% of your entities, your NER F-score might
still look good.
</Accordion>
---
### Training config files {#config}
## Training config {#config}
> #### Migration from spaCy v2.x
>
@ -237,7 +223,70 @@ compound = 1.001
<!-- TODO: refer to architectures API: /api/architectures. This should document the architectures in spacy/ml/models -->
<!-- TODO: how do we document the default configs? -->
### Metrics, training output and weighted scores {#metrics}
When you train a model using the [`spacy train`](/api/cli#train) command, you'll
see a table showing the metrics after each pass over the data. The available
metrics **depend on the pipeline components**. Pipeline components also define
which scores are shown and how they should be **weighted in the final score**
that decides about the best model.
The `training.score_weights` setting in your `config.cfg` lets you customize the
scores shown in the table and how they should be weighted. In this example, the
labeled dependency accuracy and NER F-score count towards the final score with
40% each and the tagging accuracy makes up the remaining 20%. The tokenization
accuracy and speed are both shown in the table, but not counted towards the
score.
> #### Why do I need score weights?
>
> At the end of your training process, you typically want to select the **best
> model**, but what "best" means depends on the available components and your
> specific use case. For instance, you may prefer a model with higher NER and
> lower POS tagging accuracy over a model with lower NER and higher POS
> accuracy. You can express this preference in the score weights, e.g. by
> assigning `ents_f` (NER F-score) a higher weight.
```ini
[training.score_weights]
dep_las = 0.4
ents_f = 0.4
tag_acc = 0.2
token_acc = 0.0
speed = 0.0
```
The `score_weights` don't _have to_ sum to `1.0`, but it's recommended. When
you generate a config for a given pipeline, the score weights are generated by
combining and normalizing the default score weights of the pipeline components.
The default score weights are defined by each pipeline component via the
`default_score_weights` setting on the
[`@Language.component`](/api/language#component) or
[`@Language.factory`](/api/language#factory). By default, all pipeline
components are weighted equally.
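To make the weighting concrete, a small sketch (numbers invented, not spaCy's internal implementation) of how the final score follows from the weights above:

```python
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2,
                 "token_acc": 0.0, "speed": 0.0}
scores = {"dep_las": 0.91, "ents_f": 0.85, "tag_acc": 0.97,
          "token_acc": 0.99, "speed": 14000}
final_score = sum(scores[name] * weight for name, weight in score_weights.items())
print(f"{final_score:.3f}")  # 0.898
```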
<Accordion title="Understanding the training output and score types" spaced>
| Name | Description |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
| **Loss** | The training loss representing the amount of work left for the optimizer. Should decrease, but usually not to `0`. |
| **Precision** (P) | The percentage of predicted annotations that are correct. Should increase. |
| **Recall** (R) | The percentage of reference annotations that the model predicted. Should increase. |
| **F-Score** (F) | The weighted average of precision and recall. Should increase. |
| **UAS** / **LAS** | Unlabeled and labeled attachment score for the dependency parser, i.e. the percentage of correct arcs. Should increase. |
| **Words per second** (WPS) | Prediction speed in words per second. Should stay stable. |
<!-- TODO: is this still relevant? -->
Note that if the development data has raw text, some of the gold-standard
entities might not align to the predicted tokenization. These tokenization
errors are **excluded from the NER evaluation**. If your tokenization makes it
impossible for the model to predict 50% of your entities, your NER F-score might
still look good.
</Accordion>
## Transfer learning {#transfer-learning}

View File

@ -88,7 +88,8 @@ The recommended workflow for training is to use spaCy's
[`spacy train`](/api/cli#train) command. The training config defines all
component settings and hyperparameters in one place and lets you describe a tree
of objects by referring to creation functions, including functions you register
yourself.
yourself. For details on how to get started with training your own model, check
out the [training quickstart](/usage/training#quickstart).
<Project id="en_core_bert">

View File

@ -3,21 +3,23 @@ import React, { useState, useRef } from 'react'
import Icon from './icon'
import classes from '../styles/copy.module.sass'
export function copyToClipboard(ref, callback) {
const isClient = typeof window !== 'undefined'
if (ref.current && isClient) {
ref.current.select()
document.execCommand('copy')
callback(true)
ref.current.blur()
setTimeout(() => callback(false), 1000)
}
}
const CopyInput = ({ text, prefix }) => {
const isClient = typeof window !== 'undefined'
const supportsCopy = isClient && document.queryCommandSupported('copy')
const textareaRef = useRef()
const [copySuccess, setCopySuccess] = useState(false)
function copyToClipboard() {
if (textareaRef.current && isClient) {
textareaRef.current.select()
document.execCommand('copy')
setCopySuccess(true)
textareaRef.current.blur()
setTimeout(() => setCopySuccess(false), 1000)
}
}
const onClick = () => copyToClipboard(textareaRef, setCopySuccess)
function selectText() {
if (textareaRef.current && isClient) {
@ -37,7 +39,7 @@ const CopyInput = ({ text, prefix }) => {
onClick={selectText}
/>
{supportsCopy && (
<button title="Copy to clipboard" onClick={copyToClipboard}>
<button title="Copy to clipboard" onClick={onClick}>
<Icon width={16} name={copySuccess ? 'accept' : 'clipboard'} />
</button>
)}

View File

@ -22,6 +22,7 @@ import { ReactComponent as SearchIcon } from '../images/icons/search.svg'
import { ReactComponent as MoonIcon } from '../images/icons/moon.svg'
import { ReactComponent as ClipboardIcon } from '../images/icons/clipboard.svg'
import { ReactComponent as NetworkIcon } from '../images/icons/network.svg'
import { ReactComponent as DownloadIcon } from '../images/icons/download.svg'
import classes from '../styles/icon.module.sass'
@ -46,7 +47,8 @@ const icons = {
search: SearchIcon,
moon: MoonIcon,
clipboard: ClipboardIcon,
network: NetworkIcon
network: NetworkIcon,
download: DownloadIcon,
}
const Icon = ({ name, width, height, inline, variant, className }) => {

View File

@ -1,4 +1,4 @@
import React, { Fragment, useState, useEffect } from 'react'
import React, { Fragment, useState, useEffect, useRef } from 'react'
import PropTypes from 'prop-types'
import classNames from 'classnames'
import { window } from 'browser-monads'
@ -6,6 +6,7 @@ import { window } from 'browser-monads'
import Section from './section'
import Icon from './icon'
import { H2 } from './typography'
import { copyToClipboard } from './copy'
import classes from '../styles/quickstart.module.sass'
function getNewChecked(optionId, checkedForId, multiple) {
@@ -14,10 +15,41 @@ function getNewChecked(optionId, checkedForId, multiple) {
return [...checkedForId, optionId]
}
const Quickstart = ({ data, title, description, id, children }) => {
function getRawContent(ref) {
if (ref.current && ref.current.childNodes) {
// Select all currently visible nodes (spans and text nodes)
const result = [...ref.current.childNodes].filter(el => el.offsetParent !== null)
return result.map(el => el.textContent).join('\n')
}
return ''
}
const Quickstart = ({
data,
title,
description,
copy,
download,
id,
setters = {},
hidePrompts,
children,
}) => {
const contentRef = useRef()
const copyAreaRef = useRef()
const isClient = typeof window !== 'undefined'
const supportsCopy = isClient && document.queryCommandSupported('copy')
const showCopy = supportsCopy && copy
const [styles, setStyles] = useState({})
const [checked, setChecked] = useState({})
const [initialized, setInitialized] = useState(false)
const [copySuccess, setCopySuccess] = useState(false)
const [otherState, setOtherState] = useState({})
const setOther = (id, value) => setOtherState({ ...otherState, [id]: value })
const onClickCopy = () => {
copyAreaRef.current.value = getRawContent(contentRef)
copyToClipboard(copyAreaRef, setCopySuccess)
}
const getCss = (id, checkedOptions) => {
const checkedForId = checkedOptions[id] || []
@@ -32,7 +64,7 @@ const Quickstart = ({ data, title, description, id, children }) => {
if (!initialized) {
const initialChecked = Object.assign(
{},
...data.map(({ id, options }) => ({
...data.map(({ id, options = [] }) => ({
[id]: options.filter(option => option.checked).map(({ id }) => id),
}))
)
@@ -48,7 +80,7 @@ const Quickstart = ({ data, title, description, id, children }) => {
return !data.length ? null : (
<Section id={id}>
<div className={classes.root}>
<div className={classNames(classes.root, { [classes.hidePrompts]: !!hidePrompts })}>
{title && (
<H2 className={classes.title} name={id}>
<a href={`#${id}`}>{title}</a>
@@ -57,82 +89,154 @@
{description && <p className={classes.description}>{description}</p>}
{data.map(({ id, title, options = [], multiple, help }) => (
<div key={id} data-quickstart-group={id} className={classes.group}>
<style data-quickstart-style={id}>
{styles[id] ||
`[data-quickstart-results]>[data-quickstart-${id}] { display: none }`}
</style>
<div className={classes.legend}>
{title}
{help && (
<span data-tooltip={help} className={classes.help}>
{' '}
<Icon name="help" width={16} spaced />
</span>
)}
</div>
<div className={classes.fields}>
{options.map(option => {
const optionType = multiple ? 'checkbox' : 'radio'
const checkedForId = checked[id] || []
return (
<Fragment key={option.id}>
<input
onChange={() => {
const newChecked = {
...checked,
[id]: getNewChecked(
option.id,
checkedForId,
multiple
),
{data.map(
({
id,
title,
options = [],
dropdown = [],
defaultValue,
multiple,
other,
help,
}) => {
// Optional function that's called with the value
const setterFunc = setters[id] || (() => {})
return (
<div key={id} data-quickstart-group={id} className={classes.group}>
<style data-quickstart-style={id} scoped>
{styles[id] ||
`[data-quickstart-results]>[data-quickstart-${id}] { display: none }`}
</style>
<div className={classes.legend}>
{title}
{help && (
<span data-tooltip={help} className={classes.help}>
{' '}
<Icon name="help" width={16} spaced />
</span>
)}
</div>
<div className={classes.fields}>
{!!dropdown.length && (
<select
defaultValue={defaultValue}
className={classes.select}
onChange={({ target }) => {
const value = target.value
if (value !== other) {
setterFunc(value)
setOther(id, false)
} else {
setterFunc('')
setOther(id, true)
}
setChecked(newChecked)
setStyles({
...styles,
[id]: getCss(id, newChecked),
})
}}
type={optionType}
className={classNames(
classes.input,
classes[optionType]
)}
name={id}
id={`quickstart-${option.id}`}
value={option.id}
checked={checkedForId.includes(option.id)}
/>
<label
className={classes.label}
htmlFor={`quickstart-${option.id}`}
>
{option.title}
{option.meta && (
<span className={classes.meta}>{option.meta}</span>
)}
{option.help && (
<span
data-tooltip={option.help}
className={classes.help}
{dropdown.map(({ id, title }) => (
<option key={id} value={id}>
{title}
</option>
))}
{other && <option value={other}>{other}</option>}
</select>
)}
{other && otherState[id] && (
<input
type="text"
className={classes.textInput}
placeholder="Type here..."
onChange={({ target }) => setterFunc(target.value)}
/>
)}
{options.map(option => {
const optionType = multiple ? 'checkbox' : 'radio'
const checkedForId = checked[id] || []
return (
<Fragment key={option.id}>
<input
onChange={() => {
const newChecked = {
...checked,
[id]: getNewChecked(
option.id,
checkedForId,
multiple
),
}
setChecked(newChecked)
setStyles({
...styles,
[id]: getCss(id, newChecked),
})
setterFunc(newChecked[id])
}}
type={optionType}
className={classNames(
classes.input,
classes[optionType]
)}
name={id}
id={`quickstart-${option.id}`}
value={option.id}
checked={checkedForId.includes(option.id)}
/>
<label
className={classes.label}
htmlFor={`quickstart-${option.id}`}
>
{' '}
<Icon name="help" width={16} spaced />
</span>
)}
</label>
</Fragment>
)
})}
</div>
</div>
))}
{option.title}
{option.meta && (
<span className={classes.meta}>
{option.meta}
</span>
)}
{option.help && (
<span
data-tooltip={option.help}
className={classes.help}
>
{' '}
<Icon name="help" width={16} spaced />
</span>
)}
</label>
</Fragment>
)
})}
</div>
</div>
)
}
)}
<pre className={classes.code}>
<code className={classes.results} data-quickstart-results="">
<code className={classes.results} data-quickstart-results="" ref={contentRef}>
{children}
</code>
<menu className={classes.menu}>
{showCopy && (
<button
title="Copy to clipboard"
onClick={onClickCopy}
className={classes.iconButton}
>
<Icon width={18} name={copySuccess ? 'accept' : 'clipboard'} />
</button>
)}
{download && (
<a
href={`data:application/octet-stream,${getRawContent(contentRef)}`}
title="Download file"
download={download}
className={classes.iconButton}
>
<Icon width={18} name="download" />
</a>
)}
</menu>
</pre>
{showCopy && <textarea ref={copyAreaRef} className={classes.copyArea} rows={1} />}
</div>
</Section>
)
@@ -141,6 +245,7 @@ const Quickstart = ({ data, title, description, id, children }) => {
Quickstart.defaultProps = {
data: [],
id: 'quickstart',
copy: true,
}
Quickstart.propTypes = {
@@ -164,12 +269,13 @@ Quickstart.propTypes = {
),
}
const QS = ({ children, prompt = 'bash', divider = false, ...props }) => {
const QS = ({ children, prompt = 'bash', divider = false, comment = false, ...props }) => {
const qsClassNames = classNames({
[classes.prompt]: !!prompt && !divider,
[classes.bash]: prompt === 'bash' && !divider,
[classes.python]: prompt === 'python' && !divider,
[classes.divider]: !!divider,
[classes.comment]: !!comment,
})
const attrs = Object.assign(
{},

View File

@@ -0,0 +1,4 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24">
<path d="M16.707 7.404c-0.189-0.188-0.448-0.283-0.707-0.283s-0.518 0.095-0.707 0.283l-2.293 2.293v-6.697c0-0.552-0.448-1-1-1s-1 0.448-1 1v6.697l-2.293-2.293c-0.189-0.188-0.44-0.293-0.707-0.293s-0.518 0.105-0.707 0.293c-0.39 0.39-0.39 1.024 0 1.414l4.707 4.682 4.709-4.684c0.388-0.387 0.388-1.022-0.002-1.412z"></path>
<path d="M20.987 16c0-0.105-0.004-0.211-0.039-0.316l-2-6c-0.136-0.409-0.517-0.684-0.948-0.684h-0.219c-0.094 0.188-0.21 0.368-0.367 0.525l-1.482 1.475h1.348l1.667 5h-13.893l1.667-5h1.348l-1.483-1.475c-0.157-0.157-0.274-0.337-0.367-0.525h-0.219c-0.431 0-0.812 0.275-0.948 0.684l-2 6c-0.035 0.105-0.039 0.211-0.039 0.316-0.013 0-0.013 5-0.013 5 0 0.553 0.447 1 1 1h16c0.553 0 1-0.447 1-1 0 0 0-5-0.013-5z"></path>
</svg>


View File

@@ -24,7 +24,7 @@
.code,
.juniper-input pre
display: block
padding: 1.75em 2em
padding: 1.75em 1.5em
.code
&[data-prompt]:before,

View File

@@ -370,9 +370,9 @@ body [id]:target
background-color: var(--color-dark-secondary)
border-left: 0.35em solid var(--color-theme)
display: block
margin-right: -2em
margin-left: -2em
padding-right: 2em
margin-right: -1.5em
margin-left: -1.5em
padding-right: 1.5em
padding-left: 1.65em
&:empty:before

View File

@@ -83,6 +83,24 @@
.fields
flex: 100%
.select
cursor: pointer
border: 1px solid var(--color-subtle)
border-radius: var(--border-radius)
display: inline-block
padding: 0.35rem 1.25rem
margin: 0 1rem 0.75rem 0
font-size: var(--font-size-sm)
background: var(--color-back)
.text-input
border: 1px solid var(--color-subtle)
border-radius: var(--border-radius)
display: inline-block
padding: 0.35rem 0.75rem
font-size: var(--font-size-sm)
background: var(--color-back)
.code
background: var(--color-front)
color: var(--color-back)
@@ -95,6 +113,7 @@
border-bottom-right-radius: var(--border-radius)
-webkit-font-smoothing: subpixel-antialiased
-moz-osx-font-smoothing: auto
position: relative
.results
display: block
@@ -105,6 +124,9 @@
& > span
display: block
.hide-prompts .prompt:before
content: initial !important
.prompt:before
color: var(--color-theme)
margin-right: 1em
@@ -115,6 +137,9 @@
.python:before
content: ">>>"
.comment
color: var(--syntax-comment)
.divider
padding: 1.5rem 0
@@ -123,3 +148,29 @@
.input:checked + .label &
color: inherit
.copy-area
width: 1px
height: 1px
opacity: 0
position: absolute
.menu
color: var(--color-subtle)
padding-right: 1.5rem
display: inline-block
position: absolute
bottom: var(--spacing-xs)
right: 0
.icon-button
display: inline-block
color: inherit
cursor: pointer
transition: transform 0.05s ease
&:not(:last-child)
margin-right: 1.5rem
&:hover
transform: scale(1.1)

View File

@@ -92,7 +92,7 @@ const QuickstartInstall = ({ id, title }) => (
</QS>
<QS package="source">pip install -r requirements.txt</QS>
<QS addition="transformers" package="pip">
pip install -U spacy-lookups-transformers
pip install -U spacy-transformers
</QS>
<QS addition="transformers" package="source">
pip install -U spacy-transformers

View File

@@ -0,0 +1,118 @@
import React, { useState } from 'react'
import { StaticQuery, graphql } from 'gatsby'
import { Quickstart, QS } from '../components/quickstart'
const DEFAULT_LANG = 'en'
const MODELS_SMALL = { en: 'roberta-base-small' }
const MODELS_LARGE = { en: 'roberta-base' }
const COMPONENTS = ['tagger', 'parser', 'ner', 'textcat']
const COMMENT = `# This is an auto-generated partial config for training a model.
# TODO: instructions for how to fill and use it`
const DATA = [
{
id: 'lang',
title: 'Language',
defaultValue: DEFAULT_LANG,
},
{
id: 'components',
title: 'Components',
help: 'Pipeline components to train. Requires training data for those annotations.',
options: COMPONENTS.map(id => ({ id, title: id })),
multiple: true,
},
{
id: 'hardware',
title: 'Hardware',
options: [
{ id: 'cpu-only', title: 'CPU only' },
{ id: 'cpu', title: 'CPU preferred' },
{ id: 'gpu', title: 'GPU', checked: true },
],
},
{
id: 'optimize',
title: 'Optimize for',
help: '...',
options: [
{ id: 'efficiency', title: 'efficiency', checked: true },
{ id: 'accuracy', title: 'accuracy' },
],
},
{
id: 'config',
title: 'Configuration',
options: [
{
id: 'independent',
title: 'independent components',
help: "Make components independent and don't share weights",
},
],
multiple: true,
},
]
const QuickstartTraining = ({ id, title, download = 'config.cfg' }) => {
const [lang, setLang] = useState(DEFAULT_LANG)
const [pipeline, setPipeline] = useState([])
const setters = { lang: setLang, components: setPipeline }
return (
<StaticQuery
query={query}
render={({ site }) => {
const langs = site.siteMetadata.languages
DATA[0].dropdown = langs.map(({ name, code }) => ({
id: code,
title: name,
}))
return (
<Quickstart
download={download}
data={DATA}
title={title}
id={id}
setters={setters}
hidePrompts
>
<QS comment>{COMMENT}</QS>
<span>[nlp]</span>
<span>lang = "{lang}"</span>
<span>pipeline = {JSON.stringify(pipeline).replace(/,/g, ', ')}</span>
<br />
<span>[components]</span>
<br />
<span>[components.transformer]</span>
<QS optimize="efficiency">name = "{MODELS_SMALL[lang]}"</QS>
<QS optimize="accuracy">name = "{MODELS_LARGE[lang]}"</QS>
{!!pipeline.length && <br />}
{pipeline.map((pipe, i) => (
<React.Fragment key={pipe}>
{i !== 0 && <br />}
<span>[components.{pipe}]</span>
<span>factory = "{pipe}"</span>
</React.Fragment>
))}
</Quickstart>
)
}}
/>
)
}
const query = graphql`
query QuickstartTrainingQuery {
site {
siteMetadata {
languages {
code
name
}
}
}
}
`
export default QuickstartTraining