diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py
index a0c71198e..613ce57f9 100644
--- a/spacy/tests/pipeline/test_tagger.py
+++ b/spacy/tests/pipeline/test_tagger.py
@@ -24,7 +24,9 @@ def test_issue4348():
     optimizer = nlp.initialize()
     for i in range(5):
         losses = {}
-        batches = util.minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            TRAIN_DATA, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 942062d1d..569f1b429 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -91,7 +91,9 @@ def test_issue3611():
     optimizer = nlp.initialize()
     for i in range(3):
         losses = {}
-        batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
@@ -128,7 +130,9 @@ def test_issue4030():
     optimizer = nlp.initialize()
     for i in range(3):
         losses = {}
-        batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
diff --git a/spacy/tests/training/test_training.py b/spacy/tests/training/test_training.py
index 7933ea31f..a187885b9 100644
--- a/spacy/tests/training/test_training.py
+++ b/spacy/tests/training/test_training.py
@@ -905,7 +905,9 @@ def _train_tuples(train_data):
     optimizer = nlp.initialize()
    for i in range(5):
         losses = {}
-        batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
+        batches = minibatch(
+            train_examples, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)
diff --git a/spacy/util.py b/spacy/util.py
index aafbbb5de..ab334002b 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -1582,12 +1582,12 @@ def minibatch(items, size):
     so that batch-size can vary on each step.
     """
     if isinstance(size, int):
-        size_ = constant_schedule(size)
+        size_ = itertools.repeat(size)
     else:
-        size_ = size
+        size_ = iter(size)
     items = iter(items)
-    for step in itertools.count():
-        batch_size = size_(step)
+    while True:
+        batch_size = next(size_)
         batch = list(itertools.islice(items, int(batch_size)))
         if len(batch) == 0:
             break
diff --git a/website/docs/api/top-level.md b/website/docs/api/top-level.md
index ba9f18902..7ff899098 100644
--- a/website/docs/api/top-level.md
+++ b/website/docs/api/top-level.md
@@ -751,14 +751,14 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
 > get_length = null
 > ```
 
-| Name               | Description                                                                                                                                                       |
-| ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `seqs`             | The sequences to minibatch. ~~Iterable[Any]~~                                                                                                                     |
+| Name               | Description                                                                                                                                                                       |
+| ------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `seqs`             | The sequences to minibatch. ~~Iterable[Any]~~                                                                                                                                     |
 | `size`             | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
-| `tolerance`        | What percentage of the size to allow batches to exceed. ~~float~~                                                                                                 |
-| `discard_oversize` | Whether to discard sequences that by themselves exceed the tolerated size. ~~bool~~                                                                               |
-| `get_length`       | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~          |
-| **CREATES**        | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~                                              |
+| `tolerance`        | What percentage of the size to allow batches to exceed. ~~float~~                                                                                                                 |
+| `discard_oversize` | Whether to discard sequences that by themselves exceed the tolerated size. ~~bool~~                                                                                               |
+| `get_length`       | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~                          |
+| **CREATES**        | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~                                                              |
 
 ### spacy.batch_by_sequence.v1 {#batch_by_sequence tag="registered function"}
 
 > #### Example config
 >
 > ```ini
 > [training.batcher]
 > @batchers = "spacy.batch_by_sequence.v1"
 > size = 32
 > get_length = null
 > ```
 
 Create a batcher that creates batches of the specified size.
 
-| Name         | Description                                                                                                                                                       |
-| ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| Name         | Description                                                                                                                                                                       |
+| ------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | `size`       | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
-| `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~          |
-| **CREATES**  | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~                                              |
+| `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~                          |
+| **CREATES**  | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~                                                              |
 
 ### spacy.batch_by_padded.v1 {#batch_by_padded tag="registered function"}
 
 > #### Example config
 >
 > ```ini
 > [training.batcher]
 > @batchers = "spacy.batch_by_padded.v1"
 > size = 100
 > buffer = 256
 > discard_oversize = false
 > get_length = null
 > ```
 
 Minibatch a sequence by the size of padded batches that would result, with
 sequences binned by length within a window. The padded size is defined as the
 maximum length of sequences in the batch.
 
 | Name               | Description                                                                                                                                                                                                                                  |
 | ------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `size`             | The largest padded size to batch sequences into. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Schedule]~~                                                |
+| `size`             | The largest padded size to batch sequences into. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~                                 |
 | `buffer`           | The number of sequences to accumulate before sorting by length. A larger buffer will result in more even sizing, but if the buffer is very large, the iteration order will be less random, which can result in suboptimal training. ~~int~~ |
 | `discard_oversize` | Whether to discard sequences that are by themselves longer than the largest padded batch size. ~~bool~~                                                                                                                                      |
 | `get_length`       | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~                                                                                      |
@@ -1352,7 +1352,7 @@ vary on each step.
 | Name       | Description                                      |
 | ---------- | ------------------------------------------------ |
 | `items`    | The items to batch up. ~~Iterable[Any]~~         |
-| `size`     | The batch size(s). ~~Union[int, Schedule]~~      |
+| `size`     | The batch size(s). ~~Union[int, Iterable[int]]~~ |
 | **YIELDS** | The batches.                                     |
 
 ### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
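For reviewers, here is a minimal, self-contained sketch of the contract this patch moves `minibatch` to: `size` may be a plain `int` or any iterable of ints, with one size drawn per batch. The `minibatch` body below follows the `spacy/util.py` hunk above (the trailing `yield` is implied by the truncated hunk context), and `compounding_sizes` is a hypothetical stand-in for the `compounding(4.0, 32.0, 1.001).to_generator()` calls in the updated tests, not Thinc's actual schedule implementation.

```python
import itertools


def minibatch(items, size):
    """Iterate over batches of items. `size` may be an int or an
    iterable of ints, so that batch size can vary on each step."""
    if isinstance(size, int):
        size_ = itertools.repeat(size)  # constant batch size
    else:
        size_ = iter(size)  # draw one size per batch, in order
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = list(itertools.islice(items, int(batch_size)))
        if len(batch) == 0:
            break
        yield batch


# Hypothetical stand-in for compounding(4.0, 32.0, 1.001).to_generator():
# an infinite generator of batch sizes compounding from `start` toward `stop`.
def compounding_sizes(start, stop, compound):
    while True:
        yield min(start, stop)
        start *= compound


for batch in minibatch(range(20), size=compounding_sizes(4.0, 32.0, 1.001)):
    print(len(batch))  # prints 4 five times: int(4.0 * 1.001**n) stays 4 early on
```

Because the loop is `while True: batch_size = next(size_)`, a size iterable is expected to be infinite, like the generator above (and, presumably, the one `to_generator()` returns); a finite one would be exhausted mid-iteration rather than ending the batching cleanly.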