fix get_dim calls in build_simple_cnn_text_classifier

This commit is contained in:
svlandeg 2020-10-09 15:40:58 +02:00
parent 853edace37
commit 040c7c0541
2 changed files with 3 additions and 3 deletions

View File

@ -24,11 +24,11 @@ def build_simple_cnn_text_classifier(
""" """
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
if exclusive_classes: if exclusive_classes:
output_layer = Softmax(nO=nO, nI=tok2vec.get_dim("nO")) output_layer = Softmax(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
model.set_ref("output_layer", output_layer) model.set_ref("output_layer", output_layer)
else: else:
linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO")) linear_layer = Linear(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
model = ( model = (
tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic() tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()
) )

View File

@ -622,7 +622,7 @@ def load_meta(path: Union[str, Path]) -> Dict[str, Any]:
if not path.parent.exists(): if not path.parent.exists():
raise IOError(Errors.E052.format(path=path.parent)) raise IOError(Errors.E052.format(path=path.parent))
if not path.exists() or not path.is_file(): if not path.exists() or not path.is_file():
raise IOError(Errors.E053.format(path=path, name="meta.json")) raise IOError(Errors.E053.format(path=path.parent, name="meta.json"))
meta = srsly.read_json(path) meta = srsly.read_json(path)
for setting in ["lang", "name", "version"]: for setting in ["lang", "name", "version"]:
if setting not in meta or not meta[setting]: if setting not in meta or not meta[setting]: