Fix shape inference

This commit is contained in:
Matthw Honnibal 2020-05-21 20:46:10 +02:00
parent df87c32a40
commit d507ac28d8
2 changed files with 4 additions and 2 deletions

View File

@ -15,9 +15,10 @@ def build_tb_parser_model(
use_upper=True, use_upper=True,
nO=None, nO=None,
): ):
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
tok2vec = chain( tok2vec = chain(
tok2vec, tok2vec,
with_array(Linear(hidden_width)), with_array(Linear(hidden_width, t2v_width)),
list2array(), list2array(),
) )
tok2vec.set_dim("nO", hidden_width) tok2vec.set_dim("nO", hidden_width)

View File

@ -7,7 +7,8 @@ from ...util import registry
@registry.architectures.register("spacy.Tagger.v1") @registry.architectures.register("spacy.Tagger.v1")
def build_tagger_model(tok2vec, nO=None) -> Model: def build_tagger_model(tok2vec, nO=None) -> Model:
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?! # TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
output_layer = Softmax(nO, init_W=zero_init) t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
softmax = with_array(output_layer) softmax = with_array(output_layer)
model = chain(tok2vec, softmax) model = chain(tok2vec, softmax)
model.set_ref("tok2vec", tok2vec) model.set_ref("tok2vec", tok2vec)