Fix Transformer width in TextCatEnsemble (#6431)

* add convenience method to determine tok2vec width in a model

* fix transformer tok2vec dimensions in TextCatEnsemble architecture

* init function should not be nested to avoid pickle issues
This commit is contained in:
Sofie Van Landeghem 2021-01-06 12:44:04 +01:00 committed by GitHub
parent 402dbc5bae
commit 3983bc6b1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 9 deletions

View File

@ -6,6 +6,7 @@ from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
from thinc.api import HashEmbed, with_array, with_cpu, uniqued from thinc.api import HashEmbed, with_array, with_cpu, uniqued
from thinc.api import Relu, residual, expand_window from thinc.api import Relu, residual, expand_window
from thinc.layers.chain import init as init_chain
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
from ...util import registry from ...util import registry
@ -13,6 +14,7 @@ from ..extract_ngrams import extract_ngrams
from ..staticvectors import StaticVectors from ..staticvectors import StaticVectors
from ..featureextractor import FeatureExtractor from ..featureextractor import FeatureExtractor
from ...tokens import Doc from ...tokens import Doc
from .tok2vec import get_tok2vec_width
@registry.architectures.register("spacy.TextCatCNN.v1") @registry.architectures.register("spacy.TextCatCNN.v1")
@ -69,16 +71,17 @@ def build_text_classifier_v2(
exclusive_classes = not linear_model.attrs["multi_label"] exclusive_classes = not linear_model.attrs["multi_label"]
with Model.define_operators({">>": chain, "|": concatenate}): with Model.define_operators({">>": chain, "|": concatenate}):
width = tok2vec.maybe_get_dim("nO") width = tok2vec.maybe_get_dim("nO")
attention_layer = ParametricAttention(width) # TODO: benchmark performance difference of this layer
maxout_layer = Maxout(nO=width, nI=width)
linear_layer = Linear(nO=nO, nI=width)
cnn_model = ( cnn_model = (
tok2vec tok2vec
>> list2ragged() >> list2ragged()
>> ParametricAttention( >> attention_layer
width >> reduce_sum()
) # TODO: benchmark performance difference of this layer >> residual(maxout_layer)
>> reduce_sum() >> linear_layer
>> residual(Maxout(nO=width, nI=width)) >> Dropout(0.0)
>> Linear(nO=nO, nI=width)
>> Dropout(0.0)
) )
nO_double = nO * 2 if nO else None nO_double = nO * 2 if nO else None
@ -91,7 +94,22 @@ def build_text_classifier_v2(
if model.has_dim("nO") is not False: if model.has_dim("nO") is not False:
model.set_dim("nO", nO) model.set_dim("nO", nO)
model.set_ref("output_layer", linear_model.get_ref("output_layer")) model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.set_ref("attention_layer", attention_layer)
model.set_ref("maxout_layer", maxout_layer)
model.set_ref("linear_layer", linear_layer)
model.attrs["multi_label"] = not exclusive_classes model.attrs["multi_label"] = not exclusive_classes
model.init = init_ensemble_textcat
return model
def init_ensemble_textcat(model, X, Y) -> Model:
tok2vec_width = get_tok2vec_width(model)
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
model.get_ref("linear_layer").set_dim("nI", tok2vec_width)
init_chain(model, X, Y)
return model return model

View File

@ -20,6 +20,17 @@ def tok2vec_listener_v1(width: int, upstream: str = "*"):
return tok2vec return tok2vec
def get_tok2vec_width(model: Model):
nO = None
if model.has_ref("tok2vec"):
tok2vec = model.get_ref("tok2vec")
if tok2vec.has_dim("nO"):
nO = tok2vec.get_dim("nO")
elif tok2vec.has_ref("listener"):
nO = tok2vec.get_ref("listener").get_dim("nO")
return nO
@registry.architectures.register("spacy.HashEmbedCNN.v1") @registry.architectures.register("spacy.HashEmbedCNN.v1")
def build_hash_embed_cnn_tok2vec( def build_hash_embed_cnn_tok2vec(
*, *,