Add error for non-string labels (#4690)

Add error when attempting to add non-string labels to `Tagger` or
`TextCategorizer`.
This commit is contained in:
adrianeboyd 2019-11-21 16:24:10 +01:00 committed by Ines Montani
parent 8d06386e1e
commit 054df5d90a
4 changed files with 28 additions and 0 deletions

View File

@ -529,6 +529,7 @@ class Errors(object):
E185 = ("Received invalid attribute in component attribute declaration: "
"{obj}.{attr}\nAttribute '{attr}' does not exist on {obj}.")
E186 = ("'{tok_a}' and '{tok_b}' are different texts.")
E187 = ("Only unicode strings are supported as labels.")
@add_codes

View File

@ -13,6 +13,7 @@ from thinc.misc import LayerNorm
from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module
from ..compat import basestring_
from ..tokens.doc cimport Doc
from ..syntax.nn_parser cimport Parser
from ..syntax.ner cimport BiluoPushDown
@ -547,6 +548,8 @@ class Tagger(Pipe):
return build_tagger_model(n_tags, **cfg)
def add_label(self, label, values=None):
if not isinstance(label, basestring_):
raise ValueError(Errors.E187)
if label in self.labels:
return 0
if self.model not in (True, False, None):
@ -1016,6 +1019,8 @@ class TextCategorizer(Pipe):
return float(mean_square_error), d_scores
def add_label(self, label):
if not isinstance(label, basestring_):
raise ValueError(Errors.E187)
if label in self.labels:
return 0
if self.model not in (None, True, False):

View File

@ -0,0 +1,14 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.language import Language
from spacy.pipeline import Tagger
def test_label_types():
nlp = Language()
nlp.add_pipe(nlp.create_pipe("tagger"))
nlp.get_pipe("tagger").add_label("A")
with pytest.raises(ValueError):
nlp.get_pipe("tagger").add_label(9)

View File

@ -62,3 +62,11 @@ def test_textcat_learns_multilabel():
assert score < 0.5
else:
assert score > 0.5
def test_label_types():
nlp = Language()
nlp.add_pipe(nlp.create_pipe("textcat"))
nlp.get_pipe("textcat").add_label("answer")
with pytest.raises(ValueError):
nlp.get_pipe("textcat").add_label(9)