mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
state_type as Literal
This commit is contained in:
parent
35dbc63578
commit
5a9fdbc8ad
|
@ -2,7 +2,8 @@ from typing import Optional, List
|
||||||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
|
|
||||||
from ... import Errors
|
from ...errors import Errors
|
||||||
|
from ...compat import Literal
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from .._precomputable_affine import PrecomputableAffine
|
from .._precomputable_affine import PrecomputableAffine
|
||||||
from ..tb_framework import TransitionModel
|
from ..tb_framework import TransitionModel
|
||||||
|
@ -12,7 +13,7 @@ from ...tokens import Doc
|
||||||
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
||||||
def build_tb_parser_model(
|
def build_tb_parser_model(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
state_type: str,
|
state_type: Literal["parser", "ner"],
|
||||||
extra_state_tokens: bool,
|
extra_state_tokens: bool,
|
||||||
hidden_width: int,
|
hidden_width: int,
|
||||||
maxout_pieces: int,
|
maxout_pieces: int,
|
||||||
|
|
|
@ -345,3 +345,13 @@ def test_config_auto_fill_extra_fields():
|
||||||
assert "extra" not in nlp.config["training"]
|
assert "extra" not in nlp.config["training"]
|
||||||
# Make sure the config generated is valid
|
# Make sure the config generated is valid
|
||||||
load_model_from_config(nlp.config)
|
load_model_from_config(nlp.config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_validate_literal():
|
||||||
|
nlp = English()
|
||||||
|
config = Config().from_str(parser_config_string)
|
||||||
|
config["model"]["state_type"] = "nonsense"
|
||||||
|
with pytest.raises(ConfigValidationError):
|
||||||
|
nlp.add_pipe("parser", config=config)
|
||||||
|
config["model"]["state_type"] = "ner"
|
||||||
|
nlp.add_pipe("parser", config=config)
|
Loading…
Reference in New Issue
Block a user