mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Clean up sgd
This commit is contained in:
parent
b3b6868639
commit
f2d1b7feb5
|
@ -1298,7 +1298,8 @@ class Language:
|
|||
|
||||
def create_optimizer(self):
|
||||
"""Create an optimizer, usually using the [training.optimizer] config."""
|
||||
return registry.resolve(self.config["training"]["optimizer"])
|
||||
subconfig = {"optimizer": self.config["training"]["optimizer"]}
|
||||
return registry.resolve(subconfig)["optimizer"]
|
||||
|
||||
@contextmanager
|
||||
def use_params(self, params: Optional[dict]):
|
||||
|
|
|
@ -91,9 +91,6 @@ class MultitaskObjective(Tagger):
|
|||
if label is not None and label not in self.labels:
|
||||
self.labels[label] = len(self.labels)
|
||||
self.model.initialize() # TODO: fix initialization by defining X and Y
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
def predict(self, docs):
|
||||
tokvecs = self.model.get_ref("tok2vec")(docs)
|
||||
|
@ -181,9 +178,6 @@ class ClozeMultitask(Pipe):
|
|||
self.model.initialize() # TODO: fix initialization by defining X and Y
|
||||
X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
|
||||
self.model.output_layer.initialize(X)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
def predict(self, docs):
|
||||
tokvecs = self.model.get_ref("tok2vec")(docs)
|
||||
|
|
|
@ -149,9 +149,6 @@ class SentenceRecognizer(Tagger):
|
|||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
def add_label(self, label, values=None):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -348,8 +348,6 @@ class TextCategorizer(Pipe):
|
|||
pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
|
||||
components that this component is part of. Corresponds to
|
||||
nlp.pipeline.
|
||||
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
|
||||
create_optimizer if it doesn't exist.
|
||||
RETURNS (thinc.api.Optimizer): The optimizer.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
|
||||
|
@ -367,9 +365,6 @@ class TextCategorizer(Pipe):
|
|||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
"""Score a batch of examples.
|
||||
|
|
|
@ -440,7 +440,6 @@ cdef class Parser(Pipe):
|
|||
self.model.initialize(doc_sample)
|
||||
if pipeline is not None:
|
||||
self.init_multitask_objectives(get_examples, pipeline)
|
||||
return sgd
|
||||
|
||||
def to_disk(self, path, exclude=tuple()):
|
||||
serializers = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user