Make minibatch take iterables again as well

Daniël de Kok 2023-01-11 17:51:17 +01:00
parent a880e017ca
commit b93f1758a6
5 changed files with 29 additions and 21 deletions

View File

@@ -24,7 +24,9 @@ def test_issue4348():
     optimizer = nlp.initialize()
     for i in range(5):
         losses = {}
-        batches = util.minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            TRAIN_DATA, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)
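All of the test updates in this commit follow the same pattern: `compounding` returns a thinc `Schedule`, which `minibatch` no longer consumes directly, so the schedule is first converted to a plain generator with `.to_generator()`. A minimal standalone sketch of the pattern (the toy `train_data` list is illustrative, not taken from the tests):

```python
from thinc.api import compounding
from spacy import util

# Convert the Schedule into a generator of sizes: 4.0, 4.004, ... toward 32.0
sizes = compounding(4.0, 32.0, 1.001).to_generator()

# Any iterable of items works; this list just stands in for TRAIN_DATA.
train_data = [f"example-{i}" for i in range(20)]

for batch in util.minibatch(train_data, size=sizes):
    print(len(batch))  # 4, 4, 4, ... slowly compounding toward 32
```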

View File

@@ -91,7 +91,9 @@ def test_issue3611():
     optimizer = nlp.initialize()
     for i in range(3):
         losses = {}
-        batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
@@ -128,7 +130,9 @@ def test_issue4030():
     optimizer = nlp.initialize()
     for i in range(3):
         losses = {}
-        batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)

View File

@@ -905,7 +905,9 @@ def _train_tuples(train_data):
     optimizer = nlp.initialize()
     for i in range(5):
         losses = {}
-        batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
+        batches = minibatch(
+            train_examples, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)

View File

@@ -1582,12 +1582,12 @@ def minibatch(items, size):
     so that batch-size can vary on each step.
     """
     if isinstance(size, int):
-        size_ = constant_schedule(size)
+        size_ = itertools.repeat(size)
     else:
-        size_ = size
+        size_ = iter(size)
     items = iter(items)
-    for step in itertools.count():
-        batch_size = size_(step)
+    while True:
+        batch_size = next(size_)
         batch = list(itertools.islice(items, int(batch_size)))
         if len(batch) == 0:
             break
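The new implementation treats an `int` as an infinitely repeated constant size and anything else as an iterable of sizes drawn with `next()`. A standalone sketch of the resulting behavior, mirroring the hunk above (the final `yield` sits just below the hunk in the real function):

```python
import itertools

def minibatch(items, size):
    # Mirrors the updated spaCy helper shown above.
    if isinstance(size, int):
        size_ = itertools.repeat(size)  # constant batch size, repeated forever
    else:
        size_ = iter(size)              # any iterable of sizes now works
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = list(itertools.islice(items, int(batch_size)))
        if len(batch) == 0:
            break
        yield batch

print([len(b) for b in minibatch(range(10), size=3)])
# [3, 3, 3, 1]
print([len(b) for b in minibatch(range(10), size=[2, 3, 5, 5])])
# [2, 3, 5] -- the spare trailing size produces the empty batch that ends iteration
```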

View File

@@ -751,14 +751,14 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
 > get_length = null
 > ```
 
 | Name               | Description |
 | ------------------ | ----------- |
 | `seqs`             | The sequences to minibatch. ~~Iterable[Any]~~ |
-| `size`             | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Schedule]~~ |
+| `size`             | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
 | `tolerance`        | What percentage of the size to allow batches to exceed. ~~float~~ |
 | `discard_oversize` | Whether to discard sequences that by themselves exceed the tolerated size. ~~bool~~ |
 | `get_length`       | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
 | **CREATES**        | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~ |
 
 ### spacy.batch_by_sequence.v1 {#batch_by_sequence tag="registered function"}
@@ -773,11 +773,11 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
 Create a batcher that creates batches of the specified size.
 
 | Name         | Description |
 | ------------ | ----------- |
-| `size`       | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Schedule]~~ |
+| `size`       | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
 | `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
 | **CREATES**  | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~ |
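For context, a hedged sketch of how such a registered batcher can be resolved and called by hand; it assumes the `spacy.registry.batchers` lookup and the parameters listed in the table above (normally the factory is filled in from the `[training.batcher]` config block rather than called directly):

```python
import spacy

# Hypothetical direct use; real training resolves this from the config.
create_batcher = spacy.registry.batchers.get("spacy.batch_by_sequence.v1")

# Per the table, `size` may be an int or an iterable of ints; the spare
# trailing size keeps the size iterable from running out mid-iteration.
batcher = create_batcher(size=[2, 3, 4, 4], get_length=None)

# CREATES: Callable[[Iterable[Any]], Iterable[List[Any]]]
for batch in batcher(["a", "b", "c", "d", "e", "f"]):
    print(batch)  # ['a', 'b'], then ['c', 'd', 'e'], then ['f']
```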
 ### spacy.batch_by_padded.v1 {#batch_by_padded tag="registered function"}
@@ -799,7 +799,7 @@ sequences in the batch.
 | Name               | Description |
 | ------------------ | ----------- |
-| `size`             | The largest padded size to batch sequences into. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Schedule]~~ |
+| `size`             | The largest padded size to batch sequences into. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
 | `buffer`           | The number of sequences to accumulate before sorting by length. A larger buffer will result in more even sizing, but if the buffer is very large, the iteration order will be less random, which can result in suboptimal training. ~~int~~ |
 | `discard_oversize` | Whether to discard sequences that are by themselves longer than the largest padded batch size. ~~bool~~ |
 | `get_length`       | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
@@ -1352,7 +1352,7 @@ vary on each step.
 | Name       | Description |
 | ---------- | ----------- |
 | `items`    | The items to batch up. ~~Iterable[Any]~~ |
-| `size`     | The batch size(s). ~~Union[int, Schedule]~~ |
+| `size`     | The batch size(s). ~~Union[int, Iterable[int]]~~ |
 | **YIELDS** | The batches. |
 
 ### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}