Fix handling of score_weights

This commit is contained in:
Ines Montani 2020-09-24 10:27:33 +02:00
parent e2ffe51fb5
commit ae51f580c1
28 changed files with 95 additions and 92 deletions

View File

@ -317,21 +317,3 @@ start = 100
stop = 1000 stop = 1000
compound = 1.001 compound = 1.001
{% endif %} {% endif %}
[training.score_weights]
{%- if "tagger" in components %}
tag_acc = {{ (1.0 / components|length)|round(2) }}
{%- endif -%}
{%- if "parser" in components %}
dep_uas = 0.0
dep_las = {{ (1.0 / components|length)|round(2) }}
sents_f = 0.0
{%- endif %}
{%- if "ner" in components %}
ents_f = {{ (1.0 / components|length)|round(2) }}
ents_p = 0.0
ents_r = 0.0
{%- endif %}
{%- if "textcat" in components %}
cats_score = {{ (1.0 / components|length)|round(2) }}
{%- endif -%}

View File

@ -209,6 +209,8 @@ def create_train_batches(iterator, batcher, max_epochs: int):
def create_evaluation_callback( def create_evaluation_callback(
nlp: Language, dev_corpus: Callable, weights: Dict[str, float] nlp: Language, dev_corpus: Callable, weights: Dict[str, float]
) -> Callable[[], Tuple[float, Dict[str, float]]]: ) -> Callable[[], Tuple[float, Dict[str, float]]]:
weights = {key: value for key, value in weights.items() if value is not None}
def evaluate() -> Tuple[float, Dict[str, float]]: def evaluate() -> Tuple[float, Dict[str, float]]:
dev_examples = list(dev_corpus(nlp)) dev_examples = list(dev_corpus(nlp))
scores = nlp.evaluate(dev_examples) scores = nlp.evaluate(dev_examples)
@ -368,7 +370,8 @@ def update_meta(
) -> None: ) -> None:
nlp.meta["performance"] = {} nlp.meta["performance"] = {}
for metric in training["score_weights"]: for metric in training["score_weights"]:
nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0) if metric is not None:
nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0)
for pipe_name in nlp.pipe_names: for pipe_name in nlp.pipe_names:
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name] nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]

View File

@ -25,7 +25,6 @@ class Bengali(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None}, default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -30,7 +30,6 @@ class Greek(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None}, default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -29,7 +29,6 @@ class English(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None}, default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -28,7 +28,6 @@ class Persian(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None}, default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -33,7 +33,6 @@ class French(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None}, default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -28,7 +28,6 @@ class Norwegian(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None}, default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -30,7 +30,6 @@ class Dutch(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None}, default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -35,7 +35,6 @@ class Polish(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "pos_lookup", "lookups": None}, default_config={"model": None, "mode": "pos_lookup", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -25,7 +25,6 @@ class Russian(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "pymorphy2", "lookups": None}, default_config={"model": None, "mode": "pymorphy2", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -31,7 +31,6 @@ class Swedish(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None}, default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -25,7 +25,6 @@ class Ukrainian(Language):
"lemmatizer", "lemmatizer",
assigns=["token.lemma"], assigns=["token.lemma"],
default_config={"model": None, "mode": "pymorphy2", "lookups": None}, default_config={"model": None, "mode": "pymorphy2", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -248,9 +248,15 @@ class Language:
self._config["nlp"]["pipeline"] = list(self.component_names) self._config["nlp"]["pipeline"] = list(self.component_names)
self._config["nlp"]["disabled"] = list(self.disabled) self._config["nlp"]["disabled"] = list(self.disabled)
self._config["components"] = pipeline self._config["components"] = pipeline
if not self._config["training"].get("score_weights"): # We're merging the existing score weights back into the combined
combined_score_weights = combine_score_weights(score_weights) # weights to make sure we're preserving custom settings in the config
self._config["training"]["score_weights"] = combined_score_weights # but also reflect updates (e.g. new components added)
prev_score_weights = self._config["training"].get("score_weights", {})
combined_score_weights = combine_score_weights(score_weights)
combined_score_weights.update(prev_score_weights)
# Combine the scores a second time to normalize them
combined_score_weights = combine_score_weights([combined_score_weights])
self._config["training"]["score_weights"] = combined_score_weights
if not srsly.is_json_serializable(self._config): if not srsly.is_json_serializable(self._config):
raise ValueError(Errors.E961.format(config=self._config)) raise ValueError(Errors.E961.format(config=self._config))
return self._config return self._config
@ -412,7 +418,6 @@ class Language:
assigns: Iterable[str] = SimpleFrozenList(), assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(), requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False, retokenizes: bool = False,
scores: Iterable[str] = SimpleFrozenList(),
default_score_weights: Dict[str, float] = SimpleFrozenDict(), default_score_weights: Dict[str, float] = SimpleFrozenDict(),
func: Optional[Callable] = None, func: Optional[Callable] = None,
) -> Callable: ) -> Callable:
@ -430,12 +435,11 @@ class Language:
e.g. "token.ent_id". Used for pipeline analyis. e.g. "token.ent_id". Used for pipeline analyis.
retokenizes (bool): Whether the component changes the tokenization. retokenizes (bool): Whether the component changes the tokenization.
Used for pipeline analysis. Used for pipeline analysis.
scores (Iterable[str]): All scores set by the component if it's trainable,
e.g. ["ents_f", "ents_r", "ents_p"].
default_score_weights (Dict[str, float]): The scores to report during default_score_weights (Dict[str, float]): The scores to report during
training, and their default weight towards the final score used to training, and their default weight towards the final score used to
select the best model. Weights should sum to 1.0 per component and select the best model. Weights should sum to 1.0 per component and
will be combined and normalized for the whole pipeline. will be combined and normalized for the whole pipeline. If None,
the score won't be shown in the logs or be weighted.
func (Optional[Callable]): Factory function if not used as a decorator. func (Optional[Callable]): Factory function if not used as a decorator.
DOCS: https://nightly.spacy.io/api/language#factory DOCS: https://nightly.spacy.io/api/language#factory
@ -475,7 +479,7 @@ class Language:
default_config=default_config, default_config=default_config,
assigns=validate_attrs(assigns), assigns=validate_attrs(assigns),
requires=validate_attrs(requires), requires=validate_attrs(requires),
scores=scores, scores=list(default_score_weights.keys()),
default_score_weights=default_score_weights, default_score_weights=default_score_weights,
retokenizes=retokenizes, retokenizes=retokenizes,
) )

