Correctly determine sort key in subdivide_batch

This commit is contained in:
Daniël de Kok 2023-04-19 20:07:37 +02:00
parent add1a21657
commit 962c2972e4

View File

@@ -511,12 +511,15 @@ def train_while_improving(
break
def subdivide_batch(batch, accumulate_gradient):
def subdivide_batch(
batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient
):
batch = list(batch)
if isinstance(batch, Example):
batch.sort(key=lambda eg: len(eg.predicted))
elif isinstance(batch, Doc):
batch.sort(key=lambda doc: len(doc))
if len(batch):
if isinstance(batch[0], Example):
batch.sort(key=lambda eg: len(eg.predicted))
else:
batch.sort(key=lambda doc: len(doc))
sub_len = len(batch) // accumulate_gradient
start = 0
for i in range(accumulate_gradient):