Merge pull request #6133 from explosion/fix/score_weights

This commit is contained in:
Ines Montani 2020-09-24 12:00:57 +02:00 committed by GitHub
commit c6c67b606e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 122 additions and 100 deletions

View File

@ -317,21 +317,3 @@ start = 100
stop = 1000
compound = 1.001
{% endif %}
[training.score_weights]
{%- if "tagger" in components %}
tag_acc = {{ (1.0 / components|length)|round(2) }}
{%- endif -%}
{%- if "parser" in components %}
dep_uas = 0.0
dep_las = {{ (1.0 / components|length)|round(2) }}
sents_f = 0.0
{%- endif %}
{%- if "ner" in components %}
ents_f = {{ (1.0 / components|length)|round(2) }}
ents_p = 0.0
ents_r = 0.0
{%- endif %}
{%- if "textcat" in components %}
cats_score = {{ (1.0 / components|length)|round(2) }}
{%- endif -%}

View File

@ -209,10 +209,17 @@ def create_train_batches(iterator, batcher, max_epochs: int):
def create_evaluation_callback(
nlp: Language, dev_corpus: Callable, weights: Dict[str, float]
) -> Callable[[], Tuple[float, Dict[str, float]]]:
weights = {key: value for key, value in weights.items() if value is not None}
def evaluate() -> Tuple[float, Dict[str, float]]:
dev_examples = list(dev_corpus(nlp))
scores = nlp.evaluate(dev_examples)
# Calculate a weighted sum based on score_weights for the main score
# Calculate a weighted sum based on score_weights for the main score.
# We can only consider scores that are ints/floats, not dicts like
# entity scores per type etc.
for key, value in scores.items():
if key in weights and not isinstance(value, (int, float)):
raise ValueError(Errors.E915.format(name=key, score_type=type(value)))
try:
weighted_score = sum(
scores.get(s, 0.0) * weights.get(s, 0.0) for s in weights
@ -368,7 +375,8 @@ def update_meta(
) -> None:
nlp.meta["performance"] = {}
for metric in training["score_weights"]:
nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0)
if metric is not None:
nlp.meta["performance"][metric] = info["other_scores"].get(metric, 0.0)
for pipe_name in nlp.pipe_names:
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]

View File

