Adjust to new Schedule class and pass scores to Optimizer (#12008)

* Adjust to new `Schedule` class and pass scores to `Optimizer`

Requires https://github.com/explosion/thinc/pull/804

* Bump minimum Thinc requirement to 9.0.0.dev1
This commit is contained in:
Daniël de Kok 2022-12-29 08:03:24 +01:00 committed by GitHub
parent d30ba9b7b8
commit 20b63943f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 25 additions and 18 deletions

View File

@ -5,7 +5,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=9.0.0.dev0,<9.1.0",
"thinc>=9.0.0.dev1,<9.1.0",
"numpy>=1.15.0",
]
build-backend = "setuptools.build_meta"

View File

@ -3,7 +3,7 @@ spacy-legacy>=3.0.10,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=9.0.0.dev0,<9.1.0
thinc>=9.0.0.dev1,<9.1.0
ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.9.1,<1.2.0

View File

@ -38,7 +38,7 @@ install_requires =
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=9.0.0.dev0,<9.1.0
thinc>=9.0.0.dev1,<9.1.0
wasabi>=0.9.1,<1.2.0
srsly>=2.4.3,<3.0.0
catalogue>=2.0.6,<2.1.0

View File

@ -2,11 +2,12 @@ from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
from typing import Optional, Any
from functools import partial
import itertools
from thinc.schedules import Schedule, constant as constant_schedule
from ..util import registry, minibatch
Sizing = Union[Sequence[int], int]
Sizing = Union[Sequence[int], int, Schedule[int]]
ItemT = TypeVar("ItemT")
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
@ -111,12 +112,13 @@ def minibatch_by_padded_size(
The `len` function is used by default.
"""
if isinstance(size, int):
size_ = itertools.repeat(size) # type: Iterator[int]
size_ = constant_schedule(size)
else:
size_ = iter(size)
for outer_batch in minibatch(seqs, size=buffer):
assert isinstance(size, Schedule)
size_ = size
for step, outer_batch in enumerate(minibatch(seqs, size=buffer)):
outer_batch = list(outer_batch)
target_size = next(size_)
target_size = size_(step)
for indices in _batch_by_length(outer_batch, target_size, get_length):
subbatch = [outer_batch[i] for i in indices]
padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
@ -147,10 +149,12 @@ def minibatch_by_words(
item. The `len` function is used by default.
"""
if isinstance(size, int):
size_ = itertools.repeat(size) # type: Iterator[int]
size_ = constant_schedule(size)
else:
size_ = iter(size)
target_size = next(size_)
assert isinstance(size, Schedule)
size_ = size
step = 0
target_size = size_(step)
tol_size = target_size * tolerance
batch = []
overflow = []
@ -175,7 +179,8 @@ def minibatch_by_words(
else:
if batch:
yield batch
target_size = next(size_)
step += 1
target_size = size_(step)
tol_size = target_size * tolerance
batch = overflow
batch_size = overflow_size
@ -193,7 +198,8 @@ def minibatch_by_words(
else:
if batch:
yield batch
target_size = next(size_)
step += 1
target_size = size_(step)
tol_size = target_size * tolerance
batch = [seq]
batch_size = n_words

View File

@ -204,7 +204,7 @@ def train_while_improving(
if before_update:
before_update_args = {"step": step, "epoch": epoch}
before_update(nlp, before_update_args)
dropout = next(dropouts) # type: ignore
dropout = dropouts(optimizer.step) # type: ignore
for subbatch in subdivide_batch(batch, accumulate_gradient):
nlp.update(
subbatch,
@ -230,6 +230,7 @@ def train_while_improving(
score, other_scores = evaluate()
else:
score, other_scores = evaluate()
optimizer.last_score = score
results.append((score, step))
is_best_checkpoint = score == max(results)[0]
else:

View File

@ -9,7 +9,7 @@ import re
from pathlib import Path
import thinc
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
from thinc.api import ConfigValidationError, Model
from thinc.api import ConfigValidationError, Model, constant as constant_schedule
import functools
import itertools
import numpy
@ -1582,12 +1582,12 @@ def minibatch(items, size):
so that batch-size can vary on each step.
"""
if isinstance(size, int):
size_ = itertools.repeat(size)
size_ = constant_schedule(size)
else:
size_ = size
items = iter(items)
while True:
batch_size = next(size_)
for step in itertools.count():
batch_size = size_(step)
batch = list(itertools.islice(items, int(batch_size)))
if len(batch) == 0:
break