mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Fix scoring normalization (#7629)
* fix scoring normalization * score weights by total sum instead of per component * cleanup * more cleanup
This commit is contained in:
parent
95e3cf576b
commit
e0b29f8ef7
|
@ -334,24 +334,31 @@ def test_language_factories_invalid():
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weights,expected",
|
||||
"weights,override,expected",
|
||||
[
|
||||
([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {"a": 0.33, "b": 0.33, "c": 0.33}),
|
||||
([{"a": 1.0}, {"b": 50}, {"c": 123}], {"a": 0.33, "b": 0.33, "c": 0.33}),
|
||||
([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {}, {"a": 0.33, "b": 0.33, "c": 0.33}),
|
||||
([{"a": 1.0}, {"b": 50}, {"c": 100}], {}, {"a": 0.01, "b": 0.33, "c": 0.66}),
|
||||
(
|
||||
[{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}],
|
||||
{},
|
||||
{"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17},
|
||||
),
|
||||
(
|
||||
[{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}],
|
||||
{"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25},
|
||||
[{"a": 100, "b": 300}, {"c": 50, "d": 50}],
|
||||
{},
|
||||
{"a": 0.2, "b": 0.6, "c": 0.1, "d": 0.1},
|
||||
),
|
||||
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75}),
|
||||
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"a": 0.0, "b": 0.0, "c": 0.0}),
|
||||
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {}, {"a": 0.33, "b": 0.67}),
|
||||
([{"a": 0.5, "b": 0.0}], {}, {"a": 1.0, "b": 0.0}),
|
||||
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.0}, {"a": 0.0, "b": 1.0}),
|
||||
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {}, {"a": 0.0, "b": 0.0, "c": 0.0}),
|
||||
([{"a": 0.0, "b": 0.0}, {"c": 1.0}], {}, {"a": 0.0, "b": 0.0, "c": 1.0}),
|
||||
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"c": 0.2}, {"a": 0.0, "b": 0.0, "c": 1.0}),
|
||||
([{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}], {"a": 0.0, "b": 0.0}, {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5}),
|
||||
],
|
||||
)
|
||||
def test_language_factories_combine_score_weights(weights, expected):
|
||||
result = combine_score_weights(weights)
|
||||
def test_language_factories_combine_score_weights(weights, override, expected):
|
||||
result = combine_score_weights(weights, override)
|
||||
assert sum(result.values()) in (0.99, 1.0, 0.0)
|
||||
assert result == expected
|
||||
|
||||
|
@ -377,17 +384,17 @@ def test_language_factories_scores():
|
|||
# Test with custom defaults
|
||||
config = nlp.config.copy()
|
||||
config["training"]["score_weights"]["a1"] = 0.0
|
||||
config["training"]["score_weights"]["b3"] = 1.0
|
||||
config["training"]["score_weights"]["b3"] = 1.3
|
||||
nlp = English.from_config(config)
|
||||
score_weights = nlp.config["training"]["score_weights"]
|
||||
expected = {"a1": 0.0, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.34}
|
||||
expected = {"a1": 0.0, "a2": 0.12, "b1": 0.05, "b2": 0.17, "b3": 0.65}
|
||||
assert score_weights == expected
|
||||
# Test with null values
|
||||
config = nlp.config.copy()
|
||||
config["training"]["score_weights"]["a1"] = None
|
||||
nlp = English.from_config(config)
|
||||
score_weights = nlp.config["training"]["score_weights"]
|
||||
expected = {"a1": None, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.35}
|
||||
expected = {"a1": None, "a2": 0.12, "b1": 0.05, "b2": 0.17, "b3": 0.66}
|
||||
assert score_weights == expected
|
||||
|
||||
|
||||
|
|
|
@ -1369,32 +1369,14 @@ def combine_score_weights(
|
|||
should be preserved.
|
||||
RETURNS (Dict[str, float]): The combined and normalized weights.
|
||||
"""
|
||||
# We divide each weight by the total weight sum.
|
||||
# We first need to extract all None/null values for score weights that
|
||||
# shouldn't be shown in the table *or* be weighted
|
||||
result = {}
|
||||
all_weights = []
|
||||
for w_dict in weights:
|
||||
filtered_weights = {}
|
||||
for key, value in w_dict.items():
|
||||
value = overrides.get(key, value)
|
||||
if value is None:
|
||||
result[key] = None
|
||||
else:
|
||||
filtered_weights[key] = value
|
||||
all_weights.append(filtered_weights)
|
||||
for w_dict in all_weights:
|
||||
# We need to account for weights that don't sum to 1.0 and normalize
|
||||
# the score weights accordingly, then divide score by the number of
|
||||
# components.
|
||||
total = sum(w_dict.values())
|
||||
for key, value in w_dict.items():
|
||||
if total == 0:
|
||||
weight = 0.0
|
||||
else:
|
||||
weight = round(value / total / len(all_weights), 2)
|
||||
prev_weight = result.get(key, 0.0)
|
||||
prev_weight = 0.0 if prev_weight is None else prev_weight
|
||||
result[key] = prev_weight + weight
|
||||
result = {key: overrides.get(key, value) for w_dict in weights for (key, value) in w_dict.items()}
|
||||
weight_sum = sum([v if v else 0.0 for v in result.values()])
|
||||
for key, value in result.items():
|
||||
if value and weight_sum > 0:
|
||||
result[key] = round(value / weight_sum, 2)
|
||||
return result
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user