minibatch utility can deal with strings, docs or examples

This commit is contained in:
svlandeg 2020-06-16 21:35:55 +02:00
parent 8b66c11ff2
commit 4ed399c848
7 changed files with 28 additions and 36 deletions

View File

@ -166,8 +166,7 @@ def pretrain(
skip_counter = 0
loss_func = pretrain_config["loss_func"]
for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
examples = [Example(doc=text) for text in texts]
batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
for batch_id, batch in enumerate(batches):
docs, count = make_docs(
nlp,

View File

@ -275,7 +275,7 @@ def _fix_legacy_dict_data(example_dict):
}
def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces=None):
def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
if isinstance(biluo_or_offsets[0], (list, tuple)):
# Convert to biluo if necessary
# This is annoying but to convert the offsets we need a Doc

View File

@ -677,7 +677,7 @@ class Language(object):
# Populate vocab
else:
for example in get_examples():
for word in example.token_annotation.words:
for word in [t.text for t in example.reference]:
_ = self.vocab[word] # noqa: F841
if cfg.get("device", -1) >= 0:

View File

@ -5,7 +5,7 @@ from ..gold import Example
from ..tokens import Doc
from ..vocab import Vocab
from ..language import component
from ..util import link_vectors_to_models, minibatch, eg2doc
from ..util import link_vectors_to_models, minibatch
from .defaults import default_tok2vec
@ -51,19 +51,15 @@ class Tok2Vec(Pipe):
self.set_annotations([doc], tokvecses)
return doc
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
def pipe(self, stream, batch_size=128, n_threads=-1):
"""Process `Doc` objects as a stream.
stream (iterator): A sequence of `Doc` objects to process.
batch_size (int): Number of `Doc` objects to group.
n_threads (int): Number of threads.
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
"""
for batch in minibatch(stream, batch_size):
for docs in minibatch(stream, batch_size):
batch = list(batch)
if as_example:
docs = [eg2doc(doc) for doc in batch]
else:
docs = batch
tokvecses = self.predict(docs)
self.set_annotations(docs, tokvecses)
yield from batch

View File

@ -430,7 +430,6 @@ def test_tuple_format_implicit():
_train(train_data)
@pytest.mark.xfail # TODO
def test_tuple_format_implicit_invalid():
"""Test that an error is thrown for an implicit invalid field"""
@ -443,7 +442,7 @@ def test_tuple_format_implicit_invalid():
("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
]
with pytest.raises(TypeError):
with pytest.raises(KeyError):
_train(train_data)

View File

@ -25,15 +25,14 @@ from spacy.util import minibatch_by_words
)
def test_util_minibatch(doc_sizes, expected_batches):
docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
examples = [Example(doc=doc) for doc in docs]
tol = 0.2
batch_size = 1000
batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True))
batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=True))
assert [len(batch) for batch in batches] == expected_batches
max_size = batch_size + batch_size * tol
for batch in batches:
assert sum([len(example.doc) for example in batch]) < max_size
assert sum([len(doc) for doc in batch]) < max_size
@pytest.mark.parametrize(
@ -50,10 +49,9 @@ def test_util_minibatch(doc_sizes, expected_batches):
def test_util_minibatch_oversize(doc_sizes, expected_batches):
""" Test that oversized documents are returned in their own batch"""
docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
examples = [Example(doc=doc) for doc in docs]
tol = 0.2
batch_size = 1000
batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False))
batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False))
assert [len(batch) for batch in batches] == expected_batches

View File

@ -471,14 +471,6 @@ def get_async(stream, numpy_array):
return array
def eg2doc(example):
"""Get a Doc object from an Example (or if it's a Doc, use it directly)"""
# Put the import here to avoid circular import problems
from .tokens.doc import Doc
return example if isinstance(example, Doc) else example.doc
def env_opt(name, default=None):
if type(default) is float:
type_convert = float
@ -697,10 +689,13 @@ def decaying(start, stop, decay):
curr -= decay
def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
"""Create minibatches of roughly a given number of words. If any examples
are longer than the specified batch length, they will appear in a batch by
themselves, or be discarded if discard_oversize=True."""
themselves, or be discarded if discard_oversize=True.
The argument 'docs' can be a list of strings, Doc's or Example's. """
from .gold import Example
if isinstance(size, int):
size_ = itertools.repeat(size)
elif isinstance(size, List):
@ -715,22 +710,27 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
batch_size = 0
overflow_size = 0
for example in examples:
n_words = count_words(example.doc)
for doc in docs:
if isinstance(doc, Example):
n_words = len(doc.reference)
elif isinstance(doc, str):
n_words = len(doc.split())
else:
n_words = len(doc)
# if the current example exceeds the maximum batch size, it is returned separately
# but only if discard_oversize=False.
if n_words > target_size + tol_size:
if not discard_oversize:
yield [example]
yield [doc]
# add the example to the current batch if there's no overflow yet and it still fits
elif overflow_size == 0 and (batch_size + n_words) <= target_size:
batch.append(example)
batch.append(doc)
batch_size += n_words
# add the example to the overflow buffer if it fits in the tolerance margin
elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
overflow.append(example)
overflow.append(doc)
overflow_size += n_words
# yield the previous batch and start a new one. The new one gets the overflow examples.
@ -745,12 +745,12 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
# this example still fits
if (batch_size + n_words) <= target_size:
batch.append(example)
batch.append(doc)
batch_size += n_words
# this example fits in overflow
elif (batch_size + n_words) <= (target_size + tol_size):
overflow.append(example)
overflow.append(doc)
overflow_size += n_words
# this example does not fit with the previous overflow: start another new batch
@ -758,7 +758,7 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
yield batch
target_size = next(size_)
tol_size = target_size * tolerance
batch = [example]
batch = [doc]
batch_size = n_words
# yield the final batch