Support loading labels in morphologizer

This commit is contained in:
Matthew Honnibal 2020-10-03 19:13:42 +02:00
parent d6c967401f
commit 8ea8b7d940

View File

@ -134,7 +134,7 @@ class Morphologizer(Tagger):
self.cfg["labels_pos"][norm_label] = POS_IDS[pos] self.cfg["labels_pos"][norm_label] = POS_IDS[pos]
return 1 return 1
def initialize(self, get_examples, *, nlp=None): def initialize(self, get_examples, *, nlp=None, labels=None):
"""Initialize the pipe for training, using a representative set """Initialize the pipe for training, using a representative set
of data examples. of data examples.
@ -145,20 +145,24 @@ class Morphologizer(Tagger):
DOCS: https://nightly.spacy.io/api/morphologizer#initialize DOCS: https://nightly.spacy.io/api/morphologizer#initialize
""" """
self._ensure_examples(get_examples) self._ensure_examples(get_examples)
# First, fetch all labels from the data if labels is not None:
for example in get_examples(): self.cfg["labels_morph"] = labels["labels_morph"]
for i, token in enumerate(example.reference): self.cfg["labels_pos"] = labels["labels_pos"]
pos = token.pos_ else:
morph = str(token.morph) # First, fetch all labels from the data
# create and add the combined morph+POS label for example in get_examples():
morph_dict = Morphology.feats_to_dict(morph) for i, token in enumerate(example.reference):
if pos: pos = token.pos_
morph_dict[self.POS_FEAT] = pos morph = str(token.morph)
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)] # create and add the combined morph+POS label
# add label->morph and label->POS mappings morph_dict = Morphology.feats_to_dict(morph)
if norm_label not in self.cfg["labels_morph"]: if pos:
self.cfg["labels_morph"][norm_label] = morph morph_dict[self.POS_FEAT] = pos
self.cfg["labels_pos"][norm_label] = POS_IDS[pos] norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
# add label->morph and label->POS mappings
if norm_label not in self.cfg["labels_morph"]:
self.cfg["labels_morph"][norm_label] = morph
self.cfg["labels_pos"][norm_label] = POS_IDS[pos]
if len(self.labels) <= 1: if len(self.labels) <= 1:
raise ValueError(Errors.E143.format(name=self.name)) raise ValueError(Errors.E143.format(name=self.name))
doc_sample = [] doc_sample = []