state_type as Literal

This commit is contained in:
svlandeg 2020-09-23 17:32:14 +02:00
parent 35dbc63578
commit 5a9fdbc8ad
2 changed files with 13 additions and 2 deletions

View File

@ -2,7 +2,8 @@ from typing import Optional, List
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
from thinc.types import Floats2d from thinc.types import Floats2d
from ... import Errors from ...errors import Errors
from ...compat import Literal
from ...util import registry from ...util import registry
from .._precomputable_affine import PrecomputableAffine from .._precomputable_affine import PrecomputableAffine
from ..tb_framework import TransitionModel from ..tb_framework import TransitionModel
@ -12,7 +13,7 @@ from ...tokens import Doc
@registry.architectures.register("spacy.TransitionBasedParser.v1") @registry.architectures.register("spacy.TransitionBasedParser.v1")
def build_tb_parser_model( def build_tb_parser_model(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
state_type: str, state_type: Literal["parser", "ner"],
extra_state_tokens: bool, extra_state_tokens: bool,
hidden_width: int, hidden_width: int,
maxout_pieces: int, maxout_pieces: int,

View File

@ -345,3 +345,13 @@ def test_config_auto_fill_extra_fields():
assert "extra" not in nlp.config["training"] assert "extra" not in nlp.config["training"]
# Make sure the config generated is valid # Make sure the config generated is valid
load_model_from_config(nlp.config) load_model_from_config(nlp.config)
def test_config_validate_literal():
nlp = English()
config = Config().from_str(parser_config_string)
config["model"]["state_type"] = "nonsense"
with pytest.raises(ConfigValidationError):
nlp.add_pipe("parser", config=config)
config["model"]["state_type"] = "ner"
nlp.add_pipe("parser", config=config)