mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Update textcat ensemble model
This commit is contained in:
parent
f50502dad7
commit
c2a18e4fa3
|
@ -72,23 +72,20 @@ def build_text_classifier_v2(
|
|||
attention_layer = ParametricAttention(
|
||||
width
|
||||
) # TODO: benchmark performance difference of this layer
|
||||
maxout_layer = Maxout(nO=width, nI=width)
|
||||
linear_layer = Linear(nO=nO, nI=width)
|
||||
maxout_layer = Maxout(nO=width, nI=width, dropout=0.0, normalize=True)
|
||||
cnn_model = (
|
||||
tok2vec
|
||||
>> list2ragged()
|
||||
>> attention_layer
|
||||
>> reduce_sum()
|
||||
>> residual(maxout_layer)
|
||||
>> linear_layer
|
||||
>> Dropout(0.0)
|
||||
)
|
||||
|
||||
nO_double = nO * 2 if nO else None
|
||||
if exclusive_classes:
|
||||
output_layer = Softmax(nO=nO, nI=nO_double)
|
||||
else:
|
||||
output_layer = Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic()
|
||||
output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
|
||||
model = (linear_model | cnn_model) >> output_layer
|
||||
model.set_ref("tok2vec", tok2vec)
|
||||
if model.has_dim("nO") is not False:
|
||||
|
@ -96,7 +93,6 @@ def build_text_classifier_v2(
|
|||
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
||||
model.set_ref("attention_layer", attention_layer)
|
||||
model.set_ref("maxout_layer", maxout_layer)
|
||||
model.set_ref("linear_layer", linear_layer)
|
||||
model.attrs["multi_label"] = not exclusive_classes
|
||||
|
||||
model.init = init_ensemble_textcat
|
||||
|
@ -108,7 +104,6 @@ def init_ensemble_textcat(model, X, Y) -> Model:
|
|||
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
||||
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
|
||||
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
|
||||
model.get_ref("linear_layer").set_dim("nI", tok2vec_width)
|
||||
init_chain(model, X, Y)
|
||||
return model
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user