Clean up sgd

This commit is contained in:
Matthew Honnibal 2020-09-29 12:00:08 +02:00
parent b3b6868639
commit f2d1b7feb5
5 changed files with 2 additions and 16 deletions

View File

@ -1298,7 +1298,8 @@ class Language:
def create_optimizer(self):
"""Create an optimizer, usually using the [training.optimizer] config."""
return registry.resolve(self.config["training"]["optimizer"])
subconfig = {"optimizer": self.config["training"]["optimizer"]}
return registry.resolve(subconfig)["optimizer"]
@contextmanager
def use_params(self, params: Optional[dict]):

View File

@ -91,9 +91,6 @@ class MultitaskObjective(Tagger):
if label is not None and label not in self.labels:
self.labels[label] = len(self.labels)
self.model.initialize() # TODO: fix initialization by defining X and Y
if sgd is None:
sgd = self.create_optimizer()
return sgd
def predict(self, docs):
tokvecs = self.model.get_ref("tok2vec")(docs)
@ -181,9 +178,6 @@ class ClozeMultitask(Pipe):
self.model.initialize() # TODO: fix initialization by defining X and Y
X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
self.model.output_layer.initialize(X)
if sgd is None:
sgd = self.create_optimizer()
return sgd
def predict(self, docs):
tokvecs = self.model.get_ref("tok2vec")(docs)

View File

@ -149,9 +149,6 @@ class SentenceRecognizer(Tagger):
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample)
if sgd is None:
sgd = self.create_optimizer()
return sgd
def add_label(self, label, values=None):
raise NotImplementedError

View File

@ -348,8 +348,6 @@ class TextCategorizer(Pipe):
pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
components that this component is part of. Corresponds to
nlp.pipeline.
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
create_optimizer if it doesn't exist.
RETURNS (thinc.api.Optimizer): The optimizer.
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
@ -367,9 +365,6 @@ class TextCategorizer(Pipe):
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample)
if sgd is None:
sgd = self.create_optimizer()
return sgd
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
"""Score a batch of examples.

View File

@ -440,7 +440,6 @@ cdef class Parser(Pipe):
self.model.initialize(doc_sample)
if pipeline is not None:
self.init_multitask_objectives(get_examples, pipeline)
return sgd
def to_disk(self, path, exclude=tuple()):
serializers = {