Update textcat ensemble model

parent f50502dad7
commit c2a18e4fa3
@@ -72,23 +72,20 @@ def build_text_classifier_v2(
         attention_layer = ParametricAttention(
             width
         )  # TODO: benchmark performance difference of this layer
-        maxout_layer = Maxout(nO=width, nI=width)
-        linear_layer = Linear(nO=nO, nI=width)
+        maxout_layer = Maxout(nO=width, nI=width, dropout=0.0, normalize=True)
         cnn_model = (
             tok2vec
             >> list2ragged()
             >> attention_layer
             >> reduce_sum()
             >> residual(maxout_layer)
-            >> linear_layer
-            >> Dropout(0.0)
         )

         nO_double = nO * 2 if nO else None
         if exclusive_classes:
             output_layer = Softmax(nO=nO, nI=nO_double)
         else:
-            output_layer = Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic()
+            output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
         model = (linear_model | cnn_model) >> output_layer
         model.set_ref("tok2vec", tok2vec)
         if model.has_dim("nO") is not False:
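This hunk folds the CNN branch's trailing Linear projection and no-op Dropout into the residual Maxout block, using thinc's built-in normalize and dropout arguments on Maxout. A minimal before/after sketch of that simplification (the width value and the standalone block are illustrative, not the full ensemble from the commit):

    from thinc.api import Dropout, Linear, Maxout, chain, residual

    width = 64  # illustrative; in spaCy this comes from the tok2vec layer

    # Before: a plain Maxout in the residual block, followed by a separate
    # Linear projection and a Dropout layer with rate 0.0.
    old_block = chain(
        residual(Maxout(nO=width, nI=width)),
        Linear(nO=width, nI=width),
        Dropout(0.0),
    )

    # After: layer normalization and (zero) dropout are handled inside the
    # Maxout itself, so the extra Linear and Dropout nodes can be dropped.
    new_block = residual(Maxout(nO=width, nI=width, dropout=0.0, normalize=True))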
@@ -96,7 +93,6 @@ def build_text_classifier_v2(
         model.set_ref("output_layer", linear_model.get_ref("output_layer"))
         model.set_ref("attention_layer", attention_layer)
         model.set_ref("maxout_layer", maxout_layer)
-        model.set_ref("linear_layer", linear_layer)
         model.attrs["multi_label"] = not exclusive_classes

     model.init = init_ensemble_textcat
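The second hunk drops the matching node reference: with the standalone linear layer gone, the "linear_layer" ref would dangle. For context, thinc's set_ref/get_ref let a composite model expose inner nodes by name; a minimal sketch, with illustrative names and dims:

    from thinc.api import Maxout, chain

    hidden = Maxout(nO=8, nI=8)
    model = chain(hidden, Maxout(nO=4, nI=8))
    model.set_ref("hidden", hidden)  # expose the inner node under a name

    # Later code (e.g. an init callback) can retrieve the node without
    # knowing where it sits in the chain:
    assert model.get_ref("hidden") is hidden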
@@ -108,7 +104,6 @@ def init_ensemble_textcat(model, X, Y) -> Model:
     model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
     model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
     model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
-    model.get_ref("linear_layer").set_dim("nI", tok2vec_width)
     init_chain(model, X, Y)
     return model

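The final hunk mirrors the change in the init callback: there is no linear_layer left whose input dimension needs pinning before initialization. The set_dim-then-initialize pattern it relies on looks roughly like this (a standalone sketch, not the ensemble's actual init):

    import numpy
    from thinc.api import Maxout

    layer = Maxout()                 # dims left unset at construction time
    layer.set_dim("nO", 8)           # pin output width before init
    layer.set_dim("nI", 4)           # pin input width before init
    layer.initialize(X=numpy.zeros((2, 4), dtype="f"))  # allocate parameters
    assert layer.get_dim("nO") == 8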