Add last reduction (use_reduce_last)

This commit is contained in:
Daniël de Kok 2023-12-11 10:13:53 +01:00
parent e310ddfd4a
commit 7669a450ec
9 changed files with 24 additions and 9 deletions

View File

@ -282,6 +282,7 @@ no_output_layer = false
@architectures = "spacy.TextCatReduce.v1" @architectures = "spacy.TextCatReduce.v1"
exclusive_classes = true exclusive_classes = true
use_reduce_first = false use_reduce_first = false
use_reduce_last = false
use_reduce_max = false use_reduce_max = false
use_reduce_mean = true use_reduce_mean = true
nO = null nO = null
@ -323,6 +324,7 @@ no_output_layer = false
@architectures = "spacy.TextCatReduce.v1" @architectures = "spacy.TextCatReduce.v1"
exclusive_classes = false exclusive_classes = false
use_reduce_first = false use_reduce_first = false
use_reduce_last = false
use_reduce_max = false use_reduce_max = false
use_reduce_mean = true use_reduce_mean = true
nO = null nO = null

View File

@ -985,7 +985,8 @@ class Errors(metaclass=ErrorsWithCodes):
"but only callbacks with one or three parameters are supported") "but only callbacks with one or three parameters are supported")
E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.") E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. " E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. "
"Please enable one of `use_reduce_first`, `use_reduce_max` or `use_reduce_mean`.") "Please enable one of `use_reduce_first`, `reduce_last`, `use_reduce_max` or "
"`use_reduce_mean`.")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -18,6 +18,7 @@ from thinc.api import (
concatenate, concatenate,
list2ragged, list2ragged,
reduce_first, reduce_first,
reduce_last,
reduce_max, reduce_max,
reduce_mean, reduce_mean,
reduce_sum, reduce_sum,
@ -55,6 +56,7 @@ def build_simple_cnn_text_classifier(
tok2vec=tok2vec, tok2vec=tok2vec,
exclusive_classes=exclusive_classes, exclusive_classes=exclusive_classes,
use_reduce_first=False, use_reduce_first=False,
use_reduce_last=False,
use_reduce_max=False, use_reduce_max=False,
use_reduce_mean=True, use_reduce_mean=True,
nO=nO, nO=nO,
@ -214,6 +216,7 @@ def build_reduce_text_classifier(
tok2vec: Model, tok2vec: Model,
exclusive_classes: bool, exclusive_classes: bool,
use_reduce_first: bool, use_reduce_first: bool,
use_reduce_last: bool,
use_reduce_max: bool, use_reduce_max: bool,
use_reduce_mean: bool, use_reduce_mean: bool,
nO: Optional[int] = None, nO: Optional[int] = None,
@ -227,6 +230,8 @@ def build_reduce_text_classifier(
exclusive_classes (bool): Whether or not classes are mutually exclusive. exclusive_classes (bool): Whether or not classes are mutually exclusive.
use_reduce_first (bool): Pool by using the hidden representation of the use_reduce_first (bool): Pool by using the hidden representation of the
first token of a `Doc`. first token of a `Doc`.
use_reduce_first (bool): Pool by using the hidden representation of the
last token of a `Doc`.
use_reduce_max (bool): Pool by taking the maximum values of the hidden use_reduce_max (bool): Pool by taking the maximum values of the hidden
representations of a `Doc`. representations of a `Doc`.
use_reduce_mean (bool): Pool by taking the mean of all hidden use_reduce_mean (bool): Pool by taking the mean of all hidden
@ -238,6 +243,8 @@ def build_reduce_text_classifier(
reductions = [] reductions = []
if use_reduce_first: if use_reduce_first:
reductions.append(reduce_first()) reductions.append(reduce_first())
if use_reduce_last:
reductions.append(reduce_last())
if use_reduce_max: if use_reduce_max:
reductions.append(reduce_max()) reductions.append(reduce_max())
if use_reduce_mean: if use_reduce_mean:

View File

@ -58,6 +58,7 @@ single_label_cnn_config = """
@architectures = "spacy.TextCatReduce.v1" @architectures = "spacy.TextCatReduce.v1"
exclusive_classes = true exclusive_classes = true
use_reduce_first = false use_reduce_first = false
use_reduce_last = false
use_reduce_max = false use_reduce_max = false
use_reduce_mean = true use_reduce_mean = true

View File

@ -56,6 +56,7 @@ multi_label_cnn_config = """
@architectures = "spacy.TextCatReduce.v1" @architectures = "spacy.TextCatReduce.v1"
exclusive_classes = false exclusive_classes = false
use_reduce_first = false use_reduce_first = false
use_reduce_last = false
use_reduce_max = false use_reduce_max = false
use_reduce_mean = true use_reduce_mean = true

View File

@ -457,8 +457,8 @@ def test_no_resize(name, textcat_config):
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}), ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}), ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# CNN # CNN
("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
], ],
) )
# fmt: on # fmt: on
@ -486,8 +486,8 @@ def test_resize(name, textcat_config):
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}), ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}), ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
# REDUCE # REDUCE
("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
], ],
) )
# fmt: on # fmt: on
@ -705,8 +705,8 @@ def test_overfitting_IO_multi():
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}), ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}), ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
# REDUCE V1 # REDUCE V1
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
], ],
) )
# fmt: on # fmt: on

