Clean up TextCategorizer slightly

This commit is contained in:
Matthew Honnibal 2019-02-23 12:28:06 +01:00
parent d13b9373bf
commit 6b0008afc6

View File

@ -946,7 +946,7 @@ class TextCategorizer(Pipe):
not_missing = self.model.ops.asarray(not_missing) not_missing = self.model.ops.asarray(not_missing)
d_scores = (scores-truths) / scores.shape[0] d_scores = (scores-truths) / scores.shape[0]
d_scores *= not_missing d_scores *= not_missing
mean_square_error = ((scores-truths)**2).sum(axis=1).mean() mean_square_error = (d_scores**2).sum(axis=1).mean()
return float(mean_square_error), d_scores return float(mean_square_error), d_scores
def add_label(self, label): def add_label(self, label):
@ -968,11 +968,6 @@ class TextCategorizer(Pipe):
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
**kwargs): **kwargs):
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
token_vector_width = pipeline[0].model.nO
else:
token_vector_width = 64
if self.model is True: if self.model is True:
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors') self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
self.model = self.Model(len(self.labels), **self.cfg) self.model = self.Model(len(self.labels), **self.cfg)