mirror of
https://github.com/explosion/spaCy.git
synced 2024-09-21 03:19:13 +03:00
Support large/infinite training corpora (#7208)
* Support infinite generators for training corpora. Support a training corpus with an infinite generator in the `spacy train` training loop: * Revert `create_train_batches` to the state where an infinite generator can be used as the train corpus when training for exactly one epoch without resulting in a memory leak (`max_epochs != 1` will still result in a memory leak) * Move the shuffling for the first epoch into the corpus reader, renaming it to `spacy.Corpus.v2`. * Switch to a training option for shuffling in memory. Training loop: * Add option `training.shuffle_train_corpus_in_memory` that controls whether the corpus is loaded in memory once and shuffled in the training loop * Revert changes to `create_train_batches` and rename it to `create_train_batches_with_shuffling` for use with `spacy.Corpus.v1` and a corpus that should be loaded in memory * Add `create_train_batches_without_shuffling` for a corpus that should not be shuffled in the training loop: the corpus is merely batched during training. Corpus readers: * Restore `spacy.Corpus.v1` * Add `spacy.ShuffledCorpus.v1` for a corpus shuffled in memory in the reader instead of the training loop * In combination with `shuffle_train_corpus_in_memory = False`, each epoch could result in a different augmentation * Refactor create_train_batches, validation * Rename config setting to `training.shuffle_train_corpus` * Refactor to use a single `create_train_batches` method with a `shuffle` option * Only validate `get_examples` in initialize step if: * labels are required * labels are not provided * Switch back to max_epochs=-1 for streaming train corpus * Use the first 100 examples to initialize from a streamed train corpus * Always check validate_get_examples in initialize
This commit is contained in:
parent
81fd595223
commit
ff84075839
|
@ -70,6 +70,9 @@ dropout = 0.1
|
||||||
accumulate_gradient = 1
|
accumulate_gradient = 1
|
||||||
# Controls early-stopping. 0 disables early stopping.
|
# Controls early-stopping. 0 disables early stopping.
|
||||||
patience = 1600
|
patience = 1600
|
||||||
|
# Number of epochs. 0 means unlimited. If >= 0, train corpus is loaded once in
|
||||||
|
# memory and shuffled within the training loop. -1 means stream train corpus
|
||||||
|
# instead of loading it in memory; a streamed corpus is not shuffled within the training loop.
|
||||||
max_epochs = 0
|
max_epochs = 0
|
||||||
max_steps = 20000
|
max_steps = 20000
|
||||||
eval_frequency = 200
|
eval_frequency = 200
|
||||||
|
|
|
@ -2,6 +2,7 @@ import warnings
|
||||||
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
|
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import random
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
|
@ -96,6 +97,7 @@ class Corpus:
|
||||||
Defaults to 0, which indicates no limit.
|
Defaults to 0, which indicates no limit.
|
||||||
augment (Callable[Example, Iterable[Example]]): Optional data augmentation
|
augment (Callable[Example, Iterable[Example]]): Optional data augmentation
|
||||||
function, to extrapolate additional examples from your annotations.
|
function, to extrapolate additional examples from your annotations.
|
||||||
|
shuffle (bool): Whether to shuffle the examples.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/corpus
|
DOCS: https://spacy.io/api/corpus
|
||||||
"""
|
"""
|
||||||
|
@ -108,12 +110,14 @@ class Corpus:
|
||||||
gold_preproc: bool = False,
|
gold_preproc: bool = False,
|
||||||
max_length: int = 0,
|
max_length: int = 0,
|
||||||
augmenter: Optional[Callable] = None,
|
augmenter: Optional[Callable] = None,
|
||||||
|
shuffle: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = util.ensure_path(path)
|
self.path = util.ensure_path(path)
|
||||||
self.gold_preproc = gold_preproc
|
self.gold_preproc = gold_preproc
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
self.augmenter = augmenter if augmenter is not None else dont_augment
|
self.augmenter = augmenter if augmenter is not None else dont_augment
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
||||||
"""Yield examples from the data.
|
"""Yield examples from the data.
|
||||||
|
@ -124,6 +128,10 @@ class Corpus:
|
||||||
DOCS: https://spacy.io/api/corpus#call
|
DOCS: https://spacy.io/api/corpus#call
|
||||||
"""
|
"""
|
||||||
ref_docs = self.read_docbin(nlp.vocab, walk_corpus(self.path, FILE_TYPE))
|
ref_docs = self.read_docbin(nlp.vocab, walk_corpus(self.path, FILE_TYPE))
|
||||||
|
if self.shuffle:
|
||||||
|
ref_docs = list(ref_docs)
|
||||||
|
random.shuffle(ref_docs)
|
||||||
|
|
||||||
if self.gold_preproc:
|
if self.gold_preproc:
|
||||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -8,6 +8,7 @@ import tarfile
|
||||||
import gzip
|
import gzip
|
||||||
import zipfile
|
import zipfile
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
from .pretrain import get_tok2vec_ref
|
from .pretrain import get_tok2vec_ref
|
||||||
from ..lookups import Lookups
|
from ..lookups import Lookups
|
||||||
|
@ -68,7 +69,11 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
# Make sure that listeners are defined before initializing further
|
# Make sure that listeners are defined before initializing further
|
||||||
nlp._link_components()
|
nlp._link_components()
|
||||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
if T["max_epochs"] == -1:
|
||||||
|
logger.debug("Due to streamed train corpus, using only first 100 examples for initialization. If necessary, provide all labels in [initialize]. More info: https://spacy.io/api/cli#init_labels")
|
||||||
|
nlp.initialize(lambda: islice(train_corpus(nlp), 100), sgd=optimizer)
|
||||||
|
else:
|
||||||
|
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
||||||
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
|
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
|
||||||
# Detect components with listeners that are not frozen consistently
|
# Detect components with listeners that are not frozen consistently
|
||||||
for name, proc in nlp.pipeline:
|
for name, proc in nlp.pipeline:
|
||||||
|
|
|
@ -78,7 +78,7 @@ def train(
|
||||||
training_step_iterator = train_while_improving(
|
training_step_iterator = train_while_improving(
|
||||||
nlp,
|
nlp,
|
||||||
optimizer,
|
optimizer,
|
||||||
create_train_batches(train_corpus(nlp), batcher, T["max_epochs"]),
|
create_train_batches(nlp, train_corpus, batcher, T["max_epochs"]),
|
||||||
create_evaluation_callback(nlp, dev_corpus, score_weights),
|
create_evaluation_callback(nlp, dev_corpus, score_weights),
|
||||||
dropout=T["dropout"],
|
dropout=T["dropout"],
|
||||||
accumulate_gradient=T["accumulate_gradient"],
|
accumulate_gradient=T["accumulate_gradient"],
|
||||||
|
@ -290,17 +290,22 @@ def create_evaluation_callback(
|
||||||
|
|
||||||
|
|
||||||
def create_train_batches(
|
def create_train_batches(
|
||||||
iterator: Iterator[Example],
|
nlp: "Language",
|
||||||
|
corpus: Callable[["Language"], Iterable[Example]],
|
||||||
batcher: Callable[[Iterable[Example]], Iterable[Example]],
|
batcher: Callable[[Iterable[Example]], Iterable[Example]],
|
||||||
max_epochs: int,
|
max_epochs: int,
|
||||||
):
|
):
|
||||||
epoch = 0
|
epoch = 0
|
||||||
examples = list(iterator)
|
if max_epochs >= 0:
|
||||||
if not examples:
|
examples = list(corpus(nlp))
|
||||||
# Raise error if no data
|
if not examples:
|
||||||
raise ValueError(Errors.E986)
|
# Raise error if no data
|
||||||
|
raise ValueError(Errors.E986)
|
||||||
while max_epochs < 1 or epoch != max_epochs:
|
while max_epochs < 1 or epoch != max_epochs:
|
||||||
random.shuffle(examples)
|
if max_epochs >= 0:
|
||||||
|
random.shuffle(examples)
|
||||||
|
else:
|
||||||
|
examples = corpus(nlp)
|
||||||
for batch in batcher(examples):
|
for batch in batcher(examples):
|
||||||
yield epoch, batch
|
yield epoch, batch
|
||||||
epoch += 1
|
epoch += 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user