diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 5d5c1770a..f2b9777cd 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -282,6 +282,7 @@ no_output_layer = false @architectures = "spacy.TextCatReduce.v1" exclusive_classes = true use_reduce_first = false +use_reduce_last = false use_reduce_max = false use_reduce_mean = true nO = null @@ -323,6 +324,7 @@ no_output_layer = false @architectures = "spacy.TextCatReduce.v1" exclusive_classes = false use_reduce_first = false +use_reduce_last = false use_reduce_max = false use_reduce_mean = true nO = null diff --git a/spacy/errors.py b/spacy/errors.py index 28f34e266..2455c2f86 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -985,7 +985,8 @@ class Errors(metaclass=ErrorsWithCodes): "but only callbacks with one or three parameters are supported") E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.") E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. " - "Please enable one of `use_reduce_first`, `use_reduce_max` or `use_reduce_mean`.") + "Please enable one of `use_reduce_first`, `reduce_last`, `use_reduce_max` or " + "`use_reduce_mean`.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index 95362d45b..e0dcb47aa 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -18,6 +18,7 @@ from thinc.api import ( concatenate, list2ragged, reduce_first, + reduce_last, reduce_max, reduce_mean, reduce_sum, @@ -55,6 +56,7 @@ def build_simple_cnn_text_classifier( tok2vec=tok2vec, exclusive_classes=exclusive_classes, use_reduce_first=False, + use_reduce_last=False, use_reduce_max=False, use_reduce_mean=True, nO=nO, @@ -214,6 +216,7 @@ def build_reduce_text_classifier( tok2vec: Model, exclusive_classes: bool, use_reduce_first: bool, + use_reduce_last: bool, use_reduce_max: bool, use_reduce_mean: bool, nO: Optional[int] = None, @@ -227,6 +230,8 @@ def build_reduce_text_classifier( exclusive_classes (bool): Whether or not classes are mutually exclusive. use_reduce_first (bool): Pool by using the hidden representation of the first token of a `Doc`. + use_reduce_first (bool): Pool by using the hidden representation of the + last token of a `Doc`. use_reduce_max (bool): Pool by taking the maximum values of the hidden representations of a `Doc`. use_reduce_mean (bool): Pool by taking the mean of all hidden @@ -238,6 +243,8 @@ def build_reduce_text_classifier( reductions = [] if use_reduce_first: reductions.append(reduce_first()) + if use_reduce_last: + reductions.append(reduce_last()) if use_reduce_max: reductions.append(reduce_max()) if use_reduce_mean: diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 964a772c7..ae227017a 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -58,6 +58,7 @@ single_label_cnn_config = """ @architectures = "spacy.TextCatReduce.v1" exclusive_classes = true use_reduce_first = false +use_reduce_last = false use_reduce_max = false use_reduce_mean = true diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index 1183aaff3..2f8d5e604 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -56,6 +56,7 @@ multi_label_cnn_config = """ @architectures = "spacy.TextCatReduce.v1" exclusive_classes = false use_reduce_first = false +use_reduce_last = false use_reduce_max = false use_reduce_mean = true diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 3eee067eb..5dff8d124 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -457,8 +457,8 @@ def test_no_resize(name, textcat_config): ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}), ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}), # CNN - ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), - ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), + ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}), + ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}), ], ) # fmt: on @@ -486,8 +486,8 @@ def test_resize(name, textcat_config): ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}), ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}), # REDUCE - ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), - ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), + ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}), + ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}), ], ) # fmt: on @@ -705,8 +705,8 @@ def test_overfitting_IO_multi(): ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}), ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}), # REDUCE V1 - ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), - ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}), + ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}), + ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}), ], ) # fmt: on diff --git a/spacy/tests/test_models.py b/spacy/tests/test_models.py index ce033a274..5ce9da508 100644 --- a/spacy/tests/test_models.py +++ b/spacy/tests/test_models.py @@ -296,6 +296,7 @@ def test_textcat_reduce_invalid_args(): tok2vec=tok2vec, exclusive_classes=False, use_reduce_first=False, + use_reduce_last=False, use_reduce_max=False, use_reduce_mean=False, ) diff --git a/website/docs/api/architectures.mdx b/website/docs/api/architectures.mdx index 1e5f50512..63f723a28 100644 --- a/website/docs/api/architectures.mdx +++ b/website/docs/api/architectures.mdx @@ -1065,6 +1065,7 @@ the others, but may not be as accurate, especially if texts are short. > @architectures = "spacy.TextCatReduce.v1" > exclusive_classes = false > use_reduce_first = false +> use_reduce_last = false > use_reduce_max = false > use_reduce_mean = true > nO = null @@ -1097,6 +1098,7 @@ reduction, whereas `TextCatReduce` also supports first/max reductions. | `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ | | `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ | | `use_reduce_first` | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~ | +| `use_reduce_last` | Pool by using the hidden representation of the last token of a `Doc`. ~~bool~~ | | `use_reduce_max` | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~ | | `use_reduce_mean` | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~ | | `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ | diff --git a/website/docs/api/legacy.mdx b/website/docs/api/legacy.mdx index 5fdc791c2..b44df5387 100644 --- a/website/docs/api/legacy.mdx +++ b/website/docs/api/legacy.mdx @@ -165,7 +165,7 @@ that you can add labels to a previously trained textcat. `TextCatCNN` v1 did not yet support that. `TextCatCNN` has been replaced by the more general [`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is identical to `TextCatReduce` with `use_reduce_mean=true`, -`use_reduce_first=false` and `use_reduce_max=false`. +`use_reduce_first=false`, `reduce_last=false` and `use_reduce_max=false`. > #### Example Config > @@ -225,7 +225,7 @@ architecture is usually less accurate than the ensemble, but runs faster. `TextCatCNN` has been replaced by the more general [`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is identical to `TextCatReduce` with `use_reduce_mean=true`, -`use_reduce_first=false` and `use_reduce_max=false`. +`use_reduce_first=false`, `reduce_last=false` and `use_reduce_max=false`. | Name | Description | | ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |