mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
fix get_dim calls in build_simple_cnn_text_classifier
This commit is contained in:
parent
853edace37
commit
040c7c0541
|
@ -24,11 +24,11 @@ def build_simple_cnn_text_classifier(
|
||||||
"""
|
"""
|
||||||
with Model.define_operators({">>": chain}):
|
with Model.define_operators({">>": chain}):
|
||||||
if exclusive_classes:
|
if exclusive_classes:
|
||||||
output_layer = Softmax(nO=nO, nI=tok2vec.get_dim("nO"))
|
output_layer = Softmax(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
|
||||||
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
|
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
|
||||||
model.set_ref("output_layer", output_layer)
|
model.set_ref("output_layer", output_layer)
|
||||||
else:
|
else:
|
||||||
linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO"))
|
linear_layer = Linear(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
|
||||||
model = (
|
model = (
|
||||||
tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()
|
tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()
|
||||||
)
|
)
|
||||||
|
|
|
@ -622,7 +622,7 @@ def load_meta(path: Union[str, Path]) -> Dict[str, Any]:
|
||||||
if not path.parent.exists():
|
if not path.parent.exists():
|
||||||
raise IOError(Errors.E052.format(path=path.parent))
|
raise IOError(Errors.E052.format(path=path.parent))
|
||||||
if not path.exists() or not path.is_file():
|
if not path.exists() or not path.is_file():
|
||||||
raise IOError(Errors.E053.format(path=path, name="meta.json"))
|
raise IOError(Errors.E053.format(path=path.parent, name="meta.json"))
|
||||||
meta = srsly.read_json(path)
|
meta = srsly.read_json(path)
|
||||||
for setting in ["lang", "name", "version"]:
|
for setting in ["lang", "name", "version"]:
|
||||||
if setting not in meta or not meta[setting]:
|
if setting not in meta or not meta[setting]:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user