mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Correctly determine sort key in subdivide_batch
This commit is contained in:
parent
add1a21657
commit
962c2972e4
|
@ -511,11 +511,14 @@ def train_while_improving(
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def subdivide_batch(batch, accumulate_gradient):
|
def subdivide_batch(
|
||||||
|
batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient
|
||||||
|
):
|
||||||
batch = list(batch)
|
batch = list(batch)
|
||||||
if isinstance(batch, Example):
|
if len(batch):
|
||||||
|
if isinstance(batch[0], Example):
|
||||||
batch.sort(key=lambda eg: len(eg.predicted))
|
batch.sort(key=lambda eg: len(eg.predicted))
|
||||||
elif isinstance(batch, Doc):
|
else:
|
||||||
batch.sort(key=lambda doc: len(doc))
|
batch.sort(key=lambda doc: len(doc))
|
||||||
sub_len = len(batch) // accumulate_gradient
|
sub_len = len(batch) // accumulate_gradient
|
||||||
start = 0
|
start = 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user