Add TextCatReduce.v1
This is a textcat classifier that pools the vectors produced by a tok2vec implementation and then applies a classifier to the pooled representation. Three reductions are supported for pooling: first, max, and mean. When multiple reductions are enabled, their outputs are concatenated before being passed to the classification layer.

This model is a generalization of the TextCatCNN model, which only supports mean reduction and is a bit of a misnomer, since it can also be used with transformers. This change also reimplements TextCatCNN.v2 using the new TextCatReduce.v1 layer.
This commit is contained in:
parent e467573550
commit 61471038f5
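
For illustration only (not part of this commit): a training config could switch the `textcat` component to the new architecture roughly as follows. The parameter names and the `HashEmbedCNN` settings are taken from the example config added to the docs further down; the `[components.textcat]` wrapper is the usual spaCy setup rather than anything introduced here.

```ini
[components.textcat]
factory = "textcat"

[components.textcat.model]
@architectures = "spacy.TextCatReduce.v1"
exclusive_classes = true
use_reduce_first = false
use_reduce_max = true
use_reduce_mean = true
nO = null

[components.textcat.model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
```

With both `use_reduce_max` and `use_reduce_mean` enabled, the two pooled vectors are concatenated before the classification layer, as described above.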
@@ -984,6 +984,8 @@ class Errors(metaclass=ErrorsWithCodes):
    E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
             "but only callbacks with one or three parameters are supported")
    E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
    E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. "
             "Please enable one of `use_reduce_first`, `use_reduce_max` or `use_reduce_mean`.")


# Deprecated model shortcuts, only used in errors and warnings

@@ -17,6 +17,8 @@ from thinc.api import (
    clone,
    concatenate,
    list2ragged,
    reduce_first,
    reduce_max,
    reduce_mean,
    reduce_sum,
    residual,

@@ -49,39 +51,14 @@ def build_simple_cnn_text_classifier(
     outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
     is applied instead, so that outputs are in the range [0, 1].
     """
-    fill_defaults = {"b": 0, "W": 0}
-    with Model.define_operators({">>": chain}):
-        cnn = tok2vec >> list2ragged() >> reduce_mean()
-        nI = tok2vec.maybe_get_dim("nO")
-        if exclusive_classes:
-            output_layer = Softmax(nO=nO, nI=nI)
-            fill_defaults["b"] = NEG_VALUE
-            resizable_layer: Model = resizable(
-                output_layer,
-                resize_layer=partial(
-                    resize_linear_weighted, fill_defaults=fill_defaults
-                ),
-            )
-            model = cnn >> resizable_layer
-        else:
-            output_layer = Linear(nO=nO, nI=nI)
-            resizable_layer = resizable(
-                output_layer,
-                resize_layer=partial(
-                    resize_linear_weighted, fill_defaults=fill_defaults
-                ),
-            )
-            model = cnn >> resizable_layer >> Logistic()
-        model.set_ref("output_layer", output_layer)
-        model.attrs["resize_output"] = partial(
-            resize_and_set_ref,
-            resizable_layer=resizable_layer,
-        )
-    model.set_ref("tok2vec", tok2vec)
-    if nO is not None:
-        model.set_dim("nO", cast(int, nO))
-    model.attrs["multi_label"] = not exclusive_classes
-    return model
+    return build_reduce_text_classifier(
+        tok2vec=tok2vec,
+        exclusive_classes=exclusive_classes,
+        use_reduce_first=False,
+        use_reduce_max=False,
+        use_reduce_mean=True,
+        nO=nO,
+    )


 def resize_and_set_ref(model, new_nO, resizable_layer):

@@ -230,3 +207,75 @@ def build_text_classifier_lowdata(
    model = model >> Dropout(dropout)
    model = model >> Logistic()
    return model


@registry.architectures("spacy.TextCatReduce.v1")
def build_reduce_text_classifier(
    tok2vec: Model,
    exclusive_classes: bool,
    use_reduce_first: bool,
    use_reduce_max: bool,
    use_reduce_mean: bool,
    nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
    """Build a model that classifies pooled `Doc` representations.

    Pooling is performed using reductions. Reductions are concatenated when
    multiple reductions are used.

    tok2vec (Model): The tok2vec layer to pool over.
    exclusive_classes (bool): Whether or not classes are mutually exclusive.
    use_reduce_first (bool): Pool by using the hidden representation of the
        first token of a `Doc`.
    use_reduce_max (bool): Pool by taking the maximum values of the hidden
        representations of a `Doc`.
    use_reduce_mean (bool): Pool by taking the mean of all hidden
        representations of a `Doc`.
    nO (Optional[int]): Number of classes.
    """

    fill_defaults = {"b": 0, "W": 0}
    reductions = []
    if use_reduce_first:
        reductions.append(reduce_first())
    if use_reduce_max:
        reductions.append(reduce_max())
    if use_reduce_mean:
        reductions.append(reduce_mean())

    if not len(reductions):
        raise ValueError(Errors.E1057)

    with Model.define_operators({">>": chain}):
        cnn = tok2vec >> list2ragged() >> concatenate(*reductions)
        nO_tok2vec = tok2vec.maybe_get_dim("nO")
        nI = nO_tok2vec * len(reductions) if nO_tok2vec is not None else None
        if exclusive_classes:
            output_layer = Softmax(nO=nO, nI=nI)
            fill_defaults["b"] = NEG_VALUE
            resizable_layer: Model = resizable(
                output_layer,
                resize_layer=partial(
                    resize_linear_weighted, fill_defaults=fill_defaults
                ),
            )
            model = cnn >> resizable_layer
        else:
            output_layer = Linear(nO=nO, nI=nI)
            resizable_layer = resizable(
                output_layer,
                resize_layer=partial(
                    resize_linear_weighted, fill_defaults=fill_defaults
                ),
            )
            model = cnn >> resizable_layer >> Logistic()
        model.set_ref("output_layer", output_layer)
        model.attrs["resize_output"] = partial(
            resize_and_set_ref,
            resizable_layer=resizable_layer,
        )
    model.set_ref("tok2vec", tok2vec)
    if nO is not None:
        model.set_dim("nO", cast(int, nO))
    model.attrs["multi_label"] = not exclusive_classes
    return model

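A minimal sketch (not part of the diff) of what the width bookkeeping above means in practice, building the layer directly through the registry the same way the new test below does; it assumes the `spacy.HashEmbedCNN.v2` architecture accepts the same keyword arguments as its documented config keys:

```python
from spacy.util import registry

# Build a tok2vec layer through the registry; the settings mirror the
# HashEmbedCNN example config from the docs (assumed keyword arguments).
make_tok2vec = registry.architectures.get("spacy.HashEmbedCNN.v2")
tok2vec = make_tok2vec(
    width=96,
    depth=4,
    embed_size=2000,
    window_size=1,
    maxout_pieces=3,
    subword_features=True,
    pretrained_vectors=None,
)

# With max and mean reductions enabled, the pooled vectors are concatenated,
# so the classification layer receives nI = 96 * 2 = 192 features.
make_textcat = registry.architectures.get("spacy.TextCatReduce.v1")
model = make_textcat(
    tok2vec=tok2vec,
    exclusive_classes=True,
    use_reduce_first=False,
    use_reduce_max=True,
    use_reduce_mean=True,
    nO=None,
)
```
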
@@ -457,8 +457,8 @@ def test_no_resize(name, textcat_config):
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
        # CNN
        ("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
        ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
    ],
)
# fmt: on

@@ -485,9 +485,9 @@ def test_resize(name, textcat_config):
        ("textcat", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
        # CNN
        ("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
        ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
        # REDUCE
        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
    ],
)
# fmt: on

@@ -701,9 +701,9 @@ def test_overfitting_IO_multi():
        # ENSEMBLE V2
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
        # CNN V2
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
        # REDUCE V1
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
    ],
)
# fmt: on

@@ -26,6 +26,8 @@ from spacy.ml.models import (
    build_Tok2Vec_model,
)
from spacy.ml.staticvectors import StaticVectors
from spacy.pipeline import tok2vec
from spacy.util import registry


def get_textcat_bow_kwargs():

@@ -284,3 +286,16 @@ def test_spancat_model_forward_backward(nO=5):
    Y, backprop = model((docs, spans), is_train=True)
    assert Y.shape == (spans.dataXd.shape[0], nO)
    backprop(Y)


def test_textcat_reduce_invalid_args():
    textcat_reduce = registry.architectures.get("spacy.TextCatReduce.v1")
    tok2vec = make_test_tok2vec()
    with pytest.raises(ValueError, match=r"must be used with at least one reduction"):
        textcat_reduce(
            tok2vec=tok2vec,
            exclusive_classes=False,
            use_reduce_first=False,
            use_reduce_max=False,
            use_reduce_mean=False,
        )

@@ -1043,6 +1043,9 @@ A neural network model where token vectors are calculated using a CNN. The
vectors are mean pooled and used as features in a feed-forward network. This
architecture is usually less accurate than the ensemble, but runs faster.

This model is identical to [TextCatReduce.v1](#TextCatReduce) with
`use_reduce_mean=true`.

| Name                | Description                                                |
| ------------------- | ---------------------------------------------------------- |
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |

@@ -1096,6 +1099,44 @@ the others, but may not be as accurate, especially if texts are short.

</Accordion>

### spacy.TextCatReduce.v1 {id="TextCatReduce"}

> #### Example Config
>
> ```ini
> [model]
> @architectures = "spacy.TextCatReduce.v1"
> exclusive_classes = false
> use_reduce_first = false
> use_reduce_max = false
> use_reduce_mean = true
> nO = null
>
> [model.tok2vec]
> @architectures = "spacy.HashEmbedCNN.v2"
> pretrained_vectors = null
> width = 96
> depth = 4
> embed_size = 2000
> window_size = 1
> maxout_pieces = 3
> subword_features = true
> ```

A classifier that pools token hidden representations of each `Doc` using first,
max or mean reduction and then applies a classification layer. Reductions are
concatenated when multiple reductions are used.

| Name                | Description                                                                                                                                                                                        |
| ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                          |
| `tok2vec`           | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~                                                                                                                                             |
| `use_reduce_first`  | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~                                                                                                                     |
| `use_reduce_max`    | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~                                                                                                                |
| `use_reduce_mean`   | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~                                                                                                                          |
| `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~      |
| **CREATES**         | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                                    |

## Span classification architectures {id="spancat",source="spacy/ml/models/spancat.py"}

### spacy.SpanCategorizer.v1 {id="SpanCategorizer"}

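A closing usage sketch (again, not part of the commit): adding a `textcat` component that uses the new architecture from Python, with the tok2vec settings mirroring the example config above.

```python
import spacy

nlp = spacy.blank("en")
config = {
    "model": {
        "@architectures": "spacy.TextCatReduce.v1",
        "exclusive_classes": True,
        "use_reduce_first": False,
        "use_reduce_max": True,
        "use_reduce_mean": True,
        "tok2vec": {
            "@architectures": "spacy.HashEmbedCNN.v2",
            "pretrained_vectors": None,
            "width": 96,
            "depth": 4,
            "embed_size": 2000,
            "window_size": 1,
            "maxout_pieces": 3,
            "subword_features": True,
        },
    }
}
textcat = nlp.add_pipe("textcat", config=config)
textcat.add_label("POSITIVE")
textcat.add_label("NEGATIVE")
# Training then proceeds as for any other textcat model, e.g. via
# nlp.initialize(get_examples) followed by nlp.update(...) calls.
```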