more friendly textcat errors (#3946)

* more friendly textcat errors with require_model and require_labels

* update thinc version with recent bugfix
This commit is contained in:
Sofie Van Landeghem 2019-07-10 19:39:38 +02:00 committed by Matthew Honnibal
parent b94c5443d9
commit c4c21cb428
3 changed files with 9 additions and 1 deletions

View File

@ -1,7 +1,7 @@
# Our libraries
cymem>=2.0.2,<2.1.0
preshed>=2.0.1,<2.1.0
thinc>=7.0.2,<7.1.0
thinc>=7.0.5,<7.1.0
blis>=0.2.2,<0.3.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.2.0,<1.1.0

View File

@ -403,6 +403,7 @@ class Errors(object):
E140 = ("The list of entities, prior probabilities and entity vectors should be of equal length.")
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?")
@add_codes

View File

@ -902,6 +902,11 @@ class TextCategorizer(Pipe):
def labels(self):
return tuple(self.cfg.setdefault("labels", []))
def require_labels(self):
"""Raise an error if the component's model has no labels defined."""
if not self.labels:
raise ValueError(Errors.E143.format(name=self.name))
@labels.setter
def labels(self, value):
self.cfg["labels"] = tuple(value)
@ -931,6 +936,7 @@ class TextCategorizer(Pipe):
doc.cats[label] = float(scores[i, j])
def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
self.require_model()
scores, bp_scores = self.model.begin_update(docs, drop=drop)
loss, d_scores = self.get_loss(docs, golds, scores)
bp_scores(d_scores, sgd=sgd)
@ -985,6 +991,7 @@ class TextCategorizer(Pipe):
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
if self.model is True:
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors")
self.require_labels()
self.model = self.Model(len(self.labels), **self.cfg)
link_vectors_to_models(self.vocab)
if sgd is None: