mirror of https://github.com/explosion/spaCy.git
synced 2025-02-11 17:10:36 +03:00
minibatch utility can deal with strings, docs or examples
This commit is contained in:
parent 8b66c11ff2
commit 4ed399c848
@@ -166,8 +166,7 @@ def pretrain(
     skip_counter = 0
     loss_func = pretrain_config["loss_func"]
     for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
-        examples = [Example(doc=text) for text in texts]
-        batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
+        batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
         for batch_id, batch in enumerate(batches):
             docs, count = make_docs(
                 nlp,
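Note: with this hunk, pretraining hands its raw strings straight to the batcher instead of wrapping each text in an Example first. A minimal sketch of the new calling pattern (the `texts` list here is hypothetical, and assumes a spaCy build from this branch):

    from spacy import util

    # Hypothetical corpus: raw strings as read from the pretraining data.
    texts = ["This is a sentence.", "Another short text."]

    # No Example wrapping needed; the utility word-counts each string itself.
    for batch in util.minibatch_by_words(texts, size=1000):
        print(len(batch), "texts in this batch")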
@@ -275,7 +275,7 @@ def _fix_legacy_dict_data(example_dict):
     }


-def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces=None):
+def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
     if isinstance(biluo_or_offsets[0], (list, tuple)):
         # Convert to biluo if necessary
         # This is annoying but to convert the offsets we need a Doc
@@ -677,7 +677,7 @@ class Language(object):
         # Populate vocab
         else:
             for example in get_examples():
-                for word in example.token_annotation.words:
+                for word in [t.text for t in example.reference]:
                     _ = self.vocab[word]  # noqa: F841

         if cfg.get("device", -1) >= 0:
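In the new Example API used here, `example.reference` is the gold-standard Doc, so token texts come from `t.text` rather than from the old `token_annotation` lists. A hedged sketch of that access pattern, assuming this branch's `spacy.gold.Example` with a `from_dict` constructor:

    from spacy.gold import Example
    from spacy.lang.en import English

    nlp = English()
    doc = nlp.make_doc("Berlin is nice")
    example = Example.from_dict(doc, {"words": ["Berlin", "is", "nice"]})

    # example.reference is a Doc; this mirrors the vocab-population loop above
    words = [t.text for t in example.reference]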
@@ -5,7 +5,7 @@ from ..gold import Example
 from ..tokens import Doc
 from ..vocab import Vocab
 from ..language import component
-from ..util import link_vectors_to_models, minibatch, eg2doc
+from ..util import link_vectors_to_models, minibatch
 from .defaults import default_tok2vec


@@ -51,19 +51,15 @@ class Tok2Vec(Pipe):
         self.set_annotations([doc], tokvecses)
         return doc

-    def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
+    def pipe(self, stream, batch_size=128, n_threads=-1):
         """Process `Doc` objects as a stream.
         stream (iterator): A sequence of `Doc` objects to process.
         batch_size (int): Number of `Doc` objects to group.
         n_threads (int): Number of threads.
         YIELDS (iterator): A sequence of `Doc` objects, in order of input.
         """
-        for batch in minibatch(stream, batch_size):
+        for docs in minibatch(stream, batch_size):
             batch = list(batch)
-            if as_example:
-                docs = [eg2doc(doc) for doc in batch]
-            else:
-                docs = batch
             tokvecses = self.predict(docs)
             self.set_annotations(docs, tokvecses)
             yield from batch
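`pipe` now always takes and yields plain Doc objects; the `as_example` escape hatch and the `eg2doc` shim are gone. A minimal usage sketch, assuming an `nlp` object and a constructed `tok2vec` component from this branch:

    # Hypothetical stream of Docs.
    docs = (nlp.make_doc(text) for text in ["One text.", "Another text."])

    # In and out are both Docs now; there is no as_example flag.
    for doc in tok2vec.pipe(docs, batch_size=128):
        print(len(doc))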
@@ -430,7 +430,6 @@ def test_tuple_format_implicit():
     _train(train_data)


-@pytest.mark.xfail  # TODO
 def test_tuple_format_implicit_invalid():
     """Test that an error is thrown for an implicit invalid field"""

@@ -443,7 +442,7 @@ def test_tuple_format_implicit_invalid():
         ("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
     ]

-    with pytest.raises(TypeError):
+    with pytest.raises(KeyError):
        _train(train_data)


@@ -25,15 +25,14 @@ from spacy.util import minibatch_by_words
 )
 def test_util_minibatch(doc_sizes, expected_batches):
     docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
-    examples = [Example(doc=doc) for doc in docs]
     tol = 0.2
     batch_size = 1000
-    batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True))
+    batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=True))
     assert [len(batch) for batch in batches] == expected_batches

     max_size = batch_size + batch_size * tol
     for batch in batches:
-        assert sum([len(example.doc) for example in batch]) < max_size
+        assert sum([len(doc) for doc in batch]) < max_size


 @pytest.mark.parametrize(
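For concreteness, the bound that the last assertion in this test checks: with batch_size=1000 and tol=0.2, max_size is 1000 + 1000 * 0.2 = 1200, so the word total of every batch must stay under 1200.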
@@ -50,10 +49,9 @@ def test_util_minibatch(doc_sizes, expected_batches):
 def test_util_minibatch_oversize(doc_sizes, expected_batches):
     """ Test that oversized documents are returned in their own batch"""
     docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
-    examples = [Example(doc=doc) for doc in docs]
     tol = 0.2
     batch_size = 1000
-    batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False))
+    batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False))
     assert [len(batch) for batch in batches] == expected_batches


@@ -471,14 +471,6 @@ def get_async(stream, numpy_array):
     return array


-def eg2doc(example):
-    """Get a Doc object from an Example (or if it's a Doc, use it directly)"""
-    # Put the import here to avoid circular import problems
-    from .tokens.doc import Doc
-
-    return example if isinstance(example, Doc) else example.doc
-
-
 def env_opt(name, default=None):
     if type(default) is float:
         type_convert = float
@@ -697,10 +689,13 @@ def decaying(start, stop, decay):
         curr -= decay


-def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
+def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
     """Create minibatches of roughly a given number of words. If any examples
     are longer than the specified batch length, they will appear in a batch by
-    themselves, or be discarded if discard_oversize=True."""
+    themselves, or be discarded if discard_oversize=True.
+    The argument 'docs' can be a list of strings, Doc's or Example's. """
+    from .gold import Example
+
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     elif isinstance(size, List):
@@ -715,22 +710,27 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
     batch_size = 0
     overflow_size = 0

-    for example in examples:
-        n_words = count_words(example.doc)
+    for doc in docs:
+        if isinstance(doc, Example):
+            n_words = len(doc.reference)
+        elif isinstance(doc, str):
+            n_words = len(doc.split())
+        else:
+            n_words = len(doc)
         # if the current example exceeds the maximum batch size, it is returned separately
         # but only if discard_oversize=False.
         if n_words > target_size + tol_size:
             if not discard_oversize:
-                yield [example]
+                yield [doc]

         # add the example to the current batch if there's no overflow yet and it still fits
         elif overflow_size == 0 and (batch_size + n_words) <= target_size:
-            batch.append(example)
+            batch.append(doc)
             batch_size += n_words

         # add the example to the overflow buffer if it fits in the tolerance margin
         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
-            overflow.append(example)
+            overflow.append(doc)
             overflow_size += n_words

         # yield the previous batch and start a new one. The new one gets the overflow examples.
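The word count is now dispatched on the element type: an Example is measured by its reference Doc, a string by whitespace splitting, and anything else (normally a Doc) by len(). The branch above, pulled out as a standalone sketch (the helper name is ours, not the library's):

    from spacy.gold import Example

    def count_words(doc):
        # Example -> gold tokens; str -> whitespace tokens; Doc -> len()
        if isinstance(doc, Example):
            return len(doc.reference)
        elif isinstance(doc, str):
            return len(doc.split())
        return len(doc)

    assert count_words("two words") == 2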
@@ -745,12 +745,12 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):

         # this example still fits
         if (batch_size + n_words) <= target_size:
-            batch.append(example)
+            batch.append(doc)
             batch_size += n_words

         # this example fits in overflow
         elif (batch_size + n_words) <= (target_size + tol_size):
-            overflow.append(example)
+            overflow.append(doc)
             overflow_size += n_words

         # this example does not fit with the previous overflow: start another new batch
@@ -758,7 +758,7 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
             yield batch
             target_size = next(size_)
             tol_size = target_size * tolerance
-            batch = [example]
+            batch = [doc]
             batch_size = n_words

         # yield the final batch