mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 00:50:33 +03:00
minibatch utility can deal with strings, docs or examples
This commit is contained in:
parent
8b66c11ff2
commit
4ed399c848
|
@ -166,8 +166,7 @@ def pretrain(
|
|||
skip_counter = 0
|
||||
loss_func = pretrain_config["loss_func"]
|
||||
for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
|
||||
examples = [Example(doc=text) for text in texts]
|
||||
batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
|
||||
batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
|
||||
for batch_id, batch in enumerate(batches):
|
||||
docs, count = make_docs(
|
||||
nlp,
|
||||
|
|
|
@ -275,7 +275,7 @@ def _fix_legacy_dict_data(example_dict):
|
|||
}
|
||||
|
||||
|
||||
def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces=None):
|
||||
def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
|
||||
if isinstance(biluo_or_offsets[0], (list, tuple)):
|
||||
# Convert to biluo if necessary
|
||||
# This is annoying but to convert the offsets we need a Doc
|
||||
|
|
|
@ -677,7 +677,7 @@ class Language(object):
|
|||
# Populate vocab
|
||||
else:
|
||||
for example in get_examples():
|
||||
for word in example.token_annotation.words:
|
||||
for word in [t.text for t in example.reference]:
|
||||
_ = self.vocab[word] # noqa: F841
|
||||
|
||||
if cfg.get("device", -1) >= 0:
|
||||
|
|
|
@ -5,7 +5,7 @@ from ..gold import Example
|
|||
from ..tokens import Doc
|
||||
from ..vocab import Vocab
|
||||
from ..language import component
|
||||
from ..util import link_vectors_to_models, minibatch, eg2doc
|
||||
from ..util import link_vectors_to_models, minibatch
|
||||
from .defaults import default_tok2vec
|
||||
|
||||
|
||||
|
@ -51,19 +51,15 @@ class Tok2Vec(Pipe):
|
|||
self.set_annotations([doc], tokvecses)
|
||||
return doc
|
||||
|
||||
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
|
||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||
"""Process `Doc` objects as a stream.
|
||||
stream (iterator): A sequence of `Doc` objects to process.
|
||||
batch_size (int): Number of `Doc` objects to group.
|
||||
n_threads (int): Number of threads.
|
||||
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
|
||||
"""
|
||||
for batch in minibatch(stream, batch_size):
|
||||
for docs in minibatch(stream, batch_size):
|
||||
batch = list(batch)
|
||||
if as_example:
|
||||
docs = [eg2doc(doc) for doc in batch]
|
||||
else:
|
||||
docs = batch
|
||||
tokvecses = self.predict(docs)
|
||||
self.set_annotations(docs, tokvecses)
|
||||
yield from batch
|
||||
|
|
|
@ -430,7 +430,6 @@ def test_tuple_format_implicit():
|
|||
_train(train_data)
|
||||
|
||||
|
||||
@pytest.mark.xfail # TODO
|
||||
def test_tuple_format_implicit_invalid():
|
||||
"""Test that an error is thrown for an implicit invalid field"""
|
||||
|
||||
|
@ -443,7 +442,7 @@ def test_tuple_format_implicit_invalid():
|
|||
("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
|
||||
]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(KeyError):
|
||||
_train(train_data)
|
||||
|
||||
|
||||
|
|
|
@ -25,15 +25,14 @@ from spacy.util import minibatch_by_words
|
|||
)
|
||||
def test_util_minibatch(doc_sizes, expected_batches):
|
||||
docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
|
||||
examples = [Example(doc=doc) for doc in docs]
|
||||
tol = 0.2
|
||||
batch_size = 1000
|
||||
batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True))
|
||||
batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=True))
|
||||
assert [len(batch) for batch in batches] == expected_batches
|
||||
|
||||
max_size = batch_size + batch_size * tol
|
||||
for batch in batches:
|
||||
assert sum([len(example.doc) for example in batch]) < max_size
|
||||
assert sum([len(doc) for doc in batch]) < max_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -50,10 +49,9 @@ def test_util_minibatch(doc_sizes, expected_batches):
|
|||
def test_util_minibatch_oversize(doc_sizes, expected_batches):
|
||||
""" Test that oversized documents are returned in their own batch"""
|
||||
docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
|
||||
examples = [Example(doc=doc) for doc in docs]
|
||||
tol = 0.2
|
||||
batch_size = 1000
|
||||
batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False))
|
||||
batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False))
|
||||
assert [len(batch) for batch in batches] == expected_batches
|
||||
|
||||
|
||||
|
|
|
@ -471,14 +471,6 @@ def get_async(stream, numpy_array):
|
|||
return array
|
||||
|
||||
|
||||
def eg2doc(example):
|
||||
"""Get a Doc object from an Example (or if it's a Doc, use it directly)"""
|
||||
# Put the import here to avoid circular import problems
|
||||
from .tokens.doc import Doc
|
||||
|
||||
return example if isinstance(example, Doc) else example.doc
|
||||
|
||||
|
||||
def env_opt(name, default=None):
|
||||
if type(default) is float:
|
||||
type_convert = float
|
||||
|
@ -697,10 +689,13 @@ def decaying(start, stop, decay):
|
|||
curr -= decay
|
||||
|
||||
|
||||
def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
|
||||
def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||
"""Create minibatches of roughly a given number of words. If any examples
|
||||
are longer than the specified batch length, they will appear in a batch by
|
||||
themselves, or be discarded if discard_oversize=True."""
|
||||
themselves, or be discarded if discard_oversize=True.
|
||||
The argument 'docs' can be a list of strings, Doc's or Example's. """
|
||||
from .gold import Example
|
||||
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
elif isinstance(size, List):
|
||||
|
@ -715,22 +710,27 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
|
|||
batch_size = 0
|
||||
overflow_size = 0
|
||||
|
||||
for example in examples:
|
||||
n_words = count_words(example.doc)
|
||||
for doc in docs:
|
||||
if isinstance(doc, Example):
|
||||
n_words = len(doc.reference)
|
||||
elif isinstance(doc, str):
|
||||
n_words = len(doc.split())
|
||||
else:
|
||||
n_words = len(doc)
|
||||
# if the current example exceeds the maximum batch size, it is returned separately
|
||||
# but only if discard_oversize=False.
|
||||
if n_words > target_size + tol_size:
|
||||
if not discard_oversize:
|
||||
yield [example]
|
||||
yield [doc]
|
||||
|
||||
# add the example to the current batch if there's no overflow yet and it still fits
|
||||
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
|
||||
batch.append(example)
|
||||
batch.append(doc)
|
||||
batch_size += n_words
|
||||
|
||||
# add the example to the overflow buffer if it fits in the tolerance margin
|
||||
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(example)
|
||||
overflow.append(doc)
|
||||
overflow_size += n_words
|
||||
|
||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||
|
@ -745,12 +745,12 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
|
|||
|
||||
# this example still fits
|
||||
if (batch_size + n_words) <= target_size:
|
||||
batch.append(example)
|
||||
batch.append(doc)
|
||||
batch_size += n_words
|
||||
|
||||
# this example fits in overflow
|
||||
elif (batch_size + n_words) <= (target_size + tol_size):
|
||||
overflow.append(example)
|
||||
overflow.append(doc)
|
||||
overflow_size += n_words
|
||||
|
||||
# this example does not fit with the previous overflow: start another new batch
|
||||
|
@ -758,7 +758,7 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
|
|||
yield batch
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = [example]
|
||||
batch = [doc]
|
||||
batch_size = n_words
|
||||
|
||||
# yield the final batch
|
||||
|
|
Loading…
Reference in New Issue
Block a user