mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
Shuffle on first epoch of train
This commit is contained in:
parent
ec14744ee4
commit
fe298fa50a
|
@ -186,18 +186,12 @@ def train(
|
||||||
|
|
||||||
|
|
||||||
def create_train_batches(iterator, batcher, max_epochs: int):
|
def create_train_batches(iterator, batcher, max_epochs: int):
|
||||||
epoch = 1
|
epoch = 0
|
||||||
examples = []
|
examples = list(iterator)
|
||||||
# Stream the first epoch, so we start training faster and support
|
|
||||||
# infinite streams.
|
|
||||||
for batch in batcher(iterator):
|
|
||||||
yield epoch, batch
|
|
||||||
if max_epochs != 1:
|
|
||||||
examples.extend(batch)
|
|
||||||
if not examples:
|
if not examples:
|
||||||
# Raise error if no data
|
# Raise error if no data
|
||||||
raise ValueError(Errors.E986)
|
raise ValueError(Errors.E986)
|
||||||
while epoch != max_epochs:
|
while max_epochs < 1 or epoch != max_epochs:
|
||||||
random.shuffle(examples)
|
random.shuffle(examples)
|
||||||
for batch in batcher(examples):
|
for batch in batcher(examples):
|
||||||
yield epoch, batch
|
yield epoch, batch
|
||||||
|
|
Loading…
Reference in New Issue
Block a user