Mirror of https://github.com/explosion/spaCy.git, synced 2025-01-23 15:54:13 +03:00
Adjust to new `Schedule` class and pass scores to `Optimizer` (#12008)

* Adjust to new `Schedule` class and pass scores to `Optimizer`. Requires https://github.com/explosion/thinc/pull/804
* Bump minimum Thinc requirement to 9.0.0.dev1
This commit is contained in:
parent d30ba9b7b8
commit 20b63943f5
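The heart of the change: with thinc#804, Thinc schedules are no longer generators consumed with next() but Schedule objects called with an explicit step, so a schedule can be queried at any position without hidden iterator state. A minimal sketch of the two styles (assuming Thinc >= 9.0.0.dev1 is installed):

import itertools
from thinc.api import constant

# Old style: an infinite iterator, advanced statefully with next()
old_sizes = itertools.repeat(32)
print(next(old_sizes))  # 32

# New style: a Schedule, a function of the step counter
new_sizes = constant(32)
print(new_sizes(0), new_sizes(100))  # 32 32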
pyproject.toml
@@ -5,7 +5,7 @@ requires = [
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
     "murmurhash>=0.28.0,<1.1.0",
-    "thinc>=9.0.0.dev0,<9.1.0",
+    "thinc>=9.0.0.dev1,<9.1.0",
     "numpy>=1.15.0",
 ]
 build-backend = "setuptools.build_meta"
requirements.txt
@@ -3,7 +3,7 @@ spacy-legacy>=3.0.10,<3.1.0
 spacy-loggers>=1.0.0,<2.0.0
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=9.0.0.dev0,<9.1.0
+thinc>=9.0.0.dev1,<9.1.0
 ml_datasets>=0.2.0,<0.3.0
 murmurhash>=0.28.0,<1.1.0
 wasabi>=0.9.1,<1.2.0
setup.cfg
@@ -38,7 +38,7 @@ install_requires =
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc>=9.0.0.dev0,<9.1.0
+    thinc>=9.0.0.dev1,<9.1.0
     wasabi>=0.9.1,<1.2.0
     srsly>=2.4.3,<3.0.0
     catalogue>=2.0.6,<2.1.0
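All three pins move together from 9.0.0.dev0 to 9.0.0.dev1, since the `Schedule` class used below only ships from that pre-release on. A quick runtime sanity check (a sketch; it assumes the `packaging` library is available):

# Verify the installed Thinc is new enough for the Schedule API.
import thinc
from packaging.version import Version

assert Version(thinc.__version__) >= Version("9.0.0.dev1"), thinc.__version__
from thinc.schedules import Schedule  # present from Thinc 9.0.0.dev1 on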
spacy/training/batchers.py
@@ -2,11 +2,12 @@ from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
 from typing import Optional, Any
 from functools import partial
 import itertools
+from thinc.schedules import Schedule, constant as constant_schedule

 from ..util import registry, minibatch


-Sizing = Union[Sequence[int], int]
+Sizing = Union[Sequence[int], int, Schedule[int]]
 ItemT = TypeVar("ItemT")
 BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
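`Sizing` now admits a `Schedule[int]` next to a fixed int or a sequence, so growing batch sizes can be expressed directly. For example, with Thinc's `compounding` schedule (the parameters here are arbitrary):

from thinc.api import compounding

sizes = compounding(100.0, 1000.0, 1.001)
print(sizes(0), sizes(1), sizes(1000))  # 100.0, 100.1, ~272; clipped at 1000.0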
spacy/training/batchers.py
@@ -111,12 +112,13 @@ def minibatch_by_padded_size(
     The `len` function is used by default.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)  # type: Iterator[int]
+        size_ = constant_schedule(size)
     else:
-        size_ = iter(size)
-    for outer_batch in minibatch(seqs, size=buffer):
+        assert isinstance(size, Schedule)
+        size_ = size
+    for step, outer_batch in enumerate(minibatch(seqs, size=buffer)):
         outer_batch = list(outer_batch)
-        target_size = next(size_)
+        target_size = size_(step)
         for indices in _batch_by_length(outer_batch, target_size, get_length):
             subbatch = [outer_batch[i] for i in indices]
             padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
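Since the schedule is a pure function of its step, the separate iterator bookkeeping collapses into `enumerate`: the outer-batch index doubles as the schedule position. Note that the non-int branch now asserts a `Schedule`, so a bare sequence of sizes would trip the assertion in this batcher. The call pattern in miniature (a toy sketch, not the real batcher):

from thinc.api import compounding

size_ = compounding(2.0, 8.0, 1.5)
for step, outer_batch in enumerate([["a"] * 4, ["b"] * 4, ["c"] * 4]):
    target_size = size_(step)
    print(step, int(target_size))  # 0 2, 1 3, 2 4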
spacy/training/batchers.py
@@ -147,10 +149,12 @@ def minibatch_by_words(
     item. The `len` function is used by default.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)  # type: Iterator[int]
+        size_ = constant_schedule(size)
     else:
-        size_ = iter(size)
-    target_size = next(size_)
+        assert isinstance(size, Schedule)
+        size_ = size
+    step = 0
+    target_size = size_(step)
     tol_size = target_size * tolerance
     batch = []
     overflow = []
spacy/training/batchers.py
@@ -175,7 +179,8 @@ def minibatch_by_words(
         else:
             if batch:
                 yield batch
-            target_size = next(size_)
+            step += 1
+            target_size = size_(step)
             tol_size = target_size * tolerance
             batch = overflow
             batch_size = overflow_size
spacy/training/batchers.py
@@ -193,7 +198,8 @@ def minibatch_by_words(
         else:
             if batch:
                 yield batch
-            target_size = next(size_)
+            step += 1
+            target_size = size_(step)
             tol_size = target_size * tolerance
             batch = [seq]
             batch_size = n_words
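Unlike the padded-size batcher, `minibatch_by_words` advances the size only when it actually emits a batch, not once per input sequence, so the position cannot come from `enumerate`; the two hunks above keep an explicit counter instead. The pattern in isolation (a toy sketch):

from thinc.api import compounding

size_ = compounding(100.0, 1000.0, 1.01)
step = 0
target_size = size_(step)
for _emitted_batch in range(3):  # pretend three batches get yielded
    step += 1                    # advance the schedule once per emitted batch
    target_size = size_(step)
    print(step, round(target_size, 1))  # 1 101.0, 2 102.0, 3 103.0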
spacy/training/loop.py
@@ -204,7 +204,7 @@ def train_while_improving(
         if before_update:
             before_update_args = {"step": step, "epoch": epoch}
             before_update(nlp, before_update_args)
-        dropout = next(dropouts)  # type: ignore
+        dropout = dropouts(optimizer.step)  # type: ignore
         for subbatch in subdivide_batch(batch, accumulate_gradient):
             nlp.update(
                 subbatch,
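Dropout is now keyed to the optimizer's own step counter instead of a separately consumed iterator, which keeps the two in lockstep. A sketch with Thinc's `decaying` schedule (`Optimizer.step` comes from Thinc 9; treat details beyond the diff as assumptions):

from thinc.api import Adam, decaying

dropouts = decaying(0.6, 1e-4)      # 0.6 / (1 + 1e-4 * step)
optimizer = Adam(0.001)
dropout = dropouts(optimizer.step)  # follows the optimizer's step counter
print(optimizer.step, dropout)      # 0 0.6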
|
@ -230,6 +230,7 @@ def train_while_improving(
|
|||
score, other_scores = evaluate()
|
||||
else:
|
||||
score, other_scores = evaluate()
|
||||
optimizer.last_score = score
|
||||
results.append((score, step))
|
||||
is_best_checkpoint = score == max(results)[0]
|
||||
else:
|
||||
|
|
|
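Storing the evaluation score on the optimizer is what makes score-aware schedules possible: Thinc 9 adds a `plateau` wrapper that scales the wrapped learning-rate schedule when `last_score` stops improving. A hedged sketch (the `plateau` signature is assumed from Thinc's schedules module):

from thinc.api import Adam, constant, plateau

# Halve the learning rate after two evaluations without improvement.
learn_rate = plateau(max_patience=2, scale=0.5, schedule=constant(0.001))
optimizer = Adam(learn_rate)
optimizer.last_score = 0.85  # what the training loop now sets after evaluate()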
spacy/util.py
@@ -9,7 +9,7 @@ import re
 from pathlib import Path
 import thinc
 from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
-from thinc.api import ConfigValidationError, Model
+from thinc.api import ConfigValidationError, Model, constant as constant_schedule
 import functools
 import itertools
 import numpy
spacy/util.py
@@ -1582,12 +1582,12 @@ def minibatch(items, size):
     so that batch-size can vary on each step.
     """
     if isinstance(size, int):
-        size_ = itertools.repeat(size)
+        size_ = constant_schedule(size)
     else:
         size_ = size
     items = iter(items)
-    while True:
-        batch_size = next(size_)
+    for step in itertools.count():
+        batch_size = size_(step)
         batch = list(itertools.islice(items, int(batch_size)))
         if len(batch) == 0:
             break
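The public `spacy.util.minibatch` helper behaves as before for plain ints, while `itertools.count()` replaces `while True` so a step index is available to the schedule. For example:

from spacy.util import minibatch

print(list(minibatch(range(10), size=3)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]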