View File

@ -43,8 +43,14 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
"min_action_freq": 30, "min_action_freq": 30,
"model": DEFAULT_PARSER_MODEL, "model": DEFAULT_PARSER_MODEL,
}, },
scores=["dep_uas", "dep_las", "dep_las_per_type", "sents_p", "sents_r", "sents_f"], default_score_weights={
default_score_weights={"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0}, "dep_uas": 0.5,
"dep_las": 0.5,
"dep_las_per_type": None,
"sents_p": None,
"sents_r": None,
"sents_f": 0.0,
},
) )
def make_parser( def make_parser(
nlp: Language, nlp: Language,

View File

@ -25,8 +25,12 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
"overwrite_ents": False, "overwrite_ents": False,
"ent_id_sep": DEFAULT_ENT_ID_SEP, "ent_id_sep": DEFAULT_ENT_ID_SEP,
}, },
scores=["ents_p", "ents_r", "ents_f", "ents_per_type"], default_score_weights={
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0}, "ents_f": 1.0,
"ents_p": 0.0,
"ents_r": 0.0,
"ents_per_type": None,
},
) )
def make_entity_ruler( def make_entity_ruler(
nlp: Language, nlp: Language,

View File

@ -21,7 +21,6 @@ from .. import util
"lookups": None, "lookups": None,
"overwrite": False, "overwrite": False,
}, },
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0}, default_score_weights={"lemma_acc": 1.0},
) )
def make_lemmatizer( def make_lemmatizer(

View File

@ -49,8 +49,7 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
"morphologizer", "morphologizer",
assigns=["token.morph", "token.pos"], assigns=["token.morph", "token.pos"],
default_config={"model": DEFAULT_MORPH_MODEL}, default_config={"model": DEFAULT_MORPH_MODEL},
scores=["pos_acc", "morph_acc", "morph_per_feat"], default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5},
) )
def make_morphologizer( def make_morphologizer(
nlp: Language, nlp: Language,

View File

@ -39,8 +39,7 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
"model": DEFAULT_NER_MODEL, "model": DEFAULT_NER_MODEL,
}, },
scores=["ents_p", "ents_r", "ents_f", "ents_per_type"], default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
) )
def make_ner( def make_ner(

View File

@ -15,7 +15,6 @@ from .. import util
"sentencizer", "sentencizer",
assigns=["token.is_sent_start", "doc.sents"], assigns=["token.is_sent_start", "doc.sents"],
default_config={"punct_chars": None}, default_config={"punct_chars": None},
scores=["sents_p", "sents_r", "sents_f"],
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
) )
def make_sentencizer( def make_sentencizer(

View File

@ -36,7 +36,6 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
"senter", "senter",
assigns=["token.is_sent_start"], assigns=["token.is_sent_start"],
default_config={"model": DEFAULT_SENTER_MODEL}, default_config={"model": DEFAULT_SENTER_MODEL},
scores=["sents_p", "sents_r", "sents_f"],
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
) )
def make_senter(nlp: Language, name: str, model: Model): def make_senter(nlp: Language, name: str, model: Model):

