mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Fix textcat + transformer architecture (#6371)
* add pooling to textcat TransformerListener * maybe_get_dim in case it's null
This commit is contained in:
parent
3ca5c7082d
commit
a0c899a0ff
|
@ -143,6 +143,9 @@ nO = null
|
||||||
@architectures = "spacy-transformers.TransformerListener.v1"
|
@architectures = "spacy-transformers.TransformerListener.v1"
|
||||||
grad_factor = 1.0
|
grad_factor = 1.0
|
||||||
|
|
||||||
|
[components.textcat.model.tok2vec.pooling]
|
||||||
|
@layers = "reduce_mean.v1"
|
||||||
|
|
||||||
[components.textcat.model.linear_model]
|
[components.textcat.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
|
|
|
@ -61,14 +61,14 @@ def build_bow_text_classifier(
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures.register("spacy.TextCatEnsemble.v2")
|
@registry.architectures.register("spacy.TextCatEnsemble.v2")
|
||||||
def build_text_classifier(
|
def build_text_classifier_v2(
|
||||||
tok2vec: Model[List[Doc], List[Floats2d]],
|
tok2vec: Model[List[Doc], List[Floats2d]],
|
||||||
linear_model: Model[List[Doc], Floats2d],
|
linear_model: Model[List[Doc], Floats2d],
|
||||||
nO: Optional[int] = None,
|
nO: Optional[int] = None,
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
exclusive_classes = not linear_model.attrs["multi_label"]
|
exclusive_classes = not linear_model.attrs["multi_label"]
|
||||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||||
width = tok2vec.get_dim("nO")
|
width = tok2vec.maybe_get_dim("nO")
|
||||||
cnn_model = (
|
cnn_model = (
|
||||||
tok2vec
|
tok2vec
|
||||||
>> list2ragged()
|
>> list2ragged()
|
||||||
|
@ -94,7 +94,7 @@ def build_text_classifier(
|
||||||
|
|
||||||
# TODO: move to legacy
|
# TODO: move to legacy
|
||||||
@registry.architectures.register("spacy.TextCatEnsemble.v1")
|
@registry.architectures.register("spacy.TextCatEnsemble.v1")
|
||||||
def build_text_classifier(
|
def build_text_classifier_v1(
|
||||||
width: int,
|
width: int,
|
||||||
embed_size: int,
|
embed_size: int,
|
||||||
pretrained_vectors: Optional[bool],
|
pretrained_vectors: Optional[bool],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user