mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Fix Transformer width in TextCatEnsemble (#6431)
* add convenience method to determine tok2vec width in a model * fix transformer tok2vec dimensions in TextCatEnsemble architecture * init function should not be nested to avoid pickle issues
This commit is contained in:
parent
402dbc5bae
commit
3983bc6b1e
|
@ -6,6 +6,7 @@ from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
|
|||
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
|
||||
from thinc.api import HashEmbed, with_array, with_cpu, uniqued
|
||||
from thinc.api import Relu, residual, expand_window
|
||||
from thinc.layers.chain import init as init_chain
|
||||
|
||||
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
|
||||
from ...util import registry
|
||||
|
@ -13,6 +14,7 @@ from ..extract_ngrams import extract_ngrams
|
|||
from ..staticvectors import StaticVectors
|
||||
from ..featureextractor import FeatureExtractor
|
||||
from ...tokens import Doc
|
||||
from .tok2vec import get_tok2vec_width
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCatCNN.v1")
|
||||
|
@ -69,15 +71,16 @@ def build_text_classifier_v2(
|
|||
exclusive_classes = not linear_model.attrs["multi_label"]
|
||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||
width = tok2vec.maybe_get_dim("nO")
|
||||
attention_layer = ParametricAttention(width) # TODO: benchmark performance difference of this layer
|
||||
maxout_layer = Maxout(nO=width, nI=width)
|
||||
linear_layer = Linear(nO=nO, nI=width)
|
||||
cnn_model = (
|
||||
tok2vec
|
||||
>> list2ragged()
|
||||
>> ParametricAttention(
|
||||
width
|
||||
) # TODO: benchmark performance difference of this layer
|
||||
>> attention_layer
|
||||
>> reduce_sum()
|
||||
>> residual(Maxout(nO=width, nI=width))
|
||||
>> Linear(nO=nO, nI=width)
|
||||
>> residual(maxout_layer)
|
||||
>> linear_layer
|
||||
>> Dropout(0.0)
|
||||
)
|
||||
|
||||
|
@ -91,7 +94,22 @@ def build_text_classifier_v2(
|
|||
if model.has_dim("nO") is not False:
|
||||
model.set_dim("nO", nO)
|
||||
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
||||
model.set_ref("attention_layer", attention_layer)
|
||||
model.set_ref("maxout_layer", maxout_layer)
|
||||
model.set_ref("linear_layer", linear_layer)
|
||||
model.attrs["multi_label"] = not exclusive_classes
|
||||
|
||||
model.init = init_ensemble_textcat
|
||||
return model
|
||||
|
||||
|
||||
def init_ensemble_textcat(model, X, Y) -> Model:
|
||||
tok2vec_width = get_tok2vec_width(model)
|
||||
model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
|
||||
model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
|
||||
model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
|
||||
model.get_ref("linear_layer").set_dim("nI", tok2vec_width)
|
||||
init_chain(model, X, Y)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -20,6 +20,17 @@ def tok2vec_listener_v1(width: int, upstream: str = "*"):
|
|||
return tok2vec
|
||||
|
||||
|
||||
def get_tok2vec_width(model: Model):
|
||||
nO = None
|
||||
if model.has_ref("tok2vec"):
|
||||
tok2vec = model.get_ref("tok2vec")
|
||||
if tok2vec.has_dim("nO"):
|
||||
nO = tok2vec.get_dim("nO")
|
||||
elif tok2vec.has_ref("listener"):
|
||||
nO = tok2vec.get_ref("listener").get_dim("nO")
|
||||
return nO
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.HashEmbedCNN.v1")
|
||||
def build_hash_embed_cnn_tok2vec(
|
||||
*,
|
||||
|
|
Loading…
Reference in New Issue
Block a user