mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Wire up parser model
This commit is contained in:
parent
71abe2e42d
commit
45ca12f07a
|
@ -1,13 +1,15 @@
|
|||
from typing import Optional, List
|
||||
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from thinc.types import Floats2d
|
||||
from thinc.api import Model
|
||||
|
||||
from ...errors import Errors
|
||||
from ...compat import Literal
|
||||
from ...util import registry
|
||||
from .._precomputable_affine import PrecomputableAffine
|
||||
from ..tb_framework import TransitionModel
|
||||
from ...tokens import Doc
|
||||
from ...tokens.doc import Doc
|
||||
|
||||
TransitionSystem = Any # TODO
|
||||
State = Any # TODO
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TransitionBasedParser.v1")
|
||||
|
@ -19,7 +21,7 @@ def transition_parser_v1(
|
|||
maxout_pieces: int,
|
||||
use_upper: bool = True,
|
||||
nO: Optional[int] = None,
|
||||
) -> Model:
|
||||
) -> Model[Tuple[List[Doc], TransitionSystem], List[Tuple[State, List[Floats2d]]]]:
|
||||
return build_tb_parser_model(
|
||||
tok2vec,
|
||||
state_type,
|
||||
|
@ -47,8 +49,26 @@ def transition_parser_v2(
|
|||
extra_state_tokens,
|
||||
hidden_width,
|
||||
maxout_pieces,
|
||||
use_upper,
|
||||
nO,
|
||||
nO=nO,
|
||||
)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TransitionBasedParser.v3")
|
||||
def transition_parser_v2(
|
||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||
state_type: Literal["parser", "ner"],
|
||||
extra_state_tokens: bool,
|
||||
hidden_width: int,
|
||||
maxout_pieces: int,
|
||||
nO: Optional[int] = None,
|
||||
) -> Model:
|
||||
return build_tb_parser_model(
|
||||
tok2vec,
|
||||
state_type,
|
||||
extra_state_tokens,
|
||||
hidden_width,
|
||||
maxout_pieces,
|
||||
nO=nO,
|
||||
)
|
||||
|
||||
|
||||
|
@ -58,7 +78,6 @@ def build_tb_parser_model(
|
|||
extra_state_tokens: bool,
|
||||
hidden_width: int,
|
||||
maxout_pieces: int,
|
||||
use_upper: bool,
|
||||
nO: Optional[int] = None,
|
||||
) -> Model:
|
||||
"""
|
||||
|
@ -110,102 +129,11 @@ def build_tb_parser_model(
|
|||
nr_feature_tokens = 6 if extra_state_tokens else 3
|
||||
else:
|
||||
raise ValueError(Errors.E917.format(value=state_type))
|
||||
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||
tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))
|
||||
tok2vec.set_dim("nO", hidden_width)
|
||||
lower = _define_lower(
|
||||
nO=hidden_width if use_upper else nO,
|
||||
nF=nr_feature_tokens,
|
||||
nI=tok2vec.get_dim("nO"),
|
||||
nP=maxout_pieces,
|
||||
return TransitionModel(
|
||||
tok2vec=tok2vec,
|
||||
state_tokens=nr_feature_tokens,
|
||||
hidden_width=hidden_width,
|
||||
maxout_pieces=maxout_pieces,
|
||||
nO=nO,
|
||||
unseen_classes=set(),
|
||||
)
|
||||
upper = None
|
||||
if use_upper:
|
||||
with use_ops("numpy"):
|
||||
# Initialize weights at zero, as it's a classification layer.
|
||||
upper = _define_upper(nO=nO, nI=None)
|
||||
return TransitionModel(tok2vec, lower, upper, resize_output)
|
||||
|
||||
|
||||
def _define_upper(nO, nI):
|
||||
return Linear(nO=nO, nI=nI, init_W=zero_init)
|
||||
|
||||
|
||||
def _define_lower(nO, nF, nI, nP):
|
||||
return PrecomputableAffine(nO=nO, nF=nF, nI=nI, nP=nP)
|
||||
|
||||
|
||||
def resize_output(model, new_nO):
|
||||
if model.attrs["has_upper"]:
|
||||
return _resize_upper(model, new_nO)
|
||||
return _resize_lower(model, new_nO)
|
||||
|
||||
|
||||
def _resize_upper(model, new_nO):
|
||||
upper = model.get_ref("upper")
|
||||
if upper.has_dim("nO") is None:
|
||||
upper.set_dim("nO", new_nO)
|
||||
return model
|
||||
elif new_nO == upper.get_dim("nO"):
|
||||
return model
|
||||
|
||||
smaller = upper
|
||||
nI = smaller.maybe_get_dim("nI")
|
||||
with use_ops("numpy"):
|
||||
larger = _define_upper(nO=new_nO, nI=nI)
|
||||
# it could be that the model is not initialized yet, then skip this bit
|
||||
if smaller.has_param("W"):
|
||||
larger_W = larger.ops.alloc2f(new_nO, nI)
|
||||
larger_b = larger.ops.alloc1f(new_nO)
|
||||
smaller_W = smaller.get_param("W")
|
||||
smaller_b = smaller.get_param("b")
|
||||
# Weights are stored in (nr_out, nr_in) format, so we're basically
|
||||
# just adding rows here.
|
||||
if smaller.has_dim("nO"):
|
||||
old_nO = smaller.get_dim("nO")
|
||||
larger_W[:old_nO] = smaller_W
|
||||
larger_b[:old_nO] = smaller_b
|
||||
for i in range(old_nO, new_nO):
|
||||
model.attrs["unseen_classes"].add(i)
|
||||
|
||||
larger.set_param("W", larger_W)
|
||||
larger.set_param("b", larger_b)
|
||||
model._layers[-1] = larger
|
||||
model.set_ref("upper", larger)
|
||||
return model
|
||||
|
||||
|
||||
def _resize_lower(model, new_nO):
|
||||
lower = model.get_ref("lower")
|
||||
if lower.has_dim("nO") is None:
|
||||
lower.set_dim("nO", new_nO)
|
||||
return model
|
||||
|
||||
smaller = lower
|
||||
nI = smaller.maybe_get_dim("nI")
|
||||
nF = smaller.maybe_get_dim("nF")
|
||||
nP = smaller.maybe_get_dim("nP")
|
||||
larger = _define_lower(nO=new_nO, nI=nI, nF=nF, nP=nP)
|
||||
# it could be that the model is not initialized yet, then skip this bit
|
||||
if smaller.has_param("W"):
|
||||
larger_W = larger.ops.alloc4f(nF, new_nO, nP, nI)
|
||||
larger_b = larger.ops.alloc2f(new_nO, nP)
|
||||
larger_pad = larger.ops.alloc4f(1, nF, new_nO, nP)
|
||||
smaller_W = smaller.get_param("W")
|
||||
smaller_b = smaller.get_param("b")
|
||||
smaller_pad = smaller.get_param("pad")
|
||||
# Copy the old weights and padding into the new layer
|
||||
if smaller.has_dim("nO"):
|
||||
old_nO = smaller.get_dim("nO")
|
||||
larger_W[:, 0:old_nO, :, :] = smaller_W
|
||||
larger_pad[:, :, 0:old_nO, :] = smaller_pad
|
||||
larger_b[0:old_nO, :] = smaller_b
|
||||
for i in range(old_nO, new_nO):
|
||||
model.attrs["unseen_classes"].add(i)
|
||||
|
||||
larger.set_param("W", larger_W)
|
||||
larger.set_param("b", larger_b)
|
||||
larger.set_param("pad", larger_pad)
|
||||
model._layers[1] = larger
|
||||
model.set_ref("lower", larger)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue
Block a user