diff --git a/spacy/training/loop.py b/spacy/training/loop.py index 7f67aa2cf..71fcc65cb 100644 --- a/spacy/training/loop.py +++ b/spacy/training/loop.py @@ -511,12 +511,15 @@ def train_while_improving( break -def subdivide_batch(batch, accumulate_gradient): +def subdivide_batch( + batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient +): batch = list(batch) - if isinstance(batch, Example): - batch.sort(key=lambda eg: len(eg.predicted)) - elif isinstance(batch, Doc): - batch.sort(key=lambda doc: len(doc)) + if len(batch): + if isinstance(batch[0], Example): + batch.sort(key=lambda eg: len(eg.predicted)) + else: + batch.sort(key=lambda doc: len(doc)) sub_len = len(batch) // accumulate_gradient start = 0 for i in range(accumulate_gradient):