minibatch utility can deal with strings, docs or examples

svlandeg 2020-06-16 21:35:55 +02:00
parent 8b66c11ff2
commit 4ed399c848
7 changed files with 28 additions and 36 deletions
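Before the per-file hunks: after this change, minibatch_by_words accepts raw strings, Doc objects, or Example objects directly and works out a word count for each. A minimal sketch of the string case (sizes made up for illustration; assumes a spaCy checkout at this commit):

from spacy.util import minibatch_by_words

texts = ["one two three", "four five six seven", "eight nine", "ten"]
# Target ~5 words per batch; strings are counted via str.split() (see the util.py hunk below).
for batch in minibatch_by_words(texts, size=5):
    print(len(batch), batch)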

View File

@@ -166,8 +166,7 @@ def pretrain(
     skip_counter = 0
     loss_func = pretrain_config["loss_func"]
     for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
-        examples = [Example(doc=text) for text in texts]
-        batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
+        batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
         for batch_id, batch in enumerate(batches):
             docs, count = make_docs(
                 nlp,
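With the batcher counting words for raw strings itself, the Example(doc=text) wrapping above is no longer needed. A hedged sketch of the resulting loop shape, with stand-in corpus and config (make_docs is omitted since its arguments aren't shown in this hunk):

from spacy import util

texts = ["raw pretraining text", "another raw text"]       # stand-in corpus
pretrain_config = {"batch_size": 3000, "max_epochs": 2}    # stand-in config
for epoch in range(pretrain_config["max_epochs"]):
    batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
    for batch_id, batch in enumerate(batches):
        print(epoch, batch_id, len(batch))  # each batch is a list of raw strings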

View File

@@ -275,7 +275,7 @@ def _fix_legacy_dict_data(example_dict):
     }


-def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces=None):
+def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
     if isinstance(biluo_or_offsets[0], (list, tuple)):
         # Convert to biluo if necessary
         # This is annoying but to convert the offsets we need a Doc

View File

@@ -677,7 +677,7 @@ class Language(object):
             # Populate vocab
             else:
                 for example in get_examples():
-                    for word in example.token_annotation.words:
+                    for word in [t.text for t in example.reference]:
                         _ = self.vocab[word]  # noqa: F841

         if cfg.get("device", -1) >= 0:
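The vocab is now populated from each Example's reference Doc instead of the old token_annotation lists. A standalone sketch of the same lookup, assuming the v3-style Example API (Example.from_dict is an assumption about the API at this commit; nlp and the text are stand-ins):

import spacy
from spacy.gold import Example

nlp = spacy.blank("en")
doc = nlp.make_doc("Berlin is a city")
example = Example.from_dict(doc, {"entities": []})  # assumed constructor
for token in example.reference:                     # reference is a Doc of gold tokens
    _ = nlp.vocab[token.text]                       # touching the lexeme adds it to the vocab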

View File

@@ -5,7 +5,7 @@ from ..gold import Example
 from ..tokens import Doc
 from ..vocab import Vocab
 from ..language import component
-from ..util import link_vectors_to_models, minibatch, eg2doc
+from ..util import link_vectors_to_models, minibatch
 from .defaults import default_tok2vec
@@ -51,19 +51,15 @@ class Tok2Vec(Pipe):
         self.set_annotations([doc], tokvecses)
         return doc

-    def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
+    def pipe(self, stream, batch_size=128, n_threads=-1):
         """Process `Doc` objects as a stream.
         stream (iterator): A sequence of `Doc` objects to process.
         batch_size (int): Number of `Doc` objects to group.
         n_threads (int): Number of threads.
         YIELDS (iterator): A sequence of `Doc` objects, in order of input.
         """
-        for batch in minibatch(stream, batch_size):
-            batch = list(batch)
-            if as_example:
-                docs = [eg2doc(doc) for doc in batch]
-            else:
-                docs = batch
+        for docs in minibatch(stream, batch_size):
+            docs = list(docs)
             tokvecses = self.predict(docs)
             self.set_annotations(docs, tokvecses)
-            yield from batch
+            yield from docs
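pipe now always consumes and yields Doc objects; the as_example escape hatch and the eg2doc shim are gone. A standalone sketch of the same loop shape with stub predict/set_annotations (the stubs are mine, not spaCy's), using strings as stand-ins for Docs:

from spacy.util import minibatch

def predict(docs):                       # stub for self.predict
    return [None] * len(docs)

def set_annotations(docs, tokvecses):    # stub for self.set_annotations
    pass

def pipe(stream, batch_size=128):
    # Mirrors the new Tok2Vec.pipe: batch the stream, predict, annotate, re-yield.
    for docs in minibatch(stream, batch_size):
        docs = list(docs)
        tokvecses = predict(docs)
        set_annotations(docs, tokvecses)
        yield from docs

print(list(pipe(iter(["a", "b", "c"]), batch_size=2)))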

View File

@@ -430,7 +430,6 @@ def test_tuple_format_implicit():
     _train(train_data)


-@pytest.mark.xfail  # TODO
 def test_tuple_format_implicit_invalid():
     """Test that an error is thrown for an implicit invalid field"""
@@ -443,7 +442,7 @@ def test_tuple_format_implicit_invalid():
         ("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
     ]
-    with pytest.raises(TypeError):
+    with pytest.raises(KeyError):
         _train(train_data)

View File

@@ -25,15 +25,14 @@ from spacy.util import minibatch_by_words
 )
 def test_util_minibatch(doc_sizes, expected_batches):
     docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
-    examples = [Example(doc=doc) for doc in docs]
     tol = 0.2
     batch_size = 1000
-    batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True))
+    batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=True))
     assert [len(batch) for batch in batches] == expected_batches
     max_size = batch_size + batch_size * tol
     for batch in batches:
-        assert sum([len(example.doc) for example in batch]) < max_size
+        assert sum([len(doc) for doc in batch]) < max_size


 @pytest.mark.parametrize(
@@ -50,10 +49,9 @@ def test_util_minibatch(doc_sizes, expected_batches):
 def test_util_minibatch_oversize(doc_sizes, expected_batches):
     """ Test that oversized documents are returned in their own batch"""
     docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
-    examples = [Example(doc=doc) for doc in docs]
     tol = 0.2
     batch_size = 1000
-    batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False))
+    batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False))
     assert [len(batch) for batch in batches] == expected_batches

View File

@@ -471,14 +471,6 @@ def get_async(stream, numpy_array):
     return array


-def eg2doc(example):
-    """Get a Doc object from an Example (or if it's a Doc, use it directly)"""
-    # Put the import here to avoid circular import problems
-    from .tokens.doc import Doc
-    return example if isinstance(example, Doc) else example.doc
-
-
 def env_opt(name, default=None):
     if type(default) is float:
         type_convert = float
@@ -697,10 +689,13 @@ def decaying(start, stop, decay):
         curr -= decay


-def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
+def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
     """Create minibatches of roughly a given number of words. If any examples
     are longer than the specified batch length, they will appear in a batch by
-    themselves, or be discarded if discard_oversize=True."""
+    themselves, or be discarded if discard_oversize=True.
+    The argument 'docs' can be a list of strings, Doc's or Example's. """
+    from .gold import Example
+
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     elif isinstance(size, List):
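For context between these two hunks, the size argument handling is unchanged: an int becomes an endless schedule, a list supplies one target size per batch. A standalone sketch of that dispatch (mirroring the context lines above, not the full spaCy source):

import itertools

def size_schedule(size):
    # int -> same target for every batch; list -> per-batch schedule
    return itertools.repeat(size) if isinstance(size, int) else iter(size)

size_ = size_schedule(1000)
target_size = next(size_)
tol_size = target_size * 0.2  # default 20% tolerance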
@@ -715,22 +710,27 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
     batch_size = 0
     overflow_size = 0
-    for example in examples:
-        n_words = count_words(example.doc)
+    for doc in docs:
+        if isinstance(doc, Example):
+            n_words = len(doc.reference)
+        elif isinstance(doc, str):
+            n_words = len(doc.split())
+        else:
+            n_words = len(doc)

         # if the current example exceeds the maximum batch size, it is returned separately
         # but only if discard_oversize=False.
         if n_words > target_size + tol_size:
             if not discard_oversize:
-                yield [example]
+                yield [doc]

         # add the example to the current batch if there's no overflow yet and it still fits
         elif overflow_size == 0 and (batch_size + n_words) <= target_size:
-            batch.append(example)
+            batch.append(doc)
             batch_size += n_words

         # add the example to the overflow buffer if it fits in the tolerance margin
         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
-            overflow.append(example)
+            overflow.append(doc)
             overflow_size += n_words

         # yield the previous batch and start a new one. The new one gets the overflow examples.
@@ -745,12 +745,12 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):

             # this example still fits
             if (batch_size + n_words) <= target_size:
-                batch.append(example)
+                batch.append(doc)
                 batch_size += n_words

             # this example fits in overflow
             elif (batch_size + n_words) <= (target_size + tol_size):
-                overflow.append(example)
+                overflow.append(doc)
                 overflow_size += n_words

             # this example does not fit with the previous overflow: start another new batch
@@ -758,7 +758,7 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
                 yield batch
                 target_size = next(size_)
                 tol_size = target_size * tolerance
-                batch = [example]
+                batch = [doc]
                 batch_size = n_words

     # yield the final batch
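Taken together, the heart of the change is the per-item word count. A standalone restatement of that dispatch (not the spaCy source; the Example import path follows the util.py hunk above):

from spacy.gold import Example

def count_words(item):
    # Mirrors the isinstance dispatch added to minibatch_by_words.
    if isinstance(item, Example):
        return len(item.reference)   # gold tokens on the reference Doc
    if isinstance(item, str):
        return len(item.split())     # raw text: whitespace word count
    return len(item)                 # Doc (or any sized sequence)

assert count_words("one two three") == 3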