rewrite Maxout layer as separate layers to avoid shape inference trouble (#6760)

This commit is contained in:
Sofie Van Landeghem 2021-01-19 00:37:17 +01:00 committed by GitHub
parent 26c34ab8b0
commit c8761b0e6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ from thinc.types import Floats2d
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
from thinc.api import with_cpu, Relu, residual from thinc.api import with_cpu, Relu, residual, LayerNorm
from thinc.layers.chain import init as init_chain from thinc.layers.chain import init as init_chain
from ...attrs import ORTH from ...attrs import ORTH
@ -72,13 +72,14 @@ def build_text_classifier_v2(
attention_layer = ParametricAttention( attention_layer = ParametricAttention(
width width
) # TODO: benchmark performance difference of this layer ) # TODO: benchmark performance difference of this layer
maxout_layer = Maxout(nO=width, nI=width, dropout=0.0, normalize=True) maxout_layer = Maxout(nO=width, nI=width)
norm_layer = LayerNorm(nI=width)
cnn_model = ( cnn_model = (
tok2vec tok2vec
>> list2ragged() >> list2ragged()
>> attention_layer >> attention_layer
>> reduce_sum() >> reduce_sum()
>> residual(maxout_layer) >> residual(maxout_layer >> norm_layer >> Dropout(0.0))
) )
nO_double = nO * 2 if nO else None nO_double = nO * 2 if nO else None
@ -93,6 +94,7 @@ def build_text_classifier_v2(
model.set_ref("output_layer", linear_model.get_ref("output_layer")) model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.set_ref("attention_layer", attention_layer) model.set_ref("attention_layer", attention_layer)
model.set_ref("maxout_layer", maxout_layer) model.set_ref("maxout_layer", maxout_layer)
model.set_ref("norm_layer", norm_layer)
model.attrs["multi_label"] = not exclusive_classes model.attrs["multi_label"] = not exclusive_classes
model.init = init_ensemble_textcat model.init = init_ensemble_textcat
@ -104,6 +106,7 @@ def init_ensemble_textcat(model, X, Y) -> Model:
model.get_ref("attention_layer").set_dim("nO", tok2vec_width) model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width) model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width) model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
init_chain(model, X, Y) init_chain(model, X, Y)
return model return model