Add back spacy.TransitionBasedParser.v2

This commit is contained in:
Daniël de Kok 2022-09-27 12:12:24 +02:00
parent fbe430a00f
commit 5ffee4863f
3 changed files with 29 additions and 1 deletion

View File

@ -213,6 +213,8 @@ class Warnings(metaclass=ErrorsWithCodes):
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class " W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
"is a Cython extension type.") "is a Cython extension type.")
W400 = ("`use_upper=False` is ignored, the upper layer is always enabled")
class Errors(metaclass=ErrorsWithCodes): class Errors(metaclass=ErrorsWithCodes):
E001 = ("No component '{name}' found in pipeline. Available names: {opts}") E001 = ("No component '{name}' found in pipeline. Available names: {opts}")

View File

@ -1,8 +1,9 @@
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from thinc.types import Floats2d from thinc.types import Floats2d
from thinc.api import Model from thinc.api import Model
import warnings
from ...errors import Errors from ...errors import Errors, Warnings
from ...compat import Literal from ...compat import Literal
from ...util import registry from ...util import registry
from ..tb_framework import TransitionModel from ..tb_framework import TransitionModel
@ -12,6 +13,29 @@ TransitionSystem = Any # TODO
State = Any # TODO State = Any # TODO
@registry.architectures.register("spacy.TransitionBasedParser.v2")
def transition_parser_v2(
    tok2vec: Model[List[Doc], List[Floats2d]],
    state_type: Literal["parser", "ner"],
    extra_state_tokens: bool,
    hidden_width: int,
    maxout_pieces: int,
    use_upper: bool,
    nO: Optional[int] = None,
) -> Model:
    """Build a transition-based parser model for the legacy v2 architecture name.

    Registered as "spacy.TransitionBasedParser.v2". The signature still accepts
    `use_upper`, but the value is never forwarded: a falsy `use_upper` only
    emits warning W400, and the model is always produced by
    `build_tb_parser_model`.

    tok2vec, state_type, extra_state_tokens, hidden_width, maxout_pieces:
        Passed positionally, unchanged, to `build_tb_parser_model`.
    use_upper (bool): Accepted but ignored; a falsy value triggers W400.
    nO (Optional[int]): Passed through as the `nO` keyword argument.
    RETURNS (Model): Whatever `build_tb_parser_model` returns.
    """
    upper_disabled = not use_upper
    if upper_disabled:
        # The setting cannot be honoured, so tell the user instead of
        # silently building a different model than they asked for.
        warnings.warn(Warnings.W400)

    parser_model = build_tb_parser_model(
        tok2vec, state_type, extra_state_tokens, hidden_width, maxout_pieces, nO=nO
    )
    return parser_model
@registry.architectures.register("spacy.TransitionBasedParser.v3") @registry.architectures.register("spacy.TransitionBasedParser.v3")
def transition_parser_v3( def transition_parser_v3(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],

View File

@ -459,6 +459,8 @@ def test_overfitting_IO(pipe_name):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"parser_config", "parser_config",
[ [
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": False}),
({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}), ({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}),
], ],
) )