Shuffle examples on the first epoch of training

This commit is contained in:
Matthew Honnibal 2020-08-31 19:55:22 +02:00
parent ec14744ee4
commit fe298fa50a

View File

@ -186,18 +186,12 @@ def train(
def create_train_batches(iterator, batcher, max_epochs: int):
epoch = 1
examples = []
# Stream the first epoch, so we start training faster and support
# infinite streams.
for batch in batcher(iterator):
yield epoch, batch
if max_epochs != 1:
examples.extend(batch)
epoch = 0
examples = list(iterator)
if not examples:
# Raise error if no data
raise ValueError(Errors.E986)
while epoch != max_epochs:
while max_epochs < 1 or epoch != max_epochs:
random.shuffle(examples)
for batch in batcher(examples):
yield epoch, batch