mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
Make minibatch take iterables again as well
This commit is contained in:
parent
a880e017ca
commit
b93f1758a6
|
@ -24,7 +24,9 @@ def test_issue4348():
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
losses = {}
|
losses = {}
|
||||||
batches = util.minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
|
batches = util.minibatch(
|
||||||
|
TRAIN_DATA, size=compounding(4.0, 32.0, 1.001).to_generator()
|
||||||
|
)
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
nlp.update(batch, sgd=optimizer, losses=losses)
|
nlp.update(batch, sgd=optimizer, losses=losses)
|
||||||
|
|
||||||
|
|
|
@ -91,7 +91,9 @@ def test_issue3611():
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
losses = {}
|
losses = {}
|
||||||
batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
|
batches = util.minibatch(
|
||||||
|
train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
|
||||||
|
)
|
||||||
|
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
|
nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
|
||||||
|
@ -128,7 +130,9 @@ def test_issue4030():
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
losses = {}
|
losses = {}
|
||||||
batches = util.minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
|
batches = util.minibatch(
|
||||||
|
train_data, size=compounding(4.0, 32.0, 1.001).to_generator()
|
||||||
|
)
|
||||||
|
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
|
nlp.update(examples=batch, sgd=optimizer, drop=0.1, losses=losses)
|
||||||
|
|
|
@ -905,7 +905,9 @@ def _train_tuples(train_data):
|
||||||
optimizer = nlp.initialize()
|
optimizer = nlp.initialize()
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
losses = {}
|
losses = {}
|
||||||
batches = minibatch(train_examples, size=compounding(4.0, 32.0, 1.001))
|
batches = minibatch(
|
||||||
|
train_examples, size=compounding(4.0, 32.0, 1.001).to_generator()
|
||||||
|
)
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
nlp.update(batch, sgd=optimizer, losses=losses)
|
nlp.update(batch, sgd=optimizer, losses=losses)
|
||||||
|
|
||||||
|
|
|
@ -1582,12 +1582,12 @@ def minibatch(items, size):
|
||||||
so that batch-size can vary on each step.
|
so that batch-size can vary on each step.
|
||||||
"""
|
"""
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
size_ = constant_schedule(size)
|
size_ = itertools.repeat(size)
|
||||||
else:
|
else:
|
||||||
size_ = size
|
size_ = iter(size)
|
||||||
items = iter(items)
|
items = iter(items)
|
||||||
for step in itertools.count():
|
while True:
|
||||||
batch_size = size_(step)
|
batch_size = next(size_)
|
||||||
batch = list(itertools.islice(items, int(batch_size)))
|
batch = list(itertools.islice(items, int(batch_size)))
|
||||||
if len(batch) == 0:
|
if len(batch) == 0:
|
||||||
break
|
break
|
||||||
|
|
|
@ -752,7 +752,7 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `seqs` | The sequences to minibatch. ~~Iterable[Any]~~ |
|
| `seqs` | The sequences to minibatch. ~~Iterable[Any]~~ |
|
||||||
| `size` | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
|
| `size` | The target number of words per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
|
||||||
| `tolerance` | What percentage of the size to allow batches to exceed. ~~float~~ |
|
| `tolerance` | What percentage of the size to allow batches to exceed. ~~float~~ |
|
||||||
|
@ -774,7 +774,7 @@ themselves, or be discarded if `discard_oversize` is set to `True`. The argument
|
||||||
Create a batcher that creates batches of the specified size.
|
Create a batcher that creates batches of the specified size.
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `size` | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
|
| `size` | The target number of items per batch. Can also be a block referencing a schedule, e.g. [`compounding`](https://thinc.ai/docs/api-schedules/#compounding). ~~Union[int, Iterable[int], Schedule]~~ |
|
||||||
| `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
|
| `get_length` | Optional function that receives a sequence item and returns its length. Defaults to the built-in `len()` if not set. ~~Optional[Callable[[Any], int]]~~ |
|
||||||
| **CREATES** | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~ |
|
| **CREATES** | The batcher that takes an iterable of items and returns batches. ~~Callable[[Iterable[Any]], Iterable[List[Any]]]~~ |
|
||||||
|
@ -1352,7 +1352,7 @@ vary on each step.
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ---------- | ------------------------------------------------ |
|
| ---------- | ------------------------------------------------ |
|
||||||
| `items` | The items to batch up. ~~Iterable[Any]~~ |
|
| `items` | The items to batch up. ~~Iterable[Any]~~ |
|
||||||
| `size` | The batch size(s). ~~Union[int, Schedule]~~ |
|
| `size` | The batch size(s). ~~Union[int, Iterable[int]]~~ |
|
||||||
| **YIELDS** | The batches. |
|
| **YIELDS** | The batches. |
|
||||||
|
|
||||||
### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
|
### util.filter_spans {#util.filter_spans tag="function" new="2.1.4"}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user