mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Fix combined scores and update test
This commit is contained in:
parent
ae51f580c1
commit
4bbe41f017
|
@ -251,11 +251,8 @@ class Language:
|
|||
# We're merging the existing score weights back into the combined
|
||||
# weights to make sure we're preserving custom settings in the config
|
||||
# but also reflect updates (e.g. new components added)
|
||||
prev_score_weights = self._config["training"].get("score_weights", {})
|
||||
combined_score_weights = combine_score_weights(score_weights)
|
||||
combined_score_weights.update(prev_score_weights)
|
||||
# Combine the scores a second time to normalize them
|
||||
combined_score_weights = combine_score_weights([combined_score_weights])
|
||||
prev_weights = self._config["training"].get("score_weights", {})
|
||||
combined_score_weights = combine_score_weights(score_weights, prev_weights)
|
||||
self._config["training"]["score_weights"] = combined_score_weights
|
||||
if not srsly.is_json_serializable(self._config):
|
||||
raise ValueError(Errors.E961.format(config=self._config))
|
||||
|
|
|
@ -378,14 +378,14 @@ def test_language_factories_scores():
|
|||
config["training"]["score_weights"]["b3"] = 1.0
|
||||
nlp = English.from_config(config)
|
||||
score_weights = nlp.config["training"]["score_weights"]
|
||||
expected = {"a1": 0.0, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.59}
|
||||
expected = {"a1": 0.0, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.34}
|
||||
assert score_weights == expected
|
||||
# Test with null values
|
||||
config = nlp.config.copy()
|
||||
config["training"]["score_weights"]["a1"] = None
|
||||
nlp = English.from_config(config)
|
||||
score_weights = nlp.config["training"]["score_weights"]
|
||||
expected = {"a1": None, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.58} # rounding :(
|
||||
expected = {"a1": None, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.35}
|
||||
assert score_weights == expected
|
||||
|
||||
|
||||
|
|
|
@ -1202,11 +1202,16 @@ def get_arg_names(func: Callable) -> List[str]:
|
|||
return list(set([*argspec.args, *argspec.kwonlyargs]))
|
||||
|
||||
|
||||
def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
|
||||
def combine_score_weights(
|
||||
weights: List[Dict[str, float]],
|
||||
overrides: Dict[str, Optional[Union[float, int]]] = SimpleFrozenDict(),
|
||||
) -> Dict[str, float]:
|
||||
"""Combine and normalize score weights defined by components, e.g.
|
||||
{"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
|
||||
|
||||
weights (List[dict]): The weights defined by the components.
|
||||
overrides (Dict[str, Optional[Union[float, int]]]): Existing scores that
|
||||
should be preserved.
|
||||
RETURNS (Dict[str, float]): The combined and normalized weights.
|
||||
"""
|
||||
# We first need to extract all None/null values for score weights that
|
||||
|
@ -1216,6 +1221,7 @@ def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
|
|||
for w_dict in weights:
|
||||
filtered_weights = {}
|
||||
for key, value in w_dict.items():
|
||||
value = overrides.get(key, value)
|
||||
if value is None:
|
||||
result[key] = None
|
||||
else:
|
||||
|
@ -1227,7 +1233,7 @@ def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
|
|||
# components.
|
||||
total = sum(w_dict.values())
|
||||
for key, value in w_dict.items():
|
||||
weight = round(value / total / len(weights), 2)
|
||||
weight = round(value / total / len(all_weights), 2)
|
||||
result[key] = result.get(key, 0.0) + weight
|
||||
return result
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user