Merge pull request #1497 from explosion/feature/improve-optimizer-handling
💫 Improve optimizer handling
Commit: 6fdffd7246
spacy/_ml.py
@@ -15,12 +15,12 @@ from thinc.linear.linear import LinearModel
from thinc.neural.ops import NumpyOps, CupyOps
from thinc.neural.util import get_array_module, copy_array
from thinc.neural._lsuv import svd_orthonormal
from thinc.neural.optimizers import Adam

from thinc import describe
from thinc.describe import Dimension, Synapses, Biases, Gradient
from thinc.neural._classes.affine import _set_dimensions_if_needed
import thinc.extra.load_nlp
from thinc.neural._lsuv import svd_orthonormal

from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE
from . import util
@@ -39,6 +39,19 @@ def cosine(vec1, vec2):
    return vec1.dot(vec2) / (norm1 * norm2)


def create_default_optimizer(ops, **cfg):
    learn_rate = util.env_opt('learn_rate', 0.001)
    beta1 = util.env_opt('optimizer_B1', 0.9)
    beta2 = util.env_opt('optimizer_B2', 0.999)
    eps = util.env_opt('optimizer_eps', 1e-08)
    L2 = util.env_opt('L2_penalty', 1e-6)
    max_grad_norm = util.env_opt('grad_norm_clip', 1.)
    optimizer = Adam(ops, learn_rate, L2=L2, beta1=beta1,
                     beta2=beta2, eps=eps)
    optimizer.max_grad_norm = max_grad_norm
    optimizer.device = ops.device
    return optimizer

@layerize
def _flatten_add_lengths(seqs, pad=0, drop=0.):
    ops = Model.ops
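All of these hyperparameters are looked up through util.env_opt, so the defaults can be overridden through environment variables rather than code changes. A minimal usage sketch of the new helper (it assumes thinc's Model class for the default backend ops, as spacy/language.py uses below):

    from thinc.neural import Model
    from spacy._ml import create_default_optimizer

    # Build the shared default Adam optimizer for the current backend.
    optimizer = create_default_optimizer(Model.ops)

    # The helper copies its settings onto the optimizer, as in the hunk above;
    # with no overriding environment variables this prints the 1.0 default
    # from env_opt('grad_norm_clip', 1.).
    print(optimizer.max_grad_norm)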

spacy/language.py
@@ -19,7 +19,7 @@ from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter
from .compat import json_dumps, izip
from .scorer import Scorer
from ._ml import link_vectors_to_models
from ._ml import link_vectors_to_models, create_default_optimizer
from .attrs import IS_STOP
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
from .lang.punctuation import TOKENIZER_INFIXES
@@ -407,27 +407,7 @@ class Language(object):
        for doc, gold in docs_golds:
            yield doc, gold

    def resume_training(self, **cfg):
        if cfg.get('device', -1) >= 0:
            device = util.use_gpu(cfg['device'])
            if self.vocab.vectors.data.shape[1] >= 1:
                self.vocab.vectors.data = Model.ops.asarray(
                    self.vocab.vectors.data)
        else:
            device = None
        learn_rate = util.env_opt('learn_rate', 0.001)
        beta1 = util.env_opt('optimizer_B1', 0.9)
        beta2 = util.env_opt('optimizer_B2', 0.999)
        eps = util.env_opt('optimizer_eps', 1e-08)
        L2 = util.env_opt('L2_penalty', 1e-6)
        max_grad_norm = util.env_opt('grad_norm_clip', 1.)
        self._optimizer = Adam(Model.ops, learn_rate, L2=L2, beta1=beta1,
                               beta2=beta2, eps=eps)
        self._optimizer.max_grad_norm = max_grad_norm
        self._optimizer.device = device
        return self._optimizer

    def begin_training(self, get_gold_tuples=None, **cfg):
    def begin_training(self, get_gold_tuples=None, sgd=None, **cfg):
        """Allocate models, pre-process training data and acquire a trainer and
        optimizer. Used as a contextmanager.

@@ -452,21 +432,14 @@ class Language(object):
        else:
            device = None
        link_vectors_to_models(self.vocab)
        if sgd is None:
            sgd = create_default_optimizer(Model.ops)
        self._optimizer = sgd
        for name, proc in self.pipeline:
            if hasattr(proc, 'begin_training'):
                context = proc.begin_training(get_gold_tuples(),
                                              pipeline=self.pipeline)
                contexts.append(context)
        learn_rate = util.env_opt('learn_rate', 0.001)
        beta1 = util.env_opt('optimizer_B1', 0.9)
        beta2 = util.env_opt('optimizer_B2', 0.999)
        eps = util.env_opt('optimizer_eps', 1e-08)
        L2 = util.env_opt('L2_penalty', 1e-6)
        max_grad_norm = util.env_opt('grad_norm_clip', 1.)
        self._optimizer = Adam(Model.ops, learn_rate, L2=L2, beta1=beta1,
                               beta2=beta2, eps=eps)
        self._optimizer.max_grad_norm = max_grad_norm
        self._optimizer.device = device
                proc.begin_training(get_gold_tuples(),
                                    pipeline=self.pipeline,
                                    sgd=self._optimizer)
        return self._optimizer

    def evaluate(self, docs_golds, verbose=False):
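Language.begin_training now accepts an optimizer through the new sgd argument and otherwise falls back to create_default_optimizer, instead of constructing Adam inline. A rough sketch of both paths, using a blank English pipeline purely for illustration (the empty gold-tuples callable keeps the sketch self-contained):

    from spacy.lang.en import English
    from thinc.neural import Model
    from thinc.neural.optimizers import Adam

    nlp = English()

    # Default path: spaCy builds the optimizer via create_default_optimizer().
    optimizer = nlp.begin_training(lambda: [])

    # Custom path: pass your own optimizer through `sgd`; begin_training
    # hands the same object back (the hyperparameter here is illustrative).
    my_sgd = Adam(Model.ops, 0.001)
    optimizer = nlp.begin_training(lambda: [], sgd=my_sgd)
    assert optimizer is my_sgd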

spacy/pipeline.pyx
@@ -30,6 +30,7 @@ from .attrs import POS
from .parts_of_speech import X
from ._ml import Tok2Vec, build_text_classifier, build_tagger_model
from ._ml import link_vectors_to_models, zero_init, flatten
from ._ml import create_default_optimizer
from . import util

@@ -138,13 +139,20 @@ class Pipe(object):
        problem.
        """
        raise NotImplementedError

    def create_optimizer(self):
        return create_default_optimizer(self.model.ops,
                                        **self.cfg.get('optimizer', {}))

    def begin_training(self, gold_tuples=tuple(), pipeline=None):
    def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
        """Initialize the pipe for training, using data examples if available.
        If no model has been initialized yet, the model is added."""
        if self.model is True:
            self.model = self.Model(**self.cfg)
        link_vectors_to_models(self.vocab)
        if sgd is None:
            sgd = self.create_optimizer()
        return sgd

    def use_params(self, params):
        """Modify the pipe's model, to use the given parameter values."""
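Pipe.create_optimizer reads self.model.ops, so it needs the component's model to exist; its main use is getting a fresh optimizer for a component whose model has already been built, for example after loading a trained pipeline. A hypothetical sketch, assuming the 'en' model package is installed:

    import spacy

    nlp = spacy.load('en')
    tagger = nlp.get_pipe('tagger')

    # The loaded tagger's model is already built, so create_optimizer()
    # can read self.model.ops and return a default optimizer, which can
    # then be passed to the component's update() calls as `sgd`.
    optimizer = tagger.create_optimizer()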
@@ -336,8 +344,8 @@ class Tensorizer(Pipe):
        loss = (d_scores**2).sum()
        return loss, d_scores

    def begin_training(self, gold_tuples=tuple(), pipeline=None):
        """Allocate models, pre-process training data and acquire a trainer and
    def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
        """Allocate models, pre-process training data and acquire an
        optimizer.

        gold_tuples (iterable): Gold-standard training data.
@@ -349,9 +357,11 @@ class Tensorizer(Pipe):
        if self.model is True:
            self.cfg['input_size'] = 384
            self.cfg['output_size'] = 300
            #self.cfg['pretrained_dims'] = self.vocab.vectors_length
            self.model = self.Model(**self.cfg)
        link_vectors_to_models(self.vocab)
        if sgd is None:
            sgd = self.create_optimizer()
        return sgd


class Tagger(Pipe):
@@ -457,7 +467,7 @@ class Tagger(Pipe):
        d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
        return float(loss), d_scores

    def begin_training(self, gold_tuples=tuple(), pipeline=None):
    def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
        orig_tag_map = dict(self.vocab.morphology.tag_map)
        new_tag_map = {}
        for raw_text, annots_brackets in gold_tuples:
@@ -477,6 +487,9 @@ class Tagger(Pipe):
            self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
            self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
        link_vectors_to_models(self.vocab)
        if sgd is None:
            sgd = self.create_optimizer()
        return sgd

    @classmethod
    def Model(cls, n_tags, **cfg):
@@ -627,7 +640,8 @@ class MultitaskObjective(Tagger):
    def set_annotations(self, docs, dep_ids, tensors=None):
        pass

    def begin_training(self, gold_tuples=tuple(), pipeline=None, tok2vec=None):
    def begin_training(self, gold_tuples=tuple(), pipeline=None, tok2vec=None,
                       sgd=None):
        gold_tuples = nonproj.preprocess_training_data(gold_tuples)
        for raw_text, annots_brackets in gold_tuples:
            for annots, brackets in annots_brackets:
@@ -643,6 +657,9 @@ class MultitaskObjective(Tagger):
                Softmax(len(self.labels), token_vector_width)
            )
        link_vectors_to_models(self.vocab)
        if sgd is None:
            sgd = self.create_optimizer()
        return sgd

    @classmethod
    def Model(cls, n_tags, tok2vec=None, **cfg):
@@ -739,7 +756,7 @@ class SimilarityHook(Pipe):
    def update(self, doc1_doc2, golds, sgd=None, drop=0.):
        sims, bp_sims = self.model.begin_update(doc1_doc2, drop=drop)

    def begin_training(self, _=tuple(), pipeline=None):
    def begin_training(self, _=tuple(), pipeline=None, sgd=None):
        """Allocate model, using width from tensorizer in pipeline.

        gold_tuples (iterable): Gold-standard training data.
@@ -748,6 +765,9 @@ class SimilarityHook(Pipe):
        if self.model is True:
            self.model = self.Model(pipeline[0].model.nO)
        link_vectors_to_models(self.vocab)
        if sgd is None:
            sgd = self.create_optimizer()
        return sgd


class TextCategorizer(Pipe):
@@ -831,7 +851,7 @@ class TextCategorizer(Pipe):
            self.labels.append(label)
            return 1

    def begin_training(self, gold_tuples=tuple(), pipeline=None):
    def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
        if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
            token_vector_width = pipeline[0].model.nO
        else:
@@ -841,6 +861,9 @@ class TextCategorizer(Pipe):
            self.model = self.Model(len(self.labels), token_vector_width,
                                    **self.cfg)
        link_vectors_to_models(self.vocab)
        if sgd is None:
            sgd = self.create_optimizer()
        return sgd


cdef class DependencyParser(Parser):
@@ -851,12 +874,12 @@ cdef class DependencyParser(Parser):
    def postprocesses(self):
        return [nonproj.deprojectivize]

    def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
        for target in []:
            labeller = MultitaskObjective(self.vocab, target=target)
            tok2vec = self.model[0]
            labeller.begin_training(gold_tuples, pipeline=pipeline,
                                    tok2vec=tok2vec)
                                    tok2vec=tok2vec, sgd=sgd)
            pipeline.append(labeller)
            self._multitasks.append(labeller)

@@ -871,7 +894,7 @@ cdef class EntityRecognizer(Parser):

    nr_feature = 6

    def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
    def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
        for target in []:
            labeller = MultitaskObjective(self.vocab, target=target)
            tok2vec = self.model[0]
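Taken together, the pipeline.pyx changes mean one optimizer is created up front and threaded through every component's begin_training and update. Roughly the training-loop pattern this enables, sketched with a text classifier; the factory name and the update call follow spaCy 2.0's usage docs and are assumptions about the surrounding API, not part of this diff:

    import random
    from spacy.lang.en import English

    nlp = English()
    textcat = nlp.create_pipe('textcat')
    textcat.add_label('POSITIVE')
    nlp.add_pipe(textcat)

    train_data = [
        (u'This is great', {'cats': {'POSITIVE': 1.0}}),
        (u'This is rubbish', {'cats': {'POSITIVE': 0.0}}),
    ]

    # begin_training builds the models and returns the shared optimizer.
    optimizer = nlp.begin_training(lambda: [])
    for i in range(10):
        random.shuffle(train_data)
        losses = {}
        for text, annotations in train_data:
            # The same optimizer drives every component's weight updates.
            nlp.update([text], [annotations], sgd=optimizer, losses=losses)
        print(losses)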

spacy/syntax/nn_parser.pyx
@@ -30,7 +30,7 @@ from thinc.neural.util import get_array_module
from thinc.linalg cimport Vec, VecVec

from .._ml import zero_init, PrecomputableAffine, Tok2Vec, flatten
from .._ml import link_vectors_to_models
from .._ml import link_vectors_to_models, create_default_optimizer
from ..compat import json_dumps, copy_array
from ..tokens.doc cimport Doc
from ..gold cimport GoldParse
@@ -273,6 +273,10 @@ cdef class Parser:
        }
        return (tok2vec, lower, upper), cfg

    def create_optimizer(self):
        return create_default_optimizer(self.model[0].ops,
                                        **self.cfg.get('optimizer', {}))

    def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
        """Create a Parser.

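Unlike Pipe.create_optimizer, the parser's model is the (tok2vec, lower, upper) tuple returned by Model() above, so the helper takes its ops from model[0]. A hypothetical sketch against a loaded pipeline, again assuming the 'en' package is installed:

    import spacy

    nlp = spacy.load('en')
    parser = nlp.get_pipe('parser')

    # The parser's model is a tuple of thinc models, so create_optimizer()
    # reads the backend ops from model[0] (the tok2vec layer).
    optimizer = parser.create_optimizer()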
@@ -793,7 +797,7 @@ cdef class Parser:
            copy_array(larger.b[:smaller.nO], smaller.b)
            self.model[-1]._layers[-1] = larger

    def begin_training(self, gold_tuples, pipeline=None, **cfg):
    def begin_training(self, gold_tuples, pipeline=None, sgd=None, **cfg):
        if 'model' in cfg:
            self.model = cfg['model']
        gold_tuples = nonproj.preprocess_training_data(gold_tuples,
@@ -805,9 +809,14 @@ cdef class Parser:
        if self.model is True:
            cfg['pretrained_dims'] = self.vocab.vectors_length
            self.model, cfg = self.Model(self.moves.n_moves, **cfg)
            self.init_multitask_objectives(gold_tuples, pipeline, **cfg)
            if sgd is None:
                sgd = self.create_optimizer()
            self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg)
            link_vectors_to_models(self.vocab)
            self.cfg.update(cfg)
        elif sgd is None:
            sgd = self.create_optimizer()
        return sgd

    def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
        '''Setup models for secondary objectives, to benefit from multi-task

Website API docs: Language
@@ -200,8 +200,8 @@ p
        +cell Config parameters.

    +row("foot")
        +cell yields
        +cell tuple
        +cell returns
        +cell callable
        +cell An optimizer.

+h(2, "use_params") Language.use_params

Website API docs: shared pipeline component template (Pipe subclasses)
@@ -262,13 +262,13 @@ p
+tag method

p
    |  Initialize the pipe for training, using data examples if available. If no
    |  model has been initialized yet, the model is added.
    |  Initialise the pipe for training, using data examples if available. If no
    |  model has been initialised yet, the model is added.

+aside-code("Example").
    #{VARNAME} = #{CLASSNAME}(nlp.vocab)
    nlp.pipeline.append(#{VARNAME})
    #{VARNAME}.begin_training(pipeline=nlp.pipeline)
    optimizer = #{VARNAME}.begin_training(pipeline=nlp.pipeline)

+table(["Name", "Type", "Description"])
    +row
@@ -285,6 +285,36 @@ p
            |  Optional list of #[+api("pipe") #[code Pipe]] components that
            |  this component is part of.

    +row
        +cell #[code sgd]
        +cell callable
        +cell
            |  An optional optimizer. Should take two arguments #[code weights]
            |  and #[code gradient], and an optional ID. Will be created via
            |  #[+api(CLASSNAME.toLowerCase() + "#create_optimizer") #[code create_optimizer]]
            |  if not set.

    +row("foot")
        +cell returns
        +cell callable
        +cell An optimizer.

+h(2, "create_optimizer") #{CLASSNAME}.create_optimizer
+tag method

p
    |  Create an optimizer for the pipeline component.

+aside-code("Example").
    #{VARNAME} = #{CLASSNAME}(nlp.vocab)
    optimizer = #{VARNAME}.create_optimizer()

+table(["Name", "Type", "Description"])
    +row("foot")
        +cell returns
        +cell callable
        +cell The optimizer.

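The begin_training table above describes sgd as a callable taking weights, a gradient and an optional ID. For illustration only, a bare-bones callable matching that documented contract (the keyword name for the ID and the gradient reset are assumptions, mirroring how thinc's built-in optimizers behave):

    def my_sgd(weights, gradient, key=None):
        # Plain stochastic gradient descent with a fixed learning rate.
        learn_rate = 0.001
        weights -= learn_rate * gradient
        # Reset the gradient after applying it, as thinc's optimizers do.
        gradient.fill(0.)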
+h(2, "use_params") #{CLASSNAME}.use_params
+tag method
+tag contextmanager
@@ -309,9 +339,14 @@ p Modify the pipe's model, to use the given parameter values.

p Add a new label to the pipe.

+aside-code("Example").
    #{VARNAME} = #{CLASSNAME}(nlp.vocab)
    #{VARNAME}.add_label('MY_LABEL')
if CLASSNAME == "Tagger"
    +aside-code("Example").
        #{VARNAME} = #{CLASSNAME}(nlp.vocab)
        #{VARNAME}.add_label('MY_LABEL', {POS: 'NOUN'})
else
    +aside-code("Example").
        #{VARNAME} = #{CLASSNAME}(nlp.vocab)
        #{VARNAME}.add_label('MY_LABEL')

+table(["Name", "Type", "Description"])
    +row
@@ -319,6 +354,14 @@ p Add a new label to the pipe.
        +cell unicode
        +cell The label to add.

    if CLASSNAME == "Tagger"
        +row
            +cell #[code values]
            +cell dict
            +cell
                |  Optional values to map to the label, e.g. a tag map
                |  dictionary.

+h(2, "to_disk") #{CLASSNAME}.to_disk
+tag method