diff --git a/spacy/language.py b/spacy/language.py index a5b78b178..5b1f50ee2 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -19,7 +19,7 @@ from .vocab import Vocab, create_vocab from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis from .training import Example, validate_examples from .scorer import Scorer -from .util import create_default_optimizer, registry, SimpleFrozenList +from .util import registry, SimpleFrozenList from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES @@ -1065,7 +1065,7 @@ class Language: validate_examples(examples, "Language.update") if sgd is None: if self._optimizer is None: - self._optimizer = create_default_optimizer() + self._optimizer = self.create_optimizer() sgd = self._optimizer if component_cfg is None: component_cfg = {} @@ -1123,7 +1123,7 @@ class Language: validate_examples(examples, "Language.rehearse") if sgd is None: if self._optimizer is None: - self._optimizer = create_default_optimizer() + self._optimizer = self.create_optimizer() sgd = self._optimizer pipes = list(self.pipeline) random.shuffle(pipes) @@ -1161,16 +1161,14 @@ class Language: def initialize( self, get_examples: Optional[Callable[[], Iterable[Example]]] = None, - *, - sgd: Optional[Optimizer] = None, - device: int = -1, - ) -> Optimizer: + sgd: Optional[Optimizer]=None + ) -> None: """Initialize the pipe for training, using data examples if available. get_examples (Callable[[], Iterable[Example]]): Optional function that returns gold-standard Example objects. - sgd (thinc.api.Optimizer): Optional optimizer. Will be created with - create_optimizer if it doesn't exist. + sgd (Optional[Optimizer]): An optimizer to use for updates. If not + provided, will be created using the .create_optimizer() method. RETURNS (thinc.api.Optimizer): The optimizer. DOCS: https://nightly.spacy.io/api/language#initialize @@ -1199,25 +1197,22 @@ class Language: if not valid_examples: err = Errors.E930.format(name="Language", obj="empty list") raise ValueError(err) - if device >= 0: # TODO: do we need this here? - require_gpu(device) - if self.vocab.vectors.data.shape[1] >= 1: - ops = get_current_ops() - self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) - if sgd is None: - sgd = create_default_optimizer() - self._optimizer = sgd + if self.vocab.vectors.data.shape[1] >= 1: + ops = get_current_ops() + self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) for name, proc in self.pipeline: if hasattr(proc, "initialize"): proc.initialize( - get_examples, pipeline=self.pipeline, sgd=self._optimizer + get_examples, pipeline=self.pipeline ) self._link_components() + if sgd is not None: + self._optimizer = sgd + elif self._optimizer is None: + self._optimizer = self.create_optimizer() return self._optimizer - def resume_training( - self, *, sgd: Optional[Optimizer] = None, device: int = -1 - ) -> Optimizer: + def resume_training(self, *, sgd: Optional[Optimizer] = None) -> Optimizer: """Continue training a pretrained model. Create and return an optimizer, and initialize "rehearsal" for any pipeline @@ -1226,22 +1221,20 @@ class Language: rehearsal, collect samples of text you want the models to retain performance on, and call nlp.rehearse() with a batch of Example objects. - sgd (Optional[Optimizer]): An optimizer. RETURNS (Optimizer): The optimizer. DOCS: https://nightly.spacy.io/api/language#resume_training """ - if device >= 0: # TODO: do we need this here? - require_gpu(device) - ops = get_current_ops() - if self.vocab.vectors.data.shape[1] >= 1: - self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) - if sgd is None: - sgd = create_default_optimizer() - self._optimizer = sgd + ops = get_current_ops() + if self.vocab.vectors.data.shape[1] >= 1: + self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data) for name, proc in self.pipeline: if hasattr(proc, "_rehearsal_model"): proc._rehearsal_model = deepcopy(proc.model) + if sgd is not None: + self._optimizer = sgd + elif self._optimizer is None: + self._optimizer = self.create_optimizer() return self._optimizer def evaluate( @@ -1302,6 +1295,10 @@ class Language: n_words = sum(len(doc) for doc in docs) results["speed"] = n_words / (end_time - start_time) return results + + def create_optimizer(self): + """Create an optimizer, usually using the [training.optimizer] config.""" + return registry.resolve(self.config["training"]["optimizer"]) @contextmanager def use_params(self, params: Optional[dict]):