Mirror of https://github.com/explosion/spaCy.git, synced 2025-11-04 01:48:04 +03:00
Merge branch 'develop' of https://github.com/explosion/spaCy into develop

Commit 803e41bc66
spacy/cli/train.py

@@ -32,6 +32,7 @@ numpy.random.seed(0)
     n_sents=("number of sentences", "option", "ns", int),
     use_gpu=("Use GPU", "option", "g", int),
     vectors=("Model to load vectors from", "option", "v"),
+    vectors_limit=("Truncate to N vectors (requires -v)", "option", None, int),
     no_tagger=("Don't train tagger", "flag", "T", bool),
     no_parser=("Don't train parser", "flag", "P", bool),
     no_entities=("Don't train NER", "flag", "N", bool),
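For readers unfamiliar with the declaration style above: these are plac annotation tuples of the form (help text, kind, abbreviation, type), and an abbreviation of None means the new option gets no one-letter alias. A minimal standalone sketch of the same mechanism (hypothetical demo script, not part of this commit):

import plac

@plac.annotations(
    vectors_limit=("Truncate to N vectors (requires -v)", "option", None, int))
def main(vectors_limit=None):
    # plac has already parsed the command line and coerced the value to int.
    print(vectors_limit)

if __name__ == '__main__':
    plac.call(main)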
@@ -40,9 +41,9 @@ numpy.random.seed(0)
     meta_path=("Optional path to meta.json. All relevant properties will be "
                "overwritten.", "option", "m", Path))
 def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
-          use_gpu=-1, vectors=None, no_tagger=False, no_parser=False,
-          no_entities=False, gold_preproc=False, version="0.0.0",
-          meta_path=None):
+          use_gpu=-1, vectors=None, vectors_limit=None, no_tagger=False,
+          no_parser=False, no_entities=False, gold_preproc=False,
+          version="0.0.0", meta_path=None):
     """
     Train a model. Expects data in spaCy's JSON format.
     """
@@ -94,6 +95,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
     nlp.meta.update(meta)
     if vectors:
         util.load_model(vectors, vocab=nlp.vocab)
+        if vectors_limit is not None:
+            nlp.vocab.prune_vectors(vectors_limit)
     for name in pipeline:
         nlp.add_pipe(nlp.create_pipe(name), name=name)
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
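Taken together, these changes let the training command cap the vector table as it is loaded. A hypothetical invocation (paths and the vectors package name are placeholders, and the exact long-option spelling depends on how plac renders the vectors_limit argument name):

python -m spacy train en /tmp/model train.json dev.json --vectors en_vectors_web_lg --vectors-limit 10000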
spacy/vocab.pyx

@@ -5,6 +5,7 @@ import numpy
 import dill
 
 from collections import OrderedDict
+from thinc.neural.util import get_array_module
 from .lexeme cimport EMPTY_LEXEME
 from .lexeme cimport Lexeme
 from .strings cimport hash_string
@@ -247,6 +248,44 @@ cdef class Vocab:
             width = self.vectors.data.shape[1]
         self.vectors = Vectors(self.strings, width=width)
 
+    def prune_vectors(self, nr_row, batch_size=1024):
+        """Reduce the current vector table to `nr_row` unique entries. Words
+        mapped to the discarded vectors will be remapped to the closest vector
+        among those remaining.
+
+        For example, suppose the original table had vectors for the words:
+        ['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to
+        two rows, we would discard the vectors for 'feline' and 'reclined'.
+        These words would then be remapped to the closest remaining vector
+        -- so "feline" would have the same vector as "cat", and "reclined"
+        would have the same vector as "sat".
+
+        The similarities are judged by cosine. The original vectors may
+        be large, so the cosines are calculated in minibatches, to reduce
+        memory usage.
+        """
+        xp = get_array_module(self.vectors.data)
+        # Work in batches, to avoid memory problems.
+        keep = self.vectors.data[:nr_row]
+        toss = self.vectors.data[nr_row:]
+        # Normalize the vectors, so cosine similarity is just dot product.
+        # Note we can't modify the ones we're keeping in-place...
+        keep = keep / (xp.linalg.norm(keep)+1e-8)
+        keep = xp.ascontiguousarray(keep.T)
+        neighbours = xp.zeros((toss.shape[0],), dtype='i')
+        for i in range(0, toss.shape[0], batch_size):
+            batch = toss[i : i+batch_size]
+            batch /= xp.linalg.norm(batch)+1e-8
+            neighbours[i:i+batch_size] = xp.dot(batch, keep).argmax(axis=1)
+        for lex in self:
+            # If we're losing the vector for this word, map it to the nearest
+            # vector we're keeping.
+            if lex.rank >= nr_row:
+                lex.rank = neighbours[lex.rank-nr_row]
+                self.vectors.add(lex.orth, row=lex.rank)
+        # Make copy, to encourage the original table to be garbage collected.
+        self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
+
     def get_vector(self, orth):
         """Retrieve a vector for a word in the vocabulary. Words can be looked
         up by string or int ID. If no vectors data is loaded, ValueError is
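A short usage sketch for the new method (the model name is an assumption; any installed model that ships word vectors would do, and the shapes are illustrative):

import spacy

nlp = spacy.load('en_core_web_md')    # assumed: a model with pretrained vectors
print(nlp.vocab.vectors.data.shape)   # e.g. (20000, 300)
nlp.vocab.prune_vectors(10000)        # keep only the first 10000 rows
print(nlp.vocab.vectors.data.shape)   # now (10000, 300)
# Words whose rows were discarded alias their nearest surviving
# neighbour, so vector lookups still succeed:
print(nlp.vocab[u'feline'].vector.shape)

Note that the method keeps the first nr_row rows unchanged, so it implicitly assumes the table is ordered with the most important (typically most frequent) entries first.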
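One caveat worth flagging in the implementation above: xp.linalg.norm called without an axis argument returns a single scalar for the whole array, so both divisions rescale the matrices uniformly instead of normalizing each row to unit length. A uniform rescaling does not change which column wins the argmax, so the neighbours are effectively chosen by raw dot product, which favours high-norm rows, rather than by the cosine similarity the comments promise. A row-wise variant would look like the following sketch (plain NumPy, not the commit's code):

import numpy as np

def nearest_kept_rows(data, nr_row, batch_size=1024):
    # Same batched neighbour search as prune_vectors, but normalizing
    # per row (axis=1), so the dot products are true cosine similarities.
    keep = data[:nr_row]
    toss = data[nr_row:]
    keep = keep / (np.linalg.norm(keep, axis=1, keepdims=True) + 1e-8)
    keep = np.ascontiguousarray(keep.T)
    neighbours = np.zeros((toss.shape[0],), dtype='i')
    for i in range(0, toss.shape[0], batch_size):
        batch = toss[i : i + batch_size]
        batch = batch / (np.linalg.norm(batch, axis=1, keepdims=True) + 1e-8)
        neighbours[i : i + batch_size] = np.dot(batch, keep).argmax(axis=1)
    return neighbours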