Mirror of https://github.com/explosion/spaCy.git
Remove 'device' argument from Language, clean up 'sgd' arg
parent ff9a63bfbd
commit 5276db6f3f
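In practical terms, device selection is no longer a concern of Language.initialize() or Language.resume_training(): the `device` argument is gone, and `sgd` is optional everywhere, with the pipeline falling back to an optimizer created by its own create_optimizer() method. A minimal sketch of the calling pattern after this change (the prefer_gpu() call and the blank English pipeline are illustrative assumptions, not part of the diff):

    import spacy

    # Assumption: GPU selection now happens up front via prefer_gpu()/require_gpu(),
    # not through a `device` argument on initialize()/resume_training().
    spacy.prefer_gpu()

    nlp = spacy.blank("en")
    nlp.initialize()  # `sgd` is optional; an optimizer is created on demand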
spacy/language.py

@@ -19,7 +19,7 @@ from .vocab import Vocab, create_vocab
 from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
 from .training import Example, validate_examples
 from .scorer import Scorer
-from .util import create_default_optimizer, registry, SimpleFrozenList
+from .util import registry, SimpleFrozenList
 from .util import SimpleFrozenDict, combine_score_weights, CONFIG_SECTION_ORDER
 from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
@@ -1065,7 +1065,7 @@ class Language:
         validate_examples(examples, "Language.update")
         if sgd is None:
             if self._optimizer is None:
-                self._optimizer = create_default_optimizer()
+                self._optimizer = self.create_optimizer()
             sgd = self._optimizer
         if component_cfg is None:
             component_cfg = {}
@@ -1123,7 +1123,7 @@ class Language:
         validate_examples(examples, "Language.rehearse")
         if sgd is None:
             if self._optimizer is None:
-                self._optimizer = create_default_optimizer()
+                self._optimizer = self.create_optimizer()
             sgd = self._optimizer
         pipes = list(self.pipeline)
         random.shuffle(pipes)
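The two hunks above only swap the helper used for the lazy fallback: when update() or rehearse() is called without `sgd`, the optimizer now comes from the Language object's own create_optimizer() method rather than the module-level create_default_optimizer(). A rough training-loop sketch under that assumption (the textcat component and the example data are made up for illustration):

    import spacy
    from spacy.training import Example

    nlp = spacy.blank("en")
    nlp.add_pipe("textcat")
    examples = [
        Example.from_dict(nlp.make_doc("I loved it"), {"cats": {"POS": 1.0, "NEG": 0.0}}),
        Example.from_dict(nlp.make_doc("I hated it"), {"cats": {"POS": 0.0, "NEG": 1.0}}),
    ]
    nlp.initialize(lambda: examples)
    losses = nlp.update(examples)        # no sgd passed: the pipeline's own optimizer is used
    optimizer = nlp.create_optimizer()   # or build one explicitly...
    losses = nlp.update(examples, sgd=optimizer)  # ...and pass it in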
@@ -1161,16 +1161,14 @@ class Language:
     def initialize(
         self,
         get_examples: Optional[Callable[[], Iterable[Example]]] = None,
-        *,
-        sgd: Optional[Optimizer] = None,
-        device: int = -1,
-    ) -> Optimizer:
+        sgd: Optional[Optimizer]=None
+    ) -> None:
         """Initialize the pipe for training, using data examples if available.

         get_examples (Callable[[], Iterable[Example]]): Optional function that
             returns gold-standard Example objects.
-        sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
-            create_optimizer if it doesn't exist.
+        sgd (Optional[Optimizer]): An optimizer to use for updates. If not
+            provided, will be created using the .create_optimizer() method.
         RETURNS (thinc.api.Optimizer): The optimizer.

         DOCS: https://nightly.spacy.io/api/language#initialize
@@ -1199,25 +1197,22 @@ class Language:
         if not valid_examples:
             err = Errors.E930.format(name="Language", obj="empty list")
             raise ValueError(err)
-        if device >= 0:  # TODO: do we need this here?
-            require_gpu(device)
-            if self.vocab.vectors.data.shape[1] >= 1:
-                ops = get_current_ops()
-                self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
-        if sgd is None:
-            sgd = create_default_optimizer()
-        self._optimizer = sgd
+        if self.vocab.vectors.data.shape[1] >= 1:
+            ops = get_current_ops()
+            self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
         for name, proc in self.pipeline:
             if hasattr(proc, "initialize"):
                 proc.initialize(
-                    get_examples, pipeline=self.pipeline, sgd=self._optimizer
+                    get_examples, pipeline=self.pipeline
                 )
         self._link_components()
+        if sgd is not None:
+            self._optimizer = sgd
+        elif self._optimizer is None:
+            self._optimizer = self.create_optimizer()
         return self._optimizer

-    def resume_training(
-        self, *, sgd: Optional[Optimizer] = None, device: int = -1
-    ) -> Optimizer:
+    def resume_training(self, *, sgd: Optional[Optimizer] = None) -> Optimizer:
         """Continue training a pretrained model.

         Create and return an optimizer, and initialize "rehearsal" for any pipeline
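With the new `sgd` handling, an optimizer passed to initialize() takes precedence and is stored on the pipeline; otherwise an existing optimizer is reused or a new one is created lazily, so later update() calls share it. A short sketch of both paths (the Adam settings are illustrative, not taken from this diff):

    import spacy
    from thinc.api import Adam

    # Path 1: let the pipeline build its optimizer from the config.
    nlp = spacy.blank("en")
    nlp.initialize()

    # Path 2: pass an optimizer explicitly; it becomes the pipeline's optimizer.
    nlp = spacy.blank("en")
    nlp.initialize(sgd=Adam(learn_rate=0.001))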
@@ -1226,22 +1221,20 @@ class Language:
         rehearsal, collect samples of text you want the models to retain performance
         on, and call nlp.rehearse() with a batch of Example objects.

-        sgd (Optional[Optimizer]): An optimizer.
         RETURNS (Optimizer): The optimizer.

         DOCS: https://nightly.spacy.io/api/language#resume_training
         """
-        if device >= 0:  # TODO: do we need this here?
-            require_gpu(device)
-            ops = get_current_ops()
-            if self.vocab.vectors.data.shape[1] >= 1:
-                self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
-        if sgd is None:
-            sgd = create_default_optimizer()
-        self._optimizer = sgd
+        ops = get_current_ops()
+        if self.vocab.vectors.data.shape[1] >= 1:
+            self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
         for name, proc in self.pipeline:
             if hasattr(proc, "_rehearsal_model"):
                 proc._rehearsal_model = deepcopy(proc.model)
+        if sgd is not None:
+            self._optimizer = sgd
+        elif self._optimizer is None:
+            self._optimizer = self.create_optimizer()
         return self._optimizer

     def evaluate(
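The resume_training() docstring above describes the intended rehearsal workflow; a small sketch of it, assuming a trained pipeline such as en_core_web_sm is installed (the package name and the texts are placeholders):

    import spacy
    from spacy.training import Example

    nlp = spacy.load("en_core_web_sm")
    optimizer = nlp.resume_training()  # no `device` argument anymore
    texts = ["Text the model should keep performing well on."]
    examples = [Example.from_dict(nlp.make_doc(t), {}) for t in texts]
    losses = {}
    nlp.rehearse(examples, sgd=optimizer, losses=losses)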
@@ -1302,6 +1295,10 @@ class Language:
         n_words = sum(len(doc) for doc in docs)
         results["speed"] = n_words / (end_time - start_time)
         return results
+
+    def create_optimizer(self):
+        """Create an optimizer, usually using the [training.optimizer] config."""
+        return registry.resolve(self.config["training"]["optimizer"])

     @contextmanager
     def use_params(self, params: Optional[dict]):
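The new create_optimizer() method resolves the optimizer from the pipeline's own config, so manual training loops pick up whatever is defined under [training.optimizer]. A small sketch (what exactly gets printed depends on the default config, typically an Adam block):

    import spacy

    nlp = spacy.blank("en")
    print(nlp.config["training"]["optimizer"])  # the block that create_optimizer() resolves
    optimizer = nlp.create_optimizer()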