Update textcat ensemble model

This commit is contained in:
Matthew Honnibal 2021-01-19 02:53:02 +11:00
parent f50502dad7
commit c2a18e4fa3

View File

@ -72,23 +72,20 @@ def build_text_classifier_v2(
attention_layer = ParametricAttention(
width
) # TODO: benchmark performance difference of this layer
maxout_layer = Maxout(nO=width, nI=width)
linear_layer = Linear(nO=nO, nI=width)
maxout_layer = Maxout(nO=width, nI=width, dropout=0.0, normalize=True)
cnn_model = (
tok2vec
>> list2ragged()
>> attention_layer
>> reduce_sum()
>> residual(maxout_layer)
>> linear_layer
>> Dropout(0.0)
)
nO_double = nO * 2 if nO else None
if exclusive_classes:
output_layer = Softmax(nO=nO, nI=nO_double)
else:
output_layer = Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic()
output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
model = (linear_model | cnn_model) >> output_layer
model.set_ref("tok2vec", tok2vec)
if model.has_dim("nO") is not False:
@ -96,7 +93,6 @@ def build_text_classifier_v2(
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.set_ref("attention_layer", attention_layer)
model.set_ref("maxout_layer", maxout_layer)
model.set_ref("linear_layer", linear_layer)
model.attrs["multi_label"] = not exclusive_classes
model.init = init_ensemble_textcat
@ -108,7 +104,6 @@ def init_ensemble_textcat(model, X, Y) -> Model:
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
model.get_ref("linear_layer").set_dim("nI", tok2vec_width)
init_chain(model, X, Y)
return model