mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
Add back spacy.TransitionBasedParser.v2
This commit is contained in:
parent
fbe430a00f
commit
5ffee4863f
|
@ -213,6 +213,8 @@ class Warnings(metaclass=ErrorsWithCodes):
|
||||||
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
|
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
|
||||||
"is a Cython extension type.")
|
"is a Cython extension type.")
|
||||||
|
|
||||||
|
W400 = ("`use_upper=False` is ignored, the upper layer is always enabled")
|
||||||
|
|
||||||
|
|
||||||
class Errors(metaclass=ErrorsWithCodes):
|
class Errors(metaclass=ErrorsWithCodes):
|
||||||
E001 = ("No component '{name}' found in pipeline. Available names: {opts}")
|
E001 = ("No component '{name}' found in pipeline. Available names: {opts}")
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
from typing import Optional, List, Tuple, Any
|
from typing import Optional, List, Tuple, Any
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
from thinc.api import Model
|
from thinc.api import Model
|
||||||
|
import warnings
|
||||||
|
|
||||||
from ...errors import Errors
|
from ...errors import Errors, Warnings
|
||||||
from ...compat import Literal
|
from ...compat import Literal
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from ..tb_framework import TransitionModel
|
from ..tb_framework import TransitionModel
|
||||||
|
@ -12,6 +13,29 @@ TransitionSystem = Any # TODO
|
||||||
State = Any # TODO
|
State = Any # TODO
|
||||||
|
|
||||||
|
|
||||||
|
@registry.architectures.register("spacy.TransitionBasedParser.v2")
|
||||||
|
def transition_parser_v2(
|
||||||
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
|
state_type: Literal["parser", "ner"],
|
||||||
|
extra_state_tokens: bool,
|
||||||
|
hidden_width: int,
|
||||||
|
maxout_pieces: int,
|
||||||
|
use_upper: bool,
|
||||||
|
nO: Optional[int] = None,
|
||||||
|
) -> Model:
|
||||||
|
if not use_upper:
|
||||||
|
warnings.warn(Warnings.W400)
|
||||||
|
|
||||||
|
return build_tb_parser_model(
|
||||||
|
tok2vec,
|
||||||
|
state_type,
|
||||||
|
extra_state_tokens,
|
||||||
|
hidden_width,
|
||||||
|
maxout_pieces,
|
||||||
|
nO=nO,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TransitionBasedParser.v3")
|
@registry.architectures.register("spacy.TransitionBasedParser.v3")
|
||||||
def transition_parser_v3(
|
def transition_parser_v3(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
|
|
|
@ -459,6 +459,8 @@ def test_overfitting_IO(pipe_name):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"parser_config",
|
"parser_config",
|
||||||
[
|
[
|
||||||
|
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": True}),
|
||||||
|
({"@architectures": "spacy.TransitionBasedParser.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2, "use_upper": False}),
|
||||||
({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}),
|
({"@architectures": "spacy.TransitionBasedParser.v3", "tok2vec": DEFAULT_TOK2VEC_MODEL, "state_type": "parser", "extra_state_tokens": False, "hidden_width": 64, "maxout_pieces": 2}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user