diff --git a/extra/experiments/onto-joint/defaults.cfg b/extra/experiments/onto-joint/defaults.cfg
index 97eebe6b4..90101281c 100644
--- a/extra/experiments/onto-joint/defaults.cfg
+++ b/extra/experiments/onto-joint/defaults.cfg
@@ -8,6 +8,22 @@ init_tok2vec = null
 seed = 0
 use_pytorch_for_gpu_memory = false
 
+[corpora]
+
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths:train}
+gold_preproc = true
+max_length = 0
+limit = 0
+
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths:dev}
+gold_preproc = ${corpora.train.gold_preproc}
+max_length = 0
+limit = 0
+
 [training]
 seed = ${system:seed}
 dropout = 0.1
@@ -20,22 +36,8 @@ patience = 10000
 eval_frequency = 200
 score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
 frozen_components = []
-
-[training.corpus]
-
-[training.corpus.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths:train}
-gold_preproc = true
-max_length = 0
-limit = 0
-
-[training.corpus.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths:dev}
-gold_preproc = ${training.read_train:gold_preproc}
-max_length = 0
-limit = 0
+dev_corpus = "corpora.dev"
+train_corpus = "corpora.train"
 
 [training.batcher]
 @batchers = "spacy.batch_by_words.v1"
diff --git a/extra/experiments/ptb-joint-pos-dep/defaults.cfg b/extra/experiments/ptb-joint-pos-dep/defaults.cfg
index 03e2f5bd7..55fb52b99 100644
--- a/extra/experiments/ptb-joint-pos-dep/defaults.cfg
+++ b/extra/experiments/ptb-joint-pos-dep/defaults.cfg
@@ -8,6 +8,22 @@ init_tok2vec = null
 seed = 0
 use_pytorch_for_gpu_memory = false
 
+[corpora]
+
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths:train}
+gold_preproc = true
+max_length = 0
+limit = 0
+
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths:dev}
+gold_preproc = ${corpora.train.gold_preproc}
+max_length = 0
+limit = 0
+
 [training]
 seed = ${system:seed}
 dropout = 0.2
@@ -20,22 +36,6 @@ patience = 10000
 eval_frequency = 200
 score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
 
-[training.corpus]
-
-[training.corpus.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths:train}
-gold_preproc = true
-max_length = 0
-limit = 0
-
-[training.corpus.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths:dev}
-gold_preproc = ${training.read_train:gold_preproc}
-max_length = 0
-limit = 0
-
 [training.batcher]
 @batchers = "spacy.batch_by_words.v1"
 discard_oversize = false
diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index 70858123d..3567e7339 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -20,6 +20,7 @@ from ..ml.models.multi_task import build_cloze_characters_multi_task_model
 from ..tokens import Doc
 from ..attrs import ID
 from .. import util
+from ..util import dot_to_object
 
 
 @app.command(
@@ -106,7 +107,7 @@ def pretrain(
         use_pytorch_for_gpu_memory()
     nlp, config = util.load_model_from_config(config)
     P_cfg = config["pretraining"]
-    corpus = P_cfg["corpus"]
+    corpus = dot_to_object(config, P_cfg["corpus"])
     batcher = P_cfg["batcher"]
     model = create_pretraining_model(nlp, config["pretraining"])
     optimizer = config["pretraining"]["optimizer"]
diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja
index 39d4d875d..00b77af4d 100644
--- a/spacy/cli/templates/quickstart_training.jinja
+++ b/spacy/cli/templates/quickstart_training.jinja
@@ -173,6 +173,18 @@ factory = "{{ pipe }}"
 {% endif %}
 {% endfor %}
 
+[corpora]
+
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths.train}
+max_length = {{ 500 if hardware == "gpu" else 2000 }}
+
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths.dev}
+max_length = 0
+
 [training]
 {% if use_transformer or optimize == "efficiency" or not word_vectors -%}
 vectors = null
@@ -182,11 +194,12 @@ vectors = "{{ word_vectors }}"
 {% if use_transformer -%}
 accumulate_gradient = {{ transformer["size_factor"] }}
 {% endif %}
+dev_corpus = "corpora.dev"
+train_corpus = "corpora.train"
 
 [training.optimizer]
 @optimizers = "Adam.v1"
-
 {% if use_transformer -%}
 [training.optimizer.learn_rate]
 @schedules = "warmup_linear.v1"
@@ -195,18 +208,6 @@ total_steps = 20000
 initial_rate = 5e-5
 {% endif %}
 
-[training.corpus]
-
-[training.corpus.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths.train}
-max_length = {{ 500 if hardware == "gpu" else 2000 }}
-
-[training.corpus.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths.dev}
-max_length = 0
-
 {% if use_transformer %}
 [training.batcher]
 @batchers = "spacy.batch_by_padded.v1"
diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index 2c2eeb88b..15c745b69 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -18,6 +18,7 @@ from ..language import Language
 from .. import util
 from ..training.example import Example
 from ..errors import Errors
+from ..util import dot_to_object
 
 
 @app.command(
@@ -92,8 +93,8 @@ def train(
     raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
     T_cfg = config["training"]
     optimizer = T_cfg["optimizer"]
-    train_corpus = T_cfg["corpus"]["train"]
-    dev_corpus = T_cfg["corpus"]["dev"]
+    train_corpus = dot_to_object(config, T_cfg["train_corpus"])
+    dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
     batcher = T_cfg["batcher"]
     train_logger = T_cfg["logger"]
     # Components that shouldn't be updated during training
diff --git a/spacy/default_config.cfg b/spacy/default_config.cfg
index 61f3dfe25..c7c9593d7 100644
--- a/spacy/default_config.cfg
+++ b/spacy/default_config.cfg
@@ -22,6 +22,33 @@ after_pipeline_creation = null
 
 [components]
 
+# Readers for corpora like dev and train.
+[corpora]
+
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths.train}
+# Whether to train on sequences with 'gold standard' sentence boundaries
+# and tokens. If you set this to true, take care to ensure your run-time
+# data is passed in sentence-by-sentence via some prior preprocessing.
+gold_preproc = false
+# Limitations on training document length
+max_length = 0
+# Limitation on number of training examples
+limit = 0
+
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths.dev}
+# Whether to train on sequences with 'gold standard' sentence boundaries
+# and tokens. If you set this to true, take care to ensure your run-time
+# data is passed in sentence-by-sentence via some prior preprocessing.
+gold_preproc = false
+# Limitations on training document length
+max_length = 0
+# Limitation on number of training examples
+limit = 0
+
 # Training hyper-parameters and additional features.
 [training]
 seed = ${system.seed}
@@ -40,35 +67,14 @@ eval_frequency = 200
 score_weights = {}
 # Names of pipeline components that shouldn't be updated during training
 frozen_components = []
+# Location in the config where the dev corpus is defined
+dev_corpus = "corpora.dev"
+# Location in the config where the train corpus is defined
+train_corpus = "corpora.train"
 
 [training.logger]
 @loggers = "spacy.ConsoleLogger.v1"
 
-[training.corpus]
-
-[training.corpus.train]
-@readers = "spacy.Corpus.v1"
-path = ${paths.train}
-# Whether to train on sequences with 'gold standard' sentence boundaries
-# and tokens. If you set this to true, take care to ensure your run-time
-# data is passed in sentence-by-sentence via some prior preprocessing.
-gold_preproc = false
-# Limitations on training document length
-max_length = 0
-# Limitation on number of training examples
-limit = 0
-
-[training.corpus.dev]
-@readers = "spacy.Corpus.v1"
-path = ${paths.dev}
-# Whether to train on sequences with 'gold standard' sentence boundaries
-# and tokens. If you set this to true, take care to ensure your run-time
-# data is passed in sentence-by-sentence via some prior preprocessing.
-gold_preproc = false
-# Limitations on training document length
-max_length = 0
-# Limitation on number of training examples
-limit = 0
 
 [training.batcher]
 @batchers = "spacy.batch_by_words.v1"
diff --git a/spacy/default_config_pretraining.cfg b/spacy/default_config_pretraining.cfg
index 9120db338..bbd595308 100644
--- a/spacy/default_config_pretraining.cfg
+++ b/spacy/default_config_pretraining.cfg
@@ -4,6 +4,7 @@ dropout = 0.2
 n_save_every = null
 component = "tok2vec"
 layer = ""
+corpus = "corpora.pretrain"
 
 [pretraining.batcher]
 @batchers = "spacy.batch_by_words.v1"
@@ -12,13 +13,6 @@ discard_oversize = false
 tolerance = 0.2
 get_length = null
 
-[pretraining.corpus]
-@readers = "spacy.JsonlReader.v1"
-path = ${paths.raw}
-min_length = 5
-max_length = 500
-limit = 0
-
 [pretraining.objective]
 type = "characters"
 n_characters = 4
@@ -33,3 +27,12 @@ grad_clip = 1.0
 use_averages = true
 eps = 1e-8
 learn_rate = 0.001
+
+[corpora]
+
+[corpora.pretrain]
+@readers = "spacy.JsonlReader.v1"
+path = ${paths.raw}
+min_length = 5
+max_length = 500
+limit = 0
diff --git a/spacy/schemas.py b/spacy/schemas.py
index 2030048d8..a530db3d0 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -198,7 +198,8 @@ class ModelMetaSchema(BaseModel):
 class ConfigSchemaTraining(BaseModel):
     # fmt: off
     vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
-    corpus: Dict[str, Reader] = Field(..., title="Reader for the training and dev data")
+    dev_corpus: StrictStr = Field(..., title="Path in the config to the dev data")
+    train_corpus: StrictStr = Field(..., title="Path in the config to the training data")
     batcher: Batcher = Field(..., title="Batcher for the training data")
     dropout: StrictFloat = Field(..., title="Dropout rate")
     patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
@@ -248,7 +249,7 @@ class ConfigSchemaPretrain(BaseModel):
     dropout: StrictFloat = Field(..., title="Dropout rate")
     n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
     optimizer: Optimizer = Field(..., title="The optimizer to use")
-    corpus: Reader = Field(..., title="Reader for the training data")
+    corpus: StrictStr = Field(..., title="Path in the config to the training data")
     batcher: Batcher = Field(..., title="Batcher for the training data")
     component: str = Field(..., title="Component to find the layer to pretrain")
     layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
@@ -267,6 +268,7 @@ class ConfigSchema(BaseModel):
     nlp: ConfigSchemaNlp
     pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
     components: Dict[str, Dict[str, Any]]
+    corpora: Dict[str, Reader]
 
     @root_validator(allow_reuse=True)
     def validate_config(cls, values):
diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py
index d113ac2a5..1e17b3212 100644
--- a/spacy/tests/serialize/test_serialize_config.py
+++ b/spacy/tests/serialize/test_serialize_config.py
@@ -17,18 +17,18 @@ nlp_config_string = """
 train = ""
 dev = ""
 
-[training]
+[corpora]
 
-[training.corpus]
-
-[training.corpus.train]
+[corpora.train]
 @readers = "spacy.Corpus.v1"
 path = ${paths.train}
 
-[training.corpus.dev]
+[corpora.dev]
 @readers = "spacy.Corpus.v1"
 path = ${paths.dev}
 
+[training]
+
 [training.batcher]
 @batchers = "spacy.batch_by_words.v1"
 size = 666
@@ -302,20 +302,20 @@ def test_config_overrides():
 
 def test_config_interpolation():
     config = Config().from_str(nlp_config_string, interpolate=False)
-    assert config["training"]["corpus"]["train"]["path"] == "${paths.train}"
+    assert config["corpora"]["train"]["path"] == "${paths.train}"
     interpolated = config.interpolate()
-    assert interpolated["training"]["corpus"]["train"]["path"] == ""
+    assert interpolated["corpora"]["train"]["path"] == ""
     nlp = English.from_config(config)
-    assert nlp.config["training"]["corpus"]["train"]["path"] == "${paths.train}"
+    assert nlp.config["corpora"]["train"]["path"] == "${paths.train}"
     # Ensure that variables are preserved in nlp config
     width = "${components.tok2vec.model.width}"
     assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
     assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
     interpolated2 = nlp.config.interpolate()
-    assert interpolated2["training"]["corpus"]["train"]["path"] == ""
+    assert interpolated2["corpora"]["train"]["path"] == ""
     assert interpolated2["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
     nlp2 = English.from_config(interpolated)
-    assert nlp2.config["training"]["corpus"]["train"]["path"] == ""
+    assert nlp2.config["corpora"]["train"]["path"] == ""
     assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
diff --git a/spacy/tests/training/test_readers.py b/spacy/tests/training/test_readers.py
index c81ec0897..52a4abecc 100644
--- a/spacy/tests/training/test_readers.py
+++ b/spacy/tests/training/test_readers.py
@@ -1,6 +1,57 @@
+from typing import Dict, Iterable, Callable
 import pytest
 from thinc.api import Config
-from spacy.util import load_model_from_config
+
+from spacy import Language
+from spacy.util import load_model_from_config, registry, dot_to_object
+from spacy.training import Example
+
+
+def test_readers():
+    config_string = """
+    [training]
+
+    [corpora]
+    @readers = "myreader.v1"
+
+    [nlp]
+    lang = "en"
+    pipeline = ["tok2vec", "textcat"]
+
+    [components]
+
+    [components.tok2vec]
+    factory = "tok2vec"
+
+    [components.textcat]
+    factory = "textcat"
+    """
+
+    @registry.readers.register("myreader.v1")
+    def myreader() -> Dict[str, Callable[[Language], Iterable[Example]]]:
+        annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
+        def reader(nlp: Language):
+            doc = nlp.make_doc("This is an example")
+            return [Example.from_dict(doc, annots)]
+        return {"train": reader, "dev": reader, "extra": reader, "something": reader}
+
+    config = Config().from_str(config_string)
+    nlp, resolved = load_model_from_config(config, auto_fill=True)
+
+    train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
+    assert isinstance(train_corpus, Callable)
+    optimizer = resolved["training"]["optimizer"]
+    # simulate a training loop
+    nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
+    for example in train_corpus(nlp):
+        nlp.update([example], sgd=optimizer)
+    dev_corpus = dot_to_object(resolved, resolved["training"]["dev_corpus"])
+    scores = nlp.evaluate(list(dev_corpus(nlp)))
+    assert scores["cats_score"]
+    # ensure the pipeline runs
+    doc = nlp("Quick test")
+    assert doc.cats
+    extra_corpus = resolved["corpora"]["extra"]
+    assert isinstance(extra_corpus, Callable)
 
 
 @pytest.mark.slow
@@ -16,7 +67,7 @@ def test_cat_readers(reader, additional_config):
     nlp_config_string = """
     [training]
 
-    [training.corpus]
+    [corpora]
     @readers = "PLACEHOLDER"
 
     [nlp]
@@ -32,11 +83,11 @@ def test_cat_readers(reader, additional_config):
     factory = "textcat"
     """
     config = Config().from_str(nlp_config_string)
-    config["training"]["corpus"]["@readers"] = reader
-    config["training"]["corpus"].update(additional_config)
+    config["corpora"]["@readers"] = reader
+    config["corpora"].update(additional_config)
     nlp, resolved = load_model_from_config(config, auto_fill=True)
 
-    train_corpus = resolved["training"]["corpus"]["train"]
+    train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
     optimizer = resolved["training"]["optimizer"]
     # simulate a training loop
     nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
@@ -46,7 +97,7 @@ def test_cat_readers(reader, additional_config):
         assert sorted(list(set(example.y.cats.values()))) == [0.0, 1.0]
         nlp.update([example], sgd=optimizer)
     # simulate performance benchmark on dev corpus
-    dev_corpus = resolved["training"]["corpus"]["dev"]
+    dev_corpus = dot_to_object(resolved, resolved["training"]["dev_corpus"])
     dev_examples = list(dev_corpus(nlp))
     for example in dev_examples:
         # this shouldn't fail if each dev example has at least one positive label
diff --git a/website/docs/api/cli.md b/website/docs/api/cli.md
index 7dd6e6184..5c5eb6486 100644
--- a/website/docs/api/cli.md
+++ b/website/docs/api/cli.md
@@ -355,6 +355,16 @@ Registry      @architectures
 Name          spacy.MaxoutWindowEncoder.v1
 Module        spacy.ml.models.tok2vec
 File          /path/to/spacy/ml/models/tok2vec.py (line 207)
+ℹ [corpora.dev]
+Registry      @readers
+Name          spacy.Corpus.v1
+Module        spacy.training.corpus
+File          /path/to/spacy/training/corpus.py (line 18)
+ℹ [corpora.train]
+Registry      @readers
+Name          spacy.Corpus.v1
+Module        spacy.training.corpus
+File          /path/to/spacy/training/corpus.py (line 18)
 ℹ [training.logger]
 Registry      @loggers
 Name          spacy.ConsoleLogger.v1
@@ -370,16 +380,6 @@ Registry      @schedules
 Name          compounding.v1
 Module        thinc.schedules
 File          /path/to/thinc/thinc/schedules.py (line 43)
-ℹ [training.corpus.dev]
-Registry      @readers
-Name          spacy.Corpus.v1
-Module        spacy.training.corpus
-File          /path/to/spacy/training/corpus.py (line 18)
-ℹ [training.corpus.train]
-Registry      @readers
-Name          spacy.Corpus.v1
-Module        spacy.training.corpus
-File          /path/to/spacy/training/corpus.py (line 18)
 ℹ [training.optimizer]
 Registry      @optimizers
 Name          Adam.v1
diff --git a/website/docs/api/corpus.md b/website/docs/api/corpus.md
index c25ce1651..2b308d618 100644
--- a/website/docs/api/corpus.md
+++ b/website/docs/api/corpus.md
@@ -26,7 +26,7 @@ streaming.
 > [paths]
 > train = "corpus/train.spacy"
 >
-> [training.corpus.train]
+> [corpora.train]
 > @readers = "spacy.Corpus.v1"
 > path = ${paths.train}
 > gold_preproc = false
@@ -135,7 +135,7 @@ Initialize the reader.
 >
 > ```ini
 > ### Example config
-> [pretraining.corpus]
+> [corpora.pretrain]
 > @readers = "spacy.JsonlReader.v1"
 > path = "corpus/raw_text.jsonl"
 > min_length = 0
diff --git a/website/docs/api/data-formats.md b/website/docs/api/data-formats.md
index cf091e16c..f868233c7 100644
--- a/website/docs/api/data-formats.md
+++ b/website/docs/api/data-formats.md
@@ -121,28 +121,78 @@ that you don't want to hard-code in your config file.
 $ python -m spacy train config.cfg --paths.train ./corpus/train.spacy
 ```
 
+### corpora {#config-corpora tag="section"}
+
+This section defines a dictionary mapping string keys to callables. Each
+callable takes an `nlp` object and yields [`Example`](/api/example) objects. By
+default, the two keys `train` and `dev` are specified and each refers to a
+[`Corpus`](/api/top-level#Corpus). When pretraining, an additional `pretrain`
+key is added that defaults to a [`JsonlReader`](/api/top-level#JsonlReader).
+
+The section can be extended with additional subsections, each referring to a
+callback of type `Callable[[Language], Iterator[Example]]`:
+
+> #### Example
+>
+> ```ini
+> [corpora]
+> [corpora.train]
+> @readers = "spacy.Corpus.v1"
+> path = ${paths.train}
+>
+> [corpora.dev]
+> @readers = "spacy.Corpus.v1"
+> path = ${paths.dev}
+>
+> [corpora.pretrain]
+> @readers = "spacy.JsonlReader.v1"
+> path = ${paths.raw}
+> min_length = 5
+> max_length = 500
+>
+> [corpora.mydata]
+> @readers = "my_reader.v1"
+> shuffle = true
+> ```
+
+Alternatively, the `corpora` block can refer to a single function with return
+type `Dict[str, Callable[[Language], Iterator[Example]]]`:
+
+> #### Example
+>
+> ```ini
+> [corpora]
+> @readers = "my_dict_reader.v1"
+> train_path = ${paths.train}
+> dev_path = ${paths.dev}
+> shuffle = true
+> ```
+
 ### training {#config-training tag="section"}
 
 This section defines settings and controls for the training and evaluation
 process that are used when you run [`spacy train`](/api/cli#train).
 
-| Name | Description |
-| --------------------- | ----------- |
-| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
-| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
-| `corpus` | Dictionary with `train` and `dev` keys, each referring to a callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. Defaults to [`Corpus`](/api/top-level#Corpus). ~~Callable[[Language], Iterator[Example]]~~ |
-| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
-| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
-| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
-| `init_tok2vec` | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). Defaults to variable `${paths.init_tok2vec}`. ~~Optional[str]~~ |
-| `max_epochs` | Maximum number of epochs to train for. Defaults to `0`. ~~int~~ |
-| `max_steps` | Maximum number of update steps to train for. Defaults to `20000`. ~~int~~ |
-| `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
-| `patience` | How many steps to continue without improvement in evaluation score. Defaults to `1600`. ~~int~~ |
-| `raw_text` | Optional path to a jsonl file with unlabelled text documents for a [rehearsal](/api/language#rehearse) step. Defaults to variable `${paths.raw}`. ~~Optional[str]~~ |
-| `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
-| `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
-| `vectors` | Name or path of pipeline containing pretrained word vectors to use, e.g. created with [`init vocab`](/api/cli#init-vocab). Defaults to `null`. ~~Optional[str]~~ |
+| Name | Description |
+| --------------------- | ----------- |
+| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
+| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
+| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
+| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
+| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
+| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
+| `init_tok2vec` | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). Defaults to variable `${paths.init_tok2vec}`. ~~Optional[str]~~ |
+| `max_epochs` | Maximum number of epochs to train for. Defaults to `0`. ~~int~~ |
+| `max_steps` | Maximum number of update steps to train for. Defaults to `20000`. ~~int~~ |
+| `optimizer` | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
+| `patience` | How many steps to continue without improvement in evaluation score. Defaults to `1600`. ~~int~~ |
+| `raw_text` | Optional path to a jsonl file with unlabelled text documents for a [rehearsal](/api/language#rehearse) step. Defaults to variable `${paths.raw}`. ~~Optional[str]~~ |
+| `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
+| `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
+| `train_corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |
+| `vectors` | Name or path of pipeline containing pretrained word vectors to use, e.g. created with [`init vocab`](/api/cli#init-vocab). Defaults to `null`. ~~Optional[str]~~ |
 
 ### pretraining {#config-pretraining tag="section,optional"}
 
@@ -150,17 +200,18 @@ This section is optional and defines settings and controls for
 [language model pretraining](/usage/embeddings-transformers#pretraining). It's
 used when you run [`spacy pretrain`](/api/cli#pretrain).
 
-| Name | Description |
-| -------------- | ----------- |
-| `max_epochs` | Maximum number of epochs. Defaults to `1000`. ~~int~~ |
-| `dropout` | The dropout rate. Defaults to `0.2`. ~~float~~ |
-| `n_save_every` | Saving frequency. Defaults to `null`. ~~Optional[int]~~ |
-| `objective` | The pretraining objective. Defaults to `{"type": "characters", "n_characters": 4}`. ~~Dict[str, Any]~~ |
-| `optimizer` | The optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
-| `corpus` | Callable that takes the current `nlp` object and yields [`Doc`](/api/doc) objects. Defaults to [`JsonlReader`](/api/top-level#JsonlReader). ~~Callable[[Language, str], Iterable[Example]]~~ |
-| `batcher` | Batcher for the training data. ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
-| `component` | Component to find the layer to pretrain. Defaults to `"tok2vec"`. ~~str~~ |
-| `layer` | The layer to pretrain. If empty, the whole component model will be used. ~~str~~ |
+| Name | Description |
+| -------------- | ----------- |
+| `max_epochs` | Maximum number of epochs. Defaults to `1000`. ~~int~~ |
+| `dropout` | The dropout rate. Defaults to `0.2`. ~~float~~ |
+| `n_save_every` | Saving frequency. Defaults to `null`. ~~Optional[int]~~ |
+| `objective` | The pretraining objective. Defaults to `{"type": "characters", "n_characters": 4}`. ~~Dict[str, Any]~~ |
+| `optimizer` | The optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
+| `corpus` | Dot notation of the config location defining the raw text corpus for pretraining. Defaults to `corpora.pretrain`. ~~str~~ |
+| `batcher` | Batcher for the training data. ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
+| `component` | Component to find the layer to pretrain. Defaults to `"tok2vec"`. ~~str~~ |
+| `layer` | The layer to pretrain. If empty, the whole component model will be used. ~~str~~ |
 
 ## Training data {#training}
diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md
index be7994d5d..72b79de48 100644
--- a/website/docs/api/top-level.md
+++ b/website/docs/api/top-level.md
@@ -448,7 +448,7 @@ remain in the config file stored on your local system.
 > [training.logger]
 > @loggers = "spacy.WandbLogger.v1"
 > project_name = "monitor_spacy_training"
-> remove_config_values = ["paths.train", "paths.dev", "training.corpus.train.path", "training.corpus.dev.path"]
+> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
 > ```
 
 | Name | Description |
@@ -478,7 +478,7 @@ the [`Corpus`](/api/corpus) class.
 > [paths]
 > train = "corpus/train.spacy"
 >
-> [training.corpus.train]
+> [corpora.train]
 > @readers = "spacy.Corpus.v1"
 > path = ${paths.train}
 > gold_preproc = false
@@ -506,7 +506,7 @@ JSONL file. Also see the [`JsonlReader`](/api/corpus#jsonlreader) class.
 > [paths]
 > pretrain = "corpus/raw_text.jsonl"
 >
-> [pretraining.corpus]
+> [corpora.pretrain]
 > @readers = "spacy.JsonlReader.v1"
 > path = ${paths.pretrain}
 > min_length = 0
diff --git a/website/docs/usage/projects.md b/website/docs/usage/projects.md
index 3a6bd4551..665caa15b 100644
--- a/website/docs/usage/projects.md
+++ b/website/docs/usage/projects.md
@@ -969,7 +969,7 @@ your results.
 > [training.logger]
 > @loggers = "spacy.WandbLogger.v1"
 > project_name = "monitor_spacy_training"
-> remove_config_values = ["paths.train", "paths.dev", "training.corpus.train.path", "training.corpus.dev.path"]
+> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
 > ```
 
 ![Screenshot: Visualized training results](../images/wandb1.jpg)
diff --git a/website/docs/usage/training.md b/website/docs/usage/training.md
index bba2e2853..c0f4caad7 100644
--- a/website/docs/usage/training.md
+++ b/website/docs/usage/training.md
@@ -746,7 +746,7 @@ as **config settings** – in this case, `source`.
 > #### config.cfg
 >
 > ```ini
-> [training.corpus.train]
+> [corpora.train]
 > @readers = "corpus_variants.v1"
 > source = "s3://your_bucket/path/data.csv"
 > ```
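
Note for reviewers: everything above hinges on resolving dot-notation strings like `"corpora.train"` against the config. Below is a minimal sketch of what a helper in the spirit of `spacy.util.dot_to_object` has to do — `resolve_dot_notation` is a stand-in name, and the real implementation may differ in validation and error handling:

```python
from typing import Any, Dict


def resolve_dot_notation(config: Dict[str, Any], path: str) -> Any:
    """Walk a nested config dict along a dot-separated path like "corpora.train"."""
    node: Any = config
    for key in path.split("."):
        node = node[key]  # a missing section surfaces as a KeyError
    return node


# Mirrors the new lookup in train.py: [training] stores only string pointers,
# while the actual reader callables live under the top-level [corpora] block.
config = {
    "corpora": {"train": "<train reader>", "dev": "<dev reader>"},
    "training": {"train_corpus": "corpora.train", "dev_corpus": "corpora.dev"},
}
assert resolve_dot_notation(config, config["training"]["train_corpus"]) == "<train reader>"
```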
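
The data-formats docs above show a `[corpora]` block backed by a single `my_dict_reader.v1` function but leave the Python side implicit. Here is a sketch of how such a reader could be registered, following the pattern used in `test_readers.py`; `load_my_data` and its CSV format are hypothetical illustrations, not spaCy APIs:

```python
import csv
import random
from typing import Callable, Dict, Iterator, List, Tuple

from spacy.language import Language
from spacy.training import Example
from spacy.util import registry


def load_my_data(path: str, *, shuffle: bool) -> List[Tuple[str, dict]]:
    # Hypothetical loader: one "text,label" row per line of a CSV file.
    with open(path, newline="", encoding="utf8") as f:
        rows = [(text, {"cats": {label: 1.0}}) for text, label in csv.reader(f)]
    if shuffle:
        random.shuffle(rows)
    return rows


@registry.readers.register("my_dict_reader.v1")
def my_dict_reader(
    train_path: str, dev_path: str, shuffle: bool
) -> Dict[str, Callable[[Language], Iterator[Example]]]:
    # One callable per corpus name: spaCy resolves "corpora.train" to the
    # "train" entry of this dict via the training.train_corpus setting.
    def make_reader(path: str) -> Callable[[Language], Iterator[Example]]:
        def read(nlp: Language) -> Iterator[Example]:
            for text, annots in load_my_data(path, shuffle=shuffle):
                yield Example.from_dict(nlp.make_doc(text), annots)
        return read
    return {"train": make_reader(train_path), "dev": make_reader(dev_path)}
```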