View File

@ -42,7 +42,6 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
"tagger", "tagger",
assigns=["token.tag"], assigns=["token.tag"],
default_config={"model": DEFAULT_TAGGER_MODEL}, default_config={"model": DEFAULT_TAGGER_MODEL},
scores=["tag_acc"],
default_score_weights={"tag_acc": 1.0}, default_score_weights={"tag_acc": 1.0},
) )
def make_tagger(nlp: Language, name: str, model: Model): def make_tagger(nlp: Language, name: str, model: Model):

View File

@ -62,18 +62,17 @@ subword_features = true
"positive_label": None, "positive_label": None,
"model": DEFAULT_TEXTCAT_MODEL, "model": DEFAULT_TEXTCAT_MODEL,
}, },
scores=[ default_score_weights={
"cats_score", "cats_score": 1.0,
"cats_score_desc", "cats_score_desc": None,
"cats_p", "cats_p": None,
"cats_r", "cats_r": None,
"cats_f", "cats_f": None,
"cats_macro_f", "cats_macro_f": None,
"cats_macro_auc", "cats_macro_auc": None,
"cats_f_per_type", "cats_f_per_type": None,
"cats_macro_auc_per_type", "cats_macro_auc_per_type": None,
], },
default_score_weights={"cats_score": 1.0},
) )
def make_textcat( def make_textcat(
nlp: Language, nlp: Language,

View File

@ -211,7 +211,7 @@ class ConfigSchemaTraining(BaseModel):
seed: Optional[StrictInt] = Field(..., title="Random seed") seed: Optional[StrictInt] = Field(..., title="Random seed")
gpu_allocator: Optional[StrictStr] = Field(..., title="Memory allocator when running on GPU") gpu_allocator: Optional[StrictStr] = Field(..., title="Memory allocator when running on GPU")
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps") accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Scores to report and their weights for selecting final model") score_weights: Dict[StrictStr, Optional[Union[StrictFloat, StrictInt]]] = Field(..., title="Scores to report and their weights for selecting final model")
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights") init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
raw_text: Optional[StrictStr] = Field(default=None, title="Raw text") raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
optimizer: Optimizer = Field(..., title="The optimizer to use") optimizer: Optimizer = Field(..., title="The optimizer to use")

View File

@ -359,12 +359,8 @@ def test_language_factories_scores():
func = lambda nlp, name: lambda doc: doc func = lambda nlp, name: lambda doc: doc
weights1 = {"a1": 0.5, "a2": 0.5} weights1 = {"a1": 0.5, "a2": 0.5}
weights2 = {"b1": 0.2, "b2": 0.7, "b3": 0.1} weights2 = {"b1": 0.2, "b2": 0.7, "b3": 0.1}
Language.factory( Language.factory(f"{name}1", default_score_weights=weights1, func=func)
f"{name}1", scores=list(weights1), default_score_weights=weights1, func=func, Language.factory(f"{name}2", default_score_weights=weights2, func=func)
)
Language.factory(
f"{name}2", scores=list(weights2), default_score_weights=weights2, func=func,
)
meta1 = Language.get_factory_meta(f"{name}1") meta1 = Language.get_factory_meta(f"{name}1")
assert meta1.default_score_weights == weights1 assert meta1.default_score_weights == weights1
meta2 = Language.get_factory_meta(f"{name}2") meta2 = Language.get_factory_meta(f"{name}2")
@ -376,6 +372,21 @@ def test_language_factories_scores():
cfg = nlp.config["training"] cfg = nlp.config["training"]
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05} expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
assert cfg["score_weights"] == expected_weights assert cfg["score_weights"] == expected_weights
# Test with custom defaults
config = nlp.config.copy()
config["training"]["score_weights"]["a1"] = 0.0
config["training"]["score_weights"]["b3"] = 1.0
nlp = English.from_config(config)
score_weights = nlp.config["training"]["score_weights"]
expected = {"a1": 0.0, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.59}
assert score_weights == expected
# Test with null values
config = nlp.config.copy()
config["training"]["score_weights"]["a1"] = None
nlp = English.from_config(config)
score_weights = nlp.config["training"]["score_weights"]
expected = {"a1": None, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.58} # rounding :(
assert score_weights == expected
def test_pipe_factories_from_source(): def test_pipe_factories_from_source():

