mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Add last reduction (use_reduce_last)
				
					
				
			This commit is contained in:
		
							parent
							
								
									e310ddfd4a
								
							
						
					
					
						commit
						7669a450ec
					
				| 
						 | 
					@ -282,6 +282,7 @@ no_output_layer = false
 | 
				
			||||||
@architectures = "spacy.TextCatReduce.v1"
 | 
					@architectures = "spacy.TextCatReduce.v1"
 | 
				
			||||||
exclusive_classes = true
 | 
					exclusive_classes = true
 | 
				
			||||||
use_reduce_first = false
 | 
					use_reduce_first = false
 | 
				
			||||||
 | 
					use_reduce_last = false
 | 
				
			||||||
use_reduce_max = false
 | 
					use_reduce_max = false
 | 
				
			||||||
use_reduce_mean = true
 | 
					use_reduce_mean = true
 | 
				
			||||||
nO = null
 | 
					nO = null
 | 
				
			||||||
| 
						 | 
					@ -323,6 +324,7 @@ no_output_layer = false
 | 
				
			||||||
@architectures = "spacy.TextCatReduce.v1"
 | 
					@architectures = "spacy.TextCatReduce.v1"
 | 
				
			||||||
exclusive_classes = false
 | 
					exclusive_classes = false
 | 
				
			||||||
use_reduce_first = false
 | 
					use_reduce_first = false
 | 
				
			||||||
 | 
					use_reduce_last = false
 | 
				
			||||||
use_reduce_max = false
 | 
					use_reduce_max = false
 | 
				
			||||||
use_reduce_mean = true
 | 
					use_reduce_mean = true
 | 
				
			||||||
