mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-22 01:51:58 +03:00
Correctly determine sort key in subdivide_batch
This commit is contained in:
parent
add1a21657
commit
962c2972e4
|
@ -511,12 +511,15 @@ def train_while_improving(
|
|||
break
|
||||
|
||||
|
||||
def subdivide_batch(batch, accumulate_gradient):
|
||||
def subdivide_batch(
|
||||
batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient
|
||||
):
|
||||
batch = list(batch)
|
||||
if isinstance(batch, Example):
|
||||
batch.sort(key=lambda eg: len(eg.predicted))
|
||||
elif isinstance(batch, Doc):
|
||||
batch.sort(key=lambda doc: len(doc))
|
||||
if len(batch):
|
||||
if isinstance(batch[0], Example):
|
||||
batch.sort(key=lambda eg: len(eg.predicted))
|
||||
else:
|
||||
batch.sort(key=lambda doc: len(doc))
|
||||
sub_len = len(batch) // accumulate_gradient
|
||||
start = 0
|
||||
for i in range(accumulate_gradient):
|
||||
|
|
Loading…
Reference in New Issue
Block a user