mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Add error for non-string labels (#4690)
Add error when attempting to add non-string labels to `Tagger` or `TextCategorizer`.
This commit is contained in:
parent
8d06386e1e
commit
054df5d90a
|
@ -529,6 +529,7 @@ class Errors(object):
|
||||||
E185 = ("Received invalid attribute in component attribute declaration: "
|
E185 = ("Received invalid attribute in component attribute declaration: "
|
||||||
"{obj}.{attr}\nAttribute '{attr}' does not exist on {obj}.")
|
"{obj}.{attr}\nAttribute '{attr}' does not exist on {obj}.")
|
||||||
E186 = ("'{tok_a}' and '{tok_b}' are different texts.")
|
E186 = ("'{tok_a}' and '{tok_b}' are different texts.")
|
||||||
|
E187 = ("Only unicode strings are supported as labels.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -13,6 +13,7 @@ from thinc.misc import LayerNorm
|
||||||
from thinc.neural.util import to_categorical
|
from thinc.neural.util import to_categorical
|
||||||
from thinc.neural.util import get_array_module
|
from thinc.neural.util import get_array_module
|
||||||
|
|
||||||
|
from ..compat import basestring_
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..syntax.nn_parser cimport Parser
|
from ..syntax.nn_parser cimport Parser
|
||||||
from ..syntax.ner cimport BiluoPushDown
|
from ..syntax.ner cimport BiluoPushDown
|
||||||
|
@ -547,6 +548,8 @@ class Tagger(Pipe):
|
||||||
return build_tagger_model(n_tags, **cfg)
|
return build_tagger_model(n_tags, **cfg)
|
||||||
|
|
||||||
def add_label(self, label, values=None):
|
def add_label(self, label, values=None):
|
||||||
|
if not isinstance(label, basestring_):
|
||||||
|
raise ValueError(Errors.E187)
|
||||||
if label in self.labels:
|
if label in self.labels:
|
||||||
return 0
|
return 0
|
||||||
if self.model not in (True, False, None):
|
if self.model not in (True, False, None):
|
||||||
|
@ -1016,6 +1019,8 @@ class TextCategorizer(Pipe):
|
||||||
return float(mean_square_error), d_scores
|
return float(mean_square_error), d_scores
|
||||||
|
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
|
if not isinstance(label, basestring_):
|
||||||
|
raise ValueError(Errors.E187)
|
||||||
if label in self.labels:
|
if label in self.labels:
|
||||||
return 0
|
return 0
|
||||||
if self.model not in (None, True, False):
|
if self.model not in (None, True, False):
|
||||||
|
|
14
spacy/tests/pipeline/test_tagger.py
Normal file
14
spacy/tests/pipeline/test_tagger.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from spacy.language import Language
|
||||||
|
from spacy.pipeline import Tagger
|
||||||
|
|
||||||
|
|
||||||
|
def test_label_types():
|
||||||
|
nlp = Language()
|
||||||
|
nlp.add_pipe(nlp.create_pipe("tagger"))
|
||||||
|
nlp.get_pipe("tagger").add_label("A")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.get_pipe("tagger").add_label(9)
|
|
@ -62,3 +62,11 @@ def test_textcat_learns_multilabel():
|
||||||
assert score < 0.5
|
assert score < 0.5
|
||||||
else:
|
else:
|
||||||
assert score > 0.5
|
assert score > 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_label_types():
|
||||||
|
nlp = Language()
|
||||||
|
nlp.add_pipe(nlp.create_pipe("textcat"))
|
||||||
|
nlp.get_pipe("textcat").add_label("answer")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.get_pipe("textcat").add_label(9)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user