Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 09:57:26 +03:00)
minibatch utility can deal with strings, docs or examples
This commit is contained in:
parent 8b66c11ff2
commit 4ed399c848
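The behavior named in the commit message can be exercised directly. A minimal
usage sketch against this revision (the blank pipeline and the toy texts are
illustrative assumptions, not part of the commit):

    import spacy
    from spacy.util import minibatch_by_words

    nlp = spacy.blank("en")
    texts = ["A short text.", "Another text that is a little longer."]

    # Raw strings are now accepted: word counts come from str.split()
    for batch in minibatch_by_words(texts, size=1000):
        print(len(batch), "texts in this batch")

    # Doc objects still work: word counts come from len(doc)
    docs = [nlp(text) for text in texts]
    for batch in minibatch_by_words(docs, size=1000):
        print(len(batch), "docs in this batch")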
@@ -166,8 +166,7 @@ def pretrain(
     skip_counter = 0
     loss_func = pretrain_config["loss_func"]
     for epoch in range(epoch_resume, pretrain_config["max_epochs"]):
-        examples = [Example(doc=text) for text in texts]
-        batches = util.minibatch_by_words(examples, size=pretrain_config["batch_size"])
+        batches = util.minibatch_by_words(texts, size=pretrain_config["batch_size"])
         for batch_id, batch in enumerate(batches):
             docs, count = make_docs(
                 nlp,
@@ -275,7 +275,7 @@ def _fix_legacy_dict_data(example_dict):
     }


-def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces=None):
+def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
     if isinstance(biluo_or_offsets[0], (list, tuple)):
         # Convert to biluo if necessary
         # This is annoying but to convert the offsets we need a Doc
@@ -677,7 +677,7 @@ class Language(object):
         # Populate vocab
         else:
             for example in get_examples():
-                for word in example.token_annotation.words:
+                for word in [t.text for t in example.reference]:
                     _ = self.vocab[word]  # noqa: F841

         if cfg.get("device", -1) >= 0:
@@ -5,7 +5,7 @@ from ..gold import Example
 from ..tokens import Doc
 from ..vocab import Vocab
 from ..language import component
-from ..util import link_vectors_to_models, minibatch, eg2doc
+from ..util import link_vectors_to_models, minibatch
 from .defaults import default_tok2vec


@@ -51,19 +51,15 @@ class Tok2Vec(Pipe):
         self.set_annotations([doc], tokvecses)
         return doc

-    def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
+    def pipe(self, stream, batch_size=128, n_threads=-1):
         """Process `Doc` objects as a stream.
         stream (iterator): A sequence of `Doc` objects to process.
         batch_size (int): Number of `Doc` objects to group.
         n_threads (int): Number of threads.
         YIELDS (iterator): A sequence of `Doc` objects, in order of input.
         """
-        for batch in minibatch(stream, batch_size):
+        for docs in minibatch(stream, batch_size):
             batch = list(batch)
-            if as_example:
-                docs = [eg2doc(doc) for doc in batch]
-            else:
-                docs = batch
             tokvecses = self.predict(docs)
             self.set_annotations(docs, tokvecses)
             yield from batch
@@ -430,7 +430,6 @@ def test_tuple_format_implicit():
     _train(train_data)


-@pytest.mark.xfail # TODO
 def test_tuple_format_implicit_invalid():
     """Test that an error is thrown for an implicit invalid field"""

@@ -443,7 +442,7 @@ def test_tuple_format_implicit_invalid():
         ("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
     ]

-    with pytest.raises(TypeError):
+    with pytest.raises(KeyError):
         _train(train_data)


@@ -25,15 +25,14 @@ from spacy.util import minibatch_by_words
 )
 def test_util_minibatch(doc_sizes, expected_batches):
     docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
-    examples = [Example(doc=doc) for doc in docs]
     tol = 0.2
     batch_size = 1000
-    batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True))
+    batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=True))
     assert [len(batch) for batch in batches] == expected_batches

     max_size = batch_size + batch_size * tol
     for batch in batches:
-        assert sum([len(example.doc) for example in batch]) < max_size
+        assert sum([len(doc) for doc in batch]) < max_size


 @pytest.mark.parametrize(
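The bound asserted in this test follows from the tolerance arithmetic inside
minibatch_by_words; a worked example using the test's own values:

    # size=1000 and tolerance=0.2 give a hard cap of 1200 words per batch
    batch_size = 1000
    tol = 0.2
    max_size = batch_size + batch_size * tol  # 1200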
@@ -50,10 +49,9 @@ def test_util_minibatch(doc_sizes, expected_batches):
 def test_util_minibatch_oversize(doc_sizes, expected_batches):
     """ Test that oversized documents are returned in their own batch"""
     docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
-    examples = [Example(doc=doc) for doc in docs]
     tol = 0.2
     batch_size = 1000
-    batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=False))
+    batches = list(minibatch_by_words(docs, size=batch_size, tolerance=tol, discard_oversize=False))
     assert [len(batch) for batch in batches] == expected_batches


@@ -471,14 +471,6 @@ def get_async(stream, numpy_array):
         return array


-def eg2doc(example):
-    """Get a Doc object from an Example (or if it's a Doc, use it directly)"""
-    # Put the import here to avoid circular import problems
-    from .tokens.doc import Doc
-
-    return example if isinstance(example, Doc) else example.doc
-
-
 def env_opt(name, default=None):
     if type(default) is float:
         type_convert = float
@@ -697,10 +689,13 @@ def decaying(start, stop, decay):
         curr -= decay


-def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
+def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
     """Create minibatches of roughly a given number of words. If any examples
     are longer than the specified batch length, they will appear in a batch by
-    themselves, or be discarded if discard_oversize=True."""
+    themselves, or be discarded if discard_oversize=True.
+    The argument 'docs' can be a list of strings, Doc's or Example's. """
+    from .gold import Example
+
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     elif isinstance(size, List):
@@ -715,22 +710,27 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
     batch_size = 0
     overflow_size = 0

-    for example in examples:
-        n_words = count_words(example.doc)
+    for doc in docs:
+        if isinstance(doc, Example):
+            n_words = len(doc.reference)
+        elif isinstance(doc, str):
+            n_words = len(doc.split())
+        else:
+            n_words = len(doc)
         # if the current example exceeds the maximum batch size, it is returned separately
         # but only if discard_oversize=False.
         if n_words > target_size + tol_size:
             if not discard_oversize:
-                yield [example]
+                yield [doc]

         # add the example to the current batch if there's no overflow yet and it still fits
         elif overflow_size == 0 and (batch_size + n_words) <= target_size:
-            batch.append(example)
+            batch.append(doc)
             batch_size += n_words

         # add the example to the overflow buffer if it fits in the tolerance margin
         elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
-            overflow.append(example)
+            overflow.append(doc)
             overflow_size += n_words

         # yield the previous batch and start a new one. The new one gets the overflow examples.
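A hedged illustration of the three counting branches added above (the blank
pipeline and sample sentence are assumptions for demonstration; the Example
branch counts the tokens of doc.reference the same way as a Doc):

    import spacy

    nlp = spacy.blank("en")
    text = "Hello, world"

    print(len(text.split()))  # 2: the str branch counts whitespace-separated tokens
    print(len(nlp(text)))     # 3: the Doc branch counts spaCy tokens ("Hello", ",", "world")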
@@ -745,12 +745,12 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):

             # this example still fits
             if (batch_size + n_words) <= target_size:
-                batch.append(example)
+                batch.append(doc)
                 batch_size += n_words

             # this example fits in overflow
             elif (batch_size + n_words) <= (target_size + tol_size):
-                overflow.append(example)
+                overflow.append(doc)
                 overflow_size += n_words

             # this example does not fit with the previous overflow: start another new batch
@@ -758,7 +758,7 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
                 yield batch
                 target_size = next(size_)
                 tol_size = target_size * tolerance
-                batch = [example]
+                batch = [doc]
                 batch_size = n_words

     # yield the final batch