View File

@ -1209,8 +1209,19 @@ def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
weights (List[dict]): The weights defined by the components. weights (List[dict]): The weights defined by the components.
RETURNS (Dict[str, float]): The combined and normalized weights. RETURNS (Dict[str, float]): The combined and normalized weights.
""" """
# We first need to extract all None/null values for score weights that
# shouldn't be shown in the table *or* be weighted
result = {} result = {}
all_weights = []
for w_dict in weights: for w_dict in weights:
filtered_weights = {}
for key, value in w_dict.items():
if value is None:
result[key] = None
else:
filtered_weights[key] = value
all_weights.append(filtered_weights)
for w_dict in all_weights:
# We need to account for weights that don't sum to 1.0 and normalize # We need to account for weights that don't sum to 1.0 and normalize
# the score weights accordingly, then divide score by the number of # the score weights accordingly, then divide score by the number of
# components. # components.

View File

@ -145,17 +145,16 @@ examples, see the
> ) > )
> ``` > ```
| Name | Description | | Name | Description |
| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | ----------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `name` | The name of the component factory. ~~str~~ | | `name` | The name of the component factory. ~~str~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ | | `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ |
| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | | `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | | `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~ | | `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~ |
| `scores` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | | `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. If a weight is set to `None`, the score will not be logged or weighted. ~~Dict[str, Optional[float]]~~ |
| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. ~~Dict[str, float]~~ | | `func` | Optional function if not used a a decorator. ~~Optional[Callable[[...], Callable[[Doc], Doc]]]~~ |
| `func` | Optional function if not used a a decorator. ~~Optional[Callable[[...], Callable[[Doc], Doc]]]~~ |
## Language.\_\_call\_\_ {#call tag="method"} ## Language.\_\_call\_\_ {#call tag="method"}
@ -1036,12 +1035,12 @@ provided by the [`@Language.component`](/api/language#component) or
component is defined and stored on the `Language` class for each component component is defined and stored on the `Language` class for each component
instance and factory instance. instance and factory instance.
| Name | Description | | Name | Description |
| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | ----------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `factory` | The name of the registered component factory. ~~str~~ | | `factory` | The name of the registered component factory. ~~str~~ |
| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ | | `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ |
| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | | `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~  | | `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~  |
| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~  | | `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~  |
| `scores` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ | | `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. If a weight is set to `None`, the score will not be logged or weighted. ~~Dict[str, Optional[float]]~~ |
| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. ~~Dict[str, float]~~ | | `scores` | All scores set by the components if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Based on the `default_score_weights` and used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |

View File

@ -470,6 +470,7 @@ score.
```ini ```ini
[training.score_weights] [training.score_weights]
dep_las = 0.4 dep_las = 0.4
dep_uas = null
ents_f = 0.4 ents_f = 0.4
tag_acc = 0.2 tag_acc = 0.2
token_acc = 0.0 token_acc = 0.0
@ -481,9 +482,9 @@ you generate a config for a given pipeline, the score weights are generated by
combining and normalizing the default score weights of the pipeline components. combining and normalizing the default score weights of the pipeline components.
The default score weights are defined by each pipeline component via the The default score weights are defined by each pipeline component via the
`default_score_weights` setting on the `default_score_weights` setting on the
[`@Language.component`](/api/language#component) or [`@Language.factory`](/api/language#factory) decorator. By default, all pipeline
[`@Language.factory`](/api/language#factory). By default, all pipeline components are weighted equally. If a score weight is set to `null`, it will be
components are weighted equally. excluded from the logs and the score won't be weighted.
<Accordion title="Understanding the training output and score types" spaced> <Accordion title="Understanding the training output and score types" spaced>