mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Add TextCatReduce.v1
This is a textcat classifier that pools the vectors generated by a tok2vec implementation and then applies a classifier to the pooled representation. Three reductions are supported for pooling: first, max, and mean. When multiple reductions are enabled, the reductions are concatenated before providing them to the classification layer. This model is a generalization of the TextCatCNN model, which only supports mean reductions and is a bit of a misnomer, because it can also be used with transformers. This change also reimplements TextCatCNN.v2 using the new TextCatReduce.v1 layer.
This commit is contained in:
		
							parent
							
								
									e467573550
								
							
						
					
					
						commit
						61471038f5
					
				| 
						 | 
					@ -984,6 +984,8 @@ class Errors(metaclass=ErrorsWithCodes):
 | 
				
			||||||
    E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
 | 
					    E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
 | 
				
			||||||
             "but only callbacks with one or three parameters are supported")
 | 
					             "but only callbacks with one or three parameters are supported")
 | 
				
			||||||
    E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
 | 
					    E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
 | 
				
			||||||
 | 
					    E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. "
 | 
				
			||||||
 | 
					             "Please enable one of `use_reduce_first`, `use_reduce_max` or `use_reduce_mean`.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Deprecated model shortcuts, only used in errors and warnings
 | 
					# Deprecated model shortcuts, only used in errors and warnings
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -17,6 +17,8 @@ from thinc.api import (
 | 
				
			||||||
    clone,
 | 
					    clone,
 | 
				
			||||||
    concatenate,
 | 
					    concatenate,
 | 
				
			||||||
    list2ragged,
 | 
					    list2ragged,
 | 
				
			||||||
 | 
					    reduce_first,
 | 
				
			||||||
 | 
					    reduce_max,
 | 
				
			||||||
    reduce_mean,
 | 
					    reduce_mean,
 | 
				
			||||||
    reduce_sum,
 | 
					    reduce_sum,
 | 
				
			||||||
    residual,
 | 
					    residual,
 | 
				
			||||||
| 
						 | 
					@ -49,39 +51,14 @@ def build_simple_cnn_text_classifier(
 | 
				
			||||||
    outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
 | 
					    outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
 | 
				
			||||||
    is applied instead, so that outputs are in the range [0, 1].
 | 
					    is applied instead, so that outputs are in the range [0, 1].
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    fill_defaults = {"b": 0, "W": 0}
 | 
					    return build_reduce_text_classifier(
 | 
				
			||||||
    with Model.define_operators({">>": chain}):
 | 
					        tok2vec=tok2vec,
 | 
				
			||||||
        cnn = tok2vec >> list2ragged() >> reduce_mean()
 | 
					        exclusive_classes=exclusive_classes,
 | 
				
			||||||
        nI = tok2vec.maybe_get_dim("nO")
 | 
					        use_reduce_first=False,
 | 
				
			||||||
        if exclusive_classes:
 | 
					        use_reduce_max=False,
 | 
				
			||||||
            output_layer = Softmax(nO=nO, nI=nI)
 | 
					        use_reduce_mean=True,
 | 
				
			||||||
            fill_defaults["b"] = NEG_VALUE
 | 
					        nO=nO,
 | 
				
			||||||
            resizable_layer: Model = resizable(
 | 
					    )
 | 
				
			||||||
                output_layer,
 | 
					 | 
				
			||||||
                resize_layer=partial(
 | 
					 | 
				
			||||||
                    resize_linear_weighted, fill_defaults=fill_defaults
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            model = cnn >> resizable_layer
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            output_layer = Linear(nO=nO, nI=nI)
 | 
					 | 
				
			||||||
            resizable_layer = resizable(
 | 
					 | 
				
			||||||
                output_layer,
 | 
					 | 
				
			||||||
                resize_layer=partial(
 | 
					 | 
				
			||||||
                    resize_linear_weighted, fill_defaults=fill_defaults
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            model = cnn >> resizable_layer >> Logistic()
 | 
					 | 
				
			||||||
        model.set_ref("output_layer", output_layer)
 | 
					 | 
				
			||||||
        model.attrs["resize_output"] = partial(
 | 
					 | 
				
			||||||
            resize_and_set_ref,
 | 
					 | 
				
			||||||
            resizable_layer=resizable_layer,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
    model.set_ref("tok2vec", tok2vec)
 | 
					 | 
				
			||||||
    if nO is not None:
 | 
					 | 
				
			||||||
        model.set_dim("nO", cast(int, nO))
 | 
					 | 
				
			||||||
    model.attrs["multi_label"] = not exclusive_classes
 | 
					 | 
				
			||||||
    return model
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def resize_and_set_ref(model, new_nO, resizable_layer):
 | 
					def resize_and_set_ref(model, new_nO, resizable_layer):
 | 
				
			||||||
| 
						 | 
					@ -230,3 +207,75 @@ def build_text_classifier_lowdata(
 | 
				
			||||||
            model = model >> Dropout(dropout)
 | 
					            model = model >> Dropout(dropout)
 | 
				
			||||||
        model = model >> Logistic()
 | 
					        model = model >> Logistic()
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@registry.architectures("spacy.TextCatReduce.v1")
 | 
				
			||||||
 | 
					def build_reduce_text_classifier(
 | 
				
			||||||
 | 
					    tok2vec: Model,
 | 
				
			||||||
 | 
					    exclusive_classes: bool,
 | 
				
			||||||
 | 
					    use_reduce_first: bool,
 | 
				
			||||||
 | 
					    use_reduce_max: bool,
 | 
				
			||||||
 | 
					    use_reduce_mean: bool,
 | 
				
			||||||
 | 
					    nO: Optional[int] = None,
 | 
				
			||||||
 | 
					) -> Model[List[Doc], Floats2d]:
 | 
				
			||||||
 | 
					    """Build a model that classifies pooled `Doc` representations.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Pooling is performed using reductions. Reductions are concatenated when
 | 
				
			||||||
 | 
					    multible reductions are used.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tok2vec (Model): the tok2vec layer to pool over.
 | 
				
			||||||
 | 
					    exclusive_classes (bool): Whether or not classes are mutually exclusive.
 | 
				
			||||||
 | 
					    use_reduce_first (bool): Pool by using the hidden representation of the
 | 
				
			||||||
 | 
					        first token of a `Doc`
 | 
				
			||||||
 | 
					    use_reduce_max (bool): Pool by taking the maximum values of the hidden
 | 
				
			||||||
 | 
					        representations of a `Doc`.
 | 
				
			||||||
 | 
					    use_reduce_mean (bool): Pool by taking the mean of all hidden
 | 
				
			||||||
 | 
					        representations of a `Doc`.
 | 
				
			||||||
 | 
					    nO (Optional[int]): Number of classes.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fill_defaults = {"b": 0, "W": 0}
 | 
				
			||||||
 | 
					    reductions = []
 | 
				
			||||||
 | 
					    if use_reduce_first:
 | 
				
			||||||
 | 
					        reductions.append(reduce_first())
 | 
				
			||||||
 | 
					    if use_reduce_max:
 | 
				
			||||||
 | 
					        reductions.append(reduce_max())
 | 
				
			||||||
 | 
					    if use_reduce_mean:
 | 
				
			||||||
 | 
					        reductions.append(reduce_mean())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not len(reductions):
 | 
				
			||||||
 | 
					        raise ValueError(Errors.E1057)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with Model.define_operators({">>": chain}):
 | 
				
			||||||
 | 
					        cnn = tok2vec >> list2ragged() >> concatenate(*reductions)
 | 
				
			||||||
 | 
					        nO_tok2vec = tok2vec.maybe_get_dim("nO")
 | 
				
			||||||
 | 
					        nI = nO_tok2vec * len(reductions) if nO_tok2vec is not None else None
 | 
				
			||||||
 | 
					        if exclusive_classes:
 | 
				
			||||||
 | 
					            output_layer = Softmax(nO=nO, nI=nI)
 | 
				
			||||||
 | 
					            fill_defaults["b"] = NEG_VALUE
 | 
				
			||||||
 | 
					            resizable_layer: Model = resizable(
 | 
				
			||||||
 | 
					                output_layer,
 | 
				
			||||||
 | 
					                resize_layer=partial(
 | 
				
			||||||
 | 
					                    resize_linear_weighted, fill_defaults=fill_defaults
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            model = cnn >> resizable_layer
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            output_layer = Linear(nO=nO, nI=nI)
 | 
				
			||||||
 | 
					            resizable_layer = resizable(
 | 
				
			||||||
 | 
					                output_layer,
 | 
				
			||||||
 | 
					                resize_layer=partial(
 | 
				
			||||||
 | 
					                    resize_linear_weighted, fill_defaults=fill_defaults
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            model = cnn >> resizable_layer >> Logistic()
 | 
				
			||||||
 | 
					        model.set_ref("output_layer", output_layer)
 | 
				
			||||||
 | 
					        model.attrs["resize_output"] = partial(
 | 
				
			||||||
 | 
					            resize_and_set_ref,
 | 
				
			||||||
 | 
					            resizable_layer=resizable_layer,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    model.set_ref("tok2vec", tok2vec)
 | 
				
			||||||
 | 
					    if nO is not None:
 | 
				
			||||||
 | 
					        model.set_dim("nO", cast(int, nO))
 | 
				
			||||||
 | 
					    model.attrs["multi_label"] = not exclusive_classes
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -457,8 +457,8 @@ def test_no_resize(name, textcat_config):
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
 | 
				
			||||||
        # CNN
 | 
					        # CNN
 | 
				
			||||||
        ("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
 | 
					        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
# fmt: on
 | 
					# fmt: on
 | 
				
			||||||
| 
						 | 
					@ -485,9 +485,9 @@ def test_resize(name, textcat_config):
 | 
				
			||||||
        ("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
 | 
					        ("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
 | 
				
			||||||
        # CNN
 | 
					        # REDUCE
 | 
				
			||||||
        ("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
 | 
					        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
# fmt: on
 | 
					# fmt: on
 | 
				
			||||||
| 
						 | 
					@ -701,9 +701,9 @@ def test_overfitting_IO_multi():
 | 
				
			||||||
        # ENSEMBLE V2
 | 
					        # ENSEMBLE V2
 | 
				
			||||||
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
 | 
					        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
 | 
				
			||||||
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
 | 
					        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
 | 
				
			||||||
        # CNN V2
 | 
					        # REDUCE V1
 | 
				
			||||||
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
 | 
					        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
 | 
					        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
# fmt: on
 | 
					# fmt: on
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,6 +26,8 @@ from spacy.ml.models import (
 | 
				
			||||||
    build_Tok2Vec_model,
 | 
					    build_Tok2Vec_model,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from spacy.ml.staticvectors import StaticVectors
 | 
					from spacy.ml.staticvectors import StaticVectors
 | 
				
			||||||
 | 
					from spacy.pipeline import tok2vec
 | 
				
			||||||
 | 
					from spacy.util import registry
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_textcat_bow_kwargs():
 | 
					def get_textcat_bow_kwargs():
 | 
				
			||||||
| 
						 | 
					@ -284,3 +286,16 @@ def test_spancat_model_forward_backward(nO=5):
 | 
				
			||||||
    Y, backprop = model((docs, spans), is_train=True)
 | 
					    Y, backprop = model((docs, spans), is_train=True)
 | 
				
			||||||
    assert Y.shape == (spans.dataXd.shape[0], nO)
 | 
					    assert Y.shape == (spans.dataXd.shape[0], nO)
 | 
				
			||||||
    backprop(Y)
 | 
					    backprop(Y)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_textcat_reduce_invalid_args():
 | 
				
			||||||
 | 
					    textcat_reduce = registry.architectures.get("spacy.TextCatReduce.v1")
 | 
				
			||||||
 | 
					    tok2vec = make_test_tok2vec()
 | 
				
			||||||
 | 
					    with pytest.raises(ValueError, match=r"must be used with at least one reduction"):
 | 
				
			||||||
 | 
					        textcat_reduce(
 | 
				
			||||||
 | 
					            tok2vec=tok2vec,
 | 
				
			||||||
 | 
					            exclusive_classes=False,
 | 
				
			||||||
 | 
					            use_reduce_first=False,
 | 
				
			||||||
 | 
					            use_reduce_max=False,
 | 
				
			||||||
 | 
					            use_reduce_mean=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1043,6 +1043,9 @@ A neural network model where token vectors are calculated using a CNN. The
 | 
				
			||||||
vectors are mean pooled and used as features in a feed-forward network. This
 | 
					vectors are mean pooled and used as features in a feed-forward network. This
 | 
				
			||||||
architecture is usually less accurate than the ensemble, but runs faster.
 | 
					architecture is usually less accurate than the ensemble, but runs faster.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This model is identical to [TexCatReduce.v1](#TextCatReduce) with
 | 
				
			||||||
 | 
					`use_reduce_mean=true`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| Name                | Description                                                                                                                                                                                    |
 | 
					| Name                | Description                                                                                                                                                                                    |
 | 
				
			||||||
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
					| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
				
			||||||
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
 | 
					| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
 | 
				
			||||||
| 
						 | 
					@ -1096,6 +1099,44 @@ the others, but may not be as accurate, especially if texts are short.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
</Accordion>
 | 
					</Accordion>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### spacy.TextCatReduce.v1 {id="TextCatReduce"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					> #### Example Config
 | 
				
			||||||
 | 
					>
 | 
				
			||||||
 | 
					> ```ini
 | 
				
			||||||
 | 
					> [model]
 | 
				
			||||||
 | 
					> @architectures = "spacy.TextCatReduce.v1"
 | 
				
			||||||
 | 
					> exclusive_classes = false
 | 
				
			||||||
 | 
					> use_reduce_first = false
 | 
				
			||||||
 | 
					> use_reduce_max = false
 | 
				
			||||||
 | 
					> use_reduce_mean = true
 | 
				
			||||||
 | 
					> nO = null
 | 
				
			||||||
 | 
					>
 | 
				
			||||||
 | 
					> [model.tok2vec]
 | 
				
			||||||
 | 
					> @architectures = "spacy.HashEmbedCNN.v2"
 | 
				
			||||||
 | 
					> pretrained_vectors = null
 | 
				
			||||||
 | 
					> width = 96
 | 
				
			||||||
 | 
					> depth = 4
 | 
				
			||||||
 | 
					> embed_size = 2000
 | 
				
			||||||
 | 
					> window_size = 1
 | 
				
			||||||
 | 
					> maxout_pieces = 3
 | 
				
			||||||
 | 
					> subword_features = true
 | 
				
			||||||
 | 
					> ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					A classifier that pools token hidden representations of each `Doc` using first,
 | 
				
			||||||
 | 
					max or mean reduction and then applies a classification layer. Reductions are
 | 
				
			||||||
 | 
					concatenated when multiple reductions are used.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Name                | Description                                                                                                                                                                                    |
 | 
				
			||||||
 | 
					| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
				
			||||||
 | 
					| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
 | 
				
			||||||
 | 
					| `tok2vec`           | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~                                                                                                                                        |
 | 
				
			||||||
 | 
					| `use_reduce_first`  | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~                                                                                                                |
 | 
				
			||||||
 | 
					| `use_reduce_max`    | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~                                                                                                           |
 | 
				
			||||||
 | 
					| `use_reduce_mean`   | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~                                                                                                                     |
 | 
				
			||||||
 | 
					| `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
 | 
				
			||||||
 | 
					| **CREATES**         | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                               |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Span classification architectures {id="spancat",source="spacy/ml/models/spancat.py"}
 | 
					## Span classification architectures {id="spancat",source="spacy/ml/models/spancat.py"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### spacy.SpanCategorizer.v1 {id="SpanCategorizer"}
 | 
					### spacy.SpanCategorizer.v1 {id="SpanCategorizer"}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user