mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Use Literal type for nr_feature_tokens
This commit is contained in:
parent
e4e7f5b00d
commit
76bbed3466
|
@ -20,6 +20,7 @@ pytokenizations
|
||||||
setuptools
|
setuptools
|
||||||
packaging
|
packaging
|
||||||
importlib_metadata>=0.20; python_version < "3.8"
|
importlib_metadata>=0.20; python_version < "3.8"
|
||||||
|
typing_extensions>=3.7.4; python_version < "3.8"
|
||||||
# Development dependencies
|
# Development dependencies
|
||||||
cython>=0.25
|
cython>=0.25
|
||||||
pytest>=4.6.5
|
pytest>=4.6.5
|
||||||
|
|
|
@ -57,6 +57,7 @@ install_requires =
|
||||||
setuptools
|
setuptools
|
||||||
packaging
|
packaging
|
||||||
importlib_metadata>=0.20; python_version < "3.8"
|
importlib_metadata>=0.20; python_version < "3.8"
|
||||||
|
typing_extensions>=3.7.4; python_version < "3.8"
|
||||||
|
|
||||||
[options.entry_points]
|
[options.entry_points]
|
||||||
console_scripts =
|
console_scripts =
|
||||||
|
|
|
@ -22,6 +22,11 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
cupy = None
|
cupy = None
|
||||||
|
|
||||||
|
try: # Python 3.8+
|
||||||
|
from typing import Literal
|
||||||
|
except ImportError:
|
||||||
|
from typing_extensions import Literal # noqa: F401
|
||||||
|
|
||||||
from thinc.api import Optimizer # noqa: F401
|
from thinc.api import Optimizer # noqa: F401
|
||||||
|
|
||||||
pickle = pickle
|
pickle = pickle
|
||||||
|
|
|
@ -2,6 +2,7 @@ from typing import Optional, List
|
||||||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
|
|
||||||
|
from ...compat import Literal
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from .._precomputable_affine import PrecomputableAffine
|
from .._precomputable_affine import PrecomputableAffine
|
||||||
from ..tb_framework import TransitionModel
|
from ..tb_framework import TransitionModel
|
||||||
|
@ -11,7 +12,7 @@ from ...tokens import Doc
|
||||||
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
||||||
def build_tb_parser_model(
|
def build_tb_parser_model(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
nr_feature_tokens: int,
|
nr_feature_tokens: Literal[3, 6, 8, 13],
|
||||||
hidden_width: int,
|
hidden_width: int,
|
||||||
maxout_pieces: int,
|
maxout_pieces: int,
|
||||||
use_upper: bool = True,
|
use_upper: bool = True,
|
||||||
|
|
|
@ -67,7 +67,7 @@ width = ${components.tok2vec.model.width}
|
||||||
parser_config_string = """
|
parser_config_string = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TransitionBasedParser.v1"
|
@architectures = "spacy.TransitionBasedParser.v1"
|
||||||
nr_feature_tokens = 99
|
nr_feature_tokens = 3
|
||||||
hidden_width = 66
|
hidden_width = 66
|
||||||
maxout_pieces = 2
|
maxout_pieces = 2
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ def my_parser():
|
||||||
MaxoutWindowEncoder(width=321, window_size=3, maxout_pieces=4, depth=2),
|
MaxoutWindowEncoder(width=321, window_size=3, maxout_pieces=4, depth=2),
|
||||||
)
|
)
|
||||||
parser = build_tb_parser_model(
|
parser = build_tb_parser_model(
|
||||||
tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
|
tok2vec=tok2vec, nr_feature_tokens=8, hidden_width=65, maxout_pieces=5
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@ -340,3 +340,13 @@ def test_config_auto_fill_extra_fields():
|
||||||
assert "extra" not in nlp.config["training"]
|
assert "extra" not in nlp.config["training"]
|
||||||
# Make sure the config generated is valid
|
# Make sure the config generated is valid
|
||||||
load_model_from_config(nlp.config)
|
load_model_from_config(nlp.config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_validate_literal():
|
||||||
|
nlp = English()
|
||||||
|
config = Config().from_str(parser_config_string)
|
||||||
|
config["model"]["nr_feature_tokens"] = 666
|
||||||
|
with pytest.raises(ConfigValidationError):
|
||||||
|
nlp.add_pipe("parser", config=config)
|
||||||
|
config["model"]["nr_feature_tokens"] = 13
|
||||||
|
nlp.add_pipe("parser", config=config)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user