mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Fix combined scores and update test
This commit is contained in:
parent
ae51f580c1
commit
4bbe41f017
|
@ -251,11 +251,8 @@ class Language:
|
||||||
# We're merging the existing score weights back into the combined
|
# We're merging the existing score weights back into the combined
|
||||||
# weights to make sure we're preserving custom settings in the config
|
# weights to make sure we're preserving custom settings in the config
|
||||||
# but also reflect updates (e.g. new components added)
|
# but also reflect updates (e.g. new components added)
|
||||||
prev_score_weights = self._config["training"].get("score_weights", {})
|
prev_weights = self._config["training"].get("score_weights", {})
|
||||||
combined_score_weights = combine_score_weights(score_weights)
|
combined_score_weights = combine_score_weights(score_weights, prev_weights)
|
||||||
combined_score_weights.update(prev_score_weights)
|
|
||||||
# Combine the scores a second time to normalize them
|
|
||||||
combined_score_weights = combine_score_weights([combined_score_weights])
|
|
||||||
self._config["training"]["score_weights"] = combined_score_weights
|
self._config["training"]["score_weights"] = combined_score_weights
|
||||||
if not srsly.is_json_serializable(self._config):
|
if not srsly.is_json_serializable(self._config):
|
||||||
raise ValueError(Errors.E961.format(config=self._config))
|
raise ValueError(Errors.E961.format(config=self._config))
|
||||||
|
|
|
@ -378,14 +378,14 @@ def test_language_factories_scores():
|
||||||
config["training"]["score_weights"]["b3"] = 1.0
|
config["training"]["score_weights"]["b3"] = 1.0
|
||||||
nlp = English.from_config(config)
|
nlp = English.from_config(config)
|
||||||
score_weights = nlp.config["training"]["score_weights"]
|
score_weights = nlp.config["training"]["score_weights"]
|
||||||
expected = {"a1": 0.0, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.59}
|
expected = {"a1": 0.0, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.34}
|
||||||
assert score_weights == expected
|
assert score_weights == expected
|
||||||
# Test with null values
|
# Test with null values
|
||||||
config = nlp.config.copy()
|
config = nlp.config.copy()
|
||||||
config["training"]["score_weights"]["a1"] = None
|
config["training"]["score_weights"]["a1"] = None
|
||||||
nlp = English.from_config(config)
|
nlp = English.from_config(config)
|
||||||
score_weights = nlp.config["training"]["score_weights"]
|
score_weights = nlp.config["training"]["score_weights"]
|
||||||
expected = {"a1": None, "a2": 0.15, "b1": 0.06, "b2": 0.21, "b3": 0.58} # rounding :(
|
expected = {"a1": None, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.35}
|
||||||
assert score_weights == expected
|
assert score_weights == expected
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1202,11 +1202,16 @@ def get_arg_names(func: Callable) -> List[str]:
|
||||||
return list(set([*argspec.args, *argspec.kwonlyargs]))
|
return list(set([*argspec.args, *argspec.kwonlyargs]))
|
||||||
|
|
||||||
|
|
||||||
def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
|
def combine_score_weights(
|
||||||
|
weights: List[Dict[str, float]],
|
||||||
|
overrides: Dict[str, Optional[Union[float, int]]] = SimpleFrozenDict(),
|
||||||
|
) -> Dict[str, float]:
|
||||||
"""Combine and normalize score weights defined by components, e.g.
|
"""Combine and normalize score weights defined by components, e.g.
|
||||||
{"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
|
{"ents_r": 0.2, "ents_p": 0.3, "ents_f": 0.5} and {"some_other_score": 1.0}.
|
||||||
|
|
||||||
weights (List[dict]): The weights defined by the components.
|
weights (List[dict]): The weights defined by the components.
|
||||||
|
overrides (Dict[str, Optional[Union[float, int]]]): Existing scores that
|
||||||
|
should be preserved.
|
||||||
RETURNS (Dict[str, float]): The combined and normalized weights.
|
RETURNS (Dict[str, float]): The combined and normalized weights.
|
||||||
"""
|
"""
|
||||||
# We first need to extract all None/null values for score weights that
|
# We first need to extract all None/null values for score weights that
|
||||||
|
@ -1216,6 +1221,7 @@ def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
|
||||||
for w_dict in weights:
|
for w_dict in weights:
|
||||||
filtered_weights = {}
|
filtered_weights = {}
|
||||||
for key, value in w_dict.items():
|
for key, value in w_dict.items():
|
||||||
|
value = overrides.get(key, value)
|
||||||
if value is None:
|
if value is None:
|
||||||
result[key] = None
|
result[key] = None
|
||||||
else:
|
else:
|
||||||
|
@ -1227,7 +1233,7 @@ def combine_score_weights(weights: List[Dict[str, float]]) -> Dict[str, float]:
|
||||||
# components.
|
# components.
|
||||||
total = sum(w_dict.values())
|
total = sum(w_dict.values())
|
||||||
for key, value in w_dict.items():
|
for key, value in w_dict.items():
|
||||||
weight = round(value / total / len(weights), 2)
|
weight = round(value / total / len(all_weights), 2)
|
||||||
result[key] = result.get(key, 0.0) + weight
|
result[key] = result.get(key, 0.0) + weight
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user