mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
rewrite Maxout layer as separate layers to avoid shape inference trouble (#6760)
This commit is contained in:
parent
26c34ab8b0
commit
c8761b0e6e
|
@ -4,7 +4,7 @@ from thinc.types import Floats2d
|
||||||
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
|
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
|
||||||
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
|
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
|
||||||
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
|
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
|
||||||
from thinc.api import with_cpu, Relu, residual
|
from thinc.api import with_cpu, Relu, residual, LayerNorm
|
||||||
from thinc.layers.chain import init as init_chain
|
from thinc.layers.chain import init as init_chain
|
||||||
|
|
||||||
from ...attrs import ORTH
|
from ...attrs import ORTH
|
||||||
|
@ -72,13 +72,14 @@ def build_text_classifier_v2(
|
||||||
attention_layer = ParametricAttention(
|
attention_layer = ParametricAttention(
|
||||||
width
|
width
|
||||||
) # TODO: benchmark performance difference of this layer
|
) # TODO: benchmark performance difference of this layer
|
||||||
maxout_layer = Maxout(nO=width, nI=width, dropout=0.0, normalize=True)
|
maxout_layer = Maxout(nO=width, nI=width)
|
||||||
|
norm_layer = LayerNorm(nI=width)
|
||||||
cnn_model = (
|
cnn_model = (
|
||||||
tok2vec
|
tok2vec
|
||||||
>> list2ragged()
|
>> list2ragged()
|
||||||
>> attention_layer
|
>> attention_layer
|
||||||
>> reduce_sum()
|
>> reduce_sum()
|
||||||
>> residual(maxout_layer)
|
>> residual(maxout_layer >> norm_layer >> Dropout(0.0))
|
||||||
)
|
)
|
||||||
|
|
||||||
nO_double = nO * 2 if nO else None
|
nO_double = nO * 2 if nO else None
|
||||||
|
@ -93,6 +94,7 @@ def build_text_classifier_v2(
|
||||||
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
||||||
model.set_ref("attention_layer", attention_layer)
|
model.set_ref("attention_layer", attention_layer)
|
||||||
model.set_ref("maxout_layer", maxout_layer)
|
model.set_ref("maxout_layer", maxout_layer)
|
||||||
|
model.set_ref("norm_layer", norm_layer)
|
||||||
model.attrs["multi_label"] = not exclusive_classes
|
model.attrs["multi_label"] = not exclusive_classes
|
||||||
|
|
||||||
model.init = init_ensemble_textcat
|
model.init = init_ensemble_textcat
|
||||||
|
@ -104,6 +106,7 @@ def init_ensemble_textcat(model, X, Y) -> Model:
|
||||||
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
||||||
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
|
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
|
||||||
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
|
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
|
||||||
|
model.get_ref("norm_layer").set_dim("nI", tok2vec_width)
|
||||||
init_chain(model, X, Y)
|
init_chain(model, X, Y)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user