mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Add last reduction (use_reduce_last
)
This commit is contained in:
parent
e310ddfd4a
commit
7669a450ec
|
@ -282,6 +282,7 @@ no_output_layer = false
|
||||||
@architectures = "spacy.TextCatReduce.v1"
|
@architectures = "spacy.TextCatReduce.v1"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
use_reduce_first = false
|
use_reduce_first = false
|
||||||
|
use_reduce_last = false
|
||||||
use_reduce_max = false
|
use_reduce_max = false
|
||||||
use_reduce_mean = true
|
use_reduce_mean = true
|
||||||
nO = null
|
nO = null
|
||||||
|
@ -323,6 +324,7 @@ no_output_layer = false
|
||||||
@architectures = "spacy.TextCatReduce.v1"
|
@architectures = "spacy.TextCatReduce.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
use_reduce_first = false
|
use_reduce_first = false
|
||||||
|
use_reduce_last = false
|
||||||
use_reduce_max = false
|
use_reduce_max = false
|
||||||
use_reduce_mean = true
|
use_reduce_mean = true
|
||||||
nO = null
|
nO = null
|
||||||
|
|
|
@ -985,7 +985,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
"but only callbacks with one or three parameters are supported")
|
"but only callbacks with one or three parameters are supported")
|
||||||
E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
|
E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
|
||||||
E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. "
|
E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. "
|
||||||
"Please enable one of `use_reduce_first`, `use_reduce_max` or `use_reduce_mean`.")
|
"Please enable one of `use_reduce_first`, `reduce_last`, `use_reduce_max` or "
|
||||||
|
"`use_reduce_mean`.")
|
||||||
|
|
||||||
|
|
||||||
# Deprecated model shortcuts, only used in errors and warnings
|
# Deprecated model shortcuts, only used in errors and warnings
|
||||||
|
|
|
@ -18,6 +18,7 @@ from thinc.api import (
|
||||||
concatenate,
|
concatenate,
|
||||||
list2ragged,
|
list2ragged,
|
||||||
reduce_first,
|
reduce_first,
|
||||||
|
reduce_last,
|
||||||
reduce_max,
|
reduce_max,
|
||||||
reduce_mean,
|
reduce_mean,
|
||||||
reduce_sum,
|
reduce_sum,
|
||||||
|
@ -55,6 +56,7 @@ def build_simple_cnn_text_classifier(
|
||||||
tok2vec=tok2vec,
|
tok2vec=tok2vec,
|
||||||
exclusive_classes=exclusive_classes,
|
exclusive_classes=exclusive_classes,
|
||||||
use_reduce_first=False,
|
use_reduce_first=False,
|
||||||
|
use_reduce_last=False,
|
||||||
use_reduce_max=False,
|
use_reduce_max=False,
|
||||||
use_reduce_mean=True,
|
use_reduce_mean=True,
|
||||||
nO=nO,
|
nO=nO,
|
||||||
|
@ -214,6 +216,7 @@ def build_reduce_text_classifier(
|
||||||
tok2vec: Model,
|
tok2vec: Model,
|
||||||
exclusive_classes: bool,
|
exclusive_classes: bool,
|
||||||
use_reduce_first: bool,
|
use_reduce_first: bool,
|
||||||
|
use_reduce_last: bool,
|
||||||
use_reduce_max: bool,
|
use_reduce_max: bool,
|
||||||
use_reduce_mean: bool,
|
use_reduce_mean: bool,
|
||||||
nO: Optional[int] = None,
|
nO: Optional[int] = None,
|
||||||
|
@ -227,6 +230,8 @@ def build_reduce_text_classifier(
|
||||||
exclusive_classes (bool): Whether or not classes are mutually exclusive.
|
exclusive_classes (bool): Whether or not classes are mutually exclusive.
|
||||||
use_reduce_first (bool): Pool by using the hidden representation of the
|
use_reduce_first (bool): Pool by using the hidden representation of the
|
||||||
first token of a `Doc`.
|
first token of a `Doc`.
|
||||||
|
use_reduce_first (bool): Pool by using the hidden representation of the
|
||||||
|
last token of a `Doc`.
|
||||||
use_reduce_max (bool): Pool by taking the maximum values of the hidden
|
use_reduce_max (bool): Pool by taking the maximum values of the hidden
|
||||||
representations of a `Doc`.
|
representations of a `Doc`.
|
||||||
use_reduce_mean (bool): Pool by taking the mean of all hidden
|
use_reduce_mean (bool): Pool by taking the mean of all hidden
|
||||||
|
@ -238,6 +243,8 @@ def build_reduce_text_classifier(
|
||||||
reductions = []
|
reductions = []
|
||||||
if use_reduce_first:
|
if use_reduce_first:
|
||||||
reductions.append(reduce_first())
|
reductions.append(reduce_first())
|
||||||
|
if use_reduce_last:
|
||||||
|
reductions.append(reduce_last())
|
||||||
if use_reduce_max:
|
if use_reduce_max:
|
||||||
reductions.append(reduce_max())
|
reductions.append(reduce_max())
|
||||||
if use_reduce_mean:
|
if use_reduce_mean:
|
||||||
|
|
|
@ -58,6 +58,7 @@ single_label_cnn_config = """
|
||||||
@architectures = "spacy.TextCatReduce.v1"
|
@architectures = "spacy.TextCatReduce.v1"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
use_reduce_first = false
|
use_reduce_first = false
|
||||||
|
use_reduce_last = false
|
||||||
use_reduce_max = false
|
use_reduce_max = false
|
||||||
use_reduce_mean = true
|
use_reduce_mean = true
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,7 @@ multi_label_cnn_config = """
|
||||||
@architectures = "spacy.TextCatReduce.v1"
|
@architectures = "spacy.TextCatReduce.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
use_reduce_first = false
|
use_reduce_first = false
|
||||||
|
use_reduce_last = false
|
||||||
use_reduce_max = false
|
use_reduce_max = false
|
||||||
use_reduce_mean = true
|
use_reduce_mean = true
|
||||||
|
|
||||||
|
|
|
@ -457,8 +457,8 @@ def test_no_resize(name, textcat_config):
|
||||||
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
|
||||||
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
|
||||||
# CNN
|
# CNN
|
||||||
("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
||||||
("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
@ -486,8 +486,8 @@ def test_resize(name, textcat_config):
|
||||||
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
|
||||||
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
|
||||||
# REDUCE
|
# REDUCE
|
||||||
("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
||||||
("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
@ -705,8 +705,8 @@ def test_overfitting_IO_multi():
|
||||||
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
||||||
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
||||||
# REDUCE V1
|
# REDUCE V1
|
||||||
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
||||||
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
|
@ -296,6 +296,7 @@ def test_textcat_reduce_invalid_args():
|
||||||
tok2vec=tok2vec,
|
tok2vec=tok2vec,
|
||||||
exclusive_classes=False,
|
exclusive_classes=False,
|
||||||
use_reduce_first=False,
|
use_reduce_first=False,
|
||||||
|
use_reduce_last=False,
|
||||||
use_reduce_max=False,
|
use_reduce_max=False,
|
||||||
use_reduce_mean=False,
|
use_reduce_mean=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1065,6 +1065,7 @@ the others, but may not be as accurate, especially if texts are short.
|
||||||
> @architectures = "spacy.TextCatReduce.v1"
|
> @architectures = "spacy.TextCatReduce.v1"
|
||||||
> exclusive_classes = false
|
> exclusive_classes = false
|
||||||
> use_reduce_first = false
|
> use_reduce_first = false
|
||||||
|
> use_reduce_last = false
|
||||||
> use_reduce_max = false
|
> use_reduce_max = false
|
||||||
> use_reduce_mean = true
|
> use_reduce_mean = true
|
||||||
> nO = null
|
> nO = null
|
||||||
|
@ -1097,6 +1098,7 @@ reduction, whereas `TextCatReduce` also supports first/max reductions.
|
||||||
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
|
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
|
||||||
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
|
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
|
||||||
| `use_reduce_first` | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~ |
|
| `use_reduce_first` | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~ |
|
||||||
|
| `use_reduce_last` | Pool by using the hidden representation of the last token of a `Doc`. ~~bool~~ |
|
||||||
| `use_reduce_max` | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~ |
|
| `use_reduce_max` | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~ |
|
||||||
| `use_reduce_mean` | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~ |
|
| `use_reduce_mean` | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~ |
|
||||||
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
||||||
|
|
|
@ -165,7 +165,7 @@ that you can add labels to a previously trained textcat. `TextCatCNN` v1 did not
|
||||||
yet support that. `TextCatCNN` has been replaced by the more general
|
yet support that. `TextCatCNN` has been replaced by the more general
|
||||||
[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
|
[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
|
||||||
identical to `TextCatReduce` with `use_reduce_mean=true`,
|
identical to `TextCatReduce` with `use_reduce_mean=true`,
|
||||||
`use_reduce_first=false` and `use_reduce_max=false`.
|
`use_reduce_first=false`, `reduce_last=false` and `use_reduce_max=false`.
|
||||||
|
|
||||||
> #### Example Config
|
> #### Example Config
|
||||||
>
|
>
|
||||||
|
@ -225,7 +225,7 @@ architecture is usually less accurate than the ensemble, but runs faster.
|
||||||
`TextCatCNN` has been replaced by the more general
|
`TextCatCNN` has been replaced by the more general
|
||||||
[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
|
[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
|
||||||
identical to `TextCatReduce` with `use_reduce_mean=true`,
|
identical to `TextCatReduce` with `use_reduce_mean=true`,
|
||||||
`use_reduce_first=false` and `use_reduce_max=false`.
|
`use_reduce_first=false`, `reduce_last=false` and `use_reduce_max=false`.
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user