@ -480,6 +480,11 @@ class Errors:
E201 = ("Span index out of range.")
# TODO: fix numbering after merging develop into master
E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
"float or int but got: {score_type}. To exclude the score from the "
"final score, set its weight to null in the [training.score_weights] "
"section of your training config.")
E916 = ("Can't log score for '{name}' in table: not a valid score ({score_type})")
E917 = ("Received invalid value {value} for 'state_type' in "
"TransitionBasedParser: only 'parser' or 'ner' are valid options.")
E918 = ("Received invalid value for vocab: {vocab} ({vocab_type}). Valid "

View File

@ -25,7 +25,6 @@ class Bengali(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -30,7 +30,6 @@ class Greek(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -29,7 +29,6 @@ class English(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -28,7 +28,6 @@ class Persian(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -33,7 +33,6 @@ class French(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -28,7 +28,6 @@ class Norwegian(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -30,7 +30,6 @@ class Dutch(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -35,7 +35,6 @@ class Polish(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "pos_lookup", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -25,7 +25,6 @@ class Russian(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "pymorphy2", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -31,7 +31,6 @@ class Swedish(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "rule", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -25,7 +25,6 @@ class Ukrainian(Language):
"lemmatizer",
assigns=["token.lemma"],
default_config={"model": None, "mode": "pymorphy2", "lookups": None},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -248,9 +248,12 @@ class Language:
self._config["nlp"]["pipeline"] = list(self.component_names)
self._config["nlp"]["disabled"] = list(self.disabled)
self._config["components"] = pipeline
if not self._config["training"].get("score_weights"):
combined_score_weights = combine_score_weights(score_weights)
self._config["training"]["score_weights"] = combined_score_weights
# We're merging the existing score weights back into the combined
# weights to make sure we're preserving custom settings in the config
# but also reflect updates (e.g. new components added)
prev_weights = self._config["training"].get("score_weights", {})
combined_score_weights = combine_score_weights(score_weights, prev_weights)
self._config["training"]["score_weights"] = combined_score_weights
if not srsly.is_json_serializable(self._config):
raise ValueError(Errors.E961.format(config=self._config))
return self._config
@ -412,7 +415,6 @@ class Language:
assigns: Iterable[str] = SimpleFrozenList(),
requires: Iterable[str] = SimpleFrozenList(),
retokenizes: bool = False,
scores: Iterable[str] = SimpleFrozenList(),
default_score_weights: Dict[str, float] = SimpleFrozenDict(),
func: Optional[Callable] = None,
) -> Callable:
@ -430,12 +432,11 @@ class Language:
e.g. "token.ent_id". Used for pipeline analysis.
retokenizes (bool): Whether the component changes the tokenization.
Used for pipeline analysis.
scores (Iterable[str]): All scores set by the component if it's trainable,
e.g. ["ents_f", "ents_r", "ents_p"].
default_score_weights (Dict[str, float]): The scores to report during
training, and their default weight towards the final score used to
select the best model. Weights should sum to 1.0 per component and
will be combined and normalized for the whole pipeline.
will be combined and normalized for the whole pipeline. If None,
the score won't be shown in the logs or be weighted.
func (Optional[Callable]): Factory function if not used as a decorator.
DOCS: https://nightly.spacy.io/api/language#factory
@ -475,7 +476,7 @@ class Language:
default_config=default_config,
assigns=validate_attrs(assigns),
requires=validate_attrs(requires),
scores=scores,
scores=list(default_score_weights.keys()),
default_score_weights=default_score_weights,
retokenizes=retokenizes,
)

View File

@ -43,8 +43,14 @@ DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
"min_action_freq": 30,
"model": DEFAULT_PARSER_MODEL,
},
scores=["dep_uas", "dep_las", "dep_las_per_type", "sents_p", "sents_r", "sents_f"],
default_score_weights={"dep_uas": 0.5, "dep_las": 0.5, "sents_f": 0.0},
default_score_weights={
"dep_uas": 0.5,
"dep_las": 0.5,
"dep_las_per_type": None,
"sents_p": None,
"sents_r": None,
"sents_f": 0.0,
},
)
def make_parser(
nlp: Language,

View File

@ -25,8 +25,12 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
"overwrite_ents": False,
"ent_id_sep": DEFAULT_ENT_ID_SEP,
},
scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
default_score_weights={
"ents_f": 1.0,
"ents_p": 0.0,
"ents_r": 0.0,
"ents_per_type": None,
},
)
def make_entity_ruler(
nlp: Language,

View File

@ -21,7 +21,6 @@ from .. import util
"lookups": None,
"overwrite": False,
},
scores=["lemma_acc"],
default_score_weights={"lemma_acc": 1.0},
)
def make_lemmatizer(

View File

@ -49,8 +49,7 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
"morphologizer",
assigns=["token.morph", "token.pos"],
default_config={"model": DEFAULT_MORPH_MODEL},
scores=["pos_acc", "morph_acc", "morph_per_feat"],
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5},
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
)
def make_morphologizer(
nlp: Language,

View File

@ -39,8 +39,7 @@ DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
"update_with_oracle_cut_size": 100,
"model": DEFAULT_NER_MODEL,
},
scores=["ents_p", "ents_r", "ents_f", "ents_per_type"],
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0},
default_score_weights={"ents_f": 1.0, "ents_p": 0.0, "ents_r": 0.0, "ents_per_type": None},
)
def make_ner(

View File

@ -15,7 +15,6 @@ from .. import util
"sentencizer",
assigns=["token.is_sent_start", "doc.sents"],
default_config={"punct_chars": None},
scores=["sents_p", "sents_r", "sents_f"],
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
)
def make_sentencizer(

View File

@ -36,7 +36,6 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
"senter",
assigns=["token.is_sent_start"],
default_config={"model": DEFAULT_SENTER_MODEL},
scores=["sents_p", "sents_r", "sents_f"],
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
)
def make_senter(nlp: Language, name: str, model: Model):

View File

@ -42,7 +42,6 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
"tagger",
assigns=["token.tag"],
default_config={"model": DEFAULT_TAGGER_MODEL},
scores=["tag_acc"],
default_score_weights={"tag_acc": 1.0},
)
def make_tagger(nlp: Language, name: str, model: Model):

View File

@ -62,18 +62,17 @@ subword_features = true
"positive_label": None,
"model": DEFAULT_TEXTCAT_MODEL,
},
scores=[
"cats_score",
"cats_score_desc",
"cats_p",
"cats_r",
"cats_f",
"cats_macro_f",
"cats_macro_auc",
"cats_f_per_type",
"cats_macro_auc_per_type",
],
default_score_weights={"cats_score": 1.0},
default_score_weights={
"cats_score": 1.0,
"cats_score_desc": None,
"cats_p": None,
"cats_r": None,
"cats_f": None,
"cats_macro_f": None,
"cats_macro_auc": None,
"cats_f_per_type": None,
"cats_macro_auc_per_type": None,
},
)
def make_textcat(
nlp: Language,

View File

@ -211,7 +211,7 @@ class ConfigSchemaTraining(BaseModel):
seed: Optional[StrictInt] = Field(..., title="Random seed")
gpu_allocator: Optional[StrictStr] = Field(..., title="Memory allocator when running on GPU")
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Scores to report and their weights for selecting final model")
score_weights: Dict[StrictStr, Optional[Union[StrictFloat, StrictInt]]] = Field(..., title="Scores to report and their weights for selecting final model")
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
optimizer: Optimizer = Field(..., title="The optimizer to use")

View File

@ -359,12 +359,8 @@ def test_language_factories_scores():
func = lambda nlp, name: lambda doc: doc
weights1 = {"a1": 0.5, "a2": 0.5}
weights2 = {"b1": 0.2, "b2": 0.7, "b3": 0.1}
Language.factory(
f"{name}1", scores=list(weights1), default_score_weights=weights1, func=func,
)
Language.factory(
f"{name}2", scores=list(weights2), default_score_weights=weights2, func=func,
)
Language.factory(f"{name}1", default_score_weights=weights1, func=func)
Language.factory(f"{name}2", default_score_weights=weights2, func=func)
meta1 = Language.get_factory_meta(f"{name}1")
assert meta1.default_score_weights == weights1
meta2 = Language.get_factory_meta(f"{name}2")
@ -376,6 +372,21 @@ def test_language_factories_scores():
cfg = nlp.config["training"]
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
assert cfg["score_weights"] == expected_weights
# Test with custom defaults
config = nlp.config.copy()
config["training"]["score_weights"]["a1"] = 0.0
config["training"]["score_weights"]["b3"] = 1.0
nlp = English.from_config(config)
score_weights = nlp.config["training"]["score_weights"]
expected = {"a1": 0.0, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.34}
assert score_weights == expected
# Test with null values
config = nlp.config.copy()
config["training"]["score_weights"]["a1"] = None
nlp = English.from_config(config)
score_weights = nlp.config["training"]["score_weights"]
expected = {"a1": None, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.35}
assert score_weights == expected
def test_pipe_factories_from_source():

View File

@ -13,7 +13,8 @@ def console_logger():
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
# we assume here that only components are enabled that should be trained & logged
logged_pipes = nlp.pipe_names
score_cols = list(nlp.config["training"]["score_weights"])
score_weights = nlp.config["training"]["score_weights"]
score_cols = [col for col, value in score_weights.items() if value is not None]
score_widths = [max(len(col), 6) for col in score_cols]
loss_cols = [f"Loss {pipe}" for pipe in logged_pipes]
loss_widths = [max(len(col), 8) for col in loss_cols]
@ -40,10 +41,15 @@ def console_logger():
) from None
scores = []
for col in score_cols:
score = float(info["other_scores"].get(col, 0.0))
if col != "speed":
score *= 100
scores.append("{0:.2f}".format(score))
score = info["other_scores"].get(col, 0.0)
try:
score = float(score)
if col != "speed":
score *= 100
scores.append("{0:.2f}".format(score))
except TypeError:
err = Errors.E916.format(name=col, score_type=type(score))
raise ValueError(err) from None
data = (
[info["epoch"], info["step"]]
+ losses

View File

@ -1202,21 +1202,38 @@ def get_arg_names(func: Callable) -> List[str]:
return list(set([*argspec.args, *argspec.kwonlyargs]))
def combine_score_weights(
    weights: List[Dict[str, float]],
    overrides: Dict[str, Optional[Union[float, int]]] = SimpleFrozenDict(),
) -> Dict[str, float]:
    """Combine and normalize score weights defined by components, e.g.
    {"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.

    Each component's weights are normalized to sum to 1.0 and then divided by
    the number of components. Scores whose (possibly overridden) weight is
    None are kept in the result with the value None, unless another component
    contributes a numeric weight for the same score, in which case the numeric
    weight is used.

    weights (List[dict]): The weights defined by the components.
    overrides (Dict[str, Optional[Union[float, int]]]): Existing scores that
        should be preserved. A value found here replaces the component's own
        value for that score before normalization.
    RETURNS (Dict[str, float]): The combined and normalized weights. Values
        are None for scores that are excluded.
    """
    # We first need to extract all None/null values for score weights that
    # shouldn't be shown in the table *or* be weighted
    result = {}
    all_weights = []
    for w_dict in weights:
        filtered_weights = {}
        for key, value in w_dict.items():
            value = overrides.get(key, value)
            if value is None:
                result[key] = None
            else:
                filtered_weights[key] = value
        all_weights.append(filtered_weights)
    n_components = len(all_weights)
    for w_dict in all_weights:
        # We need to account for weights that don't sum to 1.0 and normalize
        # the score weights accordingly, then divide score by the number of
        # components.
        total = sum(w_dict.values())
        for key, value in w_dict.items():
            # A component whose weights are all 0.0 has nothing to normalize;
            # dividing by its total would raise ZeroDivisionError.
            weight = round(value / total / n_components, 2) if total else 0.0
            # The same score can be None for one component and numeric for
            # another, so treat a previously stored None as 0.0 instead of
            # attempting None + float.
            result[key] = (result.get(key) or 0.0) + weight
    return result

View File

@ -145,17 +145,16 @@ examples, see the
> )
> ```
| Name | Description |
| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `name` | The name of the component factory. ~~str~~ |
| _keyword-only_ | |
| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ |
| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~ |
| `scores` | All scores set by the component if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. ~~Dict[str, float]~~ |
| `func` | Optional function if not used as a decorator. ~~Optional[Callable[[...], Callable[[Doc], Doc]]]~~ |
| Name | Description |
| ----------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `name` | The name of the component factory. ~~str~~ |
| _keyword-only_ | |
| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ |
| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~ |
| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. If a weight is set to `None`, the score will not be logged or weighted. ~~Dict[str, Optional[float]]~~ |
| `func` | Optional function if not used as a decorator. ~~Optional[Callable[[...], Callable[[Doc], Doc]]]~~ |
## Language.\_\_call\_\_ {#call tag="method"}
@ -1036,12 +1035,12 @@ provided by the [`@Language.component`](/api/language#component) or
component is defined and stored on the `Language` class for each component
instance and factory instance.
| Name | Description |
| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `factory` | The name of the registered component factory. ~~str~~ |
| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ |
| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~  |
| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~  |
| `scores` | All scores set by the component if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. ~~Dict[str, float]~~ |
| Name | Description |
| ----------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `factory` | The name of the registered component factory. ~~str~~ |
| `default_config` | The default config, describing the default values of the factory arguments. ~~Dict[str, Any]~~ |
| `assigns` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |
| `requires` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~  |
| `retokenizes` | Whether the component changes tokenization. Used for [pipe analysis](/usage/processing-pipelines#analysis). ~~bool~~  |
| `default_score_weights` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. If a weight is set to `None`, the score will not be logged or weighted. ~~Dict[str, Optional[float]]~~ |
| `scores` | All scores set by the component if it's trainable, e.g. `["ents_f", "ents_r", "ents_p"]`. Based on the `default_score_weights` and used for [pipe analysis](/usage/processing-pipelines#analysis). ~~Iterable[str]~~ |

View File

@ -470,6 +470,7 @@ score.
```ini
[training.score_weights]
dep_las = 0.4
dep_uas = null
ents_f = 0.4
tag_acc = 0.2
token_acc = 0.0
@ -481,9 +482,9 @@ you generate a config for a given pipeline, the score weights are generated by
combining and normalizing the default score weights of the pipeline components.
The default score weights are defined by each pipeline component via the
`default_score_weights` setting on the
[`@Language.component`](/api/language#component) or
[`@Language.factory`](/api/language#factory). By default, all pipeline
components are weighted equally.
[`@Language.factory`](/api/language#factory) decorator. By default, all pipeline
components are weighted equally. If a score weight is set to `null`, it will be
excluded from the logs and the score won't be weighted.
<Accordion title="Understanding the training output and score types" spaced>