From 5ffee4863fbf7254de828e04a72855c0cbf4ad50 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Tue, 27 Sep 2022 12:12:24 +0200
Subject: [PATCH] Add back spacy.TransitionBasedParser.v2

---
 spacy/errors.py                  |  2 ++
 spacy/ml/models/parser.py        | 26 +++++++++++++++++++++++++-
 spacy/tests/parser/test_parse.py |  2 ++
 3 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/spacy/errors.py b/spacy/errors.py
index ff5ddacf0..3268cb437 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -213,6 +213,8 @@ class Warnings(metaclass=ErrorsWithCodes):
     W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
             "is a Cython extension type.")
 
+    W400 = ("`use_upper=False` is ignored, the upper layer is always enabled")
+
 
 class Errors(metaclass=ErrorsWithCodes):
     E001 = ("No component '{name}' found in pipeline. Available names: {opts}")
diff --git a/spacy/ml/models/parser.py b/spacy/ml/models/parser.py
index bbc5bf957..e2ee87d82 100644
--- a/spacy/ml/models/parser.py
+++ b/spacy/ml/models/parser.py
@@ -1,8 +1,9 @@
 from typing import Optional, List, Tuple, Any
 from thinc.types import Floats2d
 from thinc.api import Model
+import warnings
 
-from ...errors import Errors
+from ...errors import Errors, Warnings
 from ...compat import Literal
 from ...util import registry
 from ..tb_framework import TransitionModel
@@ -12,6 +13,29 @@ TransitionSystem = Any  # TODO
 State = Any  # TODO
 
 
+@registry.architectures.register("spacy.TransitionBasedParser.v2")
+def transition_parser_v2(
+    tok2vec: Model[List[Doc], List[Floats2d]],
+    state_type: Literal["parser", "ner"],
+    extra_state_tokens: bool,
+    hidden_width: int,
+    maxout_pieces: int,
+    use_upper: bool,
+    nO: Optional[int] = None,
+) -> Model:
+    if not use_upper:
+        warnings.warn(Warnings.W400)
+
+    return build_tb_parser_model(
+        tok2vec,
+        state_type,
+        extra_state_tokens,
+        hidden_width,
+        maxout_pieces,
+        nO=nO,
+    )
+
+
 @registry.architectures.register("spacy.TransitionBasedParser.v3")
 def transition_parser_v3(
     tok2vec: Model[List[Doc], List[Floats2d]],
diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py
index c7427fac6..af33dcf5f 100644
--- a/spacy/tests/parser/test_parse.py
+++ b/spacy/tests/parser/test_parse.py
@@ -459,6 +459,8 @@ def test_overfitting_IO(pipe_name):
 @pytest.mark.parametrize(
     "parser_config",
     [
+        ({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
+        ({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": False}),
         ({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}),
     ],
 )