Remove 'device' argument from Language, clean up 'sgd' arg

This commit is contained in:
Matthew Honnibal 2020-09-29 11:42:19 +02:00
parent ff9a63bfbd
commit 5276db6f3f

View File

@ -19,7 +19,7 @@ from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
from .training import Example, validate_examples
from .scorer import Scorer
from .util import create_default_optimizer, registry, SimpleFrozenList
from .util import registry, SimpleFrozenList
from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
@ -1065,7 +1065,7 @@ class Language:
validate_examples(examples, "Language.update")
if sgd is None:
if self._optimizer is None:
self._optimizer = create_default_optimizer()
self._optimizer = self.create_optimizer()
sgd = self._optimizer
if component_cfg is None:
component_cfg = {}
@ -1123,7 +1123,7 @@ class Language:
validate_examples(examples, "Language.rehearse")
if sgd is None:
if self._optimizer is None:
self._optimizer = create_default_optimizer()
self._optimizer = self.create_optimizer()
sgd = self._optimizer
pipes = list(self.pipeline)
random.shuffle(pipes)
@ -1161,16 +1161,14 @@ class Language:
def initialize(
self,
get_examples: Optional[Callable[[], Iterable[Example]]] = None,
*,
sgd: Optional[Optimizer] = None,
device: int = -1,
) -> Optimizer:
sgd: Optional[Optimizer]=None
) -> None:
"""Initialize the pipe for training, using data examples if available.
get_examples (Callable[[], Iterable[Example]]): Optional function that
returns gold-standard Example objects.
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
create_optimizer if it doesn't exist.
sgd (Optional[Optimizer]): An optimizer to use for updates. If not
provided, will be created using the .create_optimizer() method.
RETURNS (thinc.api.Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/language#initialize
@ -1199,25 +1197,22 @@ class Language:
if not valid_examples:
err = Errors.E930.format(name="Language", obj="empty list")
raise ValueError(err)
if device >= 0: # TODO: do we need this here?
require_gpu(device)
if self.vocab.vectors.data.shape[1] >= 1:
ops = get_current_ops()
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
if sgd is None:
sgd = create_default_optimizer()
self._optimizer = sgd
for name, proc in self.pipeline:
if hasattr(proc, "initialize"):
proc.initialize(
get_examples, pipeline=self.pipeline, sgd=self._optimizer
get_examples, pipeline=self.pipeline
)
self._link_components()
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:
self._optimizer = self.create_optimizer()
return self._optimizer
def resume_training(
self, *, sgd: Optional[Optimizer] = None, device: int = -1
) -> Optimizer:
def resume_training(self, *, sgd: Optional[Optimizer] = None) -> Optimizer:
"""Continue training a pretrained model.
Create and return an optimizer, and initialize "rehearsal" for any pipeline
@ -1226,22 +1221,20 @@ class Language:
rehearsal, collect samples of text you want the models to retain performance
on, and call nlp.rehearse() with a batch of Example objects.
sgd (Optional[Optimizer]): An optimizer.
RETURNS (Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/language#resume_training
"""
if device >= 0: # TODO: do we need this here?
require_gpu(device)
ops = get_current_ops()
if self.vocab.vectors.data.shape[1] >= 1:
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
if sgd is None:
sgd = create_default_optimizer()
self._optimizer = sgd
for name, proc in self.pipeline:
if hasattr(proc, "_rehearsal_model"):
proc._rehearsal_model = deepcopy(proc.model)
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:
self._optimizer = self.create_optimizer()
return self._optimizer
def evaluate(
@ -1303,6 +1296,10 @@ class Language:
results["speed"] = n_words / (end_time - start_time)
return results
def create_optimizer(self):
"""Create an optimizer, usually using the [training.optimizer] config."""
return registry.resolve(self.config["training"]["optimizer"])
@contextmanager
def use_params(self, params: Optional[dict]):
"""Replace weights of models in the pipeline with those provided in the