mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Add error for non-string labels (#4690)
Add error when attempting to add non-string labels to `Tagger` or `TextCategorizer`.
This commit is contained in:
parent
8d06386e1e
commit
054df5d90a
|
@ -529,6 +529,7 @@ class Errors(object):
|
|||
E185 = ("Received invalid attribute in component attribute declaration: "
|
||||
"{obj}.{attr}\nAttribute '{attr}' does not exist on {obj}.")
|
||||
E186 = ("'{tok_a}' and '{tok_b}' are different texts.")
|
||||
E187 = ("Only unicode strings are supported as labels.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -13,6 +13,7 @@ from thinc.misc import LayerNorm
|
|||
from thinc.neural.util import to_categorical
|
||||
from thinc.neural.util import get_array_module
|
||||
|
||||
from ..compat import basestring_
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..syntax.nn_parser cimport Parser
|
||||
from ..syntax.ner cimport BiluoPushDown
|
||||
|
@ -547,6 +548,8 @@ class Tagger(Pipe):
|
|||
return build_tagger_model(n_tags, **cfg)
|
||||
|
||||
def add_label(self, label, values=None):
|
||||
if not isinstance(label, basestring_):
|
||||
raise ValueError(Errors.E187)
|
||||
if label in self.labels:
|
||||
return 0
|
||||
if self.model not in (True, False, None):
|
||||
|
@ -1016,6 +1019,8 @@ class TextCategorizer(Pipe):
|
|||
return float(mean_square_error), d_scores
|
||||
|
||||
def add_label(self, label):
|
||||
if not isinstance(label, basestring_):
|
||||
raise ValueError(Errors.E187)
|
||||
if label in self.labels:
|
||||
return 0
|
||||
if self.model not in (None, True, False):
|
||||
|
|
14
spacy/tests/pipeline/test_tagger.py
Normal file
14
spacy/tests/pipeline/test_tagger.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from spacy.language import Language
|
||||
from spacy.pipeline import Tagger
|
||||
|
||||
|
||||
def test_label_types():
|
||||
nlp = Language()
|
||||
nlp.add_pipe(nlp.create_pipe("tagger"))
|
||||
nlp.get_pipe("tagger").add_label("A")
|
||||
with pytest.raises(ValueError):
|
||||
nlp.get_pipe("tagger").add_label(9)
|
|
@ -62,3 +62,11 @@ def test_textcat_learns_multilabel():
|
|||
assert score < 0.5
|
||||
else:
|
||||
assert score > 0.5
|
||||
|
||||
|
||||
def test_label_types():
|
||||
nlp = Language()
|
||||
nlp.add_pipe(nlp.create_pipe("textcat"))
|
||||
nlp.get_pipe("textcat").add_label("answer")
|
||||
with pytest.raises(ValueError):
|
||||
nlp.get_pipe("textcat").add_label(9)
|
||||
|
|
Loading…
Reference in New Issue
Block a user