diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index 4f4707b52..8a7c73e82 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -166,8 +166,7 @@ def pretrain(
     skip_counter = 0
     loss_func = pretrain_config["loss_func"]
     for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
-        examples = [Example(doc=text) for text in texts]
-        batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
+        batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
         for batch_id, batch in enumerate(batches):
             docs, count = make_docs(
                 nlp,
diff --git a/spacy/gold/example.pyx b/spacy/gold/example.pyx
index 6119567dc..1538be923 100644
--- a/spacy/gold/example.pyx
+++ b/spacy/gold/example.pyx
@@ -275,7 +275,7 @@ def _fix_legacy_dict_data(example_dict):
     }
 
 
-def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces=None):
+def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
     if isinstance(biluo_or_offsets[0], (list, tuple)):
         # Convert to biluo if necessary
         # This is annoying but to convert the offsets we need a Doc
diff --git a/spacy/language.py b/spacy/language.py
index 510c64d5b..d632bdf02 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -677,7 +677,7 @@ class Language(object):
         # Populate vocab
         else:
             for example in get_examples():
-                for word in example.token_annotation.words:
+                for word in [t.text for t in example.reference]:
                     _ = self.vocab[word]  # noqa: F841
 
         if cfg.get("device", -1) >= 0:
diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index 69582908a..afd9b554f 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -5,7 +5,7 @@ from ..gold import Example
 from ..tokens import Doc
 from ..vocab import Vocab
 from ..language import component
-from ..util import link_vectors_to_models, minibatch, eg2doc
+from ..util import link_vectors_to_models, minibatch
 from .defaults import default_tok2vec
 
 
@@ -51,19 +51,15 @@ class Tok2Vec(Pipe):
         self.set_annotations([doc], tokvecses)
         return doc
 
-    def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
+    def pipe(self, stream, batch_size=128, n_threads=-1):
         """Process `Doc` objects as a stream.
         stream (iterator): A sequence of `Doc` objects to process.
         batch_size (int): Number of `Doc` objects to group.
         n_threads (int): Number of threads.
         YIELDS (iterator): A sequence of `Doc` objects, in order of input.
""" - for batch in minibatch(stream, batch_size): + for docs in minibatch(stream, batch_size): batch = list(batch) - if as_example: - docs = [eg2doc(doc) for doc in batch] - else: - docs = batch tokvecses = self.predict(docs) self.set_annotations(docs, tokvecses) yield from batch diff --git a/spacy/tests/test_gold.py b/spacy/tests/test_gold.py index 83489799c..886f995c8 100644 --- a/spacy/tests/test_gold.py +++ b/spacy/tests/test_gold.py @@ -430,7 +430,6 @@ def test_tuple_format_implicit(): _train(train_data) -@pytest.mark.xfail # TODO def test_tuple_format_implicit_invalid(): """Test that an error is thrown for an implicit invalid field""" @@ -443,7 +442,7 @@ def test_tuple_format_implicit_invalid(): ("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}), ] - with pytest.raises(TypeError): + with pytest.raises(KeyError): _train(train_data) diff --git a/spacy/tests/test_util.py b/spacy/tests/test_util.py index 1410755db..d396dc74d 100644 --- a/spacy/tests/test_util.py +++ b/spacy/tests/test_util.py @@ -25,15 +25,14 @@ from spacy.util import minibatch_by_words ) def test_util_minibatch(doc_sizes, expected_batches): docs = [get_random_doc(doc_size) for doc_size in doc_sizes] - examples = [Example(doc=doc) for doc in docs] tol = 0.2 batch_size = 1000 - batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True)) + batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=True)) assert [len(batch) for batch in batches] == expected_batches max_size = batch_size + batch_size * tol for batch in batches: - assert sum([len(example.doc) for example in batch]) < max_size + assert sum([len(doc) for doc in batch]) < max_size @pytest.mark.parametrize( @@ -50,10 +49,9 @@ def test_util_minibatch(doc_sizes, expected_batches): def test_util_minibatch_oversize(doc_sizes, expected_batches): """ Test that oversized documents are returned in their own batch""" docs = [get_random_doc(doc_size) for doc_size in doc_sizes] - examples = [Example(doc=doc) for doc in docs] tol = 0.2 batch_size = 1000 - batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False)) + batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False)) assert [len(batch) for batch in batches] == expected_batches diff --git a/spacy/util.py b/spacy/util.py index e9a36da71..d85940a04 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -471,14 +471,6 @@ def get_async(stream, numpy_array): return array -def eg2doc(example): - """Get a Doc object from an Example (or if it's a Doc, use it directly)""" - # Put the import here to avoid circular import problems - from .tokens.doc import Doc - - return example if isinstance(example, Doc) else example.doc - - def env_opt(name, default=None): if type(default) is float: type_convert = float @@ -697,10 +689,13 @@ def decaying(start, stop, decay): curr -= decay -def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False): +def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False): """Create minibatches of roughly a given number of words. If any examples are longer than the specified batch length, they will appear in a batch by - themselves, or be discarded if discard_oversize=True.""" + themselves, or be discarded if discard_oversize=True. + The argument 'docs' can be a list of strings, Doc's or Example's. 
""" + from .gold import Example + if isinstance(size, int): size_ = itertools.repeat(size) elif isinstance(size, List): @@ -715,22 +710,27 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o batch_size = 0 overflow_size = 0 - for example in examples: - n_words = count_words(example.doc) + for doc in docs: + if isinstance(doc, Example): + n_words = len(doc.reference) + elif isinstance(doc, str): + n_words = len(doc.split()) + else: + n_words = len(doc) # if the current example exceeds the maximum batch size, it is returned separately # but only if discard_oversize=False. if n_words > target_size + tol_size: if not discard_oversize: - yield [example] + yield [doc] # add the example to the current batch if there's no overflow yet and it still fits elif overflow_size == 0 and (batch_size + n_words) <= target_size: - batch.append(example) + batch.append(doc) batch_size += n_words # add the example to the overflow buffer if it fits in the tolerance margin elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): - overflow.append(example) + overflow.append(doc) overflow_size += n_words # yield the previous batch and start a new one. The new one gets the overflow examples. @@ -745,12 +745,12 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o # this example still fits if (batch_size + n_words) <= target_size: - batch.append(example) + batch.append(doc) batch_size += n_words # this example fits in overflow elif (batch_size + n_words) <= (target_size + tol_size): - overflow.append(example) + overflow.append(doc) overflow_size += n_words # this example does not fit with the previous overflow: start another new batch @@ -758,7 +758,7 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o yield batch target_size = next(size_) tol_size = target_size * tolerance - batch = [example] + batch = [doc] batch_size = n_words # yield the final batch