mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
more friendly textcat errors (#3946)
* more friendly textcat errors with require_model and require_labels
* update thinc version with recent bugfix
This commit is contained in:
parent
b94c5443d9
commit
c4c21cb428
|
@ -1,7 +1,7 @@
|
|||
# Our libraries
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=2.0.1,<2.1.0
|
||||
thinc>=7.0.2,<7.1.0
|
||||
thinc>=7.0.5,<7.1.0
|
||||
blis>=0.2.2,<0.3.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
wasabi>=0.2.0,<1.1.0
|
||||
|
|
|
@ -403,6 +403,7 @@ class Errors(object):
|
|||
E140 = ("The list of entities, prior probabilities and entity vectors should be of equal length.")
|
||||
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
|
||||
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
|
||||
E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -902,6 +902,11 @@ class TextCategorizer(Pipe):
|
|||
def labels(self):
|
||||
return tuple(self.cfg.setdefault("labels", []))
|
||||
|
||||
def require_labels(self):
    """Check that at least one label has been added to this component.

    Raises:
        ValueError: Built from Errors.E143 and formatted with this
            component's name, when `self.labels` is empty.
    """
    # Guard clause: nothing to report once labels are present.
    if self.labels:
        return
    raise ValueError(Errors.E143.format(name=self.name))
|
||||
|
||||
@labels.setter
def labels(self, value):
    """Store `value` under the "labels" key of `self.cfg` as a tuple."""
    # Convert to a tuple first so the config never holds the caller's own iterable.
    frozen_labels = tuple(value)
    self.cfg["labels"] = frozen_labels
|
||||
|
@ -931,6 +936,7 @@ class TextCategorizer(Pipe):
|
|||
doc.cats[label] = float(scores[i, j])
|
||||
|
||||
def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
|
||||
self.require_model()
|
||||
scores, bp_scores = self.model.begin_update(docs, drop=drop)
|
||||
loss, d_scores = self.get_loss(docs, golds, scores)
|
||||
bp_scores(d_scores, sgd=sgd)
|
||||
|
@ -985,6 +991,7 @@ class TextCategorizer(Pipe):
|
|||
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
|
||||
if self.model is True:
|
||||
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors")
|
||||
self.require_labels()
|
||||
self.model = self.Model(len(self.labels), **self.cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
|
|
Loading…
Reference in New Issue
Block a user