Make gold_tuples arg optional in begin_training

This commit is contained in:
Matthew Honnibal 2017-07-22 20:04:43 +02:00
parent ed6c85fa3c
commit b55714d5d1

View File

@ -80,7 +80,7 @@ class BaseThincComponent(object):
def get_loss(self, docs, golds, scores):
raise NotImplementedError
def begin_training(self, gold_tuples, pipeline=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None):
token_vector_width = pipeline[0].model.nO
if self.model is True:
self.model = self.Model(1, token_vector_width)
@ -223,7 +223,7 @@ class TokenVectorEncoder(BaseThincComponent):
# TODO: implement
raise NotImplementedError
def begin_training(self, gold_tuples, pipeline=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None):
"""Allocate models, pre-process training data and acquire a trainer and
optimizer.
@ -311,7 +311,7 @@ class NeuralTagger(BaseThincComponent):
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
return float(loss), d_scores
def begin_training(self, gold_tuples, pipeline=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None):
orig_tag_map = dict(self.vocab.morphology.tag_map)
new_tag_map = {}
for raw_text, annots_brackets in gold_tuples:
@ -420,7 +420,7 @@ class NeuralLabeller(NeuralTagger):
def set_annotations(self, docs, dep_ids):
pass
def begin_training(self, gold_tuples, pipeline=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None):
gold_tuples = nonproj.preprocess_training_data(gold_tuples)
for raw_text, annots_brackets in gold_tuples:
for annots, brackets in annots_brackets:
@ -502,7 +502,7 @@ class SimilarityHook(BaseThincComponent):
return d_tensor1s, d_tensor2s
def begin_training(self, _, pipeline=None):
def begin_training(self, _=tuple(), pipeline=None):
"""
Allocate model, using width from tensorizer in pipeline.
@ -517,7 +517,7 @@ class TextCategorizer(BaseThincComponent):
name = 'textcat'
@classmethod
def Model(cls, nr_class, width=64, **cfg):
def Model(cls, nr_class=1, width=64, **cfg):
return build_text_classifier(nr_class, width, **cfg)
def __init__(self, vocab, model=True, **cfg):
@ -544,7 +544,7 @@ class TextCategorizer(BaseThincComponent):
def set_annotations(self, docs, scores):
for i, doc in enumerate(docs):
for j, label in self.labels:
for j, label in enumerate(self.labels):
doc.cats[label] = float(scores[i, j])
def update(self, docs_tensors, golds, state=None, drop=0., sgd=None, losses=None):
@ -567,8 +567,11 @@ class TextCategorizer(BaseThincComponent):
mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
return mean_square_error, d_scores
def begin_training(self, gold_tuples, pipeline=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None):
if pipeline:
token_vector_width = pipeline[0].model.nO
else:
token_vector_width = 64
if self.model is True:
self.model = self.Model(len(self.labels), token_vector_width)