Fix Transformer width in TextCatEnsemble (#6431)
* add convenience method to determine tok2vec width in a model
* fix transformer tok2vec dimensions in TextCatEnsemble architecture
* init function should not be nested to avoid pickle issues
This commit is contained in:
parent
402dbc5bae
commit
3983bc6b1e
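The third bullet point deserves a short illustration. Thinc models can end up being pickled (for example by multiprocessing workers), and pickle serializes a function by its module-level qualified name, so a closure defined inside the builder function would break serialization. A minimal, self-contained sketch of the failure mode, with illustrative names that are not from this commit:

import pickle

def build():
    def init(model, X=None, Y=None):  # nested: only exists inside build()
        return model
    return init

def init_toplevel(model, X=None, Y=None):  # module-level: picklable by name
    return model

pickle.dumps(init_toplevel)  # fine, resolved via its qualified name
try:
    pickle.dumps(build())
except (AttributeError, pickle.PicklingError) as err:
    print("cannot pickle nested init:", err)

This is why init_ensemble_textcat below is defined at module level and assigned to model.init, rather than being nested inside build_text_classifier_v2.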
spacy/ml/models/textcat.py

@@ -6,6 +6,7 @@ from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
 from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
 from thinc.api import HashEmbed, with_array, with_cpu, uniqued
 from thinc.api import Relu, residual, expand_window
+from thinc.layers.chain import init as init_chain
 
 from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
 from ...util import registry
@@ -13,6 +14,7 @@ from ..extract_ngrams import extract_ngrams
 from ..staticvectors import StaticVectors
 from ..featureextractor import FeatureExtractor
 from ...tokens import Doc
+from .tok2vec import get_tok2vec_width
 
 
 @registry.architectures.register("spacy.TextCatCNN.v1")
@@ -69,16 +71,17 @@ def build_text_classifier_v2(
     exclusive_classes = not linear_model.attrs["multi_label"]
     with Model.define_operators({">>": chain, "|": concatenate}):
         width = tok2vec.maybe_get_dim("nO")
+        attention_layer = ParametricAttention(width)  # TODO: benchmark performance difference of this layer
+        maxout_layer = Maxout(nO=width, nI=width)
+        linear_layer = Linear(nO=nO, nI=width)
         cnn_model = (
             tok2vec
             >> list2ragged()
-            >> ParametricAttention(
-                width
-            )  # TODO: benchmark performance difference of this layer
+            >> attention_layer
             >> reduce_sum()
-            >> residual(Maxout(nO=width, nI=width))
-            >> Linear(nO=nO, nI=width)
+            >> residual(maxout_layer)
+            >> linear_layer
             >> Dropout(0.0)
         )
 
         nO_double = nO * 2 if nO else None
@@ -91,7 +94,22 @@ def build_text_classifier_v2(
     if model.has_dim("nO") is not False:
         model.set_dim("nO", nO)
     model.set_ref("output_layer", linear_model.get_ref("output_layer"))
+    model.set_ref("attention_layer", attention_layer)
+    model.set_ref("maxout_layer", maxout_layer)
+    model.set_ref("linear_layer", linear_layer)
     model.attrs["multi_label"] = not exclusive_classes
 
+    model.init = init_ensemble_textcat
+    return model
+
+
+def init_ensemble_textcat(model, X, Y) -> Model:
+    tok2vec_width = get_tok2vec_width(model)
+    model.get_ref("attention_layer").set_dim("nO", tok2vec_width)
+    model.get_ref("maxout_layer").set_dim("nO", tok2vec_width)
+    model.get_ref("maxout_layer").set_dim("nI", tok2vec_width)
+    model.get_ref("linear_layer").set_dim("nI", tok2vec_width)
+    init_chain(model, X, Y)
     return model
 
+
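The pattern in the hunk above is the heart of the fix: the layers are named with set_ref at construction time, and their dimensions are filled in by a deferred init callback, because with a transformer pipeline tok2vec.maybe_get_dim("nO") returns None when the model is built. A minimal self-contained sketch of the same pattern, using stand-in layers and a hard-coded width; build_model and init_deferred are hypothetical names, not code from the commit:

from thinc.api import Maxout, Model, Relu, chain
from thinc.layers.chain import init as init_chain  # same import the commit adds

def build_model() -> Model:
    maxout = Maxout(nO=None, nI=None)        # width unknown at construction time
    model = chain(Relu(nO=8, nI=4), maxout)  # stand-in for the real ensemble
    model.set_ref("maxout_layer", maxout)    # keep a handle for the init callback
    model.init = init_deferred               # module-level, so the model pickles
    return model

def init_deferred(model: Model, X=None, Y=None) -> Model:
    width = 8  # in the commit this comes from get_tok2vec_width(model)
    model.get_ref("maxout_layer").set_dim("nO", width)
    model.get_ref("maxout_layer").set_dim("nI", width)
    init_chain(model, X, Y)                  # then run the normal chain init
    return model

model = build_model()
model.initialize()  # dims are resolved here, not at construction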
spacy/ml/models/tok2vec.py

@@ -20,6 +20,17 @@ def tok2vec_listener_v1(width: int, upstream: str = "*"):
     return tok2vec
 
 
+def get_tok2vec_width(model: Model):
+    nO = None
+    if model.has_ref("tok2vec"):
+        tok2vec = model.get_ref("tok2vec")
+        if tok2vec.has_dim("nO"):
+            nO = tok2vec.get_dim("nO")
+        elif tok2vec.has_ref("listener"):
+            nO = tok2vec.get_ref("listener").get_dim("nO")
+    return nO
+
+
 @registry.architectures.register("spacy.HashEmbedCNN.v1")
 def build_hash_embed_cnn_tok2vec(
     *,
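The new helper distinguishes two cases: a tok2vec whose output dimension is set directly, and a transformer-style tok2vec whose width lives on a "listener" ref. A hedged sketch exercising both branches with bare Thinc stand-ins; the import path is inferred from the hunk context above, and passthrough is an illustrative helper, not spaCy code:

from thinc.api import Model, chain, noop
from spacy.ml.models.tok2vec import get_tok2vec_width

def passthrough(name: str) -> Model:
    # identity layer that declares an "nO" dim so we can set it explicitly
    return Model(name, lambda model, X, is_train: (X, lambda dY: dY), dims={"nO": None})

# Case 1: the tok2vec ref carries its own output width.
t2v = passthrough("t2v")
t2v.set_dim("nO", 96)
model = chain(t2v, noop())
model.set_ref("tok2vec", t2v)    # legal: t2v is in the model's layer tree
assert get_tok2vec_width(model) == 96

# Case 2: a transformer-style wrapper has no width of its own, but its
# "listener" ref does; this is the branch the commit adds support for.
listener = passthrough("listener")
listener.set_dim("nO", 768)
wrapper = chain(listener, noop())
wrapper.set_ref("listener", listener)
model = chain(wrapper, noop())
model.set_ref("tok2vec", wrapper)
assert get_tok2vec_width(model) == 768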