mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
fix get_dim calls in build_simple_cnn_text_classifier
This commit is contained in:
parent
853edace37
commit
040c7c0541
|
@ -24,11 +24,11 @@ def build_simple_cnn_text_classifier(
|
|||
"""
|
||||
with Model.define_operators({">>": chain}):
|
||||
if exclusive_classes:
|
||||
output_layer = Softmax(nO=nO, nI=tok2vec.get_dim("nO"))
|
||||
output_layer = Softmax(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
|
||||
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
|
||||
model.set_ref("output_layer", output_layer)
|
||||
else:
|
||||
linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO"))
|
||||
linear_layer = Linear(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
|
||||
model = (
|
||||
tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()
|
||||
)
|
||||
|
|
|
@ -622,7 +622,7 @@ def load_meta(path: Union[str, Path]) -> Dict[str, Any]:
|
|||
if not path.parent.exists():
|
||||
raise IOError(Errors.E052.format(path=path.parent))
|
||||
if not path.exists() or not path.is_file():
|
||||
raise IOError(Errors.E053.format(path=path, name="meta.json"))
|
||||
raise IOError(Errors.E053.format(path=path.parent, name="meta.json"))
|
||||
meta = srsly.read_json(path)
|
||||
for setting in ["lang", "name", "version"]:
|
||||
if setting not in meta or not meta[setting]:
|
||||
|
|
Loading…
Reference in New Issue
Block a user