View File

@ -296,6 +296,7 @@ def test_textcat_reduce_invalid_args():
tok2vec=tok2vec, tok2vec=tok2vec,
exclusive_classes=False, exclusive_classes=False,
use_reduce_first=False, use_reduce_first=False,
use_reduce_last=False,
use_reduce_max=False, use_reduce_max=False,
use_reduce_mean=False, use_reduce_mean=False,
) )

View File

@ -1065,6 +1065,7 @@ the others, but may not be as accurate, especially if texts are short.
> @architectures = "spacy.TextCatReduce.v1" > @architectures = "spacy.TextCatReduce.v1"
> exclusive_classes = false > exclusive_classes = false
> use_reduce_first = false > use_reduce_first = false
> use_reduce_last = false
> use_reduce_max = false > use_reduce_max = false
> use_reduce_mean = true > use_reduce_mean = true
> nO = null > nO = null
@ -1097,6 +1098,7 @@ reduction, whereas `TextCatReduce` also supports first/max reductions.
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ | | `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ | | `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
| `use_reduce_first` | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~ | | `use_reduce_first` | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~ |
| `use_reduce_last` | Pool by using the hidden representation of the last token of a `Doc`. ~~bool~~ |
| `use_reduce_max` | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~ | | `use_reduce_max` | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~ |
| `use_reduce_mean` | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~ | | `use_reduce_mean` | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~ |
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ | | `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |

View File

@ -165,7 +165,7 @@ that you can add labels to a previously trained textcat. `TextCatCNN` v1 did not
yet support that. `TextCatCNN` has been replaced by the more general yet support that. `TextCatCNN` has been replaced by the more general
[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is [`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
identical to `TextCatReduce` with `use_reduce_mean=true`, identical to `TextCatReduce` with `use_reduce_mean=true`,
`use_reduce_first=false` and `use_reduce_max=false`. `use_reduce_first=false`, `reduce_last=false` and `use_reduce_max=false`.
> #### Example Config > #### Example Config
> >
@ -225,7 +225,7 @@ architecture is usually less accurate than the ensemble, but runs faster.
`TextCatCNN` has been replaced by the more general `TextCatCNN` has been replaced by the more general
[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is [`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
identical to `TextCatReduce` with `use_reduce_mean=true`, identical to `TextCatReduce` with `use_reduce_mean=true`,
`use_reduce_first=false` and `use_reduce_max=false`. `use_reduce_first=false`, `reduce_last=false` and `use_reduce_max=false`.
| Name | Description | | Name | Description |
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |