Mirror of https://github.com/explosion/spaCy.git, synced 2025-01-27 01:34:30 +03:00.
Construct TextCatEnsemble.v2 using helper function
This commit is contained in:
parent
ce4ea5ffa7
commit
e722284ff4
|
@ -19,6 +19,7 @@ from thinc.api import (
|
|||
clone,
|
||||
concatenate,
|
||||
list2ragged,
|
||||
noop,
|
||||
reduce_first,
|
||||
reduce_last,
|
||||
reduce_max,
|
||||
|
@ -148,50 +149,26 @@ def build_text_classifier_v2(
|
|||
linear_model: Model[List[Doc], Floats2d],
|
||||
nO: Optional[int] = None,
|
||||
) -> Model[List[Doc], Floats2d]:
|
||||
# TODO: build the model with _build_parametric_attention_with_residual_nonlinear
|
||||
# in spaCy v4. We don't do this in spaCy v3 to preserve model
|
||||
# compatibility.
|
||||
width = tok2vec.maybe_get_dim("nO")
|
||||
exclusive_classes = not linear_model.attrs["multi_label"]
|
||||
parametric_attention = _build_parametric_attention_with_residual_nonlinear(
|
||||
tok2vec=tok2vec,
|
||||
nonlinear_layer=Maxout(nI=width, nO=width),
|
||||
key_transform=noop(),
|
||||
)
|
||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||
width = tok2vec.maybe_get_dim("nO")
|
||||
attention_layer = ParametricAttention(width)
|
||||
maxout_layer = Maxout(nO=width, nI=width)
|
||||
norm_layer = LayerNorm(nI=width)
|
||||
cnn_model = (
|
||||
tok2vec
|
||||
>> list2ragged()
|
||||
>> attention_layer
|
||||
>> reduce_sum()
|
||||
>> residual(maxout_layer >> norm_layer >> Dropout(0.0))
|
||||
)
|
||||
|
||||
nO_double = nO * 2 if nO else None
|
||||
if exclusive_classes:
|
||||
output_layer = Softmax(nO=nO, nI=nO_double)
|
||||
else:
|
||||
output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
|
||||
model = (linear_model | cnn_model) >> output_layer
|
||||
model = (linear_model | parametric_attention) >> output_layer
|
||||
model.set_ref("tok2vec", tok2vec)
|
||||
if model.has_dim("nO") is not False and nO is not None:
|
||||
model.set_dim("nO", cast(int, nO))
|
||||
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
|
||||
model.set_ref("attention_layer", attention_layer)
|
||||
model.set_ref("maxout_layer", maxout_layer)
|
||||
model.set_ref("norm_layer", norm_layer)
|
||||
model.attrs["multi_label"] = not exclusive_classes
|
||||
|
||||
model.init = init_ensemble_textcat # type: ignore[assignment]
|
||||
return model
|
||||
|
||||
|
||||
def init_ensemble_textcat(model, X, Y) -> Model:
    """Initializer for the textcat ensemble model.

    Resolves the tok2vec output width at init time and propagates it into
    the dimensions of the attention, maxout, and norm sublayers (looked up
    via their model refs), then delegates to the standard chain initializer.

    Returns the initialized model.
    """
    width = get_tok2vec_width(model)
    # Every referenced sublayer must agree with the tok2vec output width.
    # (ref name, dims to set) — order matches the original statement order.
    dim_specs = (
        ("attention_layer", ("nO",)),
        ("maxout_layer", ("nO", "nI")),
        ("norm_layer", ("nI", "nO")),
    )
    for ref_name, dims in dim_specs:
        sublayer = model.get_ref(ref_name)
        for dim in dims:
            sublayer.set_dim(dim, width)
    init_chain(model, X, Y)
    return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user