remove labels from constructor

This commit is contained in:
svlandeg 2020-11-11 21:34:12 +01:00
parent fcd79e0655
commit d5a920325f
2 changed files with 13 additions and 11 deletions

View File

@ -47,7 +47,7 @@ class MultitaskObjective(Tagger):
side-objective. side-objective.
""" """
def __init__(self, vocab, model, name="nn_labeller", *, labels, target): def __init__(self, vocab, model, name="nn_labeller", *, target):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.name = name self.name = name
@ -67,7 +67,7 @@ class MultitaskObjective(Tagger):
self.make_label = target self.make_label = target
else: else:
raise ValueError(Errors.E016) raise ValueError(Errors.E016)
cfg = {"labels": labels or {}, "target": target} cfg = {"labels": {}, "target": target}
self.cfg = dict(cfg) self.cfg = dict(cfg)
@property @property
@ -81,15 +81,18 @@ class MultitaskObjective(Tagger):
def set_annotations(self, docs, dep_ids): def set_annotations(self, docs, dep_ids):
pass pass
def initialize(self, get_examples, nlp=None): def initialize(self, get_examples, nlp=None, labels=None):
if not hasattr(get_examples, "__call__"): if not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name="MultitaskObjective", obj=type(get_examples)) err = Errors.E930.format(name="MultitaskObjective", obj=type(get_examples))
raise ValueError(err) raise ValueError(err)
for example in get_examples(): if labels is not None:
for token in example.y: self.labels = labels
label = self.make_label(token) else:
if label is not None and label not in self.labels: for example in get_examples():
self.labels[label] = len(self.labels) for token in example.y:
label = self.make_label(token)
if label is not None and label not in self.labels:
self.labels[label] = len(self.labels)
self.model.initialize() # TODO: fix initialization by defining X and Y self.model.initialize() # TODO: fix initialization by defining X and Y
def predict(self, docs): def predict(self, docs):

View File

@ -61,14 +61,13 @@ class Tagger(TrainablePipe):
DOCS: https://nightly.spacy.io/api/tagger DOCS: https://nightly.spacy.io/api/tagger
""" """
def __init__(self, vocab, model, name="tagger", *, labels=None): def __init__(self, vocab, model, name="tagger"):
"""Initialize a part-of-speech tagger. """Initialize a part-of-speech tagger.
vocab (Vocab): The shared vocabulary. vocab (Vocab): The shared vocabulary.
model (thinc.api.Model): The Thinc Model powering the pipeline component. model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the name (str): The component instance name, used to add entries to the
losses during training. losses during training.
labels (List): The set of labels. Defaults to None.
DOCS: https://nightly.spacy.io/api/tagger#init DOCS: https://nightly.spacy.io/api/tagger#init
""" """
@ -76,7 +75,7 @@ class Tagger(TrainablePipe):
self.model = model self.model = model
self.name = name self.name = name
self._rehearsal_model = None self._rehearsal_model = None
cfg = {"labels": labels or []} cfg = {"labels": []}
self.cfg = dict(sorted(cfg.items())) self.cfg = dict(sorted(cfg.items()))
@property @property