mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Merge branch 'develop' of https://github.com/explosion/spaCy into develop
This commit is contained in:
commit
803e41bc66
|
@ -32,6 +32,7 @@ numpy.random.seed(0)
|
||||||
n_sents=("number of sentences", "option", "ns", int),
|
n_sents=("number of sentences", "option", "ns", int),
|
||||||
use_gpu=("Use GPU", "option", "g", int),
|
use_gpu=("Use GPU", "option", "g", int),
|
||||||
vectors=("Model to load vectors from", "option", "v"),
|
vectors=("Model to load vectors from", "option", "v"),
|
||||||
|
vectors_limit=("Truncate to N vectors (requires -v)", "option", None, int),
|
||||||
no_tagger=("Don't train tagger", "flag", "T", bool),
|
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||||
no_parser=("Don't train parser", "flag", "P", bool),
|
no_parser=("Don't train parser", "flag", "P", bool),
|
||||||
no_entities=("Don't train NER", "flag", "N", bool),
|
no_entities=("Don't train NER", "flag", "N", bool),
|
||||||
|
@ -40,9 +41,9 @@ numpy.random.seed(0)
|
||||||
meta_path=("Optional path to meta.json. All relevant properties will be "
|
meta_path=("Optional path to meta.json. All relevant properties will be "
|
||||||
"overwritten.", "option", "m", Path))
|
"overwritten.", "option", "m", Path))
|
||||||
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
||||||
use_gpu=-1, vectors=None, no_tagger=False, no_parser=False,
|
use_gpu=-1, vectors=None, vectors_limit=None, no_tagger=False,
|
||||||
no_entities=False, gold_preproc=False, version="0.0.0",
|
no_parser=False, no_entities=False, gold_preproc=False,
|
||||||
meta_path=None):
|
version="0.0.0", meta_path=None):
|
||||||
"""
|
"""
|
||||||
Train a model. Expects data in spaCy's JSON format.
|
Train a model. Expects data in spaCy's JSON format.
|
||||||
"""
|
"""
|
||||||
|
@ -94,6 +95,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0,
|
||||||
nlp.meta.update(meta)
|
nlp.meta.update(meta)
|
||||||
if vectors:
|
if vectors:
|
||||||
util.load_model(vectors, vocab=nlp.vocab)
|
util.load_model(vectors, vocab=nlp.vocab)
|
||||||
|
if vectors_limit is not None:
|
||||||
|
nlp.vocab.prune_vectors(vectors_limit)
|
||||||
for name in pipeline:
|
for name in pipeline:
|
||||||
nlp.add_pipe(nlp.create_pipe(name), name=name)
|
nlp.add_pipe(nlp.create_pipe(name), name=name)
|
||||||
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
|
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import numpy
|
||||||
import dill
|
import dill
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from thinc.neural.util import get_array_module
|
||||||
from .lexeme cimport EMPTY_LEXEME
|
from .lexeme cimport EMPTY_LEXEME
|
||||||
from .lexeme cimport Lexeme
|
from .lexeme cimport Lexeme
|
||||||
from .strings cimport hash_string
|
from .strings cimport hash_string
|
||||||
|
@ -247,6 +248,44 @@ cdef class Vocab:
|
||||||
width = self.vectors.data.shape[1]
|
width = self.vectors.data.shape[1]
|
||||||
self.vectors = Vectors(self.strings, width=width)
|
self.vectors = Vectors(self.strings, width=width)
|
||||||
|
|
||||||
|
def prune_vectors(self, nr_row, batch_size=1024):
|
||||||
|
"""Reduce the current vector table to `nr_row` unique entries. Words
|
||||||
|
mapped to the discarded vectors will be remapped to the closest vector
|
||||||
|
among those remaining.
|
||||||
|
|
||||||
|
For example, suppose the original table had vectors for the words:
|
||||||
|
['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to,
|
||||||
|
two rows, we would discard the vectors for 'feline' and 'reclined'.
|
||||||
|
These words would then be remapped to the closest remaining vector
|
||||||
|
-- so "feline" would have the same vector as "cat", and "reclined"
|
||||||
|
would have the same vector as "sat".
|
||||||
|
|
||||||
|
The similarities are judged by cosine. The original vectors may
|
||||||
|
be large, so the cosines are calculated in minibatches, to reduce
|
||||||
|
memory usage.
|
||||||
|
"""
|
||||||
|
xp = get_array_module(self.vectors.data)
|
||||||
|
# Work in batches, to avoid memory problems.
|
||||||
|
keep = self.vectors.data[:nr_row]
|
||||||
|
toss = self.vectors.data[nr_row:]
|
||||||
|
# Normalize the vectors, so cosine similarity is just dot product.
|
||||||
|
# Note we can't modify the ones we're keeping in-place...
|
||||||
|
keep = keep / (xp.linalg.norm(keep)+1e-8)
|
||||||
|
keep = xp.ascontiguousarray(keep.T)
|
||||||
|
neighbours = xp.zeros((toss.shape[0],), dtype='i')
|
||||||
|
for i in range(0, toss.shape[0], batch_size):
|
||||||
|
batch = toss[i : i+batch_size]
|
||||||
|
batch /= xp.linalg.norm(batch)+1e-8
|
||||||
|
neighbours[i:i+batch_size] = xp.dot(batch, keep).argmax(axis=1)
|
||||||
|
for lex in self:
|
||||||
|
# If we're losing the vector for this word, map it to the nearest
|
||||||
|
# vector we're keeping.
|
||||||
|
if lex.rank >= nr_row:
|
||||||
|
lex.rank = neighbours[lex.rank-nr_row]
|
||||||
|
self.vectors.add(lex.orth, row=lex.rank)
|
||||||
|
# Make copy, to encourage the original table to be garbage collected.
|
||||||
|
self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])
|
||||||
|
|
||||||
def get_vector(self, orth):
|
def get_vector(self, orth):
|
||||||
"""Retrieve a vector for a word in the vocabulary. Words can be looked
|
"""Retrieve a vector for a word in the vocabulary. Words can be looked
|
||||||
up by string or int ID. If no vectors data is loaded, ValueError is
|
up by string or int ID. If no vectors data is loaded, ValueError is
|
||||||
|
|
Loading…
Reference in New Issue
Block a user