Mirror of https://github.com/explosion/spaCy.git, synced 2025-08-24 22:14:56 +03:00

Make minibatch take iterables again as well

commit b93f1758a6, parent a880e017ca
```diff
@@ -24,7 +24,9 @@ def test_issue4348():
     optimizer = nlp.initialize()
     for i in range(5):
         losses = {}
-        batches = util.minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            TRAIN_DATA, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)
```
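This and the similar test updates below all follow the same pattern: `compounding(...)` now returns a thinc `Schedule` rather than a plain generator, so the tests convert it with `.to_generator()` before passing it as `size`. A minimal sketch of what that conversion yields, assuming thinc v9's `Schedule` API:

```python
from thinc.api import compounding

# compounding(4.0, 32.0, 1.001) is a thinc Schedule: the value starts at
# 4.0 and is multiplied by 1.001 on each step, capped at 32.0.
sizes = compounding(4.0, 32.0, 1.001).to_generator()

# to_generator() turns the Schedule into a plain iterator of values,
# which is the shape that minibatch consumes with next().
print(next(sizes))  # 4.0
print(next(sizes))  # 4.004
```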
```diff
@@ -91,7 +91,9 @@ def test_issue3611():
     optimizer = nlp.initialize()
     for i in range(3):
         losses = {}
-        batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
 
         for batch in batches:
             nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
```
```diff
@@ -128,7 +130,9 @@ def test_issue4030():
     optimizer = nlp.initialize()
     for i in range(3):
         losses = {}
-        batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
+        batches = util.minibatch(
+            train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
 
         for batch in batches:
             nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
```
```diff
@@ -905,7 +905,9 @@ def _train_tuples(train_data):
     optimizer = nlp.initialize()
     for i in range(5):
         losses = {}
-        batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
+        batches = minibatch(
+            train_examples, size=compounding(4.0, 32.0, 1.001).to_generator()
+        )
         for batch in batches:
             nlp.update(batch, sgd=optimizer, losses=losses)
```
```diff
@@ -1582,12 +1582,12 @@ def minibatch(items, size):
     so that batch-size can vary on each step.
     """
     if isinstance(size, int):
-        size_ = constant_schedule(size)
+        size_ = itertools.repeat(size)
     else:
-        size_ = size
+        size_ = iter(size)
     items = iter(items)
-    for step in itertools.count():
-        batch_size = size_(step)
+    while True:
+        batch_size = next(size_)
         batch = list(itertools.islice(items, int(batch_size)))
         if len(batch) == 0:
             break
```
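The rewritten loop draws batch sizes with `next()` instead of calling the schedule with a step index, so `size` may be a plain int or any iterable of ints again. A standalone sketch of the resulting behavior, copying the logic from the hunk above (the trailing `yield` is assumed from the unchanged remainder of the function):

```python
import itertools


def minibatch(items, size):
    """Iterate over batches of items. `size` may be an int or an iterable
    of ints, so that batch size can vary on each step."""
    if isinstance(size, int):
        # An int becomes an infinite stream of identical sizes.
        size_ = itertools.repeat(size)
    else:
        # Any iterable of sizes works, including generators and Schedule
        # objects converted with .to_generator().
        size_ = iter(size)
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = list(itertools.islice(items, int(batch_size)))
        if len(batch) == 0:
            break
        yield batch  # assumed from the unchanged tail of util.minibatch


# Fixed size: batches of 3 until the items run out.
print([len(b) for b in minibatch(range(10), size=3)])  # [3, 3, 3, 1]

# Varying sizes from an (effectively infinite) iterable.
print([len(b) for b in minibatch(range(10), size=itertools.cycle([2, 3]))])  # [2, 3, 2, 3]
```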
````diff
@@ -751,14 +751,14 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
 > get_length = null
 > ```
 
 | Name               | Description |
 | ------------------ | ----------- |
 | `seqs`             | The sequences to minibatch. ~~Iterable[Any]~~ |
-| `size`             | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Schedule]~~ |
+| `size`             | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
 | `tolerance`        | What percentage of the size to allow batches to exceed. ~~float~~ |
 | `discard_oversize` | Whether to discard sequences that by themselves exceed the tolerated size. ~~bool~~ |
 | `get_length`       | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
 | **CREATES**        | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~ |
 
 ### spacy.batch_by_sequence.v1 {#batch_by_sequence tag="registered function"}
````
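For reference, a sketch of creating the `spacy.batch_by_words.v1` batcher from Python via the registry rather than from a config block; the factory name and parameter names come from the docs table above, while the values and data here are illustrative:

```python
from spacy import registry

# Look up the registered factory and create a batcher with the
# parameters listed in the docs table (values chosen for the demo).
create_batcher = registry.batchers.get("spacy.batch_by_words.v1")
batcher = create_batcher(size=8, tolerance=0.2, discard_oversize=False, get_length=None)

# Per the CREATES row, the batcher maps an iterable of items to batches.
# The "sequences" here are plain lists, so the default len() measures them.
seqs = [["tok"] * n for n in (3, 4, 2, 5, 6)]
for batch in batcher(seqs):
    print([len(seq) for seq in batch])
```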
```diff
@@ -773,11 +773,11 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
 
 Create a batcher that creates batches of the specified size.
 
 | Name         | Description |
 | ------------ | ----------- |
-| `size`       | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Schedule]~~ |
+| `size`       | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
 | `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
 | **CREATES**  | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~ |
 
 ### spacy.batch_by_padded.v1 {#batch_by_padded tag="registered function"}
```
```diff
@@ -799,7 +799,7 @@ sequences in the batch.
 
 | Name               | Description |
 | ------------------ | ----------- |
-| `size`             | The largest padded size to batch sequences into. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Schedule]~~ |
+| `size`             | The largest padded size to batch sequences into. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
 | `buffer`           | The number of sequences to accumulate before sorting by length. A larger buffer will result in more even sizing, but if the buffer is very large, the iteration order will be less random, which can result in suboptimal training. ~~int~~ |
 | `discard_oversize` | Whether to discard sequences that are by themselves longer than the largest padded batch size. ~~bool~~ |
 | `get_length`       | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
```
```diff
@@ -1352,7 +1352,7 @@ vary on each step.
 | Name       | Description |
 | ---------- | ------------------------------------------------ |
 | `items`    | The items to batch up. ~~Iterable[Any]~~ |
-| `size`     | The batch size(s). ~~Union[int, Schedule]~~ |
+| `size`     | The batch size(s). ~~Union[int, Iterable[int]]~~ |
 | **YIELDS** | The batches. |
 
 ### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
```
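With the updated signature, `util.minibatch` takes `size` as an int or an iterable of ints; a `Schedule` now has to be converted first (e.g. with `.to_generator()`, as in the test changes above). A small sketch of the two documented forms, assuming `spacy.util` at this commit:

```python
from spacy.util import minibatch

data = [f"doc{i}" for i in range(7)]

# int: every batch has (up to) `size` items.
print([len(b) for b in minibatch(data, size=3)])  # [3, 3, 1]

# Iterable[int]: batch size can vary on each step. Sizes are drawn with
# next(), so provide at least one more size than the number of batches.
print([len(b) for b in minibatch(data, size=[2, 2, 3, 3])])  # [2, 2, 3]
```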