mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Clean up TextCategorizer slightly
This commit is contained in:
parent
d13b9373bf
commit
6b0008afc6
|
@ -946,7 +946,7 @@ class TextCategorizer(Pipe):
|
|||
not_missing = self.model.ops.asarray(not_missing)
|
||||
d_scores = (scores-truths) / scores.shape[0]
|
||||
d_scores *= not_missing
|
||||
mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
|
||||
mean_square_error = (d_scores**2).sum(axis=1).mean()
|
||||
return float(mean_square_error), d_scores
|
||||
|
||||
def add_label(self, label):
|
||||
|
@ -968,11 +968,6 @@ class TextCategorizer(Pipe):
|
|||
|
||||
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
|
||||
**kwargs):
|
||||
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
|
||||
token_vector_width = pipeline[0].model.nO
|
||||
else:
|
||||
token_vector_width = 64
|
||||
|
||||
if self.model is True:
|
||||
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
|
||||
self.model = self.Model(len(self.labels), **self.cfg)
|
||||
|
|
Loading…
Reference in New Issue
Block a user