diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py index 1e14d239e..ddc283216 100644 --- a/spacy/ml/tb_framework.py +++ b/spacy/ml/tb_framework.py @@ -1,5 +1,5 @@ from typing import List, Tuple, Any, Optional -from thinc.api import Ops, Model, normal_init +from thinc.api import Ops, Model, normal_init, chain, list2array, Linear from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d from ..tokens.doc import Doc @@ -20,11 +20,15 @@ def TransitionModel( """Set up a transition-based parsing model, using a maxout hidden layer and a linear output layer. """ + t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None + tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) + tok2vec_projected.set_dim("nO", hidden_width) + return Model( name="parser_model", forward=forward, init=init, - layers=[tok2vec], + layers=[tok2vec_projected], refs={"tok2vec": tok2vec}, params={ "lower_W": None, # Floats2d W for the hidden layer