Merge pull request #6078 from svlandeg/fix/corpus

Commit a127fa475e by Ines Montani, 2020-09-18 14:44:21 +02:00, committed via GitHub
21 changed files with 280 additions and 103 deletions

View File

@@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.0.0a31,<8.0.0a40",
"thinc>=8.0.0a33,<8.0.0a40",
"blis>=0.4.0,<0.5.0",
"pytokenizations",
"pathy"

View File

@@ -1,9 +1,9 @@
# Our libraries
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.0.0a31,<8.0.0a40
thinc>=8.0.0a33,<8.0.0a40
blis>=0.4.0,<0.5.0
ml_datasets>=0.1.1
ml_datasets==0.2.0a0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.8.0,<1.1.0
srsly>=2.1.0,<3.0.0

View File

@@ -34,13 +34,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
thinc>=8.0.0a31,<8.0.0a40
thinc>=8.0.0a33,<8.0.0a40
install_requires =
# Our libraries
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.0.0a31,<8.0.0a40
thinc>=8.0.0a33,<8.0.0a40
blis>=0.4.0,<0.5.0
wasabi>=0.8.0,<1.1.0
srsly>=2.1.0,<3.0.0

View File

@@ -20,6 +20,7 @@ from ..ml.models.multi_task import build_cloze_characters_multi_task_model
from ..tokens import Doc
from ..attrs import ID
from .. import util
from ..util import dot_to_object
@app.command(
@@ -70,9 +71,7 @@ def pretrain_cli(
with show_validation_error(config_path):
config = util.load_config(
config_path,
overrides=config_overrides,
interpolate=True
config_path, overrides=config_overrides, interpolate=True
)
if not config.get("pretraining"):
# TODO: What's the solution here? How do we handle optional blocks?
@@ -83,7 +82,7 @@ def pretrain_cli(
config.to_disk(output_dir / "config.cfg")
msg.good("Saved config file in the output directory")
pretrain(
config,
output_dir,
@@ -98,7 +97,7 @@ def pretrain(
output_dir: Path,
resume_path: Optional[Path] = None,
epoch_resume: Optional[int] = None,
use_gpu: int=-1
use_gpu: int = -1,
):
if config["system"].get("seed") is not None:
fix_random_seed(config["system"]["seed"])
@@ -106,7 +105,7 @@ def pretrain(
use_pytorch_for_gpu_memory()
nlp, config = util.load_model_from_config(config)
P_cfg = config["pretraining"]
corpus = P_cfg["corpus"]
corpus = dot_to_object(config, P_cfg["corpus"])
batcher = P_cfg["batcher"]
model = create_pretraining_model(nlp, config["pretraining"])
optimizer = config["pretraining"]["optimizer"]
@@ -147,9 +146,7 @@ def pretrain(
progress = tracker.update(epoch, loss, docs)
if progress:
msg.row(progress, **row_settings)
if P_cfg["n_save_every"] and (
batch_id % P_cfg["n_save_every"] == 0
):
if P_cfg["n_save_every"] and (batch_id % P_cfg["n_save_every"] == 0):
_save_model(epoch, is_temp=True)
_save_model(epoch)
tracker.epoch_loss = 0.0
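Both here and in the analogous `train.py` change below, the config entry is no longer an inline reader but a dot-notation string (e.g. `"corpora.pretrain"`) that `dot_to_object` resolves against the config. A minimal sketch of that lookup on a plain nested dict, illustrative only rather than the actual `spacy.util.dot_to_object` implementation:

```python
from typing import Any, Dict

def resolve_dot(config: Dict[str, Any], path: str) -> Any:
    """Follow a dot-separated path like 'corpora.pretrain' through nested dicts."""
    node: Any = config
    for key in path.split("."):
        node = node[key]  # a KeyError here means the path isn't defined in the config
    return node

# Toy config mirroring the new layout: [pretraining] points at [corpora.pretrain]
config = {
    "pretraining": {"corpus": "corpora.pretrain"},
    "corpora": {"pretrain": lambda nlp: []},  # stand-in for a JsonlReader callable
}
corpus = resolve_dot(config, config["pretraining"]["corpus"])
```

Keeping only a string reference in `[pretraining]` and `[training]` lets several sections share a single reader defined once under `[corpora]`.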

View File

@@ -173,6 +173,18 @@ factory = "{{ pipe }}"
{% endif %}
{% endfor %}
[corpora]
[corpora.train]
@readers = "spacy.Corpus.v1"
path = ${paths.train}
max_length = {{ 500 if hardware == "gpu" else 2000 }}
[corpora.dev]
@readers = "spacy.Corpus.v1"
path = ${paths.dev}
max_length = 0
[training]
{% if use_transformer or optimize == "efficiency" or not word_vectors -%}
vectors = null
@@ -182,11 +194,12 @@ vectors = "{{ word_vectors }}"
{% if use_transformer -%}
accumulate_gradient = {{ transformer["size_factor"] }}
{% endif %}
dev_corpus = "corpora.dev"
train_corpus = "corpora.train"
[training.optimizer]
@optimizers = "Adam.v1"
{% if use_transformer -%}
[training.optimizer.learn_rate]
@schedules = "warmup_linear.v1"
@@ -195,16 +208,6 @@ total_steps = 20000
initial_rate = 5e-5
{% endif %}
[training.train_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths.train}
max_length = {{ 500 if hardware == "gpu" else 2000 }}
[training.dev_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths.dev}
max_length = 0
{% if use_transformer %}
[training.batcher]
@batchers = "spacy.batch_by_padded.v1"

View File

@@ -18,6 +18,7 @@ from ..language import Language
from .. import util
from ..training.example import Example
from ..errors import Errors
from ..util import dot_to_object
@app.command(
@@ -92,8 +93,8 @@ def train(
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
T_cfg = config["training"]
optimizer = T_cfg["optimizer"]
train_corpus = T_cfg["train_corpus"]
dev_corpus = T_cfg["dev_corpus"]
train_corpus = dot_to_object(config, T_cfg["train_corpus"])
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
batcher = T_cfg["batcher"]
train_logger = T_cfg["logger"]
# Components that shouldn't be updated during training

View File

@@ -22,6 +22,33 @@ after_pipeline_creation = null
[components]
# Readers for corpora like dev and train.
[corpora]
[corpora.train]
@readers = "spacy.Corpus.v1"
path = ${paths.train}
# Whether to train on sequences with 'gold standard' sentence boundaries
# and tokens. If you set this to true, take care to ensure your run-time
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length
max_length = 0
# Limitation on number of training examples
limit = 0
[corpora.dev]
@readers = "spacy.Corpus.v1"
path = ${paths.dev}
# Whether to train on sequences with 'gold standard' sentence boundaries
# and tokens. If you set this to true, take care to ensure your run-time
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length
max_length = 0
# Limitation on number of training examples
limit = 0
# Training hyper-parameters and additional features.
[training]
seed = ${system.seed}
@@ -40,33 +67,14 @@ eval_frequency = 200
score_weights = {}
# Names of pipeline components that shouldn't be updated during training
frozen_components = []
# Location in the config where the dev corpus is defined
dev_corpus = "corpora.dev"
# Location in the config where the train corpus is defined
train_corpus = "corpora.train"
[training.logger]
@loggers = "spacy.ConsoleLogger.v1"
[training.train_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths.train}
# Whether to train on sequences with 'gold standard' sentence boundaries
# and tokens. If you set this to true, take care to ensure your run-time
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length
max_length = 0
# Limitation on number of training examples
limit = 0
[training.dev_corpus]
@readers = "spacy.Corpus.v1"
path = ${paths.dev}
# Whether to train on sequences with 'gold standard' sentence boundaries
# and tokens. If you set this to true, take care to ensure your run-time
# data is passed in sentence-by-sentence via some prior preprocessing.
gold_preproc = false
# Limitations on training document length
max_length = 0
# Limitation on number of training examples
limit = 0
[training.batcher]
@batchers = "spacy.batch_by_words.v1"

View File

@@ -4,6 +4,7 @@ dropout = 0.2
n_save_every = null
component = "tok2vec"
layer = ""
corpus = "corpora.pretrain"
[pretraining.batcher]
@batchers = "spacy.batch_by_words.v1"
@@ -12,13 +13,6 @@ discard_oversize = false
tolerance = 0.2
get_length = null
[pretraining.corpus]
@readers = "spacy.JsonlReader.v1"
path = ${paths.raw}
min_length = 5
max_length = 500
limit = 0
[pretraining.objective]
type = "characters"
n_characters = 4
@@ -33,3 +27,12 @@ grad_clip = 1.0
use_averages = true
eps = 1e-8
learn_rate = 0.001
[corpora]
[corpora.pretrain]
@readers = "spacy.JsonlReader.v1"
path = ${paths.raw}
min_length = 5
max_length = 500
limit = 0

View File

@@ -181,9 +181,9 @@ class TextCategorizer(Pipe):
DOCS: https://nightly.spacy.io/api/textcategorizer#predict
"""
tensors = [doc.tensor for doc in docs]
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
tensors = [doc.tensor for doc in docs]
xp = get_array_module(tensors)
scores = xp.zeros((len(docs), len(self.labels)))
return scores

View File

@@ -104,7 +104,7 @@ class TokenPatternOperator(str, Enum):
StringValue = Union[TokenPatternString, StrictStr]
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
UnderscoreValue = Union[
TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
TokenPatternString, TokenPatternNumber, str, int, float, list, bool
]
@@ -198,8 +198,8 @@ class ModelMetaSchema(BaseModel):
class ConfigSchemaTraining(BaseModel):
# fmt: off
vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
train_corpus: Reader = Field(..., title="Reader for the training data")
dev_corpus: Reader = Field(..., title="Reader for the dev data")
dev_corpus: StrictStr = Field(..., title="Path in the config to the dev data")
train_corpus: StrictStr = Field(..., title="Path in the config to the training data")
batcher: Batcher = Field(..., title="Batcher for the training data")
dropout: StrictFloat = Field(..., title="Dropout rate")
patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
@@ -249,7 +249,7 @@ class ConfigSchemaPretrain(BaseModel):
dropout: StrictFloat = Field(..., title="Dropout rate")
n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
optimizer: Optimizer = Field(..., title="The optimizer to use")
corpus: Reader = Field(..., title="Reader for the training data")
corpus: StrictStr = Field(..., title="Path in the config to the training data")
batcher: Batcher = Field(..., title="Batcher for the training data")
component: str = Field(..., title="Component to find the layer to pretrain")
layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
@@ -268,6 +268,7 @@ class ConfigSchema(BaseModel):
nlp: ConfigSchemaNlp
pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
components: Dict[str, Dict[str, Any]]
corpora: Dict[str, Reader]
@root_validator(allow_reuse=True)
def validate_config(cls, values):
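The net effect of the schema change: `train_corpus`, `dev_corpus` and `corpus` are validated as plain strings (dot-notation references), while the readers themselves are validated once under the new top-level `corpora` mapping. A minimal pydantic sketch of that pattern, assuming pydantic v1 semantics; the class names and values here are illustrative, not spaCy's actual schema:

```python
from typing import Any, Callable, Dict
from pydantic import BaseModel, StrictStr

Reader = Callable[..., Any]  # stand-in for spaCy's reader type

class TrainingSection(BaseModel):
    train_corpus: StrictStr  # e.g. "corpora.train": a reference, not a reader object
    dev_corpus: StrictStr    # e.g. "corpora.dev"

class FullConfig(BaseModel):
    training: TrainingSection
    corpora: Dict[str, Reader]  # the readers themselves, keyed by corpus name

cfg = FullConfig(
    training={"train_corpus": "corpora.train", "dev_corpus": "corpora.dev"},
    corpora={"train": lambda nlp: [], "dev": lambda nlp: []},
)
```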

View File

@@ -9,7 +9,7 @@ from spacy.tokens import Doc
from spacy.training import Example
from spacy import util
from spacy.lang.en import English
from .util import get_batch
from ..util import get_batch
from thinc.api import Config

View File

@@ -17,16 +17,18 @@ nlp_config_string = """
train = ""
dev = ""
[training]
[corpora]
[training.train_corpus]
[corpora.train]
@readers = "spacy.Corpus.v1"
path = ${paths.train}
[training.dev_corpus]
[corpora.dev]
@readers = "spacy.Corpus.v1"
path = ${paths.dev}
[training]
[training.batcher]
@batchers = "spacy.batch_by_words.v1"
size = 666
@@ -300,20 +302,20 @@ def test_config_overrides():
def test_config_interpolation():
config = Config().from_str(nlp_config_string, interpolate=False)
assert config["training"]["train_corpus"]["path"] == "${paths.train}"
assert config["corpora"]["train"]["path"] == "${paths.train}"
interpolated = config.interpolate()
assert interpolated["training"]["train_corpus"]["path"] == ""
assert interpolated["corpora"]["train"]["path"] == ""
nlp = English.from_config(config)
assert nlp.config["training"]["train_corpus"]["path"] == "${paths.train}"
assert nlp.config["corpora"]["train"]["path"] == "${paths.train}"
# Ensure that variables are preserved in nlp config
width = "${components.tok2vec.model.width}"
assert config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["width"] == width
interpolated2 = nlp.config.interpolate()
assert interpolated2["training"]["train_corpus"]["path"] == ""
assert interpolated2["corpora"]["train"]["path"] == ""
assert interpolated2["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
nlp2 = English.from_config(interpolated)
assert nlp2.config["training"]["train_corpus"]["path"] == ""
assert nlp2.config["corpora"]["train"]["path"] == ""
assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342

View File

@@ -0,0 +1,112 @@
from typing import Dict, Iterable, Callable
import pytest
from thinc.api import Config
from spacy import Language
from spacy.util import load_model_from_config, registry, dot_to_object
from spacy.training import Example
def test_readers():
config_string = """
[training]
[corpora]
@readers = "myreader.v1"
[nlp]
lang = "en"
pipeline = ["tok2vec", "textcat"]
[components]
[components.tok2vec]
factory = "tok2vec"
[components.textcat]
factory = "textcat"
"""
@registry.readers.register("myreader.v1")
def myreader() -> Dict[str, Callable[[Language, str], Iterable[Example]]]:
annots = {"cats": {"POS": 1.0, "NEG": 0.0}}
def reader(nlp: Language):
doc = nlp.make_doc(f"This is an example")
return [Example.from_dict(doc, annots)]
return {"train": reader, "dev": reader, "extra": reader, "something": reader}
config = Config().from_str(config_string)
nlp, resolved = load_model_from_config(config, auto_fill=True)
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
assert isinstance(train_corpus, Callable)
optimizer = resolved["training"]["optimizer"]
# simulate a training loop
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
for example in train_corpus(nlp):
nlp.update([example], sgd=optimizer)
dev_corpus = dot_to_object(resolved, resolved["training"]["dev_corpus"])
scores = nlp.evaluate(list(dev_corpus(nlp)))
assert scores["cats_score"]
# ensure the pipeline runs
doc = nlp("Quick test")
assert doc.cats
extra_corpus = resolved["corpora"]["extra"]
assert isinstance(extra_corpus, Callable)
@pytest.mark.slow
@pytest.mark.parametrize(
"reader,additional_config",
[
("ml_datasets.imdb_sentiment.v1", {"train_limit": 10, "dev_limit": 2}),
("ml_datasets.dbpedia.v1", {"train_limit": 10, "dev_limit": 2}),
("ml_datasets.cmu_movies.v1", {"limit": 10, "freq_cutoff": 200, "split": 0.8}),
],
)
def test_cat_readers(reader, additional_config):
nlp_config_string = """
[training]
[corpora]
@readers = "PLACEHOLDER"
[nlp]
lang = "en"
pipeline = ["tok2vec", "textcat"]
[components]
[components.tok2vec]
factory = "tok2vec"
[components.textcat]
factory = "textcat"
"""
config = Config().from_str(nlp_config_string)
config["corpora"]["@readers"] = reader
config["corpora"].update(additional_config)
nlp, resolved = load_model_from_config(config, auto_fill=True)
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
optimizer = resolved["training"]["optimizer"]
# simulate a training loop
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
for example in train_corpus(nlp):
assert example.y.cats
# this shouldn't fail if each training example has at least one positive label
assert sorted(list(set(example.y.cats.values()))) == [0.0, 1.0]
nlp.update([example], sgd=optimizer)
# simulate performance benchmark on dev corpus
dev_corpus = dot_to_object(resolved, resolved["training"]["dev_corpus"])
dev_examples = list(dev_corpus(nlp))
for example in dev_examples:
# this shouldn't fail if each dev example has at least one positive label
assert sorted(list(set(example.y.cats.values()))) == [0.0, 1.0]
scores = nlp.evaluate(dev_examples)
assert scores["cats_score"]
# ensure the pipeline runs
doc = nlp("Quick test")
assert doc.cats

View File

@@ -12,7 +12,7 @@ from thinc.api import compounding
import pytest
import srsly
from .util import make_tempdir, get_doc
from ..util import make_tempdir, get_doc
@pytest.fixture

View File

@@ -274,7 +274,7 @@ training -> dropout field required
training -> optimizer field required
training -> optimize extra fields not permitted
{'vectors': 'en_vectors_web_lg', 'seed': 0, 'accumulate_gradient': 1, 'init_tok2vec': None, 'raw_text': None, 'patience': 1600, 'max_epochs': 0, 'max_steps': 20000, 'eval_frequency': 200, 'frozen_components': [], 'optimize': None, 'batcher': {'@batchers': 'spacy.batch_by_words.v1', 'discard_oversize': False, 'tolerance': 0.2, 'get_length': None, 'size': {'@schedules': 'compounding.v1', 'start': 100, 'stop': 1000, 'compound': 1.001, 't': 0.0}}, 'dev_corpus': {'@readers': 'spacy.Corpus.v1', 'path': '', 'max_length': 0, 'gold_preproc': False, 'limit': 0}, 'score_weights': {'tag_acc': 0.5, 'dep_uas': 0.25, 'dep_las': 0.25, 'sents_f': 0.0}, 'train_corpus': {'@readers': 'spacy.Corpus.v1', 'path': '', 'max_length': 0, 'gold_preproc': False, 'limit': 0}}
{'vectors': 'en_vectors_web_lg', 'seed': 0, 'accumulate_gradient': 1, 'init_tok2vec': None, 'raw_text': None, 'patience': 1600, 'max_epochs': 0, 'max_steps': 20000, 'eval_frequency': 200, 'frozen_components': [], 'optimize': None, 'batcher': {'@batchers': 'spacy.batch_by_words.v1', 'discard_oversize': False, 'tolerance': 0.2, 'get_length': None, 'size': {'@schedules': 'compounding.v1', 'start': 100, 'stop': 1000, 'compound': 1.001, 't': 0.0}}, 'corpus': {'train': {'@readers': 'spacy.Corpus.v1', 'path': '', 'max_length': 0, 'gold_preproc': False, 'limit': 0}, 'dev': {'@readers': 'spacy.Corpus.v1', 'path': '', 'max_length': 0, 'gold_preproc': False, 'limit': 0}}, 'score_weights': {'tag_acc': 0.5, 'dep_uas': 0.25, 'dep_las': 0.25, 'sents_f': 0.0}}
If your config contains missing values, you can run the 'init fill-config'
command to fill in all the defaults, if possible:
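A typical invocation of that command, with the base and output file names assumed:

```
$ python -m spacy init fill-config base.cfg config.cfg
```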
@@ -357,6 +357,16 @@ Registry @architectures
Name spacy.MaxoutWindowEncoder.v1
Module spacy.ml.models.tok2vec
File /path/to/spacy/ml/models/tok2vec.py (line 207)
[corpora.dev]
Registry @readers
Name spacy.Corpus.v1
Module spacy.training.corpus
File /path/to/spacy/training/corpus.py (line 18)
[corpora.train]
Registry @readers
Name spacy.Corpus.v1
Module spacy.training.corpus
File /path/to/spacy/training/corpus.py (line 18)
[training.logger]
Registry @loggers
Name spacy.ConsoleLogger.v1
@@ -372,11 +382,6 @@ Registry @schedules
Name compounding.v1
Module thinc.schedules
File /path/to/thinc/thinc/schedules.py (line 43)
[training.dev_corpus]
Registry @readers
Name spacy.Corpus.v1
Module spacy.training.corpus
File /path/to/spacy/training/corpus.py (line 18)
[training.optimizer]
Registry @optimizers
Name Adam.v1
@@ -387,11 +392,6 @@ Registry @schedules
Name warmup_linear.v1
Module thinc.schedules
File /path/to/thinc/thinc/schedules.py (line 91)
[training.train_corpus]
Registry @readers
Name spacy.Corpus.v1
Module spacy.training.corpus
File /path/to/spacy/training/corpus.py (line 18)
```
</Accordion>

View File

@@ -26,7 +26,7 @@ streaming.
> [paths]
> train = "corpus/train.spacy"
>
> [training.train_corpus]
> [corpora.train]
> @readers = "spacy.Corpus.v1"
> path = ${paths.train}
> gold_preproc = false
@@ -135,7 +135,7 @@ Initialize the reader.
>
> ```ini
> ### Example config
> [pretraining.corpus]
> [corpora.pretrain]
> @readers = "spacy.JsonlReader.v1"
> path = "corpus/raw_text.jsonl"
> min_length = 0

View File

@@ -121,6 +121,55 @@ that you don't want to hard-code in your config file.
$ python -m spacy train config.cfg --paths.train ./corpus/train.spacy
```
### corpora {#config-corpora tag="section"}
This section defines a dictionary mapping string keys to `Callable`
functions. Each callable takes an `nlp` object and yields
[`Example`](/api/example) objects. By default, the two keys `train` and `dev`
are specified, and each refers to a [`Corpus`](/api/top-level#Corpus). When
pretraining, an additional `pretrain` subsection is added that defaults to a
[`JsonlReader`](/api/top-level#JsonlReader).
The section can be expanded with additional subsections, each referring to a
callback of type `Callable[[Language], Iterator[Example]]`:
> #### Example
>
> ```ini
> [corpora]
> [corpora.train]
> @readers = "spacy.Corpus.v1"
> path = ${paths:train}
>
> [corpora.dev]
> @readers = "spacy.Corpus.v1"
> path = ${paths:dev}
>
> [corpora.pretrain]
> @readers = "spacy.JsonlReader.v1"
> path = ${paths.raw}
> min_length = 5
> max_length = 500
>
> [corpora.mydata]
> @readers = "my_reader.v1"
> shuffle = true
> ```
Alternatively, the `corpora` block could refer to one function with return type
`Dict[str, Callable[[Language], Iterator[Example]]]`:
> #### Example
>
> ```ini
> [corpora]
> @readers = "my_dict_reader.v1"
> train_path = ${paths:train}
> dev_path = ${paths:dev}
> shuffle = true
>
> ```
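For illustration, a registered function satisfying that block might look like the sketch below; `my_dict_reader.v1` and its arguments are the ones assumed in the example config above, not a built-in reader:

```python
from typing import Callable, Dict, Iterator
from spacy.language import Language
from spacy.training import Example
from spacy.util import registry

@registry.readers.register("my_dict_reader.v1")
def my_dict_reader(
    train_path: str, dev_path: str, shuffle: bool = True
) -> Dict[str, Callable[[Language], Iterator[Example]]]:
    def make_reader(path: str) -> Callable[[Language], Iterator[Example]]:
        def read(nlp: Language) -> Iterator[Example]:
            # Load (and optionally shuffle) your data from `path` here; this
            # toy version yields a single annotated Example.
            doc = nlp.make_doc("This is an example")
            yield Example.from_dict(doc, {"cats": {"POS": 1.0, "NEG": 0.0}})
        return read
    return {"train": make_reader(train_path), "dev": make_reader(dev_path)}
```

The keys of the returned dict (`"train"`, `"dev"`) are what `training.train_corpus = "corpora.train"` and `training.dev_corpus = "corpora.dev"` resolve to.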
### training {#config-training tag="section"}
This section defines settings and controls for the training and evaluation
@@ -130,7 +179,7 @@ process that are used when you run [`spacy train`](/api/cli#train).
| --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
| `batcher` | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc]], Iterator[List[Doc]]]~~ |
| `dev_corpus` | Callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. Defaults to [`Corpus`](/api/top-level#Corpus). ~~Callable[[Language], Iterator[Example]]~~ |
| `dev_corpus` | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
| `dropout` | The dropout rate. Defaults to `0.1`. ~~float~~ |
| `eval_frequency` | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
| `frozen_components` | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
@@ -142,7 +191,7 @@ process that are used when you run [`spacy train`](/api/cli#train).
| `raw_text` | Optional path to a jsonl file with unlabelled text documents for a [rehearsal](/api/language#rehearse) step. Defaults to variable `${paths.raw}`. ~~Optional[str]~~ |
| `score_weights` | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
| `seed` | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
| `train_corpus` | Callable that takes the current `nlp` object and yields [`Example`](/api/example) objects. Defaults to [`Corpus`](/api/top-level#Corpus). ~~Callable[[Language], Iterator[Example]]~~ |
| `train_corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |
| `vectors` | Name or path of pipeline containing pretrained word vectors to use, e.g. created with [`init vocab`](/api/cli#init-vocab). Defaults to `null`. ~~Optional[str]~~ |
### pretraining {#config-pretraining tag="section,optional"}
@@ -151,17 +200,18 @@ This section is optional and defines settings and controls for
[language model pretraining](/usage/embeddings-transformers#pretraining). It's
used when you run [`spacy pretrain`](/api/cli#pretrain).
| Name | Description |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `max_epochs` | Maximum number of epochs. Defaults to `1000`. ~~int~~ |
| `dropout` | The dropout rate. Defaults to `0.2`. ~~float~~ |
| `n_save_every` | Saving frequency. Defaults to `null`. ~~Optional[int]~~ |
| `objective` | The pretraining objective. Defaults to `{"type": "characters", "n_characters": 4}`. ~~Dict[str, Any]~~ |
| `optimizer` | The optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
| `corpus` | Callable that takes the current `nlp` object and yields [`Doc`](/api/doc) objects. Defaults to [`JsonlReader`](/api/top-level#JsonlReader). ~~Callable[[Language, str], Iterable[Example]]~~ |
| `batcher` | Batcher for the training data. ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
| `component` | Component to find the layer to pretrain. Defaults to `"tok2vec"`. ~~str~~ |
| `layer` | The layer to pretrain. If empty, the whole component model will be used. ~~str~~ |
| Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------ |
| `max_epochs` | Maximum number of epochs. Defaults to `1000`. ~~int~~ |
| `dropout` | The dropout rate. Defaults to `0.2`. ~~float~~ |
| `n_save_every` | Saving frequency. Defaults to `null`. ~~Optional[int]~~ |
| `objective` | The pretraining objective. Defaults to `{"type": "characters", "n_characters": 4}`. ~~Dict[str, Any]~~ |
| `optimizer` | The optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
| `corpus` | Dot notation of the config location defining the train corpus. Defaults to `corpora.pretrain`. ~~str~~ |
| `batcher`      | Batcher for the training data. ~~Callable[[Iterator[Doc]], Iterator[List[Doc]]]~~                       |
| `component` | Component to find the layer to pretrain. Defaults to `"tok2vec"`. ~~str~~ |
| `layer` | The layer to pretrain. If empty, the whole component model will be used. ~~str~~ |
## Training data {#training}

View File

@@ -448,7 +448,7 @@ remain in the config file stored on your local system.
> [training.logger]
> @loggers = "spacy.WandbLogger.v1"
> project_name = "monitor_spacy_training"
> remove_config_values = ["paths.train", "paths.dev", "training.dev_corpus.path", "training.train_corpus.path"]
> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
> ```
| Name | Description |
@@ -478,7 +478,7 @@ the [`Corpus`](/api/corpus) class.
> [paths]
> train = "corpus/train.spacy"
>
> [training.train_corpus]
> [corpora.train]
> @readers = "spacy.Corpus.v1"
> path = ${paths.train}
> gold_preproc = false
@@ -506,7 +506,7 @@ JSONL file. Also see the [`JsonlReader`](/api/corpus#jsonlreader) class.
> [paths]
> pretrain = "corpus/raw_text.jsonl"
>
> [pretraining.corpus]
> [corpora.pretrain]
> @readers = "spacy.JsonlReader.v1"
> path = ${paths.pretrain}
> min_length = 0

View File

@@ -969,7 +969,7 @@ your results.
> [training.logger]
> @loggers = "spacy.WandbLogger.v1"
> project_name = "monitor_spacy_training"
> remove_config_values = ["paths.train", "paths.dev", "training.dev_corpus.path", "training.train_corpus.path"]
> remove_config_values = ["paths.train", "paths.dev", "corpora.train.path", "corpora.dev.path"]
> ```
![Screenshot: Visualized training results](../images/wandb1.jpg)

View File

@@ -746,7 +746,7 @@ as **config settings**, in this case `source`.
> #### config.cfg
>
> ```ini
> [training.train_corpus]
> [corpora.train]
> @readers = "corpus_variants.v1"
> source = "s3://your_bucket/path/data.csv"
> ```
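A sketch of what the `corpus_variants.v1` reader referenced above could look like; the reader body and the `load_rows` helper are hypothetical, only the registration pattern and the `source` setting come from the example:

```python
from typing import Callable, Iterator
from spacy.language import Language
from spacy.training import Example
from spacy.util import registry

@registry.readers.register("corpus_variants.v1")
def stream_data(source: str) -> Callable[[Language], Iterator[Example]]:
    # `source` is filled in from the config, e.g. "s3://your_bucket/path/data.csv"
    def reader(nlp: Language) -> Iterator[Example]:
        for text in load_rows(source):  # hypothetical helper that fetches the CSV rows
            doc = nlp.make_doc(text)
            yield Example.from_dict(doc, {})  # no annotations in this sketch
    return reader
```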