Fix multitasks

This commit is contained in:
Matthew Honnibal 2020-09-29 18:33:16 +02:00
parent 0b5c72fce2
commit a4da3120b4
2 changed files with 4 additions and 4 deletions

View File

@ -126,13 +126,13 @@ cdef class DependencyParser(Parser):
def add_multitask_objective(self, mt_component): def add_multitask_objective(self, mt_component):
self._multitasks.append(mt_component) self._multitasks.append(mt_component)
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg): def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
# TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ? # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
for labeller in self._multitasks: for labeller in self._multitasks:
labeller.model.set_dim("nO", len(self.labels)) labeller.model.set_dim("nO", len(self.labels))
if labeller.model.has_ref("output_layer"): if labeller.model.has_ref("output_layer"):
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels)) labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
labeller.initialize(get_examples, pipeline=pipeline) labeller.initialize(get_examples, nlp=nlp)
@property @property
def labels(self): def labels(self):

View File

@ -96,14 +96,14 @@ cdef class EntityRecognizer(Parser):
"""Register another component as a multi-task objective. Experimental.""" """Register another component as a multi-task objective. Experimental."""
self._multitasks.append(mt_component) self._multitasks.append(mt_component)
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg): def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
"""Setup multi-task objective components. Experimental and internal.""" """Setup multi-task objective components. Experimental and internal."""
# TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ? # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
for labeller in self._multitasks: for labeller in self._multitasks:
labeller.model.set_dim("nO", len(self.labels)) labeller.model.set_dim("nO", len(self.labels))
if labeller.model.has_ref("output_layer"): if labeller.model.has_ref("output_layer"):
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels)) labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
labeller.initialize(get_examples, pipeline=pipeline) labeller.initialize(get_examples, nlp=nlp)
@property @property
def labels(self): def labels(self):