Fix batching regression (#12094)

* Fix batching regression

Some time ago, the spaCy v4 branch switched to the new Thinc v9
schedule. However, this introduced an error in how batching is handed.

In the PR, the batchers were changed to keep track of their step,
so that the step can be passed to the schedule. However, the issue
is that the training loop repeatedly calls the batching functions
(rather than using an infinite generator/iterator). So, the step and
therefore the schedule would be reset each epoch. Before the schedule
switch we didn't have this issue, because the old schedules were
stateful.

This PR fixes this issue by reverting the batching functions to use
a (stateful) generator. Their registry functions do accept a `Schedule`
and we convert `Schedule`s to generators.

* Update batcher docs

* Docstring fixes

* Make minibatch take iterables again as well

* Bump thinc requirement to 9.0.0.dev2

* Use type declaration

* Convert another comment into a proper type declaration
This commit is contained in:
Daniël de Kok 2023-01-18 18:28:30 +01:00 committed by GitHub
parent a183db3cef
commit b052b1b47f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 64 additions and 54 deletions

View File

@ -5,7 +5,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=9.0.0.dev1,<9.1.0",
"thinc>=9.0.0.dev2,<9.1.0",
"numpy>=1.15.0",
]
build-backend = "setuptools.build_meta"

View File

@ -3,7 +3,7 @@ spacy-legacy>=3.0.11,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=9.0.0.dev1,<9.1.0
thinc>=9.0.0.dev2,<9.1.0
ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.9.1,<1.2.0

View File

@ -39,7 +39,7 @@ install_requires =
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=9.0.0.dev1,<9.1.0
thinc>=9.0.0.dev2,<9.1.0
wasabi>=0.9.1,<1.2.0
srsly>=2.4.3,<3.0.0
catalogue>=2.0.6,<2.1.0

View File

@ -24,7 +24,9 @@ def test_issue4348():
optimizer = nlp.initialize()
for i in range(5):
losses = {}
batches = util.minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
batches = util.minibatch(
TRAIN_DATA, size=compounding(4.0, 32.0, 1.001).to_generator()
)
for batch in batches:
nlp.update(batch, sgd=optimizer, losses=losses)

View File

@ -91,7 +91,9 @@ def test_issue3611():
optimizer = nlp.initialize()
for i in range(3):
losses = {}
batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
batches = util.minibatch(
train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
)
for batch in batches:
nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
@ -128,7 +130,9 @@ def test_issue4030():
optimizer = nlp.initialize()
for i in range(3):
losses = {}
batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
batches = util.minibatch(
train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
)
for batch in batches:
nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)

View File

@ -918,7 +918,9 @@ def _train_tuples(train_data):
optimizer = nlp.initialize()
for i in range(5):
losses = {}
batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
batches = minibatch(
train_examples, size=compounding(4.0, 32.0, 1.001).to_generator()
)
for batch in batches:
nlp.update(batch, sgd=optimizer, losses=losses)

View File

