Adjust to new Schedule class and pass scores to Optimizer (#12008)

* Adjust to new `Schedule` class and pass scores to `Optimizer`

Requires https://github.com/explosion/thinc/pull/804

* Bump minimum Thinc requirement to 9.0.0.dev1
Daniël de Kok 2022-12-29 08:03:24 +01:00 committed by GitHub
parent d30ba9b7b8
commit 20b63943f5
6 changed files with 25 additions and 18 deletions
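
For background: in Thinc 9, schedules are first-class `Schedule` objects that are queried with an explicit step, rather than generators consumed with `next(...)`. A minimal sketch of the new calling convention, using only Thinc's `constant` helper (values illustrative):

from thinc.schedules import constant

batch_sizes = constant(128)     # a Schedule[int], not a generator
assert batch_sizes(0) == 128    # indexed by an explicit step...
assert batch_sizes(999) == 128  # ...so it can be re-read at any point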

pyproject.toml

@@ -5,7 +5,7 @@ requires = [
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
     "murmurhash>=0.28.0,<1.1.0",
-    "thinc>=9.0.0.dev0,<9.1.0",
+    "thinc>=9.0.0.dev1,<9.1.0",
     "numpy>=1.15.0",
 ]
 build-backend = "setuptools.build_meta"

requirements.txt

@@ -3,7 +3,7 @@ spacy-legacy>=3.0.10,<3.1.0
 spacy-loggers>=1.0.0,<2.0.0
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=9.0.0.dev0,<9.1.0
+thinc>=9.0.0.dev1,<9.1.0
 ml_datasets>=0.2.0,<0.3.0
 murmurhash>=0.28.0,<1.1.0
 wasabi>=0.9.1,<1.2.0

setup.cfg

@@ -38,7 +38,7 @@ install_requires =
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc>=9.0.0.dev0,<9.1.0
+    thinc>=9.0.0.dev1,<9.1.0
     wasabi>=0.9.1,<1.2.0
     srsly>=2.4.3,<3.0.0
     catalogue>=2.0.6,<2.1.0

spacy/training/batchers.py

@@ -2,11 +2,12 @@ from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
 from typing import Optional, Any
 from functools import partial
 import itertools
+from thinc.schedules import Schedule, constant as constant_schedule
 
 from ..util import registry, minibatch
 
 
-Sizing = Union[Sequence[int], int]
+Sizing = Union[Sequence[int], int, Schedule[int]]
 ItemT = TypeVar("ItemT")
 BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
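
Since `Sizing` now admits `Schedule[int]`, a batch-size schedule such as Thinc's `compounding` can be passed to the batchers directly. A hedged sketch (values illustrative; `compounding(start, stop, compound)` multiplies by `compound` each step and caps at `stop`; spaCy truncates the float values to `int` where needed):

from thinc.schedules import compounding

size = compounding(2.0, 32.0, 1.5)
print(size(0), size(1), size(2))  # 2.0 3.0 4.5 ... capped at 32.0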
@@ -111,12 +112,13 @@ def minibatch_by_padded_size(
     The `len` function is used by default.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)  # type: Iterator[int]
+        size_ = constant_schedule(size)
     else:
-        size_ = iter(size)
-    for outer_batch in minibatch(seqs, size=buffer):
+        assert isinstance(size, Schedule)
+        size_ = size
+    for step, outer_batch in enumerate(minibatch(seqs, size=buffer)):
         outer_batch = list(outer_batch)
-        target_size = next(size_)
+        target_size = size_(step)
         for indices in _batch_by_length(outer_batch, target_size, get_length):
             subbatch = [outer_batch[i] for i in indices]
             padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
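
Unlike the old iterator, a `Schedule` is not consumed when read, which is why the loop threads an explicit step through `enumerate`: looking up the same step twice returns the same value, so the batcher stays deterministic and restartable. In isolation:

from thinc.schedules import constant

size_ = constant(8)
assert size_(3) == size_(3) == 8  # a pure lookup, no state is consumed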
@@ -147,10 +149,12 @@ def minibatch_by_words(
     item. The `len` function is used by default.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)  # type: Iterator[int]
+        size_ = constant_schedule(size)
     else:
-        size_ = iter(size)
-    target_size = next(size_)
+        assert isinstance(size, Schedule)
+        size_ = size
+    step = 0
+    target_size = size_(step)
     tol_size = target_size * tolerance
     batch = []
     overflow = []
@@ -175,7 +179,8 @@ def minibatch_by_words(
         else:
             if batch:
                 yield batch
-            target_size = next(size_)
+            step += 1
+            target_size = size_(step)
             tol_size = target_size * tolerance
             batch = overflow
             batch_size = overflow_size
@@ -193,7 +198,8 @@ def minibatch_by_words(
         else:
             if batch:
                 yield batch
-            target_size = next(size_)
+            step += 1
+            target_size = size_(step)
             tol_size = target_size * tolerance
             batch = [seq]
             batch_size = n_words
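
Note that `step` is incremented only when a batch is yielded, so the schedule is indexed by batches produced rather than by sequences consumed, mirroring what `next(size_)` did before. The counter pattern stripped down (all names and values illustrative):

from thinc.schedules import constant

size_ = constant(1000)     # target words per batch
step = 0
target_size = size_(step)
# ...each time a batch is yielded:
step += 1
target_size = size_(step)  # target for the next batch, same schedule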

spacy/training/loop.py

@@ -204,7 +204,7 @@ def train_while_improving(
         if before_update:
             before_update_args = {"step": step, "epoch": epoch}
             before_update(nlp, before_update_args)
-        dropout = next(dropouts)  # type: ignore
+        dropout = dropouts(optimizer.step)  # type: ignore
         for subbatch in subdivide_batch(batch, accumulate_gradient):
             nlp.update(
                 subbatch,
@@ -230,6 +230,7 @@ def train_while_improving(
                     score, other_scores = evaluate()
             else:
                 score, other_scores = evaluate()
+            optimizer.last_score = score
             results.append((score, step))
             is_best_checkpoint = score == max(results)[0]
         else:
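
Two `Optimizer`-facing changes land in the training loop: dropout is now read from the schedule at `optimizer.step`, and each evaluation score is written back to `optimizer.last_score`, which is what lets the score-aware schedules from the Thinc PR above react to training progress. A minimal sketch, assuming the Thinc 9 dev API exactly as this diff uses it:

from thinc.api import Adam
from thinc.schedules import constant

optimizer = Adam(0.001)
dropouts = constant(0.1)

dropout = dropouts(optimizer.step)  # schedule keyed to optimizer steps
score = 0.83                        # stand-in for evaluate()'s result
optimizer.last_score = score        # exposed to score-aware schedules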

spacy/util.py

@@ -9,7 +9,7 @@ import re
 from pathlib import Path
 import thinc
 from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
-from thinc.api import ConfigValidationError, Model
+from thinc.api import ConfigValidationError, Model, constant as constant_schedule
 import functools
 import itertools
 import numpy
@@ -1582,12 +1582,12 @@ def minibatch(items, size):
     so that batch-size can vary on each step.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)
+        size_ = constant_schedule(size)
     else:
         size_ = size
     items = iter(items)
-    while True:
-        batch_size = next(size_)
+    for step in itertools.count():
+        batch_size = size_(step)
         batch = list(itertools.islice(items, int(batch_size)))
         if len(batch) == 0:
             break
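
Caller-facing behaviour is unchanged: an `int` still yields fixed-size batches, and a `Schedule` varies the size by batch index. A usage sketch against the rewritten helper (values illustrative):

from thinc.schedules import constant
from spacy.util import minibatch

items = range(10)
print([len(b) for b in minibatch(items, size=4)])            # [4, 4, 2]
print([len(b) for b in minibatch(items, size=constant(3))])  # [3, 3, 3, 1]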