mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Clean up TextCategorizer slightly
This commit is contained in:
parent
d13b9373bf
commit
6b0008afc6
|
@ -946,7 +946,7 @@ class TextCategorizer(Pipe):
|
||||||
not_missing = self.model.ops.asarray(not_missing)
|
not_missing = self.model.ops.asarray(not_missing)
|
||||||
d_scores = (scores-truths) / scores.shape[0]
|
d_scores = (scores-truths) / scores.shape[0]
|
||||||
d_scores *= not_missing
|
d_scores *= not_missing
|
||||||
mean_square_error = ((scores-truths)**2).sum(axis=1).mean()
|
mean_square_error = (d_scores**2).sum(axis=1).mean()
|
||||||
return float(mean_square_error), d_scores
|
return float(mean_square_error), d_scores
|
||||||
|
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
|
@ -968,11 +968,6 @@ class TextCategorizer(Pipe):
|
||||||
|
|
||||||
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
|
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
|
|
||||||
token_vector_width = pipeline[0].model.nO
|
|
||||||
else:
|
|
||||||
token_vector_width = 64
|
|
||||||
|
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
|
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
|
||||||
self.model = self.Model(len(self.labels), **self.cfg)
|
self.model = self.Model(len(self.labels), **self.cfg)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user