Fix textcat + transformer architecture (#6371)

* add pooling to textcat TransformerListener

* maybe_get_dim in case it's null
This commit is contained in:
Sofie Van Landeghem 2020-11-10 13:14:47 +01:00 committed by GitHub
parent 3ca5c7082d
commit a0c899a0ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -143,6 +143,9 @@ nO = null
@architectures = "spacy-transformers.TransformerListener.v1" @architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0 grad_factor = 1.0
[components.textcat.model.tok2vec.pooling]
@layers = "reduce_mean.v1"
[components.textcat.model.linear_model] [components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v1" @architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false exclusive_classes = false

View File

@ -61,14 +61,14 @@ def build_bow_text_classifier(
@registry.architectures.register("spacy.TextCatEnsemble.v2") @registry.architectures.register("spacy.TextCatEnsemble.v2")
def build_text_classifier( def build_text_classifier_v2(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
linear_model: Model[List[Doc], Floats2d], linear_model: Model[List[Doc], Floats2d],
nO: Optional[int] = None, nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]: ) -> Model[List[Doc], Floats2d]:
exclusive_classes = not linear_model.attrs["multi_label"] exclusive_classes = not linear_model.attrs["multi_label"]
with Model.define_operators({">>": chain, "|": concatenate}): with Model.define_operators({">>": chain, "|": concatenate}):
width = tok2vec.get_dim("nO") width = tok2vec.maybe_get_dim("nO")
cnn_model = ( cnn_model = (
tok2vec tok2vec
>> list2ragged() >> list2ragged()
@ -94,7 +94,7 @@ def build_text_classifier(
# TODO: move to legacy # TODO: move to legacy
@registry.architectures.register("spacy.TextCatEnsemble.v1") @registry.architectures.register("spacy.TextCatEnsemble.v1")
def build_text_classifier( def build_text_classifier_v1(
width: int, width: int,
embed_size: int, embed_size: int,
pretrained_vectors: Optional[bool], pretrained_vectors: Optional[bool],