mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Clean up sgd
This commit is contained in:
parent
b3b6868639
commit
f2d1b7feb5
|
@ -1298,7 +1298,8 @@ class Language:
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
"""Create an optimizer, usually using the [training.optimizer] config."""
|
"""Create an optimizer, usually using the [training.optimizer] config."""
|
||||||
return registry.resolve(self.config["training"]["optimizer"])
|
subconfig = {"optimizer": self.config["training"]["optimizer"]}
|
||||||
|
return registry.resolve(subconfig)["optimizer"]
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_params(self, params: Optional[dict]):
|
def use_params(self, params: Optional[dict]):
|
||||||
|
|
|
@ -91,9 +91,6 @@ class MultitaskObjective(Tagger):
|
||||||
if label is not None and label not in self.labels:
|
if label is not None and label not in self.labels:
|
||||||
self.labels[label] = len(self.labels)
|
self.labels[label] = len(self.labels)
|
||||||
self.model.initialize() # TODO: fix initialization by defining X and Y
|
self.model.initialize() # TODO: fix initialization by defining X and Y
|
||||||
if sgd is None:
|
|
||||||
sgd = self.create_optimizer()
|
|
||||||
return sgd
|
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
tokvecs = self.model.get_ref("tok2vec")(docs)
|
tokvecs = self.model.get_ref("tok2vec")(docs)
|
||||||
|
@ -181,9 +178,6 @@ class ClozeMultitask(Pipe):
|
||||||
self.model.initialize() # TODO: fix initialization by defining X and Y
|
self.model.initialize() # TODO: fix initialization by defining X and Y
|
||||||
X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
|
X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
|
||||||
self.model.output_layer.initialize(X)
|
self.model.output_layer.initialize(X)
|
||||||
if sgd is None:
|
|
||||||
sgd = self.create_optimizer()
|
|
||||||
return sgd
|
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
tokvecs = self.model.get_ref("tok2vec")(docs)
|
tokvecs = self.model.get_ref("tok2vec")(docs)
|
||||||
|
|
|
@ -149,9 +149,6 @@ class SentenceRecognizer(Tagger):
|
||||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||||
if sgd is None:
|
|
||||||
sgd = self.create_optimizer()
|
|
||||||
return sgd
|
|
||||||
|
|
||||||
def add_label(self, label, values=None):
|
def add_label(self, label, values=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -348,8 +348,6 @@ class TextCategorizer(Pipe):
|
||||||
pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
|
pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
|
||||||
components that this component is part of. Corresponds to
|
components that this component is part of. Corresponds to
|
||||||
nlp.pipeline.
|
nlp.pipeline.
|
||||||
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
|
|
||||||
create_optimizer if it doesn't exist.
|
|
||||||
RETURNS (thinc.api.Optimizer): The optimizer.
|
RETURNS (thinc.api.Optimizer): The optimizer.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
|
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
|
||||||
|
@ -367,9 +365,6 @@ class TextCategorizer(Pipe):
|
||||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||||
if sgd is None:
|
|
||||||
sgd = self.create_optimizer()
|
|
||||||
return sgd
|
|
||||||
|
|
||||||
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||||
"""Score a batch of examples.
|
"""Score a batch of examples.
|
||||||
|
|
|
@ -440,7 +440,6 @@ cdef class Parser(Pipe):
|
||||||
self.model.initialize(doc_sample)
|
self.model.initialize(doc_sample)
|
||||||
if pipeline is not None:
|
if pipeline is not None:
|
||||||
self.init_multitask_objectives(get_examples, pipeline)
|
self.init_multitask_objectives(get_examples, pipeline)
|
||||||
return sgd
|
|
||||||
|
|
||||||
def to_disk(self, path, exclude=tuple()):
|
def to_disk(self, path, exclude=tuple()):
|
||||||
serializers = {
|
serializers = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user