throw custom error when state_type is invalid

This commit is contained in:
svlandeg 2020-09-23 16:57:14 +02:00
parent dd2292793f
commit 25b34bba94
2 changed files with 4 additions and 1 deletions

View File

@ -480,6 +480,8 @@ class Errors:
E201 = ("Span index out of range.") E201 = ("Span index out of range.")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E917 = ("Received invalid value {value} for 'state_type' in "
"TransitionBasedParser: only 'parser' or 'ner' are valid options.")
E918 = ("Received invalid value for vocab: {vocab} ({vocab_type}). Valid " E918 = ("Received invalid value for vocab: {vocab} ({vocab_type}). Valid "
"values are an instance of spacy.vocab.Vocab or True to create one" "values are an instance of spacy.vocab.Vocab or True to create one"
" (default).") " (default).")

View File

@ -2,6 +2,7 @@ from typing import Optional, List
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.types import Floats2d from thinc.types import Floats2d
from ... import Errors
from ...util import registry from ...util import registry
from .._precomputable_affine import PrecomputableAffine from .._precomputable_affine import PrecomputableAffine
from ..tb_framework import TransitionModel from ..tb_framework import TransitionModel
@ -66,7 +67,7 @@ def build_tb_parser_model(
elif state_type == "ner": elif state_type == "ner":
nr_feature_tokens = 6 if extra_state_tokens else 3 nr_feature_tokens = 6 if extra_state_tokens else 3
else: else:
raise ValueError(f"unknown state type {state_type}") # TODO error raise ValueError(Errors.E917.format(value=state_type))
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))
tok2vec.set_dim("nO", hidden_width) tok2vec.set_dim("nO", hidden_width)