Raise for bad Vocab values

This commit is contained in:
Ines Montani 2020-09-15 13:25:34 +02:00
parent 0edd695bf6
commit 253ba5ef14
3 changed files with 16 additions and 0 deletions

View File

@ -480,6 +480,9 @@ class Errors:
E201 = ("Span index out of range.") E201 = ("Span index out of range.")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E918 = ("Received invalid value for vocab: {vocab} ({vocab_type}). Valid "
"values are an instance of spacy.vocab.Vocab or True to create one"
" (default).")
E919 = ("A textcat 'positive_label' '{pos_label}' was provided for training " E919 = ("A textcat 'positive_label' '{pos_label}' was provided for training "
"data that does not appear to be a binary classification problem " "data that does not appear to be a binary classification problem "
"with two labels. Labels found: {labels}") "with two labels. Labels found: {labels}")

View File

@ -144,6 +144,8 @@ class Language:
self._pipe_meta: Dict[str, "FactoryMeta"] = {} # meta by component self._pipe_meta: Dict[str, "FactoryMeta"] = {} # meta by component
self._pipe_configs: Dict[str, Config] = {} # config by component self._pipe_configs: Dict[str, Config] = {} # config by component
if not isinstance(vocab, Vocab) and vocab is not True:
raise ValueError(Errors.E918.format(vocab=vocab, vocab_type=type(Vocab)))
if vocab is True: if vocab is True:
vectors_name = meta.get("vectors", {}).get("name") vectors_name = meta.get("vectors", {}).get("name")
vocab = create_vocab( vocab = create_vocab(

View File

@ -277,3 +277,14 @@ def test_spacy_blank():
nlp = spacy.blank("en", config=config, meta=meta) nlp = spacy.blank("en", config=config, meta=meta)
assert nlp.config["training"]["dropout"] == 0.2 assert nlp.config["training"]["dropout"] == 0.2
assert nlp.meta["name"] == "my_custom_model" assert nlp.meta["name"] == "my_custom_model"
@pytest.mark.parametrize(
"value",
[False, None, ["x", "y"], Language, Vocab],
)
def test_language_init_invalid_vocab(value):
err_fragment = "invalid value"
with pytest.raises(ValueError) as e:
Language(value)
assert err_fragment in str(e)