diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index a0ffa8f52..9a8b9d1d7 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -317,21 +317,3 @@ start = 100 stop = 1000 compound = 1.001 {% endif %} - -[training.score_weights] -{%- if "tagger" in components %} -tag_acc = {{ (1.0 / components|length)|round(2) }} -{%- endif -%} -{%- if "parser" in components %} -dep_uas = 0.0 -dep_las = {{ (1.0 / components|length)|round(2) }} -sents_f = 0.0 -{%- endif %} -{%- if "ner" in components %} -ents_f = {{ (1.0 / components|length)|round(2) }} -ents_p = 0.0 -ents_r = 0.0 -{%- endif %} -{%- if "textcat" in components %} -cats_score = {{ (1.0 / components|length)|round(2) }} -{%- endif -%} diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 2900ef379..3485a4ff2 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -209,6 +209,8 @@ def create_train_batches(iterator, batcher, max_epochs: int): def create_evaluation_callback( nlp: Language, dev_corpus: Callable, weights: Dict[str, float] ) -> Callable[[], Tuple[float, Dict[str, float]]]: + weights = {key: value for key, value in weights.items() if value is not None} + def evaluate() -> Tuple[float, Dict[str, float]]: dev_examples = list(dev_corpus(nlp)) scores = nlp.evaluate(dev_examples) @@ -368,7 +370,8 @@ def update_meta( ) -> None: nlp.meta["performance"] = {} for metric in training["score_weights"]: - nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0) + if metric is not None: + nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0) for pipe_name in nlp.pipe_names: nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name] diff --git a/spacy/lang/bn/__init__.py b/spacy/lang/bn/__init__.py index 270185a4b..923e29a17 100644 --- a/spacy/lang/bn/__init__.py +++ b/spacy/lang/bn/__init__.py @@ -25,7 +25,6 @@ class Bengali(Language): "lemmatizer", 
assigns=["token.lemma"], default_config={"model": None, "mode": "rule", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/el/__init__.py b/spacy/lang/el/__init__.py index 0c5e0672b..1a7b19914 100644 --- a/spacy/lang/el/__init__.py +++ b/spacy/lang/el/__init__.py @@ -30,7 +30,6 @@ class Greek(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "rule", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/en/__init__.py b/spacy/lang/en/__init__.py index 1a595b6e7..bf7e9987f 100644 --- a/spacy/lang/en/__init__.py +++ b/spacy/lang/en/__init__.py @@ -29,7 +29,6 @@ class English(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "rule", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/fa/__init__.py b/spacy/lang/fa/__init__.py index 244534120..f3a6635dc 100644 --- a/spacy/lang/fa/__init__.py +++ b/spacy/lang/fa/__init__.py @@ -28,7 +28,6 @@ class Persian(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "rule", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/fr/__init__.py b/spacy/lang/fr/__init__.py index 42241cd8a..72e641d1f 100644 --- a/spacy/lang/fr/__init__.py +++ b/spacy/lang/fr/__init__.py @@ -33,7 +33,6 @@ class French(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "rule", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/nb/__init__.py b/spacy/lang/nb/__init__.py index 28a2f0bf2..9672dfd6e 100644 --- a/spacy/lang/nb/__init__.py +++ b/spacy/lang/nb/__init__.py @@ -28,7 +28,6 @@ class Norwegian(Language): 
"lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "rule", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/nl/__init__.py b/spacy/lang/nl/__init__.py index 1526e41f5..15b6b9de2 100644 --- a/spacy/lang/nl/__init__.py +++ b/spacy/lang/nl/__init__.py @@ -30,7 +30,6 @@ class Dutch(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "rule", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/pl/__init__.py b/spacy/lang/pl/__init__.py index 7ddad9893..573dbc6f9 100644 --- a/spacy/lang/pl/__init__.py +++ b/spacy/lang/pl/__init__.py @@ -35,7 +35,6 @@ class Polish(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "pos_lookup", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/ru/__init__.py b/spacy/lang/ru/__init__.py index be770e3ec..4a296dd23 100644 --- a/spacy/lang/ru/__init__.py +++ b/spacy/lang/ru/__init__.py @@ -25,7 +25,6 @@ class Russian(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "pymorphy2", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/sv/__init__.py b/spacy/lang/sv/__init__.py index 6db74cd39..ea314f487 100644 --- a/spacy/lang/sv/__init__.py +++ b/spacy/lang/sv/__init__.py @@ -31,7 +31,6 @@ class Swedish(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "rule", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/lang/uk/__init__.py b/spacy/lang/uk/__init__.py index e9936cf7d..006a1cf7f 100644 --- a/spacy/lang/uk/__init__.py +++ b/spacy/lang/uk/__init__.py @@ -25,7 +25,6 @@ class 
Ukrainian(Language): "lemmatizer", assigns=["token.lemma"], default_config={"model": None, "mode": "pymorphy2", "lookups": None}, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/language.py b/spacy/language.py index 4dffd9679..0b7deacad 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -248,9 +248,15 @@ class Language: self._config["nlp"]["pipeline"] = list(self.component_names) self._config["nlp"]["disabled"] = list(self.disabled) self._config["components"] = pipeline - if not self._config["training"].get("score_weights"): - combined_score_weights = combine_score_weights(score_weights) - self._config["training"]["score_weights"] = combined_score_weights + # We're merging the existing score weights back into the combined + # weights to make sure we're preserving custom settings in the config + # but also reflect updates (e.g. new components added) + prev_score_weights = self._config["training"].get("score_weights", {}) + combined_score_weights = combine_score_weights(score_weights) + combined_score_weights.update(prev_score_weights) + # Combine the scores a second time to normalize them + combined_score_weights = combine_score_weights([combined_score_weights]) + self._config["training"]["score_weights"] = combined_score_weights if not srsly.is_json_serializable(self._config): raise ValueError(Errors.E961.format(config=self._config)) return self._config @@ -412,7 +418,6 @@ class Language: assigns: Iterable[str] = SimpleFrozenList(), requires: Iterable[str] = SimpleFrozenList(), retokenizes: bool = False, - scores: Iterable[str] = SimpleFrozenList(), default_score_weights: Dict[str, float] = SimpleFrozenDict(), func: Optional[Callable] = None, ) -> Callable: @@ -430,12 +435,11 @@ class Language: e.g. "token.ent_id". Used for pipeline analyis. retokenizes (bool): Whether the component changes the tokenization. Used for pipeline analysis. 
- scores (Iterable[str]): All scores set by the component if it's trainable, - e.g. ["ents_f", "ents_r", "ents_p"]. default_score_weights (Dict[str, float]): The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to 1.0 per component and - will be combined and normalized for the whole pipeline. + will be combined and normalized for the whole pipeline. If None, + the score won't be shown in the logs or be weighted. func (Optional[Callable]): Factory function if not used as a decorator. DOCS: https://nightly.spacy.io/api/language#factory @@ -475,7 +479,7 @@ class Language: default_config=default_config, assigns=validate_attrs(assigns), requires=validate_attrs(requires), - scores=scores, + scores=list(default_score_weights.keys()), default_score_weights=default_score_weights, retokenizes=retokenizes, ) diff --git a/spacy/pipeline/dep_parser.pyx b/spacy/pipeline/dep_parser.pyx index a49475c8e..a447434d2 100644 --- a/spacy/pipeline/dep_parser.pyx +++ b/spacy/pipeline/dep_parser.pyx @@ -43,8 +43,14 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"] "min_action_freq": 30, "model": DEFAULT_PARSER_MODEL, }, - scores=["dep_uas", "dep_las", "dep_las_per_type", "sents_p", "sents_r", "sents_f"], - default_score_weights={"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0}, + default_score_weights={ + "dep_uas": 0.5, + "dep_las": 0.5, + "dep_las_per_type": None, + "sents_p": None, + "sents_r": None, + "sents_f": 0.0, + }, ) def make_parser( nlp: Language, diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 24bbb067f..9166a69b8 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -25,8 +25,12 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]] "overwrite_ents": False, "ent_id_sep": DEFAULT_ENT_ID_SEP, }, - scores=["ents_p", "ents_r", "ents_f", "ents_per_type"], - default_score_weights={"ents_f": 1.0, "ents_p": 0.0, 
"ents_r": 0.0}, + default_score_weights={ + "ents_f": 1.0, + "ents_p": 0.0, + "ents_r": 0.0, + "ents_per_type": None, + }, ) def make_entity_ruler( nlp: Language, diff --git a/spacy/pipeline/lemmatizer.py b/spacy/pipeline/lemmatizer.py index 0fd3482c4..c30d09f62 100644 --- a/spacy/pipeline/lemmatizer.py +++ b/spacy/pipeline/lemmatizer.py @@ -21,7 +21,6 @@ from .. import util "lookups": None, "overwrite": False, }, - scores=["lemma_acc"], default_score_weights={"lemma_acc": 1.0}, ) def make_lemmatizer( diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 62ad9e0eb..5fee9a900 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -49,8 +49,7 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"] "morphologizer", assigns=["token.morph", "token.pos"], default_config={"model": DEFAULT_MORPH_MODEL}, - scores=["pos_acc", "morph_acc", "morph_per_feat"], - default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5}, + default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None}, ) def make_morphologizer( nlp: Language, diff --git a/spacy/pipeline/ner.pyx b/spacy/pipeline/ner.pyx index fc4f03473..c9b0a5031 100644 --- a/spacy/pipeline/ner.pyx +++ b/spacy/pipeline/ner.pyx @@ -39,8 +39,7 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"] "update_with_oracle_cut_size": 100, "model": DEFAULT_NER_MODEL, }, - scores=["ents_p", "ents_r", "ents_f", "ents_per_type"], - default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0}, + default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None}, ) def make_ner( diff --git a/spacy/pipeline/sentencizer.pyx b/spacy/pipeline/sentencizer.pyx index 5700c2b98..2882f6f8b 100644 --- a/spacy/pipeline/sentencizer.pyx +++ b/spacy/pipeline/sentencizer.pyx @@ -15,7 +15,6 @@ from .. 
import util "sentencizer", assigns=["token.is_sent_start", "doc.sents"], default_config={"punct_chars": None}, - scores=["sents_p", "sents_r", "sents_f"], default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, ) def make_sentencizer( diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index a7eb721fd..da85a9cf2 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -36,7 +36,6 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"] "senter", assigns=["token.is_sent_start"], default_config={"model": DEFAULT_SENTER_MODEL}, - scores=["sents_p", "sents_r", "sents_f"], default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, ) def make_senter(nlp: Language, name: str, model: Model): diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 0d78047ae..3efe29916 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -42,7 +42,6 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"] "tagger", assigns=["token.tag"], default_config={"model": DEFAULT_TAGGER_MODEL}, - scores=["tag_acc"], default_score_weights={"tag_acc": 1.0}, ) def make_tagger(nlp: Language, name: str, model: Model): diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index e7cb62a0d..6b8c0ca65 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -62,18 +62,17 @@ subword_features = true "positive_label": None, "model": DEFAULT_TEXTCAT_MODEL, }, - scores=[ - "cats_score", - "cats_score_desc", - "cats_p", - "cats_r", - "cats_f", - "cats_macro_f", - "cats_macro_auc", - "cats_f_per_type", - "cats_macro_auc_per_type", - ], - default_score_weights={"cats_score": 1.0}, + default_score_weights={ + "cats_score": 1.0, + "cats_score_desc": None, + "cats_p": None, + "cats_r": None, + "cats_f": None, + "cats_macro_f": None, + "cats_macro_auc": None, + "cats_f_per_type": None, + "cats_macro_auc_per_type": None, + }, ) def make_textcat( nlp: Language, 
diff --git a/spacy/schemas.py b/spacy/schemas.py index b0f26dcd7..e34841008 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -211,7 +211,7 @@ class ConfigSchemaTraining(BaseModel): seed: Optional[StrictInt] = Field(..., title="Random seed") gpu_allocator: Optional[StrictStr] = Field(..., title="Memory allocator when running on GPU") accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps") - score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Scores to report and their weights for selecting final model") + score_weights: Dict[StrictStr, Optional[Union[StrictFloat, StrictInt]]] = Field(..., title="Scores to report and their weights for selecting final model") init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights") raw_text: Optional[StrictStr] = Field(default=None, title="Raw text") optimizer: Optimizer = Field(..., title="The optimizer to use") diff --git a/spacy/tests/pipeline/test_pipe_factories.py b/spacy/tests/pipeline/test_pipe_factories.py index 881460704..4ab1c4248 100644 --- a/spacy/tests/pipeline/test_pipe_factories.py +++ b/spacy/tests/pipeline/test_pipe_factories.py @@ -359,12 +359,8 @@ def test_language_factories_scores(): func = lambda nlp, name: lambda doc: doc weights1 = {"a1": 0.5, "a2": 0.5} weights2 = {"b1": 0.2, "b2": 0.7, "b3": 0.1} - Language.factory( - f"{name}1", scores=list(weights1), default_score_weights=weights1, func=func, - ) - Language.factory( - f"{name}2", scores=list(weights2), default_score_weights=weights2, func=func, - ) + Language.factory(f"{name}1", default_score_weights=weights1, func=func) + Language.factory(f"{name}2", default_score_weights=weights2, func=func) meta1 = Language.get_factory_meta(f"{name}1") assert meta1.default_score_weights == weights1 meta2 = Language.get_factory_meta(f"{name}2") @@ -376,6 +372,21 @@ def test_language_factories_scores(): cfg = nlp.config["training"] expected_weights = {"a1": 0.25, 
"a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05} assert cfg["score_weights"] == expected_weights + # Test with custom defaults + config = nlp.config.copy() + config["training"]["score_weights"]["a1"] = 0.0 + config["training"]["score_weights"]["b3"] = 1.0 + nlp = English.from_config(config) + score_weights = nlp.config["training"]["score_weights"] + expected = {"a1": 0.0, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.59} + assert score_weights == expected + # Test with null values + config = nlp.config.copy() + config["training"]["score_weights"]["a1"] = None + nlp = English.from_config(config) + score_weights = nlp.config["training"]["score_weights"] + expected = {"a1": None, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.58} # rounding :( + assert score_weights == expected def test_pipe_factories_from_source(): diff --git a/spacy/util.py b/spacy/util.py index 025fe5288..f7c5cff59 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1209,8 +1209,19 @@ def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]: weights (List[dict]): The weights defined by the components. RETURNS (Dict[str, float]): The combined and normalized weights. """ + # We first need to extract all None/null values for score weights that + # shouldn't be shown in the table *or* be weighted result = {} + all_weights = [] for w_dict in weights: + filtered_weights = {} + for key, value in w_dict.items(): + if value is None: + result[key] = None + else: + filtered_weights[key] = value + all_weights.append(filtered_weights) + for w_dict in all_weights: # We need to account for weights that don't sum to 1.0 and normalize # the score weights accordingly, then divide score by the number of # components. 
diff --git a/website/docs/api/language.md b/website/docs/api/language.md index a7b9c0d88..dd3cc57dd 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -145,17 +145,16 @@ examples, see the > ) > ``` -| Name | Description | -| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `name` | The name of the component factory. ~~str~~ | -| _keyword-only_ | | -| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ | -| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | -| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | -| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~ | -| `scores` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | -| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. ~~Dict[str, float]~~ | -| `func` | Optional function if not used a a decorator. 
~~Optional[Callable[[...], Callable[[Doc], Doc]]]~~ | +| Name | Description | +| ----------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `name` | The name of the component factory. ~~str~~ | +| _keyword-only_ | | +| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ | +| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | +| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | +| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~ | +| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. If a weight is set to `None`, the score will not be logged or weighted. ~~Dict[str, Optional[float]]~~ | +| `func` | Optional function if not used a a decorator. ~~Optional[Callable[[...], Callable[[Doc], Doc]]]~~ | ## Language.\_\_call\_\_ {#call tag="method"} @@ -1036,12 +1035,12 @@ provided by the [`@Language.component`](/api/language#component) or component is defined and stored on the `Language` class for each component instance and factory instance. 
-| Name | Description | -| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `factory` | The name of the registered component factory. ~~str~~ | -| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ | -| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | -| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~  | -| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~  | -| `scores` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | -| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. ~~Dict[str, float]~~ | +| Name | Description | +| ----------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `factory` | The name of the registered component factory. ~~str~~ | +| `default_config` | The default config, describing the default values of the factory arguments. 
~~Dict[str, Any]~~ | +| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | +| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~  | +| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~  | +| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. If a weight is set to `None`, the score will not be logged or weighted. ~~Dict[str, Optional[float]]~~ | +| `scores` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Based on the `default_score_weights` and used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | diff --git a/website/docs/usage/training.md b/website/docs/usage/training.md index b63145636..65afd0eb4 100644 --- a/website/docs/usage/training.md +++ b/website/docs/usage/training.md @@ -470,6 +470,7 @@ score. ```ini [training.score_weights] dep_las = 0.4 +dep_uas = null ents_f = 0.4 tag_acc = 0.2 token_acc = 0.0 @@ -481,9 +482,9 @@ you generate a config for a given pipeline, the score weights are generated by combining and normalizing the default score weights of the pipeline components. The default score weights are defined by each pipeline component via the `default_score_weights` setting on the -[`@Language.component`](/api/language#component) or -[`@Language.factory`](/api/language#factory). By default, all pipeline -components are weighted equally. +[`@Language.factory`](/api/language#factory) decorator. By default, all pipeline +components are weighted equally. 
If a score weight is set to `null`, it will be +excluded from the logs and won't count towards the final score.