mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-15 02:32:37 +03:00
Wire up tb_framework to new parser model
This commit is contained in:
parent
0279aa036a
commit
71abe2e42d
|
@ -1,5 +1,5 @@
|
||||||
from typing import List, Tuple, Any, Optional
|
from typing import List, Tuple, Any, Optional
|
||||||
from thinc.api import Ops, Model, normal_init
|
from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
|
||||||
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
|
||||||
from ..tokens.doc import Doc
|
from ..tokens.doc import Doc
|
||||||
|
|
||||||
|
@ -20,11 +20,15 @@ def TransitionModel(
|
||||||
"""Set up a transition-based parsing model, using a maxout hidden
|
"""Set up a transition-based parsing model, using a maxout hidden
|
||||||
layer and a linear output layer.
|
layer and a linear output layer.
|
||||||
"""
|
"""
|
||||||
|
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||||
|
tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))
|
||||||
|
tok2vec_projected.set_dim("nO", hidden_width)
|
||||||
|
|
||||||
return Model(
|
return Model(
|
||||||
name="parser_model",
|
name="parser_model",
|
||||||
forward=forward,
|
forward=forward,
|
||||||
init=init,
|
init=init,
|
||||||
layers=[tok2vec],
|
layers=[tok2vec_projected],
|
||||||
refs={"tok2vec": tok2vec},
|
refs={"tok2vec": tok2vec},
|
||||||
params={
|
params={
|
||||||
"lower_W": None, # Floats2d W for the hidden layer
|
"lower_W": None, # Floats2d W for the hidden layer
|
||||||
|
|
Loading…
Reference in New Issue
Block a user