nO = null
 | 
					nO = null
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -985,7 +985,8 @@ class Errors(metaclass=ErrorsWithCodes):
 | 
				
			||||||
             "but only callbacks with one or three parameters are supported")
 | 
					             "but only callbacks with one or three parameters are supported")
 | 
				
			||||||
    E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
 | 
					    E1056 = ("The `TextCatBOW` architecture expects a length of at least 1, was {length}.")
 | 
				
			||||||
    E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. "
 | 
					    E1057 = ("The `TextCatReduce` architecture must be used with at least one reduction. "
 | 
				
			||||||
             "Please enable one of `use_reduce_first`, `use_reduce_max` or `use_reduce_mean`.")
 | 
					             "Please enable one of `use_reduce_first`, `reduce_last`, `use_reduce_max` or "
 | 
				
			||||||
 | 
					             "`use_reduce_mean`.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Deprecated model shortcuts, only used in errors and warnings
 | 
					# Deprecated model shortcuts, only used in errors and warnings
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,6 +18,7 @@ from thinc.api import (
 | 
				
			||||||
    concatenate,
 | 
					    concatenate,
 | 
				
			||||||
    list2ragged,
 | 
					    list2ragged,
 | 
				
			||||||
    reduce_first,
 | 
					    reduce_first,
 | 
				
			||||||
 | 
					    reduce_last,
 | 
				
			||||||
    reduce_max,
 | 
					    reduce_max,
 | 
				
			||||||
    reduce_mean,
 | 
					    reduce_mean,
 | 
				
			||||||
    reduce_sum,
 | 
					    reduce_sum,
 | 
				
			||||||
| 
						 | 
					@ -55,6 +56,7 @@ def build_simple_cnn_text_classifier(
 | 
				
			||||||
        tok2vec=tok2vec,
 | 
					        tok2vec=tok2vec,
 | 
				
			||||||
        exclusive_classes=exclusive_classes,
 | 
					        exclusive_classes=exclusive_classes,
 | 
				
			||||||
        use_reduce_first=False,
 | 
					        use_reduce_first=False,
 | 
				
			||||||
 | 
					        use_reduce_last=False,
 | 
				
			||||||
        use_reduce_max=False,
 | 
					        use_reduce_max=False,
 | 
				
			||||||
        use_reduce_mean=True,
 | 
					        use_reduce_mean=True,
 | 
				
			||||||
        nO=nO,
 | 
					        nO=nO,
 | 
				
			||||||
| 
						 | 
					@ -214,6 +216,7 @@ def build_reduce_text_classifier(
 | 
				
			||||||
    tok2vec: Model,
 | 
					    tok2vec: Model,
 | 
				
			||||||
    exclusive_classes: bool,
 | 
					    exclusive_classes: bool,
 | 
				
			||||||
    use_reduce_first: bool,
 | 
					    use_reduce_first: bool,
 | 
				
			||||||
 | 
					    use_reduce_last: bool,
 | 
				
			||||||
    use_reduce_max: bool,
 | 
					    use_reduce_max: bool,
 | 
				
			||||||
    use_reduce_mean: bool,
 | 
					    use_reduce_mean: bool,
 | 
				
			||||||
    nO: Optional[int] = None,
 | 
					    nO: Optional[int] = None,
 | 
				
			||||||
| 
						 | 
					@ -227,6 +230,8 @@ def build_reduce_text_classifier(
 | 
				
			||||||
    exclusive_classes (bool): Whether or not classes are mutually exclusive.
 | 
					    exclusive_classes (bool): Whether or not classes are mutually exclusive.
 | 
				
			||||||
    use_reduce_first (bool): Pool by using the hidden representation of the
 | 
					    use_reduce_first (bool): Pool by using the hidden representation of the
 | 
				
			||||||
        first token of a `Doc`.
 | 
					        first token of a `Doc`.
 | 
				
			||||||
 | 
					    use_reduce_first (bool): Pool by using the hidden representation of the
 | 
				
			||||||
 | 
					        last token of a `Doc`.
 | 
				
			||||||
    use_reduce_max (bool): Pool by taking the maximum values of the hidden
 | 
					    use_reduce_max (bool): Pool by taking the maximum values of the hidden
 | 
				
			||||||
        representations of a `Doc`.
 | 
					        representations of a `Doc`.
 | 
				
			||||||
    use_reduce_mean (bool): Pool by taking the mean of all hidden
 | 
					    use_reduce_mean (bool): Pool by taking the mean of all hidden
 | 
				
			||||||
| 
						 | 
					@ -238,6 +243,8 @@ def build_reduce_text_classifier(
 | 
				
			||||||
    reductions = []
 | 
					    reductions = []
 | 
				
			||||||
    if use_reduce_first:
 | 
					    if use_reduce_first:
 | 
				
			||||||
        reductions.append(reduce_first())
 | 
					        reductions.append(reduce_first())
 | 
				
			||||||
 | 
					    if use_reduce_last:
 | 
				
			||||||
 | 
					        reductions.append(reduce_last())
 | 
				
			||||||
    if use_reduce_max:
 | 
					    if use_reduce_max:
 | 
				
			||||||
        reductions.append(reduce_max())
 | 
					        reductions.append(reduce_max())
 | 
				
			||||||
    if use_reduce_mean:
 | 
					    if use_reduce_mean:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -58,6 +58,7 @@ single_label_cnn_config = """
 | 
				
			||||||
@architectures = "spacy.TextCatReduce.v1"
 | 
					@architectures = "spacy.TextCatReduce.v1"
 | 
				
			||||||
exclusive_classes = true
 | 
					exclusive_classes = true
 | 
				
			||||||
use_reduce_first = false
 | 
					use_reduce_first = false
 | 
				
			||||||
 | 
					use_reduce_last = false
 | 
				
			||||||
use_reduce_max = false
 | 
					use_reduce_max = false
 | 
				
			||||||
use_reduce_mean = true
 | 
					use_reduce_mean = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -56,6 +56,7 @@ multi_label_cnn_config = """
 | 
				
			||||||
@architectures = "spacy.TextCatReduce.v1"
 | 
					@architectures = "spacy.TextCatReduce.v1"
 | 
				
			||||||
exclusive_classes = false
 | 
					exclusive_classes = false
 | 
				
			||||||
use_reduce_first = false
 | 
					use_reduce_first = false
 | 
				
			||||||
 | 
					use_reduce_last = false
 | 
				
			||||||
use_reduce_max = false
 | 
					use_reduce_max = false
 | 
				
			||||||
use_reduce_mean = true
 | 
					use_reduce_mean = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -457,8 +457,8 @@ def test_no_resize(name, textcat_config):
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
 | 
				
			||||||
        # CNN
 | 
					        # CNN
 | 
				
			||||||
        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
					        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
# fmt: on
 | 
					# fmt: on
 | 
				
			||||||
| 
						 | 
					@ -486,8 +486,8 @@ def test_resize(name, textcat_config):
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v3", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
 | 
				
			||||||
        # REDUCE
 | 
					        # REDUCE
 | 
				
			||||||
        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
					        ("textcat", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
					        ("textcat_multilabel", {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
# fmt: on
 | 
					# fmt: on
 | 
				
			||||||
| 
						 | 
					@ -705,8 +705,8 @@ def test_overfitting_IO_multi():
 | 
				
			||||||
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
 | 
					        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
 | 
				
			||||||
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
 | 
					        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
 | 
				
			||||||
        # REDUCE V1
 | 
					        # REDUCE V1
 | 
				
			||||||
        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
					        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
					        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
# fmt: on
 | 
					# fmt: on
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -296,6 +296,7 @@ def test_textcat_reduce_invalid_args():
 | 
				
			||||||
            tok2vec=tok2vec,
 | 
					            tok2vec=tok2vec,
 | 
				
			||||||
            exclusive_classes=False,
 | 
					            exclusive_classes=False,
 | 
				
			||||||
            use_reduce_first=False,
 | 
					            use_reduce_first=False,
 | 
				
			||||||
 | 
					            use_reduce_last=False,
 | 
				
			||||||
            use_reduce_max=False,
 | 
					            use_reduce_max=False,
 | 
				
			||||||
            use_reduce_mean=False,
 | 
					            use_reduce_mean=False,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1065,6 +1065,7 @@ the others, but may not be as accurate, especially if texts are short.
 | 
				
			||||||
> @architectures = "spacy.TextCatReduce.v1"
 | 
					> @architectures = "spacy.TextCatReduce.v1"
 | 
				
			||||||
> exclusive_classes = false
 | 
					> exclusive_classes = false
 | 
				
			||||||
> use_reduce_first = false
 | 
					> use_reduce_first = false
 | 
				
			||||||
 | 
					> use_reduce_last = false
 | 
				
			||||||
> use_reduce_max = false
 | 
					> use_reduce_max = false
 | 
				
			||||||
> use_reduce_mean = true
 | 
					> use_reduce_mean = true
 | 
				
			||||||
> nO = null
 | 
					> nO = null
 | 
				
			||||||
| 
						 | 
					@ -1097,6 +1098,7 @@ reduction, whereas `TextCatReduce` also supports first/max reductions.
 | 
				
			||||||
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
 | 
					| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
 | 
				
			||||||
| `tok2vec`           | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~                                                                                                                                        |
 | 
					| `tok2vec`           | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~                                                                                                                                        |
 | 
				
			||||||
| `use_reduce_first`  | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~                                                                                                                |
 | 
					| `use_reduce_first`  | Pool by using the hidden representation of the first token of a `Doc`. ~~bool~~                                                                                                                |
 | 
				
			||||||
 | 
					| `use_reduce_last`   | Pool by using the hidden representation of the last token of a `Doc`. ~~bool~~                                                                                                                 |
 | 
				
			||||||
| `use_reduce_max`    | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~                                                                                                           |
 | 
					| `use_reduce_max`    | Pool by taking the maximum values of the hidden representations of a `Doc`. ~~bool~~                                                                                                           |
 | 
				
			||||||
| `use_reduce_mean`   | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~                                                                                                                     |
 | 
					| `use_reduce_mean`   | Pool by taking the mean of all hidden representations of a `Doc`. ~~bool~~                                                                                                                     |
 | 
				
			||||||
| `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
 | 
					| `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -165,7 +165,7 @@ that you can add labels to a previously trained textcat. `TextCatCNN` v1 did not
 | 
				
			||||||
yet support that. `TextCatCNN` has been replaced by the more general
 | 
					yet support that. `TextCatCNN` has been replaced by the more general
 | 
				
			||||||
[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
 | 
					[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
 | 
				
			||||||
identical to `TextCatReduce` with `use_reduce_mean=true`,
 | 
					identical to `TextCatReduce` with `use_reduce_mean=true`,
 | 
				
			||||||
`use_reduce_first=false` and `use_reduce_max=false`.
 | 
					`use_reduce_first=false`, `reduce_last=false` and `use_reduce_max=false`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
> #### Example Config
 | 
					> #### Example Config
 | 
				
			||||||
>
 | 
					>
 | 
				
			||||||
| 
						 | 
					@ -225,7 +225,7 @@ architecture is usually less accurate than the ensemble, but runs faster.
 | 
				
			||||||
`TextCatCNN` has been replaced by the more general
 | 
					`TextCatCNN` has been replaced by the more general
 | 
				
			||||||
[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
 | 
					[`TextCatReduce`](/api/architectures#TextCatReduce) layer. `TextCatCNN` is
 | 
				
			||||||
identical to `TextCatReduce` with `use_reduce_mean=true`,
 | 
					identical to `TextCatReduce` with `use_reduce_mean=true`,
 | 
				
			||||||
`use_reduce_first=false` and `use_reduce_max=false`.
 | 
					`use_reduce_first=false`, `reduce_last=false` and `use_reduce_max=false`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| Name                | Description                                                                                                                                                                                    |
 | 
					| Name                | Description                                                                                                                                                                                    |
 | 
				
			||||||
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
					| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user