Adjust to new Schedule class and pass scores to Optimizer

Requires https://github.com/explosion/thinc/pull/804
Daniël de Kok 2022-12-21 10:14:39 +01:00
parent f9308aae13
commit f62cd9782e
3 changed files with 22 additions and 15 deletions
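
This adapts spaCy's batchers, training loop, and `minibatch` utility to thinc's reworked schedules: a schedule is no longer a generator consumed with `next(...)` but a `Schedule` object called with an explicit step. A minimal sketch of the difference, using only the `constant` schedule that the diffs below import (the values in comments are illustrative):

from thinc.schedules import constant

size_schedule = constant(128)

# Old style: schedules were iterators that advanced implicitly.
#     size_ = itertools.repeat(128)
#     target_size = next(size_)
#
# New style: the caller passes the step, so it controls when and
# how often the schedule advances.
target_size = size_schedule(0)    # 128
target_size = size_schedule(500)  # 128; a constant ignores the step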

spacy/training/batchers.py

@@ -2,11 +2,12 @@ from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
 from typing import Optional, Any
 from functools import partial
 import itertools
+from thinc.schedules import Schedule, constant as constant_schedule
 from ..util import registry, minibatch
-Sizing = Union[Sequence[int], int]
+Sizing = Union[Sequence[int], int, Schedule[int]]
 ItemT = TypeVar("ItemT")
 BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
@@ -111,12 +112,13 @@ def minibatch_by_padded_size(
     The `len` function is used by default.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)  # type: Iterator[int]
+        size_ = constant_schedule(size)
     else:
-        size_ = iter(size)
-    for outer_batch in minibatch(seqs, size=buffer):
+        assert isinstance(size, Schedule)
+        size_ = size
+    for step, outer_batch in enumerate(minibatch(seqs, size=buffer)):
         outer_batch = list(outer_batch)
-        target_size = next(size_)
+        target_size = size_(step)
         for indices in _batch_by_length(outer_batch, target_size, get_length):
             subbatch = [outer_batch[i] for i in indices]
             padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
@@ -147,10 +149,12 @@ def minibatch_by_words(
     item. The `len` function is used by default.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)  # type: Iterator[int]
+        size_ = constant_schedule(size)
     else:
-        size_ = iter(size)
-    target_size = next(size_)
+        assert isinstance(size, Schedule)
+        size_ = size
+    step = 0
+    target_size = size_(step)
     tol_size = target_size * tolerance
     batch = []
     overflow = []
@@ -175,7 +179,8 @@ def minibatch_by_words(
         else:
             if batch:
                 yield batch
-            target_size = next(size_)
+            step += 1
+            target_size = size_(step)
             tol_size = target_size * tolerance
             batch = overflow
             batch_size = overflow_size
@@ -193,7 +198,8 @@ def minibatch_by_words(
         else:
             if batch:
                 yield batch
-            target_size = next(size_)
+            step += 1
+            target_size = size_(step)
             tol_size = target_size * tolerance
             batch = [seq]
             batch_size = n_words
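
A design note on the two batchers above: `minibatch_by_padded_size` uses the index of the outer buffer batch as the step, while `minibatch_by_words` keeps its own `step` counter and increments it each time a batch is yielded, so a non-constant schedule advances at different rates in the two functions. A hedged illustration with thinc's `compounding` schedule (the signature `compounding(start, stop, compound)` and clipping at the stop value are assumptions about the post-refactor API):

from thinc.schedules import compounding

# Batch size starts at 4.0, is multiplied by 1.5 at each step, and
# should not exceed 32.0.
size_ = compounding(4.0, 32.0, 1.5)
print(size_(0))  # 4.0
print(size_(1))  # 6.0
print(size_(5))  # 30.375
print(size_(9))  # 32.0, clipped at the stop value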

spacy/training/loop.py

@@ -204,7 +204,7 @@ def train_while_improving(
         if before_update:
             before_update_args = {"step": step, "epoch": epoch}
             before_update(nlp, before_update_args)
-        dropout = next(dropouts)  # type: ignore
+        dropout = dropouts(optimizer.step)  # type: ignore
         for subbatch in subdivide_batch(batch, accumulate_gradient):
             nlp.update(
                 subbatch,
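
The dropout schedule is now queried with `optimizer.step` instead of being consumed as a generator, tying the dropout value to the number of optimizer updates rather than to how many times this loop body has run. A sketch, assuming thinc's `decaying` schedule with the signature `decaying(base_rate, decay)`:

from thinc.schedules import decaying

dropouts = decaying(0.6, 1e-4)
print(dropouts(0))      # 0.6 at the first optimizer step
print(dropouts(10000))  # 0.3, i.e. 0.6 / (1 + 1e-4 * 10000)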
@@ -230,6 +230,7 @@ def train_while_improving(
                     score, other_scores = evaluate()
             else:
                 score, other_scores = evaluate()
+            optimizer.last_score = score
             results.append((score, step))
             is_best_checkpoint = score == max(results)[0]
         else:
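
Recording the evaluation score on the optimizer is the second half of the commit title: a score-aware learning-rate schedule can read the latest score when the optimizer queries it. A minimal sketch, assuming only that the thinc Optimizer from the linked PR exposes a writable `last_score`:

from thinc.api import Adam

optimizer = Adam(learn_rate=0.001)
# Set by the training loop after each evaluate() call; the score
# value here is hypothetical.
optimizer.last_score = 0.85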

spacy/util.py

@@ -9,7 +9,7 @@ import re
 from pathlib import Path
 import thinc
 from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
-from thinc.api import ConfigValidationError, Model
+from thinc.api import ConfigValidationError, Model, constant as constant_schedule
 import functools
 import itertools
 import numpy
@@ -1582,12 +1582,12 @@ def minibatch(items, size):
     so that batch-size can vary on each step.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)
+        size_ = constant_schedule(size)
     else:
         size_ = size
     items = iter(items)
-    while True:
-        batch_size = next(size_)
+    for step in itertools.count():
+        batch_size = size_(step)
         batch = list(itertools.islice(items, int(batch_size)))
         if len(batch) == 0:
             break
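
A usage sketch of the rewritten utility, assuming the function as patched above: with a plain int the behaviour is unchanged, and with a Schedule the batch size is looked up per step.

from spacy.util import minibatch

batches = list(minibatch(range(10), size=4))
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]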