Correctly determine sort key in subdivide_batch

This commit is contained in:
Daniël de Kok 2023-04-19 20:07:37 +02:00
parent add1a21657
commit 962c2972e4

View File

@@ -511,11 +511,14 @@ def train_while_improving(
break break
def subdivide_batch(batch, accumulate_gradient): def subdivide_batch(
batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient
):
batch = list(batch) batch = list(batch)
if isinstance(batch, Example): if len(batch):
if isinstance(batch[0], Example):
batch.sort(key=lambda eg: len(eg.predicted)) batch.sort(key=lambda eg: len(eg.predicted))
elif isinstance(batch, Doc): else:
batch.sort(key=lambda doc: len(doc)) batch.sort(key=lambda doc: len(doc))
sub_len = len(batch) // accumulate_gradient sub_len = len(batch) // accumulate_gradient
start = 0 start = 0