generalize corpora, dot notation for dev and train corpus

svlandeg 2020-09-17 11:38:59 +02:00
parent 8cedb2f380
commit 0c35885751
16 changed files with 261 additions and 143 deletions
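The heart of the change: corpus readers move out of `[training]` and `[pretraining]` into a shared top-level `[corpora]` block, and the training settings now point at them with dot-notation strings such as `"corpora.train"`, resolved via `util.dot_to_object`. As a rough sketch of what that resolution amounts to (illustrative only, not spaCy's exact implementation):

```python
def dot_to_object(config: dict, location: str):
    """Resolve a dot-notation string like "corpora.train" to the nested
    section it points to in a config dictionary."""
    obj = config
    for key in location.split("."):
        obj = obj[key]
    return obj

# Usage as in this commit, assuming a resolved config:
# train_corpus = dot_to_object(config, config["training"]["train_corpus"])
```

Any key under `[corpora]` can then be referenced from anywhere in the config, which is what lets `[pretraining]` default to `corpora.pretrain` and lets users register arbitrary extra corpora.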

View File

@@ -8,6 +8,22 @@ init_tok2vec = null
 seed = 0
 use_pytorch_for_gpu_memory = false
+[corpora]
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths:train}
+gold_preproc = true
+max_length = 0
+limit = 0
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths:dev}
+gold_preproc = ${corpora.train.gold_preproc}
+max_length = 0
+limit = 0
 [training]
 seed = ${system:seed}
 dropout = 0.1
@@ -20,22 +36,8 @@ patience = 10000
 eval_frequency = 200
 score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
 frozen_components = []
-[training.corpus]
-[training.corpus.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths:train}
-gold_preproc = true
-max_length = 0
-limit = 0
-[training.corpus.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths:dev}
-gold_preproc = ${training.read_train:gold_preproc}
-max_length = 0
-limit = 0
+dev_corpus = "corpora.dev"
+train_corpus = "corpora.train"
 [training.batcher]
 @batchers = "spacy.batch_by_words.v1"

View File

@@ -8,6 +8,22 @@ init_tok2vec = null
 seed = 0
 use_pytorch_for_gpu_memory = false
+[corpora]
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths:train}
+gold_preproc = true
+max_length = 0
+limit = 0
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths:dev}
+gold_preproc = ${corpora.train.gold_preproc}
+max_length = 0
+limit = 0
 [training]
 seed = ${system:seed}
 dropout = 0.2
@@ -20,22 +36,6 @@ patience = 10000
 eval_frequency = 200
 score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
-[training.corpus]
-[training.corpus.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths:train}
-gold_preproc = true
-max_length = 0
-limit = 0
-[training.corpus.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths:dev}
-gold_preproc = ${training.read_train:gold_preproc}
-max_length = 0
-limit = 0
 [training.batcher]
 @batchers = "spacy.batch_by_words.v1"
 discard_oversize = false

View File

@@ -20,6 +20,7 @@ from ..ml.models.multi_task import build_cloze_characters_multi_task_model
 from ..tokens import Doc
 from ..attrs import ID
 from .. import util
+from ..util import dot_to_object
 @app.command(
@@ -106,7 +107,7 @@ def pretrain(
     use_pytorch_for_gpu_memory()
     nlp, config = util.load_model_from_config(config)
     P_cfg = config["pretraining"]
-    corpus = P_cfg["corpus"]
+    corpus = dot_to_object(config, config["pretraining"]["corpus"])
     batcher = P_cfg["batcher"]
     model = create_pretraining_model(nlp, config["pretraining"])
     optimizer = config["pretraining"]["optimizer"]

View File

@@ -173,6 +173,18 @@ factory = "{{ pipe }}"
 {% endif %}
 {% endfor %}
+[corpora]
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths.train}
+max_length = {{ 500 if hardware == "gpu" else 2000 }}
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths.dev}
+max_length = 0
 [training]
 {% if use_transformer or optimize == "efficiency" or not word_vectors -%}
 vectors = null
@@ -182,11 +194,12 @@ vectors = "{{ word_vectors }}"
 {% if use_transformer -%}
 accumulate_gradient = {{ transformer["size_factor"] }}
 {% endif %}
+dev_corpus = "corpora.dev"
+train_corpus = "corpora.train"
 [training.optimizer]
 @optimizers = "Adam.v1"
 {% if use_transformer -%}
 [training.optimizer.learn_rate]
 @schedules = "warmup_linear.v1"
@@ -195,18 +208,6 @@ total_steps = 20000
 initial_rate = 5e-5
 {% endif %}
-[training.corpus]
-[training.corpus.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths.train}
-max_length = {{ 500 if hardware == "gpu" else 2000 }}
-[training.corpus.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths.dev}
-max_length = 0
 {% if use_transformer %}
 [training.batcher]
 @batchers = "spacy.batch_by_padded.v1"

View File

@@ -18,6 +18,7 @@ from ..language import Language
 from .. import util
 from ..training.example import Example
 from ..errors import Errors
+from ..util import dot_to_object
 @app.command(
@@ -92,8 +93,8 @@ def train(
     raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
     T_cfg = config["training"]
     optimizer = T_cfg["optimizer"]
-    train_corpus = T_cfg["corpus"]["train"]
-    dev_corpus = T_cfg["corpus"]["dev"]
+    train_corpus = dot_to_object(config, config["training"]["train_corpus"])
+    dev_corpus = dot_to_object(config, config["training"]["dev_corpus"])
     batcher = T_cfg["batcher"]
     train_logger = T_cfg["logger"]
     # Components that shouldn't be updated during training

View File

@@ -22,6 +22,33 @@ after_pipeline_creation = null
 [components]
+# Readers for corpora like dev and train.
+[corpora]
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths.train}
+# Whether to train on sequences with 'gold standard' sentence boundaries
+# and tokens. If you set this to true, take care to ensure your run-time
+# data is passed in sentence-by-sentence via some prior preprocessing.
+gold_preproc = false
+# Limitations on training document length
+max_length = 0
+# Limitation on number of training examples
+limit = 0
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths.dev}
+# Whether to train on sequences with 'gold standard' sentence boundaries
+# and tokens. If you set this to true, take care to ensure your run-time
+# data is passed in sentence-by-sentence via some prior preprocessing.
+gold_preproc = false
+# Limitations on training document length
+max_length = 0
+# Limitation on number of training examples
+limit = 0
 # Training hyper-parameters and additional features.
 [training]
 seed = ${system.seed}
@@ -40,35 +67,14 @@ eval_frequency = 200
 score_weights = {}
 # Names of pipeline components that shouldn't be updated during training
 frozen_components = []
+# Location in the config where the dev corpus is defined
+dev_corpus = "corpora.dev"
+# Location in the config where the train corpus is defined
+train_corpus = "corpora.train"
 [training.logger]
 @loggers = "spacy.ConsoleLogger.v1"
-[training.corpus]
-[training.corpus.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths.train}
-# Whether to train on sequences with 'gold standard' sentence boundaries
-# and tokens. If you set this to true, take care to ensure your run-time
-# data is passed in sentence-by-sentence via some prior preprocessing.
-gold_preproc = false
-# Limitations on training document length
-max_length = 0
-# Limitation on number of training examples
-limit = 0
-[training.corpus.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths.dev}
-# Whether to train on sequences with 'gold standard' sentence boundaries
-# and tokens. If you set this to true, take care to ensure your run-time
-# data is passed in sentence-by-sentence via some prior preprocessing.
-gold_preproc = false
-# Limitations on training document length
-max_length = 0
-# Limitation on number of training examples
-limit = 0
 [training.batcher]
 @batchers = "spacy.batch_by_words.v1"
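Because the trainer only sees a dot-notation string, swapping in a custom dev set is just a matter of adding a section under `[corpora]` and repointing `dev_corpus`. A hypothetical override (the section name `corpora.my_dev` and the path are made up for illustration):

```python
from thinc.api import Config

# Hypothetical config fragment: register an extra corpus section and point
# the trainer's dev_corpus setting at it. Names and paths are illustrative.
cfg = Config().from_str("""
[corpora.my_dev]
@readers = "spacy.Corpus.v1"
path = "corpus/my_dev.spacy"
gold_preproc = false
max_length = 0
limit = 0

[training]
dev_corpus = "corpora.my_dev"
""")
print(cfg["training"]["dev_corpus"])  # corpora.my_dev
```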

View File

@@ -4,6 +4,7 @@ dropout = 0.2
 n_save_every = null
 component = "tok2vec"
 layer = ""
+corpus = "corpora.pretrain"
 [pretraining.batcher]
 @batchers = "spacy.batch_by_words.v1"
@@ -12,13 +13,6 @@ discard_oversize = false
 tolerance = 0.2
 get_length = null
-[pretraining.corpus]
-@readers = "spacy.JsonlReader.v1"
-path = ${paths.raw}
-min_length = 5
-max_length = 500
-limit = 0
 [pretraining.objective]
 type = "characters"
 n_characters = 4
@@ -33,3 +27,12 @@ grad_clip = 1.0
 use_averages = true
 eps = 1e-8
 learn_rate = 0.001
+[corpora]
+[corpora.pretrain]
+@readers = "spacy.JsonlReader.v1"
+path = ${paths.raw}
+min_length = 5
+max_length = 500
+limit = 0

View File

@@ -198,7 +198,8 @@ class ModelMetaSchema(BaseModel):
 class ConfigSchemaTraining(BaseModel):
     # fmt: off
     vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
-    corpus: Dict[str, Reader] = Field(..., title="Reader for the training and dev data")
+    dev_corpus: StrictStr = Field(..., title="Path in the config to the dev data")
+    train_corpus: StrictStr = Field(..., title="Path in the config to the training data")
     batcher: Batcher = Field(..., title="Batcher for the training data")
     dropout: StrictFloat = Field(..., title="Dropout rate")
     patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
@@ -248,7 +249,7 @@ class ConfigSchemaPretrain(BaseModel):
     dropout: StrictFloat = Field(..., title="Dropout rate")
     n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
     optimizer: Optimizer = Field(..., title="The optimizer to use")
-    corpus: Reader = Field(..., title="Reader for the training data")
+    corpus: StrictStr = Field(..., title="Path in the config to the training data")
     batcher: Batcher = Field(..., title="Batcher for the training data")
     component: str = Field(..., title="Component to find the layer to pretrain")
     layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
@@ -267,6 +268,7 @@ class ConfigSchema(BaseModel):
     nlp: ConfigSchemaNlp
     pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
     components: Dict[str, Dict[str, Any]]
+    corpora: Dict[str, Reader]
     @root_validator(allow_reuse=True)
     def validate_config(cls, values):
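In schema terms, the net effect is that the corpus settings on `[training]` and `[pretraining]` become plain strings, while the readers themselves are validated once under the new top-level `corpora` key. A minimal pydantic sketch of that shape (field names taken from the diff; `Reader` is simplified to a bare callable here):

```python
from typing import Callable, Dict
from pydantic import BaseModel, StrictStr

Reader = Callable  # stand-in for spaCy's Reader type, for illustration only

class TrainingSketch(BaseModel):
    dev_corpus: StrictStr = "corpora.dev"      # dot-notation location, not a reader
    train_corpus: StrictStr = "corpora.train"  # ditto

class ConfigSketch(BaseModel):
    training: TrainingSketch = TrainingSketch()
    corpora: Dict[str, Reader]  # the readers live here, keyed by name

cfg = ConfigSketch(corpora={"train": lambda nlp: [], "dev": lambda nlp: []})
print(cfg.training.train_corpus)  # corpora.train
```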

View File

@@ -17,18 +17,18 @@ nlp_config_string = """
 train = ""
 dev = ""
-[training]
-[training.corpus]
-[training.corpus.train]
+[corpora]
+[corpora.train]
 @readers = "spacy.Corpus.v1"
 path = ${paths.train}
-[training.corpus.dev]
+[corpora.dev]
 @readers = "spacy.Corpus.v1"
 path = ${paths.dev}
+[training]
 [training.batcher]
 @batchers = "spacy.batch_by_words.v1"
 size = 666
@@ -302,20 +302,20 @@ def test_config_overrides():
 def test_config_interpolation():
     config = Config().from_str(nlp_config_string, interpolate=False)
-    assert config["training"]["corpus"]["train"]["path"] == "${paths.train}"
+    assert config["corpora"]["train"]["path"] == "${paths.train}"
     interpolated = config.interpolate()
-    assert interpolated["training"]["corpus"]["train"]["path"] == ""
+    assert interpolated["corpora"]["train"]["path"] == ""
     nlp = English.from_config(config)
-    assert nlp.config["training"]["corpus"]["train"]["path"] == "${paths.train}"
+    assert nlp.config["corpora"]["train"]["path"] == "${paths.train}"
     # Ensure that variables are preserved in nlp config
     width = "${components.tok2vec.model.width}"
     assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
     assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
     interpolated2 = nlp.config.interpolate()
-    assert interpolated2["training"]["corpus"]["train"]["path"] == ""
+    assert interpolated2["corpora"]["train"]["path"] == ""
     assert interpolated2["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
     nlp2 = English.from_config(interpolated)
-    assert nlp2.config["training"]["corpus"]["train"]["path"] == ""
+    assert nlp2.config["corpora"]["train"]["path"] == ""
     assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342

View File

@@ -1,6 +1,57 @@
+from typing import Dict, Iterable, Callable
 import pytest
 from thinc.api import Config
-from spacy.util import load_model_from_config
+from spacy import Language
+from spacy.util import load_model_from_config, registry, dot_to_object
+from spacy.training import Example
+
+
+def test_readers():
+    config_string = """
+    [training]
+    [corpora]
+    @readers = "myreader.v1"
+    [nlp]
+    lang = "en"
+    pipeline = ["tok2vec", "textcat"]
+    [components]
+    [components.tok2vec]
+    factory = "tok2vec"
+    [components.textcat]
+    factory = "textcat"
+    """
+    @registry.readers.register("myreader.v1")
+    def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
+        annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
+        def reader(nlp: Language):
+            doc = nlp.make_doc(f"This is an example")
+            return [Example.from_dict(doc, annots)]
+        return {"train": reader, "dev": reader, "extra": reader, "something": reader}
+    config = Config().from_str(config_string)
+    nlp, resolved = load_model_from_config(config, auto_fill=True)
+    train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
+    assert isinstance(train_corpus, Callable)
+    optimizer = resolved["training"]["optimizer"]
+    # simulate a training loop
+    nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
+    for example in train_corpus(nlp):
+        nlp.update([example], sgd=optimizer)
+    dev_corpus = dot_to_object(resolved, resolved["training"]["dev_corpus"])
+    scores = nlp.evaluate(list(dev_corpus(nlp)))
+    assert scores["cats_score"]
+    # ensure the pipeline runs
+    doc = nlp("Quick test")
+    assert doc.cats
+    extra_corpus = resolved["corpora"]["extra"]
+    assert isinstance(extra_corpus, Callable)
 @pytest.mark.slow
@@ -16,7 +67,7 @@ def test_cat_readers(reader, additional_config):
     nlp_config_string = """
     [training]
-    [training.corpus]
+    [corpora]
     @readers = "PLACEHOLDER"
     [nlp]
@@ -32,11 +83,11 @@ def test_cat_readers(reader, additional_config):
     factory = "textcat"
     """
     config = Config().from_str(nlp_config_string)
-    config["training"]["corpus"]["@readers"] = reader
-    config["training"]["corpus"].update(additional_config)
+    config["corpora"]["@readers"] = reader
+    config["corpora"].update(additional_config)
     nlp, resolved = load_model_from_config(config, auto_fill=True)
-    train_corpus = resolved["training"]["corpus"]["train"]
+    train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
    optimizer = resolved["training"]["optimizer"]
     # simulate a training loop
     nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
@@ -46,7 +97,7 @@ def test_cat_readers(reader, additional_config):
     assert sorted(list(set(example.y.cats.values()))) == [0.0, 1.0]
     nlp.update([example], sgd=optimizer)
     # simulate performance benchmark on dev corpus
-    dev_corpus = resolved["training"]["corpus"]["dev"]
+    dev_corpus = dot_to_object(resolved, resolved["training"]["dev_corpus"])
     dev_examples = list(dev_corpus(nlp))
     for example in dev_examples:
         # this shouldn't fail if each dev example has at least one positive label

View File

@@ -355,6 +355,16 @@ Registry @architectures
 Name spacy.MaxoutWindowEncoder.v1
 Module spacy.ml.models.tok2vec
 File /path/to/spacy/ml/models/tok2vec.py (line 207)
+[corpora.dev]
+Registry @readers
+Name spacy.Corpus.v1
+Module spacy.training.corpus
+File /path/to/spacy/training/corpus.py (line 18)
+[corpora.train]
+Registry @readers
+Name spacy.Corpus.v1
+Module spacy.training.corpus
+File /path/to/spacy/training/corpus.py (line 18)
 [training.logger]
 Registry @loggers
 Name spacy.ConsoleLogger.v1
@@ -370,16 +380,6 @@ Registry @schedules
 Name compounding.v1
 Module thinc.schedules
 File /path/to/thinc/thinc/schedules.py (line 43)
-[training.corpus.dev]
-Registry @readers
-Name spacy.Corpus.v1
-Module spacy.training.corpus
-File /path/to/spacy/training/corpus.py (line 18)
-[training.corpus.train]
-Registry @readers
-Name spacy.Corpus.v1
-Module spacy.training.corpus
-File /path/to/spacy/training/corpus.py (line 18)
 [training.optimizer]
 Registry @optimizers
 Name Adam.v1

View File

@@ -26,7 +26,7 @@ streaming.
 > [paths]
 > train = "corpus/train.spacy"
 >
-> [training.corpus.train]
+> [corpora.train]
 > @readers = "spacy.Corpus.v1"
 > path = ${paths.train}
 > gold_preproc = false
@@ -135,7 +135,7 @@ Initialize the reader.
 >
 > ```ini
 > ### Example config
-> [pretraining.corpus]
+> [corpora.pretrain]
 > @readers = "spacy.JsonlReader.v1"
 > path = "corpus/raw_text.jsonl"
 > min_length = 0

View File

@@ -121,28 +121,78 @@ that you don't want to hard-code in your config file.
 $ python -m spacy train config.cfg --paths.train ./corpus/train.spacy
 ```
+
+### corpora {#config-corpora tag="section"}
+
+This section defines a dictionary mapping of string keys to `Callable`
+functions. Each callable takes an `nlp` object and yields
+[`Example`](/api/example) objects. By default, the two keys `train` and `dev`
+are specified, and each refers to a [`Corpus`](/api/top-level#Corpus). When
+pretraining, an additional `pretrain` section is added that defaults to a
+[`JsonlReader`](/api/top-level#JsonlReader).
+These subsections can be expanded with additional subsections, each referring
+to a callback of type `Callable[[Language], Iterator[Example]]`:
+
+> #### Example
+>
+> ```ini
+> [corpora]
+> [corpora.train]
+> @readers = "spacy.Corpus.v1"
+> path = ${paths:train}
+>
+> [corpora.dev]
+> @readers = "spacy.Corpus.v1"
+> path = ${paths:dev}
+>
+> [corpora.pretrain]
+> @readers = "spacy.JsonlReader.v1"
+> path = ${paths.raw}
+> min_length = 5
+> max_length = 500
+>
+> [corpora.mydata]
+> @readers = "my_reader.v1"
+> shuffle = true
+> ```
+
+Alternatively, the `corpora` block can refer to one function with return type
+`Dict[str, Callable[[Language], Iterator[Example]]]`:
+
+> #### Example
+>
+> ```ini
+> [corpora]
+> @readers = "my_dict_reader.v1"
+> train_path = ${paths:train}
+> dev_path = ${paths:dev}
+> shuffle = true
+> ```
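To make the dict-style variant concrete, here is a sketch of what a reader like `my_dict_reader.v1` could look like in Python. The registry name and arguments come from the example above; the body (newline-delimited raw text, no annotations) is an assumption for illustration:

```python
import random
from typing import Callable, Dict, Iterable
from spacy import Language
from spacy.training import Example
from spacy.util import registry

# Sketch only: returns one callable per corpus name, as the docs describe.
@registry.readers.register("my_dict_reader.v1")
def my_dict_reader(
    train_path: str, dev_path: str, shuffle: bool
) -> Dict[str, Callable[[Language], Iterable[Example]]]:
    def make_reader(path: str) -> Callable[[Language], Iterable[Example]]:
        def reader(nlp: Language) -> Iterable[Example]:
            lines = [line.strip() for line in open(path, encoding="utf8")]
            if shuffle:
                random.shuffle(lines)
            for line in lines:
                doc = nlp.make_doc(line)
                yield Example.from_dict(doc, {})  # no annotations in this sketch
        return reader
    return {"train": make_reader(train_path), "dev": make_reader(dev_path)}
```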
 ### training {#config-training tag="section"}

 This section defines settings and controls for the training and evaluation
 process that are used when you run [`spacy train`](/api/cli#train).

 | Name | Description |
 | --------------------- | ----------- |
 | `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
 | `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
-| `corpus` | Dictionary with `train` and `dev` keys, each referring to a callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. Defaults to [`Corpus`](/api/top-level#Corpus). ~~Callable[[Language], Iterator[Example]]~~ |
+| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
 | `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
 | `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
 | `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
 | `init_tok2vec` | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). Defaults to variable `${paths.init_tok2vec}`. ~~Optional[str]~~ |
 | `max_epochs` | Maximum number of epochs to train for. Defaults to `0`. ~~int~~ |
 | `max_steps` | Maximum number of update steps to train for. Defaults to `20000`. ~~int~~ |
 | `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
 | `patience` | How many steps to continue without improvement in evaluation score. Defaults to `1600`. ~~int~~ |
 | `raw_text` | Optional path to a jsonl file with unlabelled text documents for a [rehearsal](/api/language#rehearse) step. Defaults to variable `${paths.raw}`. ~~Optional[str]~~ |
 | `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
 | `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
+| `train_corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |
 | `vectors` | Name or path of pipeline containing pretrained word vectors to use, e.g. created with [`init vocab`](/api/cli#init-vocab). Defaults to `null`. ~~Optional[str]~~ |
 ### pretraining {#config-pretraining tag="section,optional"}

@@ -150,17 +200,18 @@ This section is optional and defines settings and controls for
 [language model pretraining](/usage/embeddings-transformers#pretraining). It's
 used when you run [`spacy pretrain`](/api/cli#pretrain).

 | Name | Description |
 | -------------- | ----------- |
 | `max_epochs` | Maximum number of epochs. Defaults to `1000`. ~~int~~ |
 | `dropout` | The dropout rate. Defaults to `0.2`. ~~float~~ |
 | `n_save_every` | Saving frequency. Defaults to `null`. ~~Optional[int]~~ |
 | `objective` | The pretraining objective. Defaults to `{"type": "characters", "n_characters": 4}`. ~~Dict[str, Any]~~ |
 | `optimizer` | The optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
-| `corpus` | Callable that takes the current `nlp` object and yields [`Doc`](/api/doc) objects. Defaults to [`JsonlReader`](/api/top-level#JsonlReader). ~~Callable[[Language, str], Iterable[Example]]~~ |
+| `corpus` | Dot notation of the config location defining the pretraining corpus. Defaults to `corpora.pretrain`. ~~str~~ |
 | `batcher` | Batcher for the training data. ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
 | `component` | Component to find the layer to pretrain. Defaults to `"tok2vec"`. ~~str~~ |
 | `layer` | The layer to pretrain. If empty, the whole component model will be used. ~~str~~ |

 ## Training data {#training}

View File

@@ -448,7 +448,7 @@ remain in the config file stored on your local system.
 > [training.logger]
 > @loggers = "spacy.WandbLogger.v1"
 > project_name = "monitor_spacy_training"
-> remove_config_values = ["paths.train", "paths.dev", "training.corpus.train.path", "training.corpus.dev.path"]
+> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
 > ```
 | Name | Description |
@@ -478,7 +478,7 @@ the [`Corpus`](/api/corpus) class.
 > [paths]
 > train = "corpus/train.spacy"
 >
-> [training.corpus.train]
+> [corpora.train]
 > @readers = "spacy.Corpus.v1"
 > path = ${paths.train}
 > gold_preproc = false
@@ -506,7 +506,7 @@ JSONL file. Also see the [`JsonlReader`](/api/corpus#jsonlreader) class.
 > [paths]
 > pretrain = "corpus/raw_text.jsonl"
 >
-> [pretraining.corpus]
+> [corpora.pretrain]
 > @readers = "spacy.JsonlReader.v1"
 > path = ${paths.pretrain}
 > min_length = 0

View File

@@ -969,7 +969,7 @@ your results.
 > [training.logger]
 > @loggers = "spacy.WandbLogger.v1"
 > project_name = "monitor_spacy_training"
-> remove_config_values = ["paths.train", "paths.dev", "training.corpus.train.path", "training.corpus.dev.path"]
+> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
 > ```
 ![Screenshot: Visualized training results](../images/wandb1.jpg)

View File

@@ -746,7 +746,7 @@ as **config settings**: in this case, `source`.
 > #### config.cfg
 >
 > ```ini
-> [training.corpus.train]
+> [corpora.train]
 > @readers = "corpus_variants.v1"
 > source = "s3://your_bucket/path/data.csv"
 > ```