Add back spacy.TransitionBasedParser.v2

This commit is contained in:
Daniël de Kok 2022-09-27 12:12:24 +02:00
parent fbe430a00f
commit 5ffee4863f
3 changed files with 29 additions and 1 deletions

View File

@ -213,6 +213,8 @@ class Warnings(metaclass=ErrorsWithCodes):
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
"is a Cython extension type.")
W400 = ("`use_upper=False` is ignored, the upper layer is always enabled")
class Errors(metaclass=ErrorsWithCodes):
E001 = ("No component '{name}' found in pipeline. Available names: {opts}")

View File

@ -1,8 +1,9 @@
from typing import Optional, List, Tuple, Any
from thinc.types import Floats2d
from thinc.api import Model
import warnings
from ...errors import Errors
from ...errors import Errors, Warnings
from ...compat import Literal
from ...util import registry
from ..tb_framework import TransitionModel
@ -12,6 +13,29 @@ TransitionSystem = Any # TODO
# Placeholder alias: `State` is currently just `Any`; a concrete type is still
# to be declared (see the TODO).
State = Any # TODO
@registry.architectures.register("spacy.TransitionBasedParser.v2")
def transition_parser_v2(
    tok2vec: Model[List[Doc], List[Floats2d]],
    state_type: Literal["parser", "ner"],
    extra_state_tokens: bool,
    hidden_width: int,
    maxout_pieces: int,
    use_upper: bool,
    nO: Optional[int] = None,
) -> Model:
    """Build the model registered as ``spacy.TransitionBasedParser.v2``.

    This entry point keeps the ``use_upper`` setting in its signature, but the
    value is never forwarded: a falsy ``use_upper`` only triggers warning
    ``W400`` ("the upper layer is always enabled"). All remaining arguments
    are handed to ``build_tb_parser_model`` unchanged.

    tok2vec: Model mapping a list of docs to a list of 2d float arrays;
        forwarded as the first positional argument.
    state_type: Either "parser" or "ner"; forwarded as-is.
    extra_state_tokens: Forwarded as-is.
    hidden_width: Forwarded as-is.
    maxout_pieces: Forwarded as-is.
    use_upper: Ignored apart from the warning emitted when it is falsy.
    nO: Optional output size, forwarded as the ``nO`` keyword argument.
    RETURNS: Whatever ``build_tb_parser_model`` returns.
    """
    # The setting is accepted but has no effect any more, so tell the user
    # instead of silently ignoring a falsy value.
    if not use_upper:
        warnings.warn(Warnings.W400)

    # Positional arguments in the exact order the builder is called with.
    builder_args = (
        tok2vec,
        state_type,
        extra_state_tokens,
        hidden_width,
        maxout_pieces,
    )
    return build_tb_parser_model(*builder_args, nO=nO)
@registry.architectures.register("spacy.TransitionBasedParser.v3")
def transition_parser_v3(
tok2vec: Model[List[Doc], List[Floats2d]],

View File

@ -459,6 +459,8 @@ def test_overfitting_IO(pipe_name):
@pytest.mark.parametrize(
"parser_config",
[
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": False}),
({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}),
],
)