@ -2,12 +2,13 @@ from typing import Union, Iterable, Sequence, TypeVar, List, Callable, Iterator
from typing import Optional, Any
from functools import partial
import itertools
from thinc.schedules import Schedule, constant as constant_schedule
from thinc.schedules import Schedule
from ..util import registry, minibatch
Sizing = Union[Sequence[int], int, Schedule[int]]
SizingSchedule = Union[Iterable[int], int, Schedule]
Sizing = Union[Iterable[int], int]
ItemT = TypeVar("ItemT")
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
@ -15,7 +16,7 @@ BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
@registry.batchers("spacy.batch_by_padded.v1")
def configure_minibatch_by_padded_size(
*,
size: Sizing,
size: SizingSchedule,
buffer: int,
discard_oversize: bool,
get_length: Optional[Callable[[ItemT], int]] = None
@ -25,8 +26,8 @@ def configure_minibatch_by_padded_size(
The padded size is defined as the maximum length of sequences within the
batch multiplied by the number of sequences in the batch.
size (int or Sequence[int]): The largest padded size to batch sequences into.
Can be a single integer, or a sequence, allowing for variable batch sizes.
size (int, Iterable[int] or Schedule): The largest padded size to batch sequences
into. Can be a single integer, or a sequence, allowing for variable batch sizes.
buffer (int): The number of sequences to accumulate before sorting by length.
A larger buffer will result in more even sizing, but if the buffer is
very large, the iteration order will be less random, which can result
@ -40,7 +41,7 @@ def configure_minibatch_by_padded_size(
optionals = {"get_length": get_length} if get_length is not None else {}
return partial(
minibatch_by_padded_size,
size=size,
size=_schedule_to_sizing(size),
buffer=buffer,
discard_oversize=discard_oversize,
**optionals
@ -50,14 +51,14 @@ def configure_minibatch_by_padded_size(
@registry.batchers("spacy.batch_by_words.v1")
def configure_minibatch_by_words(
*,
size: Sizing,
size: SizingSchedule,
tolerance: float,
discard_oversize: bool,
get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
"""Create a batcher that uses the "minibatch by words" strategy.
size (int or Sequence[int]): The target number of words per batch.
size (int, Iterable[int] or Schedule): The target number of words per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
tolerance (float): What percentage of the size to allow batches to exceed.
discard_oversize (bool): Whether to discard sequences that by themselves
@ -68,7 +69,7 @@ def configure_minibatch_by_words(
optionals = {"get_length": get_length} if get_length is not None else {}
return partial(
minibatch_by_words,
size=size,
size=_schedule_to_sizing(size),
tolerance=tolerance,
discard_oversize=discard_oversize,
**optionals
@ -77,15 +78,15 @@ def configure_minibatch_by_words(
@registry.batchers("spacy.batch_by_sequence.v1")
def configure_minibatch(
size: Sizing, get_length: Optional[Callable[[ItemT], int]] = None
size: SizingSchedule, get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
"""Create a batcher that creates batches of the specified size.
size (int or Sequence[int]): The target number of items per batch.
size (int, Iterable[int] or Schedule): The target number of items per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
"""
optionals = {"get_length": get_length} if get_length is not None else {}
return partial(minibatch, size=size, **optionals)
return partial(minibatch, size=_schedule_to_sizing(size), **optionals)
def minibatch_by_padded_size(
@ -101,7 +102,7 @@ def minibatch_by_padded_size(
The padded size is defined as the maximum length of sequences within the
batch multiplied by the number of sequences in the batch.
size (int or Sequence[int]): The largest padded size to batch sequences into.
size (int or Iterable[int]): The largest padded size to batch sequences into.
buffer (int): The number of sequences to accumulate before sorting by length.
A larger buffer will result in more even sizing, but if the buffer is
very large, the iteration order will be less random, which can result
@ -112,13 +113,12 @@ def minibatch_by_padded_size(
The `len` function is used by default.
"""
if isinstance(size, int):
size_ = constant_schedule(size)
size_: Iterator[int] = itertools.repeat(size)
else:
assert isinstance(size, Schedule)
size_ = size
for step, outer_batch in enumerate(minibatch(seqs, size=buffer)):
size_ = iter(size)
for outer_batch in minibatch(seqs, size=buffer):
outer_batch = list(outer_batch)
target_size = size_(step)
target_size = next(size_)
for indices in _batch_by_length(outer_batch, target_size, get_length):
subbatch = [outer_batch[i] for i in indices]
padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
@ -140,7 +140,7 @@ def minibatch_by_words(
themselves, or be discarded if discard_oversize=True.
seqs (Iterable[Sequence]): The sequences to minibatch.
size (int or Sequence[int]): The target number of words per batch.
size (int or Iterable[int]): The target number of words per batch.
Can be a single integer, or a sequence, allowing for variable batch sizes.
tolerance (float): What percentage of the size to allow batches to exceed.
discard_oversize (bool): Whether to discard sequences that by themselves
@ -149,12 +149,10 @@ def minibatch_by_words(
item. The `len` function is used by default.
"""
if isinstance(size, int):
size_ = constant_schedule(size)
size_: Iterator[int] = itertools.repeat(size)
else:
assert isinstance(size, Schedule)
size_ = size
step = 0
target_size = size_(step)
size_ = iter(size)
target_size = next(size_)
tol_size = target_size * tolerance
batch = []
overflow = []
@ -179,8 +177,7 @@ def minibatch_by_words(
else:
if batch:
yield batch
step += 1
target_size = size_(step)
target_size = next(size_)
tol_size = target_size * tolerance
batch = overflow
batch_size = overflow_size
@ -198,8 +195,7 @@ def minibatch_by_words(
else:
if batch:
yield batch
step += 1
target_size = size_(step)
target_size = next(size_)
tol_size = target_size * tolerance
batch = [seq]
batch_size = n_words
@ -236,3 +232,9 @@ def _batch_by_length(
batches = [list(sorted(batch)) for batch in batches]
batches.reverse()
return batches
def _schedule_to_sizing(size: SizingSchedule) -> Sizing:
if isinstance(size, Schedule):
return size.to_generator()
return size

View File

@ -1583,12 +1583,12 @@ def minibatch(items, size):
so that batch-size can vary on each step.
"""
if isinstance(size, int):
size_ = constant_schedule(size)
size_ = itertools.repeat(size)
else:
size_ = size
size_ = iter(size)
items = iter(items)
for step in itertools.count():
batch_size = size_(step)
while True:
batch_size = next(size_)
batch = list(itertools.islice(items, int(batch_size)))
if len(batch) == 0:
break

View File

@ -752,9 +752,9 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
> ```
| Name | Description |
| ------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `seqs` | The sequences to minibatch. ~~Iterable[Any]~~ |
| `size` | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Sequence[int]]~~ |
| `size` | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
| `tolerance` | What percentage of the size to allow batches to exceed. ~~float~~ |
| `discard_oversize` | Whether to discard sequences that by themselves exceed the tolerated size. ~~bool~~ |
| `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
@ -774,8 +774,8 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
Create a batcher that creates batches of the specified size.
| Name | Description |
| ------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `size` | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Sequence[int]]~~ |
| ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `size` | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
| `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
| **CREATES** | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~ |
@ -799,7 +799,7 @@ sequences in the batch.
| Name | Description |
| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `size` | The largest padded size to batch sequences into. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Sequence[int]]~~ |
| `size` | The largest padded size to batch sequences into. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
| `buffer` | The number of sequences to accumulate before sorting by length. A larger buffer will result in more even sizing, but if the buffer is very large, the iteration order will be less random, which can result in suboptimal training. ~~int~~ |
| `discard_oversize` | Whether to discard sequences that are by themselves longer than the largest padded batch size. ~~bool~~ |
| `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
@ -1401,7 +1401,7 @@ vary on each step.
| Name | Description |
| ---------- | ------------------------------------------------ |
| `items` | The items to batch up. ~~Iterable[Any]~~ |
| `size` | The batch size(s). ~~Union[int, Sequence[int]]~~ |
| `size` | The batch size(s). ~~Union[int, Iterable[int]]~~ |
| **YIELDS** | The batches. |
### util.filter_spans {id="util.filter_spans",tag="function",version="2.1.4"}