mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Fix shape inference
This commit is contained in:
parent
df87c32a40
commit
d507ac28d8
|
@ -15,9 +15,10 @@ def build_tb_parser_model(
|
||||||
use_upper=True,
|
use_upper=True,
|
||||||
nO=None,
|
nO=None,
|
||||||
):
|
):
|
||||||
|
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||||
tok2vec = chain(
|
tok2vec = chain(
|
||||||
tok2vec,
|
tok2vec,
|
||||||
with_array(Linear(hidden_width)),
|
with_array(Linear(hidden_width, t2v_width)),
|
||||||
list2array(),
|
list2array(),
|
||||||
)
|
)
|
||||||
tok2vec.set_dim("nO", hidden_width)
|
tok2vec.set_dim("nO", hidden_width)
|
||||||
|
|
|
@ -7,7 +7,8 @@ from ...util import registry
|
||||||
@registry.architectures.register("spacy.Tagger.v1")
|
@registry.architectures.register("spacy.Tagger.v1")
|
||||||
def build_tagger_model(tok2vec, nO=None) -> Model:
|
def build_tagger_model(tok2vec, nO=None) -> Model:
|
||||||
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
|
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
|
||||||
output_layer = Softmax(nO, init_W=zero_init)
|
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
|
||||||
|
output_layer = Softmax(nO, t2v_width, init_W=zero_init)
|
||||||
softmax = with_array(output_layer)
|
softmax = with_array(output_layer)
|
||||||
model = chain(tok2vec, softmax)
|
model = chain(tok2vec, softmax)
|
||||||
model.set_ref("tok2vec", tok2vec)
|
model.set_ref("tok2vec", tok2vec)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user