From 962c2972e465ce882e7a1f814c31a432bfcaae17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Wed, 19 Apr 2023 20:07:37 +0200 Subject: [PATCH] Correctly determine sort key in subdivide_batch --- spacy/training/loop.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/spacy/training/loop.py b/spacy/training/loop.py index 7f67aa2cf..71fcc65cb 100644 --- a/spacy/training/loop.py +++ b/spacy/training/loop.py @@ -511,12 +511,15 @@ def train_while_improving( break -def subdivide_batch(batch, accumulate_gradient): +def subdivide_batch( + batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient +): batch = list(batch) - if isinstance(batch, Example): - batch.sort(key=lambda eg: len(eg.predicted)) - elif isinstance(batch, Doc): - batch.sort(key=lambda doc: len(doc)) + if len(batch): + if isinstance(batch[0], Example): + batch.sort(key=lambda eg: len(eg.predicted)) + else: + batch.sort(key=lambda doc: len(doc)) sub_len = len(batch) // accumulate_gradient start = 0 for i in range(accumulate_gradient):