mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 16:24:16 +03:00
Fix multitasks
This commit is contained in:
parent
0b5c72fce2
commit
a4da3120b4
|
@ -126,13 +126,13 @@ cdef class DependencyParser(Parser):
|
||||||
def add_multitask_objective(self, mt_component):
|
def add_multitask_objective(self, mt_component):
|
||||||
self._multitasks.append(mt_component)
|
self._multitasks.append(mt_component)
|
||||||
|
|
||||||
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
|
def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
|
||||||
# TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
|
# TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
|
||||||
for labeller in self._multitasks:
|
for labeller in self._multitasks:
|
||||||
labeller.model.set_dim("nO", len(self.labels))
|
labeller.model.set_dim("nO", len(self.labels))
|
||||||
if labeller.model.has_ref("output_layer"):
|
if labeller.model.has_ref("output_layer"):
|
||||||
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
|
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
|
||||||
labeller.initialize(get_examples, pipeline=pipeline)
|
labeller.initialize(get_examples, nlp=nlp)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
|
|
@ -96,14 +96,14 @@ cdef class EntityRecognizer(Parser):
|
||||||
"""Register another component as a multi-task objective. Experimental."""
|
"""Register another component as a multi-task objective. Experimental."""
|
||||||
self._multitasks.append(mt_component)
|
self._multitasks.append(mt_component)
|
||||||
|
|
||||||
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
|
def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
|
||||||
"""Setup multi-task objective components. Experimental and internal."""
|
"""Setup multi-task objective components. Experimental and internal."""
|
||||||
# TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
|
# TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
|
||||||
for labeller in self._multitasks:
|
for labeller in self._multitasks:
|
||||||
labeller.model.set_dim("nO", len(self.labels))
|
labeller.model.set_dim("nO", len(self.labels))
|
||||||
if labeller.model.has_ref("output_layer"):
|
if labeller.model.has_ref("output_layer"):
|
||||||
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
|
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
|
||||||
labeller.initialize(get_examples, pipeline=pipeline)
|
labeller.initialize(get_examples, nlp=nlp)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user