Shuffle on first epoch of train

This commit is contained in:
Matthw Honnibal 2020-08-31 19:55:22 +02:00
parent ec14744ee4
commit fe298fa50a

View File

@@ -186,18 +186,12 @@ def train(
def create_train_batches(iterator, batcher, max_epochs: int): def create_train_batches(iterator, batcher, max_epochs: int):
epoch = 1 epoch = 0
examples = [] examples = list(iterator)
# Stream the first epoch, so we start training faster and support
# infinite streams.
for batch in batcher(iterator):
yield epoch, batch
if max_epochs != 1:
examples.extend(batch)
if not examples: if not examples:
# Raise error if no data # Raise error if no data
raise ValueError(Errors.E986) raise ValueError(Errors.E986)
while epoch != max_epochs: while max_epochs < 1 or epoch != max_epochs:
random.shuffle(examples) random.shuffle(examples)
for batch in batcher(examples): for batch in batcher(examples):
yield epoch, batch yield epoch, batch