# Expose batch size and length caps on CLI for pretrain (#3417)
Add and document CLI options for batch size, max doc length and min doc length for `spacy pretrain`. Also improve CLI output. Closes #3216.

## Checklist

- [x] I have submitted the spaCy Contributor Agreement.
- [x] I ran the tests, and all new and existing tests passed.
- [x] My changes don't require a change to the documentation, or if they do, I've added all required information.
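For illustration, a hypothetical invocation using the new options (the input corpus, vectors model and output directory below are placeholders):

```bash
# Hypothetical example; texts.jsonl, en_vectors_web_lg and ./pretrain-output
# are placeholder arguments. Batches are capped at 3000 words, and examples
# shorter than 5 or longer than 500 words are discarded.
python -m spacy pretrain texts.jsonl en_vectors_web_lg ./pretrain-output \
    --batch-size 3000 --max-length 500 --min-length 5
```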
This commit is contained in:

parent 58d562d9b0
commit 62afa64a8d
`spacy/cli/pretrain.py`:

```diff
@@ -29,6 +29,9 @@ from .. import util
     embed_rows=("Embedding rows", "option", "er", int),
     use_vectors=("Whether to use the static vectors as input features", "flag", "uv"),
     dropout=("Dropout", "option", "d", float),
+    batch_size=("Number of words per training batch", "option", "bs", int),
+    max_length=("Max words per example.", "option", "xw", int),
+    min_length=("Min words per example.", "option", "nw", int),
     seed=("Seed for random number generators", "option", "s", float),
     nr_iter=("Number of iterations to pretrain", "option", "i", int),
 )
```
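These tuples follow the plac annotation convention used throughout spaCy's CLI: `(description, kind, abbreviation, type)`. As a minimal sketch of the pattern (the `demo` command is hypothetical, not part of spaCy):

```python
import plac


@plac.annotations(
    # Same (description, kind, abbreviation, type) shape as the tuples above.
    batch_size=("Number of words per training batch", "option", "bs", int),
)
def demo(batch_size=3000):
    # plac exposes this parameter as --batch-size / -bs on the command line.
    print("batch_size =", batch_size)


if __name__ == "__main__":
    plac.call(demo)
```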
```diff
@@ -42,6 +45,9 @@ def pretrain(
     use_vectors=False,
     dropout=0.2,
     nr_iter=1000,
+    batch_size=3000,
+    max_length=500,
+    min_length=5,
     seed=0,
 ):
     """
```
```diff
@@ -109,9 +115,14 @@ def pretrain(
     msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings)
     for epoch in range(nr_iter):
         for batch in util.minibatch_by_words(
-            ((text, None) for text in texts), size=3000
+            ((text, None) for text in texts), size=batch_size
         ):
-            docs = make_docs(nlp, [text for (text, _) in batch])
+            docs = make_docs(
+                nlp,
+                [text for (text, _) in batch],
+                max_length=max_length,
+                min_length=min_length,
+            )
             loss = make_update(model, docs, optimizer, drop=dropout)
             progress = tracker.update(epoch, loss, docs)
             if progress:
```
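`util.minibatch_by_words` sizes batches by total word count rather than by number of texts, so one very long document can't blow up a batch. A simplified sketch of the idea (not spaCy's actual implementation):

```python
def minibatch_by_words_sketch(items, size):
    """Yield batches of (text, annotation) tuples holding roughly `size` words."""
    batch, n_words = [], 0
    for text, annot in items:
        n_item = len(text.split())  # crude word count, for illustration only
        if batch and n_words + n_item > size:
            yield batch
            batch, n_words = [], 0
        batch.append((text, annot))
        n_words += n_item
    if batch:
        yield batch


# With size=5, three short texts split into a 2-text batch and a 1-text batch.
texts = ["a b c", "d e", "f g h i"]
batches = list(minibatch_by_words_sketch(((t, None) for t in texts), size=5))
assert [len(b) for b in batches] == [2, 1]
```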
```diff
@@ -152,7 +163,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
     return float(loss)
 
 
-def make_docs(nlp, batch, min_length=1, max_length=500):
+def make_docs(nlp, batch, min_length, max_length):
     docs = []
     for record in batch:
         text = record["text"]
```
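With the defaults removed, `min_length` and `max_length` must now be passed explicitly, so the CLI defaults above stay the single source of truth. The length cap itself boils down to a filter along these lines (a sketch of the idea, not the function's exact body):

```python
def filter_by_length(docs, min_length, max_length):
    # Keep only docs whose token count falls inside [min_length, max_length].
    return [doc for doc in docs if min_length <= len(doc) <= max_length]
```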
```diff
@@ -241,11 +252,23 @@ class ProgressTracker(object):
             status = (
                 epoch,
                 self.nr_word,
-                "%.8f" % self.loss,
-                "%.8f" % loss_per_word,
+                _smart_round(self.loss, width=10),
+                _smart_round(loss_per_word, width=6),
                 int(wps),
             )
             self.prev_loss = float(self.loss)
             return status
         else:
             return None
+
+
+def _smart_round(figure, width=10, max_decimal=4):
+    """Round large numbers as integers, smaller numbers as decimals."""
+    n_digits = len(str(int(figure)))
+    n_decimal = width - (n_digits + 1)
+    if n_decimal <= 1:
+        return str(int(figure))
+    else:
+        n_decimal = min(n_decimal, max_decimal)
+        format_str = "%." + str(n_decimal) + "f"
+        return format_str % figure
```
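`_smart_round` keeps each progress column within a fixed character budget: large totals print as plain integers, while small per-word losses keep a few decimal places. For example, assuming `_smart_round` as defined above is in scope:

```python
_smart_round(123456.789, width=10)   # -> "123456.789" (6 digits leave 3 decimals)
_smart_round(0.0123456, width=6)     # -> "0.0123" (capped at max_decimal=4)
_smart_round(123456789.0, width=10)  # -> "123456789" (no room left for decimals)
```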
`website/docs/api/cli.md`:

```diff
@@ -296,6 +296,9 @@ $ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir] [--width]
 | `--depth`, `-cd`       | option | Depth of CNN layers.                                       |
 | `--embed-rows`, `-er`  | option | Number of embedding rows.                                  |
 | `--dropout`, `-d`      | option | Dropout rate.                                              |
+| `--batch-size`, `-bs`  | option | Number of words per training batch.                        |
+| `--max-length`, `-xw`  | option | Maximum words per example. Longer examples are discarded.  |
+| `--min-length`, `-nw`  | option | Minimum words per example. Shorter examples are discarded. |
 | `--seed`, `-s`         | option | Seed for random number generators.                         |
 | `--n-iter`, `-i`       | option | Number of iterations to pretrain.                          |
 | `--use-vectors`, `-uv` | flag   | Whether to use the static vectors as input features.       |
```