Merge remote-tracking branch 'upstream/develop' into feature/el-docs

# Conflicts:
#	website/docs/usage/training.md
svlandeg 2020-08-06 19:48:52 +02:00
commit b17db0e994
20 changed files with 394 additions and 326 deletions

View File

@ -89,6 +89,21 @@ class AttributeRuler(Pipe):
set_token_attrs(token, attrs)
return doc
def pipe(self, stream, *, batch_size=128):
"""Apply the pipe to a stream of documents. This usually happens under
the hood when the nlp object is called on a text and all components are
applied to the Doc.
stream (Iterable[Doc]): A stream of documents.
batch_size (int): The number of documents to buffer.
YIELDS (Doc): Processed documents in order.
DOCS: https://spacy.io/api/attributeruler#pipe
"""
for doc in stream:
doc = self(doc)
yield doc
def load_from_tag_map(
self, tag_map: Dict[str, Dict[Union[int, str], Union[int, str]]]
) -> None:
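
The new `pipe` method simply applies the ruler to each incoming `Doc`. A rough usage sketch, not part of the diff (the `attribute_ruler` factory name and the `add()` signature are assumptions based on the v3 API):

```python
# Hedged sketch: stream Docs through the attribute ruler, as the nlp object
# does under the hood. Factory name and add() signature are assumptions.
import spacy

nlp = spacy.blank("en")
ruler = nlp.add_pipe("attribute_ruler")
# Assign TAG="DT" to every token matching the pattern [{"ORTH": "the"}]
ruler.add(patterns=[[{"ORTH": "the"}]], attrs={"TAG": "DT"})

docs = ruler.pipe(nlp.make_doc(text) for text in ["the cat sat", "the dog barked"])
for doc in docs:
    print([(token.text, token.tag_) for token in doc])
```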

View File

@ -30,8 +30,8 @@ bow_model_config = """
[model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size: 1
no_output_layer: false
ngram_size = 1
no_output_layer = false
"""
cnn_model_config = """

View File

@ -242,7 +242,8 @@ class Scorer:
per_feat[field].score_set(
pred_per_feat.get(field, set()), gold_per_feat.get(field, set()),
)
return {f"{attr}_per_feat": per_feat}
result = {k: v.to_dict() for k, v in per_feat.items()}
return {f"{attr}_per_feat": result}
@staticmethod
def score_spans(
@ -318,6 +319,7 @@ class Scorer:
labels: Iterable[str] = tuple(),
multi_label: bool = True,
positive_label: Optional[str] = None,
threshold: Optional[float] = None,
**cfg,
) -> Dict[str, Any]:
"""Returns PRF and ROC AUC scores for a doc-level attribute with a
@ -333,94 +335,104 @@ class Scorer:
Defaults to True.
positive_label (str): The positive label for a binary task with
exclusive classes. Defaults to None.
threshold (float): Cutoff to consider a prediction "positive". Defaults
to 0.5 for multi-label, and 0.0 (i.e. whatever's highest scoring)
otherwise.
RETURNS (Dict[str, Any]): A dictionary containing the scores, with
inapplicable scores as None:
for all:
attr_score (one of attr_f / attr_macro_f / attr_macro_auc),
attr_score (one of attr_micro_f / attr_macro_f / attr_macro_auc),
attr_score_desc (text description of the overall score),
attr_micro_f,
attr_macro_f,
attr_auc,
attr_f_per_type,
attr_auc_per_type
for binary exclusive with positive label: attr_p/r/f
for 3+ exclusive classes, macro-averaged fscore: attr_macro_f
for multilabel, macro-averaged AUC: attr_macro_auc
DOCS: https://spacy.io/api/scorer#score_cats
"""
score = PRFScore()
f_per_type = dict()
auc_per_type = dict()
for label in labels:
f_per_type[label] = PRFScore()
auc_per_type[label] = ROCAUCScore()
if threshold is None:
threshold = 0.5 if multi_label else 0.0
f_per_type = {label: PRFScore() for label in labels}
auc_per_type = {label: ROCAUCScore() for label in labels}
labels = set(labels)
if labels:
for eg in examples:
labels.update(eg.predicted.cats.keys())
labels.update(eg.reference.cats.keys())
for example in examples:
gold_doc = example.reference
pred_doc = example.predicted
gold_values = getter(gold_doc, attr)
pred_values = getter(pred_doc, attr)
if (
len(gold_values) > 0
and set(f_per_type) == set(auc_per_type) == set(gold_values)
and set(gold_values) == set(pred_values)
):
gold_val = max(gold_values, key=gold_values.get)
pred_val = max(pred_values, key=pred_values.get)
if positive_label:
score.score_set(
set([positive_label]) & set([pred_val]),
set([positive_label]) & set([gold_val]),
)
for label in set(gold_values):
auc_per_type[label].score_set(
pred_values[label], gold_values[label]
)
f_per_type[label].score_set(
set([label]) & set([pred_val]), set([label]) & set([gold_val])
)
elif len(f_per_type) > 0:
model_labels = set(f_per_type)
eval_labels = set(gold_values)
raise ValueError(
Errors.E162.format(
model_labels=model_labels, eval_labels=eval_labels
)
)
elif len(auc_per_type) > 0:
model_labels = set(auc_per_type)
eval_labels = set(gold_values)
raise ValueError(
Errors.E162.format(
model_labels=model_labels, eval_labels=eval_labels
)
)
# Through this loop, None in the gold_cats indicates missing label.
pred_cats = getter(example.predicted, attr)
gold_cats = getter(example.reference, attr)
# I think the AUC metric is applicable regardless of whether we're
# doing multi-label classification? Unsure. If not, move this into
# the elif pred_cats and gold_cats block below.
for label in labels:
pred_score = pred_cats.get(label, 0.0)
gold_score = gold_cats.get(label, 0.0)
if gold_score is not None:
auc_per_type[label].score_set(pred_score, gold_score)
if multi_label:
for label in labels:
pred_score = pred_cats.get(label, 0.0)
gold_score = gold_cats.get(label, 0.0)
if gold_score is not None:
if pred_score >= threshold and gold_score > 0:
f_per_type[label].tp += 1
elif pred_score >= threshold and gold_score == 0:
f_per_type[label].fp += 1
elif pred_score < threshold and gold_score > 0:
f_per_type[label].fn += 1
elif pred_cats and gold_cats:
# Get the highest-scoring for each.
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
if gold_score is not None:
if pred_label == gold_label and pred_score >= threshold:
f_per_type[pred_label].tp += 1
else:
f_per_type[gold_label].fn += 1
if pred_score >= threshold:
f_per_type[pred_label].fp += 1
elif gold_cats:
gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
if gold_score is not None and gold_score > 0:
f_per_type[gold_label].fn += 1
else:
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
if pred_score >= threshold:
f_per_type[pred_label].fp += 1
micro_prf = PRFScore()
for label_prf in f_per_type.values():
micro_prf.tp += label_prf.tp
micro_prf.fn += label_prf.fn
micro_prf.fp += label_prf.fp
n_cats = len(f_per_type) + 1e-100
macro_p = sum(prf.precision for prf in f_per_type.values()) / n_cats
macro_r = sum(prf.recall for prf in f_per_type.values()) / n_cats
macro_f = sum(prf.fscore for prf in f_per_type.values()) / n_cats
results = {
f"{attr}_score": None,
f"{attr}_score_desc": None,
f"{attr}_p": None,
f"{attr}_r": None,
f"{attr}_f": None,
f"{attr}_macro_f": None,
f"{attr}_micro_p": micro_prf.precision,
f"{attr}_micro_r": micro_prf.recall,
f"{attr}_micro_f": micro_prf.fscore,
f"{attr}_macro_p": macro_p,
f"{attr}_macro_r": macro_r,
f"{attr}_macro_f": macro_f,
f"{attr}_macro_auc": None,
f"{attr}_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()},
f"{attr}_auc_per_type": {k: v.score for k, v in auc_per_type.items()},
}
if len(labels) == 2 and not multi_label and positive_label:
results[f"{attr}_p"] = score.precision
results[f"{attr}_r"] = score.recall
results[f"{attr}_f"] = score.fscore
results[f"{attr}_score"] = results[f"{attr}_f"]
positive_label_f = results[f"{attr}_f_per_type"][positive_label]['f']
results[f"{attr}_score"] = positive_label_f
results[f"{attr}_score_desc"] = f"F ({positive_label})"
elif not multi_label:
results[f"{attr}_macro_f"] = sum(
[score.fscore for label, score in f_per_type.items()]
) / (len(f_per_type) + 1e-100)
results[f"{attr}_score"] = results[f"{attr}_macro_f"]
results[f"{attr}_score_desc"] = "macro F"
else:
results[f"{attr}_macro_auc"] = max(
sum([score.score for label, score in auc_per_type.items()])
/ (len(auc_per_type) + 1e-100),
-1,
)
results[f"{attr}_score"] = results[f"{attr}_macro_auc"]
results[f"{attr}_score_desc"] = "macro AUC"
return results
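
For reference, a rough usage sketch of the updated `score_cats` (not part of the diff; it assumes `score_cats` is a staticmethod with `getattr` as the default getter, and that `Example` lives in `spacy.gold` on this branch):

```python
# Hedged sketch: score doc-level categories and read the new micro/macro keys.
import spacy
from spacy.gold import Example  # module path used elsewhere on this branch
from spacy.scorer import Scorer

nlp = spacy.blank("en")
doc = nlp.make_doc("I'm pretty happy about that!")
example = Example.from_dict(doc, {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}})
example.predicted.cats = {"POSITIVE": 0.9, "NEGATIVE": 0.1}

scores = Scorer.score_cats(
    [example],
    "cats",
    labels=["POSITIVE", "NEGATIVE"],
    multi_label=False,
    positive_label="POSITIVE",
)
print(scores["cats_score"], scores["cats_score_desc"])  # 1.0, "F (POSITIVE)"
print(scores["cats_micro_f"], scores["cats_macro_f"])
```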

View File

@ -117,8 +117,10 @@ def test_overfitting_IO():
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
# Test scoring
scores = nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})
assert scores["cats_f"] == 1.0
scores = nlp.evaluate(
train_examples, scorer_cfg={"positive_label": "POSITIVE"}
)
assert scores["cats_micro_f"] == 1.0
assert scores["cats_score"] == 1.0
assert "cats_score_desc" in scores

View File

@ -259,7 +259,7 @@ def test_tag_score(tagged_doc):
assert results["tag_acc"] == 1.0
assert results["pos_acc"] == 1.0
assert results["morph_acc"] == 1.0
assert results["morph_per_feat"]["NounType"].fscore == 1.0
assert results["morph_per_feat"]["NounType"]["f"] == 1.0
# Gold annotation is modified
scorer = Scorer()
@ -282,9 +282,9 @@ def test_tag_score(tagged_doc):
assert results["tag_acc"] == 0.9
assert results["pos_acc"] == 0.9
assert results["morph_acc"] == approx(0.8)
assert results["morph_per_feat"]["NounType"].fscore == 1.0
assert results["morph_per_feat"]["Poss"].fscore == 0.0
assert results["morph_per_feat"]["Number"].fscore == approx(0.72727272)
assert results["morph_per_feat"]["NounType"]["f"] == 1.0
assert results["morph_per_feat"]["Poss"]["f"] == 0.0
assert results["morph_per_feat"]["Number"]["f"] == approx(0.72727272)
def test_roc_auc_score():

View File

@ -9,7 +9,6 @@ from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
cimport cython
from typing import Dict, List, Union, Pattern, Optional, Any
import re
import warnings

View File

@ -202,7 +202,7 @@ $ python -m spacy convert [input_file] [output_dir] [--converter]
| ID | Description |
| ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `auto` | Automatically pick converter based on file extension and file content (default). |
| `json` | JSON-formatted training data used in spaCy v2.x and produced by [`docs2json`](/api/top-level#docs_to_json). |
| `json` | JSON-formatted training data used in spaCy v2.x. |
| `conll` | Universal Dependencies `.conllu` or `.conll` format. |
| `ner` | NER with IOB/IOB2 tags, one token per line with columns separated by whitespace. The first column is the token and the final column is the IOB tag. Sentences are separated by blank lines and documents are separated by the line `-DOCSTART- -X- O O`. Supports CoNLL 2003 NER format. See [sample data](https://github.com/explosion/spaCy/tree/master/examples/training/ner_example_data). |
| `iob` | NER with IOB/IOB2 tags, one sentence per line with tokens separated by whitespace and annotation separated by `|`, either `word|B-ENT` or `word|POS|B-ENT`. See [sample data](https://github.com/explosion/spaCy/tree/master/examples/training/ner_example_data). |
@ -602,13 +602,15 @@ $ python -m spacy train [config_path] [--output] [--code] [--verbose] [overrides
<!-- TODO: document new pretrain command and link to new pretraining docs -->
Pre-train the "token to vector" (`tok2vec`) layer of pipeline components on
[raw text](/api/data-formats#pretrain), using an approximate language-modeling
objective. Specifically, we load pretrained vectors, and train a component like
a CNN, BiLSTM, etc. to predict vectors that match the pretrained ones. The
weights are saved to a directory after each epoch. You can then pass a path to
one of these pretrained weights files to the `spacy train` command. This
technique may be especially helpful if you have little labelled data.
Pre-train the "token to vector" ([`Tok2vec`](/api/tok2vec)) layer of pipeline
components on [raw text](/api/data-formats#pretrain), using an approximate
language-modeling objective. Specifically, we load pretrained vectors, and train
a component like a CNN, BiLSTM, etc. to predict vectors that match the
pretrained ones. The weights are saved to a directory after each epoch. You can
then include a **path to one of these pretrained weights files** in your
[training config](/usage/training#config) as the `init_tok2vec` setting when you
train your model. This technique may be especially helpful if you have little
labelled data.
<Infobox title="Changed in v3.0" variant="warning">

View File

@ -9,7 +9,41 @@ new: 3
This class manages annotated corpora and can be used for training and
development datasets in the [DocBin](/api/docbin) (`.spacy`) format. To
customize the data loading during training, you can register your own
[data readers and batchers](/usage/training#custom-code-readers-batchers)
[data readers and batchers](/usage/training#custom-code-readers-batchers).
## Config and implementation {#config}
`spacy.Corpus.v1` is a registered function that creates a `Corpus` of training
or evaluation data. It takes the same arguments as the `Corpus` class and
returns a callable that yields [`Example`](/api/example) objects. You can
replace it with your own registered function in the
[`@readers` registry](/api/top-level#registry) to customize the data loading and
streaming.
> #### Example config
>
> ```ini
> [paths]
> train = "corpus/train.spacy"
>
> [training.train_corpus]
> @readers = "spacy.Corpus.v1"
> path = ${paths:train}
> gold_preproc = false
> max_length = 0
> limit = 0
> ```
| Name | Type | Description |
| --------------- | ------ | ----------------------------------------------------------------------------------------------------------------------------------------------- |
| `path` | `Path` | The directory or filename to read from. Expects data in spaCy's binary [`.spacy` format](/api/data-formats#binary-training). |
|  `gold_preproc` | bool | Whether to set up the Example object with gold-standard sentences and tokens for the predictions. See [`Corpus`](/api/corpus#init) for details. |
| `max_length` | int | Maximum document length. Longer documents will be split into sentences, if sentence boundaries are available. Defaults to `0` for no limit. |
| `limit` | int | Limit corpus to a subset of examples, e.g. for debugging. Defaults to `0` for no limit. |
```python
https://github.com/explosion/spaCy/blob/develop/spacy/gold/corpus.py
```
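
As a hypothetical sketch of such a replacement (the reader name, `path` argument and data below are made up for illustration):

```python
# Hypothetical custom reader registered under a made-up name "my_reader.v1".
import spacy
from spacy.gold import Example  # module path used elsewhere on this branch

@spacy.registry.readers("my_reader.v1")
def create_reader(path: str):
    def read_examples(nlp):
        # Load your own data from `path` here and yield Example objects.
        doc = nlp.make_doc("I like London")
        yield Example.from_dict(doc, {"entities": [(7, 13, "LOC")]})
    return read_examples
```

The `[training.train_corpus]` block would then use `@readers = "my_reader.v1"` instead of `spacy.Corpus.v1`, with `path` pointing to your data.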
## Corpus.\_\_init\_\_ {#init tag="method"}

View File

@ -2,16 +2,173 @@
title: Data formats
teaser: Details on spaCy's input and output data formats
menu:
- ['Training Config', 'config']
- ['Training Data', 'training']
- ['Pretraining Data', 'pretraining']
- ['Training Config', 'config']
- ['Vocabulary', 'vocab']
---
This section documents input and output formats of data used by spaCy, including
training data and lexical vocabulary data. For an overview of label schemes used
by the models, see the [models directory](/models). Each model documents the
label schemes used in its components, depending on the data it was trained on.
the [training config](/usage/training#config), training data and lexical
vocabulary data. For an overview of label schemes used by the models, see the
[models directory](/models). Each model documents the label schemes used in its
components, depending on the data it was trained on.
## Training config {#config new="3"}
Config files define the training process and model pipeline and can be passed to
[`spacy train`](/api/cli#train). They use
[Thinc's configuration system](https://thinc.ai/docs/usage-config) under the
hood. For details on how to use training configs, see the
[usage documentation](/usage/training#config).
<!-- TODO: add details on getting started and init config -->
> #### What does the @ mean?
>
> The `@` syntax lets you refer to function names registered in the
> [function registry](/api/top-level#registry). For example,
> `@architectures = "spacy.HashEmbedCNN.v1"` refers to a registered function of
> the name [spacy.HashEmbedCNN.v1](/api/architectures#HashEmbedCNN) and all
> other values defined in its block will be passed into that function as
> arguments. Those arguments depend on the registered function. See the usage
> guide on [registered functions](/usage/training#config-functions) for details.
```ini
https://github.com/explosion/spaCy/blob/develop/spacy/default_config.cfg
```
<Infobox title="Notes on data validation" emoji="💡">
Under the hood, spaCy's configs are powered by our machine learning library
[Thinc's config system](https://thinc.ai/docs/usage-config), which uses
[`pydantic`](https://github.com/samuelcolvin/pydantic/) for data validation
based on type hints. See
[`spacy/schemas.py`](https://github.com/explosion/spaCy/blob/develop/spacy/schemas.py)
for the schemas used to validate the default config. Arguments of registered
functions are validated against their type annotations, if available. To debug
your config and check that it's valid, you can run the
[`spacy debug config`](/api/cli#debug-config) command.
</Infobox>
<!-- TODO: once we know how we want to implement "starter config" workflow or outputting a full default config for the user, update this section with the command -->
### nlp {#config-nlp tag="section"}
> #### Example
>
> ```ini
> [nlp]
> lang = "en"
> pipeline = ["tagger", "parser", "ner"]
> load_vocab_data = true
> before_creation = null
> after_creation = null
> after_pipeline_creation = null
>
> [nlp.tokenizer]
> @tokenizers = "spacy.Tokenizer.v1"
> ```
Defines the `nlp` object, its tokenizer and
[processing pipeline](/usage/processing-pipelines) component names.
| Name | Type | Description | Default |
| ------------------------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------------------- |
| `lang` | str | The language code to use. | `null` |
| `pipeline` | `List[str]` | Names of pipeline components in order. Should correspond to sections in the `[components]` block, e.g. `[components.ner]`. See docs on [defining components](/usage/training#config-components). | `[]` |
| `load_vocab_data` | bool | Whether to load additional lexeme and vocab data from [`spacy-lookups-data`](https://github.com/explosion/spacy-lookups-data) if available. | `true` |
| `before_creation` | callable | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `Language` subclass before it's initialized. | `null` |
| `after_creation` | callable | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `nlp` object right after it's initialized. | `null` |
| `after_pipeline_creation` | callable | Optional [callback](/usage/training#custom-code-nlp-callbacks) to modify `nlp` object after the pipeline components have been added. | `null` |
| `tokenizer` | callable | The tokenizer to use. | [`Tokenizer`](/api/tokenizer) |
### components {#config-components tag="section"}
> #### Example
>
> ```ini
> [components.textcat]
> factory = "textcat"
> labels = ["POSITIVE", "NEGATIVE"]
>
> [components.textcat.model]
> @architectures = "spacy.TextCatBOW.v1"
> exclusive_classes = false
> ngram_size = 1
> no_output_layer = false
> ```
This section includes definitions of the
[pipeline components](/usage/processing-pipelines) and their models, if
available. Components in this section can be referenced in the `pipeline` of the
`[nlp]` block. Component blocks need to specify either a `factory` (named
function to use to create the component) or a `source` (name or path of a pretrained
model to copy components from). See the docs on
[defining pipeline components](/usage/training#config-components) for details.
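
In Python, the same distinction looks roughly like this (a sketch assuming the v3 `nlp.add_pipe` API with a `source` argument and an installed pretrained package such as `en_core_web_sm`):

```python
# Hedged sketch of "factory" vs "source": create a new component from a
# factory, or copy a trained component from an existing pipeline.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe("textcat")                    # factory: create a fresh component
source_nlp = spacy.load("en_core_web_sm")  # assumed installed pretrained model
nlp.add_pipe("ner", source=source_nlp)     # source: copy the trained "ner"
print(nlp.pipe_names)                      # ['textcat', 'ner']
```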
### paths, system {#config-variables tag="variables"}
These sections define variables that can be referenced across the other
sections. For example, `${paths:train}` uses the value of `train` defined in
the block `[paths]`. If your config includes custom registered functions that
need paths, you can define them here. All config values can also be
[overwritten](/usage/training#config-overrides) on the CLI when you run
[`spacy train`](/api/cli#train), which is especially relevant for data paths
that you don't want to hard-code in your config file.
```bash
$ python -m spacy train ./config.cfg --paths.train ./corpus/train.spacy
```
### training {#config-training tag="section"}
This section defines settings and controls for the training and evaluation
process that are used when you run [`spacy train`](/api/cli#train).
<!-- TODO: complete -->
| Name | Type | Description | Default |
| --------------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------- |
| `seed` | int | The random seed. | `${system:seed}` |
| `dropout` | float | The dropout rate. | `0.1` |
| `accumulate_gradient` | int | Number of substeps to split each batch into, accumulating gradients across them. | `1` |
| `init_tok2vec` | str | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). | `${paths:init_tok2vec}` |
| `raw_text` | str | | `${paths:raw}` |
| `vectors` | str | | `null` |
| `patience` | int | How many steps to continue without improvement in evaluation score. | `1600` |
| `max_epochs` | int | Maximum number of epochs to train for. | `0` |
| `max_steps` | int | Maximum number of update steps to train for. | `20000` |
| `eval_frequency` | int | How often to evaluate during training (steps). | `200` |
| `score_weights` | `Dict[str, float]` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. | `{}` |
| `frozen_components` | `List[str]` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. | `[]` |
| `train_corpus` | callable | Callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. | [`Corpus`](/api/corpus) |
| `dev_corpus` | callable | Callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. | [`Corpus`](/api/corpus) |
| `batcher` | callable | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. | [`batch_by_words`](/api/top-level#batch_by_words) |
| `optimizer` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. | [`Adam`](https://thinc.ai/docs/api-optimizers#adam) |
### pretraining {#config-pretraining tag="section,optional"}
This section is optional and defines settings and controls for
[language model pretraining](/usage/training#pretraining). It's used when you
run [`spacy pretrain`](/api/cli#pretrain).
<!-- TODO: complete -->
| Name | Type | Description | Default |
| ---------------------------- | --------------------------------------------------- | ----------------------------------------------------------------------------- | --------------------------------------------------- |
| `max_epochs` | int | Maximum number of epochs. | `1000` |
| `min_length` | int | Minimum length of examples. | `5` |
| `max_length` | int | Maximum length of examples. | `500` |
| `dropout` | float | The dropout rate. | `0.2` |
| `n_save_every` | int | Saving frequency. | `null` |
| `batch_size` | int / `Sequence[int]` | The batch size or batch size [schedule](https://thinc.ai/docs/api-schedules). | `3000` |
| `seed` | int | The random seed. | `${system:seed}` |
| `use_pytorch_for_gpu_memory` | bool | Allocate memory via PyTorch. | `${system:use_pytorch_for_gpu_memory}` |
| `tok2vec_model` | str | tok2vec model section in the config. | `"components.tok2vec.model"` |
| `objective` | dict | The pretraining objective. | `{"type": "characters", "n_characters": 4}` |
| `optimizer` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. | [`Adam`](https://thinc.ai/docs/api-optimizers#adam) |
## Training data {#training}
@ -120,9 +277,7 @@ instance. It stores two [`Doc`](/api/doc) objects: one for holding the
gold-standard reference data, and one for holding the predictions of the
pipeline. Examples can be created using the
[`Example.from_dict`](/api/example#from_dict) method with a reference `Doc` and
a dictionary of gold-standard annotations. There are currently two formats
supported for this dictionary of annotations: one with a simple, **flat
structure** of keywords, and one with a more **hierarchical structure**.
a dictionary of gold-standard annotations.
> #### Example
>
@ -142,8 +297,6 @@ to keep track of your settings and hyperparameters and your own
</Infobox>
#### Flat structure {#dict-flat}
> #### Example
>
> ```python
@ -177,12 +330,12 @@ to keep track of your settings and hyperparameters and your own
| `sent_starts` | `List[bool]` | List of boolean values indicating whether each token is the first of a sentence or not. |
| `deps` | `List[str]` | List of string values indicating the [dependency relation](/usage/linguistic-features#dependency-parse) of a token to its head. |
| `heads` | `List[int]` | List of integer values indicating the dependency head of each token, referring to the absolute index of each token in the text. |
| `entities` | `List[str]` | Option 1: List of [BILUO tags](#biluo) per token of the format `"{action}-{label}"`, or `None` for unannotated tokens. |
| `entities` | `List[Tuple[int, int, str]]` | Option 2: List of `"(start, end, label)"` tuples defining all entities in the text. |
| `entities` | `List[str]` | **Option 1:** List of [BILUO tags](#biluo) per token of the format `"{action}-{label}"`, or `None` for unannotated tokens. |
| `entities` | `List[Tuple[int, int, str]]` | **Option 2:** List of `"(start, end, label)"` tuples defining all entities in the text. |
| `cats` | `Dict[str, float]` | Dictionary of `label`/`value` pairs indicating how relevant a certain [text category](/api/textcategorizer) is for the text. |
| `links` | `Dict[(int, int), Dict]` | Dictionary of `offset`/`dict` pairs defining [named entity links](/usage/linguistic-features#entity-linking). The character offsets are linked to a dictionary of relevant knowledge base IDs. |
<Infobox variant="warning" title="Important notes and caveats">
<Infobox title="Notes and caveats">
- Multiple formats are possible for the "entities" entry, but you have to pick
one.
@ -194,76 +347,34 @@ to keep track of your settings and hyperparameters and your own
</Infobox>
<!-- TODO: finish reformatting below -->
##### Examples
```python
### Examples
# Training data for a part-of-speech tagger
doc = Doc(vocab, words=["I", "like", "stuff"])
example = Example.from_dict(doc, {"tags": ["NOUN", "VERB", "NOUN"]})
gold_dict = {"tags": ["NOUN", "VERB", "NOUN"]}
example = Example.from_dict(doc, gold_dict)
# Training data for an entity recognizer (option 1)
doc = nlp("Laura flew to Silicon Valley.")
biluo_tags = ["U-PERS", "O", "O", "B-LOC", "L-LOC"]
example = Example.from_dict(doc, {"entities": biluo_tags})
gold_dict = {"entities": ["U-PERS", "O", "O", "B-LOC", "L-LOC"]}
example = Example.from_dict(doc, gold_dict)
# Training data for an entity recognizer (option 2)
doc = nlp("Laura flew to Silicon Valley.")
entity_tuples = [
(0, 5, "PERSON"),
(14, 28, "LOC"),
]
example = Example.from_dict(doc, {"entities": entity_tuples})
gold_dict = {"entities": [(0, 5, "PERSON"), (14, 28, "LOC")]}
example = Example.from_dict(doc, gold_dict)
# Training data for text categorization
doc = nlp("I'm pretty happy about that!")
example = Example.from_dict(doc, {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}})
gold_dict = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
example = Example.from_dict(doc, gold_dict)
# Training data for an Entity Linking component
doc = nlp("Russ Cochran his reprints include EC Comics.")
example = Example.from_dict(doc, {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}})
gold_dict = {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}
example = Example.from_dict(doc, gold_dict)
```
#### Hierarchical structure {#dict-hierarch}
Internally, a more hierarchical dictionary structure is used to store
gold-standard annotations. Its format is similar to the structure described in
the previous section, but there are two main sections `token_annotation` and
`doc_annotation`, and the keys for token annotations should be uppercase
[`Token` attributes](/api/token#attributes) such as "ORTH" and "TAG".
```python
### Hierarchical dictionary
{
"text": string, # Raw text.
"token_annotation": {
"ORTH": List[string], # List of gold tokens.
"LEMMA": List[string], # List of lemmas.
"SPACY": List[bool], # List of boolean values indicating whether the corresponding tokens is followed by a space or not.
"TAG": List[string], # List of fine-grained [POS tags](/usage/linguistic-features#pos-tagging).
"POS": List[string], # List of coarse-grained [POS tags](/usage/linguistic-features#pos-tagging).
"MORPH": List[string], # List of [morphological features](/usage/linguistic-features#rule-based-morphology).
"SENT_START": List[bool], # List of boolean values indicating whether each token is the first of a sentence or not.
"DEP": List[string], # List of string values indicating the [dependency relation](/usage/linguistic-features#dependency-parse) of a token to its head.
"HEAD": List[int], # List of integer values indicating the dependency head of each token, referring to the absolute index of each token in the text.
},
"doc_annotation": {
"entities": List[(int, int, string)], # List of [BILUO tags](#biluo) per token of the format `"{action}-{label}"`, or `None` for unannotated tokens.
"cats": Dict[str, float], # Dictionary of `label:value` pairs indicating how relevant a certain [category](/api/textcategorizer) is for the text.
"links": Dict[(int, int), Dict], # Dictionary of `offset:dict` pairs defining [named entity links](/usage/linguistic-features#entity-linking). The charachter offsets are linked to a dictionary of relevant knowledge base IDs.
}
}
```
There are a few caveats to take into account:
- Any values for sentence starts will be ignored if there are annotations for
dependency relations.
- If the dictionary contains values for "text" and "ORTH", but not "SPACY", the
latter are inferred automatically. If "ORTH" is not provided either, the
values are inferred from the `doc` argument.
## Pretraining data {#pretraining}
The [`spacy pretrain`](/api/cli#pretrain) command lets you pretrain the tok2vec
@ -297,29 +408,6 @@ provided.
{"tokens": ["If", "tokens", "are", "provided", "then", "we", "can", "skip", "the", "raw", "input", "text"]}
```
## Training config {#config new="3"}
Config files define the training process and model pipeline and can be passed to
[`spacy train`](/api/cli#train). They use
[Thinc's configuration system](https://thinc.ai/docs/usage-config) under the
hood. For details on how to use training configs, see the
[usage documentation](/usage/training#config).
<Infobox variant="warning">
The `@` syntax lets you refer to function names registered in the
[function registry](/api/top-level#registry). For example,
`@architectures = "spacy.HashEmbedCNN.v1"` refers to a registered function of
the name `"spacy.HashEmbedCNN.v1"` and all other values defined in its block
will be passed into that function as arguments. Those arguments depend on the
registered function. See the [model architectures](/api/architectures) docs for
API details.
</Infobox>
<!-- TODO: we need to come up with a good way to present the sections and their expected values visually? -->
<!-- TODO: once we know how we want to implement "starter config" workflow or outputting a full default config for the user, update this section with the command -->
## Lexical data for vocabulary {#vocab-jsonl new="2"}
To populate a model's vocabulary, you can use the

View File

@ -265,37 +265,6 @@ ancestor is found, e.g. if span excludes a necessary ancestor.
| ----------- | -------------------------------------- | ----------------------------------------------- |
| **RETURNS** | `numpy.ndarray[ndim=2, dtype="int32"]` | The lowest common ancestor matrix of the `Doc`. |
## Doc.to_json {#to_json tag="method" new="2.1"}
Convert a Doc to JSON. The format it produces will be the new format for the
[`spacy train`](/api/cli#train) command (not implemented yet). If custom
underscore attributes are specified, their values need to be JSON-serializable.
They'll be added to an `"_"` key in the data, e.g. `"_": {"foo": "bar"}`.
> #### Example
>
> ```python
> doc = nlp("Hello")
> json_doc = doc.to_json()
> ```
>
> #### Result
>
> ```python
> {
> "text": "Hello",
> "ents": [],
> "sents": [{"start": 0, "end": 5}],
> "tokens": [{"id": 0, "start": 0, "end": 5, "pos": "INTJ", "tag": "UH", "dep": "ROOT", "head": 0}
> ]
> }
> ```
| Name | Type | Description |
| ------------ | ---- | ------------------------------------------------------------------------------ |
| `underscore` | list | Optional list of string names of custom JSON-serializable `doc._.` attributes. |
| **RETURNS** | dict | The JSON-formatted data. |
## Doc.to_array {#to_array tag="method"}
Export given token attributes to a numpy `ndarray`. If `attr_ids` is a sequence

View File

@ -244,8 +244,7 @@ accuracy of predicted entities against the original gold-standard annotation.
## Example.to_dict {#to_dict tag="method"}
Return a
[hierarchical dictionary representation](/api/data-formats#dict-hierarch) of the
Return a [dictionary representation](/api/data-formats#dict-input) of the
reference annotation contained in this `Example`.
> #### Example
@ -256,7 +255,7 @@ reference annotation contained in this `Example`.
| Name | Type | Description |
| ----------- | ---------------- | ------------------------------------------------------ |
| **RETURNS** | `Dict[str, obj]` | Dictionary representation of the reference annotation. |
| **RETURNS** | `Dict[str, Any]` | Dictionary representation of the reference annotation. |
## Example.split_sents {#split_sents tag="method"}

View File

@ -5,9 +5,20 @@ tag: class
source: spacy/tokenizer.pyx
---
> #### Default config
>
> ```ini
> [nlp.tokenizer]
> @tokenizers = "spacy.Tokenizer.v1"
> ```
Segment text, and create `Doc` objects with the discovered segment boundaries.
For a deeper understanding, see the docs on
[how spaCy's tokenizer works](/usage/linguistic-features#how-tokenizer-works).
The tokenizer is typically created automatically when a
[`Language`](/api/language) subclass is initialized and it reads its settings
like punctuation and special case rules from the
[`Language.Defaults`](/api/language#defaults) provided by the language subclass.
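
For example, a minimal sketch (not from the diff) of customizing the tokenizer after the `nlp` object has created it:

```python
# Minimal sketch: the tokenizer is created together with the nlp object and
# can be extended, e.g. with a special-case rule.
import spacy
from spacy.symbols import ORTH

nlp = spacy.blank("en")
nlp.tokenizer.add_special_case("gimme", [{ORTH: "gim"}, {ORTH: "me"}])
print([t.text for t in nlp("gimme that")])  # ['gim', 'me', 'that']
```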
## Tokenizer.\_\_init\_\_ {#init tag="method"}

View File

@ -4,7 +4,7 @@ menu:
- ['spacy', 'spacy']
- ['displacy', 'displacy']
- ['registry', 'registry']
- ['Readers & Batchers', 'readers-batchers']
- ['Batchers', 'batchers']
- ['Data & Alignment', 'gold']
- ['Utility Functions', 'util']
---
@ -299,13 +299,14 @@ factories.
| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `architectures` | Registry for functions that create [model architectures](/api/architectures). Can be used to register custom model architectures and reference them in the `config.cfg`. |
| `factories` | Registry for functions that create [pipeline components](/usage/processing-pipelines#custom-components). Added automatically when you use the `@spacy.component` decorator and also reads from [entry points](/usage/saving-loading#entry-points) |
| `tokenizers` | Registry for tokenizer factories. Registered functions should return a callback that receives the `nlp` object and returns a [`Tokenizer`](/api/tokenizer) or a custom callable. |
| `languages` | Registry for language-specific `Language` subclasses. Automatically reads from [entry points](/usage/saving-loading#entry-points). |
| `lookups` | Registry for large lookup tables available via `vocab.lookups`. |
| `displacy_colors` | Registry for custom color scheme for the [`displacy` NER visualizer](/usage/visualizers). Automatically reads from [entry points](/usage/saving-loading#entry-points). |
| `assets` | |
| `callbacks` | Registry for custom callbacks to [modify the `nlp` object](/usage/training#custom-code-nlp-callbacks) before training. |
| `readers` | Registry for training and evaluation [data readers](#readers-batchers). |
| `batchers` | Registry for training and evaluation [data batchers](#readers-batchers). |
| `readers` | Registry for training and evaluation data readers like [`Corpus`](/api/corpus). |
| `batchers` | Registry for training and evaluation [data batchers](#batchers). |
| `optimizers` | Registry for functions that create [optimizers](https://thinc.ai/docs/api-optimizers). |
| `schedules` | Registry for functions that create [schedules](https://thinc.ai/docs/api-schedules). |
| `layers` | Registry for functions that create [layers](https://thinc.ai/docs/api-layers). |
@ -337,42 +338,9 @@ See the [`Transformer`](/api/transformer) API reference and
| [`span_getters`](/api/transformer#span_getters) | Registry for functions that take a batch of `Doc` objects and return a list of `Span` objects to process by the transformer, e.g. sentences. |
| [`annotation_setters`](/api/transformer#annotation_setters) | Registry for functions that create annotation setters. Annotation setters are functions that take a batch of `Doc` objects and a [`FullTransformerBatch`](/api/transformer#fulltransformerbatch) and can set additional annotations on the `Doc`. |
## Data readers and batchers {#readers-batchers new="3"}
## Batchers {#batchers source="spacy/gold/batchers.py" new="3"}
<!-- TODO: -->
### spacy.Corpus.v1 {#corpus tag="registered function" source="spacy/gold/corpus.py"}
Registered function that creates a [`Corpus`](/api/corpus) of training or
evaluation data. It takes the same arguments as the `Corpus` class and returns a
callable that yields [`Example`](/api/example) objects. You can replace it with
your own registered function in the [`@readers` registry](#registry) to
customize the data loading and streaming.
> #### Example config
>
> ```ini
> [paths]
> train = "corpus/train.spacy"
>
> [training.train_corpus]
> @readers = "spacy.Corpus.v1"
> path = ${paths:train}
> gold_preproc = false
> max_length = 0
> limit = 0
> ```
| Name | Type | Description |
| --------------- | ------ | ----------------------------------------------------------------------------------------------------------------------------------------------- |
| `path` | `Path` | The directory or filename to read from. Expects data in spaCy's binary [`.spacy` format](/api/data-formats#binary-training). |
|  `gold_preproc` | bool | Whether to set up the Example object with gold-standard sentences and tokens for the predictions. See [`Corpus`](/api/corpus#init) for details. |
| `max_length` | int | Maximum document length. Longer documents will be split into sentences, if sentence boundaries are available. Defaults to `0` for no limit. |
| `limit` | int | Limit corpus to a subset of examples, e.g. for debugging. Defaults to `0` for no limit. |
### Batchers {#batchers source="spacy/gold/batchers.py"}
<!-- TODO: -->
<!-- TODO: intro and also describe signature of functions -->
#### batch_by_words.v1 {#batch_by_words tag="registered function"}
@ -446,28 +414,6 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
## Training data and alignment {#gold source="spacy/gold"}
### gold.docs_to_json {#docs_to_json tag="function"}
Convert a list of Doc objects into the
[JSON-serializable format](/api/data-formats#json-input) used by the
[`spacy train`](/api/cli#train) command. Each input doc will be treated as a
'paragraph' in the output doc.
> #### Example
>
> ```python
> from spacy.gold import docs_to_json
>
> doc = nlp("I like London")
> json_data = docs_to_json([doc])
> ```
| Name | Type | Description |
| ----------- | ---------------- | ------------------------------------------ |
| `docs` | iterable / `Doc` | The `Doc` object(s) to convert. |
| `id` | int | ID to assign to the JSON. Defaults to `0`. |
| **RETURNS** | dict | The data in spaCy's JSON format. |
### gold.biluo_tags_from_offsets {#biluo_tags_from_offsets tag="function"}
Encode labelled spans into per-token tags, using the

View File

@ -1089,9 +1089,10 @@ In situations like that, you often want to align the tokenization so that you
can merge annotations from different sources together, or take vectors predicted
by a
[pretrained BERT model](https://github.com/huggingface/pytorch-transformers) and
apply them to spaCy tokens. spaCy's [`Alignment`](/api/example#alignment-object) object
allows the one-to-one mappings of token indices in both directions as well as
taking into account indices where multiple tokens align to one single token.
apply them to spaCy tokens. spaCy's [`Alignment`](/api/example#alignment-object)
object allows the one-to-one mappings of token indices in both directions as
well as taking into account indices where multiple tokens align to one single
token.
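
A very rough sketch of what this looks like in code (API details may differ on this branch; the `Example`-based access to the alignment is an assumption):

```python
# Hedged sketch: align a custom tokenization against spaCy's own tokens.
import spacy
from spacy.gold import Example  # module path used elsewhere on this branch

nlp = spacy.blank("en")
words = ["Obama", "'s", "podcasts"]           # tokens from another source
predicted = nlp.make_doc("Obama's podcasts")  # spaCy's tokenization
example = Example.from_dict(predicted, {"words": words})
alignment = example.alignment  # Alignment object mapping indices both ways
```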
> #### ✏️ Things to try
>
@ -1490,7 +1491,7 @@ language name, and even train models with it and refer to it in your
>
> ```bash
> ### {wrap="true"}
> $ python -m spacy train train.spacy dev.spacy config.cfg --code code.py
> $ python -m spacy train config.cfg --code code.py
> ```
```python

View File

@ -210,7 +210,7 @@ commands:
- name: train
help: 'Train a spaCy model using the specified corpus and config'
script:
- 'python -m spacy train ./corpus/training.spacy ./corpus/evaluation.spacy ./configs/config.cfg -o training/'
- 'python -m spacy train ./configs/config.cfg -o training/ --paths.train ./corpus/training.spacy --paths.dev ./corpus/evaluation.spacy'
deps:
- 'configs/config.cfg'
- 'corpus/training.spacy'

View File

@ -30,35 +30,14 @@ ready-to-use spaCy models.
</Infobox>
### Training CLI & config {#cli-config}
<!-- TODO: intro describing the new v3 training philosophy -->
## Quickstart {#quickstart}
The recommended way to train your spaCy models is via the
[`spacy train`](/api/cli#train) command on the command line. You can pass in the
following data and information:
1. The **training and evaluation data** in spaCy's
[binary `.spacy` format](/api/data-formats#binary-training) created using
[`spacy convert`](/api/cli#convert).
2. A [`config.cfg`](#config) **configuration file** with all settings and
hyperparameters.
3. An optional **Python file** to register
[custom functions and architectures](#custom-code).
```bash
$ python -m spacy train train.spacy dev.spacy config.cfg --output ./output
```
<Project id="some_example_project">
The easiest way to get started with an end-to-end training process is to clone a
[project](/usage/projects) template. Projects let you manage multi-step
workflows, from data preprocessing to training and packaging your model.
</Project>
## Quickstart {#quickstart}
[`spacy train`](/api/cli#train) command on the command line. It only needs a
single [`config.cfg`](#config) **configuration file** that includes all settings
and hyperparameters. You can optionally [overwrite](#config-overrides)
settings on the command line, and load in a Python file to register
[custom functions](#custom-code) and architectures.
> #### Instructions
>
@ -88,17 +67,26 @@ $ python -m spacy init config config.cfg --base base_config.cfg
> invalid entity annotations, cyclic dependencies, low data labels and more.
>
> ```bash
> $ python -m spacy debug data en train.spacy dev.spacy --verbose
> $ python -m spacy debug data config.cfg --verbose
> ```
You can now run [`train`](/api/cli#train) with your training and development
data and the training config. See the [`convert`](/api/cli#convert) command for
details on how to convert your data to spaCy's binary `.spacy` format.
You can now add your data and run [`train`](/api/cli#train) with your config.
See the [`convert`](/api/cli#convert) command for details on how to convert your
data to spaCy's binary `.spacy` format. You can either include the data paths in
the `[paths]` section of your config, or pass them in via the command line.
```bash
$ python -m spacy train train.spacy dev.spacy config.cfg --output ./output
$ python -m spacy train config.cfg --output ./output --paths.train ./train.spacy --paths.dev ./dev.spacy
```
<Project id="some_example_project">
The easiest way to get started with an end-to-end training process is to clone a
[project](/usage/projects) template. Projects let you manage multi-step
workflows, from data preprocessing to training and packaging your model.
</Project>
## Training config {#config}
> #### Migration from spaCy v2.x
@ -149,19 +137,19 @@ not just define static settings, but also construct objects like architectures,
schedules, optimizers or any other custom components. The main top-level
sections of a config file are:
| Section | Description |
| ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `nlp` | Definition of the `nlp` object, its tokenizer and [processing pipeline](/usage/processing-pipelines) component names. |
| `components` | Definitions of the [pipeline components](/usage/processing-pipelines) and their models. |
| `paths` | Paths to data and other assets. Can be re-used across the config as variables, e.g. `${paths:train}`, and [overwritten](#config-overrides) on the CLI. |
| `system` | Settings related to system and hardware. |
| `training` | Settings and controls for the training and evaluation process. |
| `pretraining` | Optional settings and controls for the [language model pretraining](#pretraining). |
| Section | Description |
| ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `nlp` | Definition of the `nlp` object, its tokenizer and [processing pipeline](/usage/processing-pipelines) component names. |
| `components` | Definitions of the [pipeline components](/usage/processing-pipelines) and their models. |
| `paths` | Paths to data and other assets. Re-used across the config as variables, e.g. `${paths:train}`, and can be [overwritten](#config-overrides) on the CLI. |
| `system` | Settings related to system and hardware. Re-used across the config as variables, e.g. `${system:seed}`, and can be [overwritten](#config-overrides) on the CLI. |
| `training` | Settings and controls for the training and evaluation process. |
| `pretraining` | Optional settings and controls for the [language model pretraining](#pretraining). |
<Infobox title="Config format and settings" emoji="📖">
For a full overview of spaCy's config format and settings, see the
[training format documentation](/api/data-formats#config) and
[data format documentation](/api/data-formats#config) and
[Thinc's config system docs](https://thinc.ai/usage/config). The settings
available for the different architectures are documented with the
[model architectures API](/api/architectures). See the Thinc documentation for
@ -172,8 +160,6 @@ available for the different architectures are documented with the
### Overwriting config settings on the command line {#config-overrides}
<!-- TODO: change example to use file path overrides -->
The config system means that you can define all settings **in one place** and in
a consistent format. There are no command-line arguments that need to be set,
and no hidden defaults. However, there can still be scenarios where you may want
@ -183,18 +169,20 @@ hard-code in a config file, or **system-dependent settings**.
For cases like this, you can set additional command-line options starting with
`--` that correspond to the config section and value to override. For example,
`--training.batch_size 128` sets the `batch_size` value in the `[training]`
block to `128`.
`--paths.train ./corpus/train.spacy` sets the `train` value in the `[paths]`
block.
```bash
$ python -m spacy train train.spacy dev.spacy config.cfg
--training.batch_size 128 --nlp.vectors /path/to/vectors
$ python -m spacy train config.cfg --paths.train ./corpus/train.spacy
--paths.dev ./corpus/dev.spacy --training.batch_size 128
```
Only existing sections and values in the config can be overwritten. At the end
of the training, the final filled `config.cfg` is exported with your model, so
you'll always have a record of the settings that were used, including your
overrides.
overrides. Note that overrides are added before [variables](#config-interpolation)
are resolved, so if you need to use a value in multiple places, reference it
across your config and override it on the CLI once.
### Defining pipeline components {#config-components}
@ -398,7 +386,7 @@ still look good.
> ```bash
> ### Example {wrap="true"}
> $ python -m spacy train train.spacy dev.spacy config.cfg --code functions.py
> $ python -m spacy train config.cfg --code functions.py
> ```
The [`spacy train`](/api/cli#train) command lets you specify an optional argument
@ -517,7 +505,7 @@ to your Python file. Before loading the config, spaCy will import the
```bash
### Training with custom code {wrap="true"}
python -m spacy train train.spacy dev.spacy config.cfg --output ./output --code ./functions.py
python -m spacy train config.cfg --output ./output --code ./functions.py
```
#### Example: Custom batch size schedule {#custom-code-schedule}
@ -610,7 +598,7 @@ config and customize the implementations, see the usage guide on
### Pretraining with spaCy {#pretraining}
<!-- TODO: document spacy pretrain -->
<!-- TODO: document spacy pretrain, objectives etc. -->
## Parallel Training with Ray {#parallel-training}

View File

@ -204,7 +204,7 @@ a Python file. For more details on training with custom code, see the
[training documentation](/usage/training#custom-code).
```bash
$ python -m spacy train ./train.spacy ./dev.spacy ./config.cfg --code ./code.py
$ python -m spacy train ./config.cfg --code ./code.py
```
### Customizing the model implementations {#training-custom-model}

View File

@ -182,7 +182,7 @@ $ python -m spacy convert ./training.json ./output
```diff
### {wrap="true"}
- python -m spacy train en ./output ./train.json ./dev.json --pipeline tagger,parser --cnn-window 1 --bilstm-depth 0
+ python -m spacy train ./train.spacy ./dev.spacy ./config.cfg --output ./output
+ python -m spacy train ./config.cfg --output ./output
```
<Project id="some_example_project">

View File

@ -11,6 +11,8 @@ import Link from './link'
import GitHubCode from './github'
import classes from '../styles/code.module.sass'
const WRAP_THRESHOLD = 15
export default props => (
<Pre>
<Code {...props} />
@ -23,7 +25,7 @@ export const Pre = props => {
export const InlineCode = ({ wrap = false, className, children, ...props }) => {
const codeClassNames = classNames(classes.inlineCode, className, {
[classes.wrap]: wrap || (isString(children) && children.length >= 20),
[classes.wrap]: wrap || (isString(children) && children.length >= WRAP_THRESHOLD),
})
return (
<code className={codeClassNames} {...props}>

View File

@ -58,7 +58,7 @@
.wrap
white-space: pre-wrap
word-wrap: break-word
word-wrap: anywhere
.title,
.juniper-button