Fix scoring normalization (#7629)

* fix scoring normalization

* score weights by total sum instead of per component

* cleanup

* more cleanup
This commit is contained in:
Sofie Van Landeghem 2021-04-26 16:53:38 +02:00 committed by GitHub
parent 95e3cf576b
commit e0b29f8ef7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 36 deletions

View File

@ -334,24 +334,31 @@ def test_language_factories_invalid():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"weights,expected", "weights,override,expected",
[ [
([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {"a": 0.33, "b": 0.33, "c": 0.33}), ([{"a": 1.0}, {"b": 1.0}, {"c": 1.0}], {}, {"a": 0.33, "b": 0.33, "c": 0.33}),
([{"a": 1.0}, {"b": 50}, {"c": 123}], {"a": 0.33, "b": 0.33, "c": 0.33}), ([{"a": 1.0}, {"b": 50}, {"c": 100}], {}, {"a": 0.01, "b": 0.33, "c": 0.66}),
( (
[{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}], [{"a": 0.7, "b": 0.3}, {"c": 1.0}, {"d": 0.5, "e": 0.5}],
{},
{"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17}, {"a": 0.23, "b": 0.1, "c": 0.33, "d": 0.17, "e": 0.17},
), ),
( (
[{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}], [{"a": 100, "b": 300}, {"c": 50, "d": 50}],
{"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25}, {},
{"a": 0.2, "b": 0.6, "c": 0.1, "d": 0.1},
), ),
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75}), ([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {}, {"a": 0.33, "b": 0.67}),
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"a": 0.0, "b": 0.0, "c": 0.0}), ([{"a": 0.5, "b": 0.0}], {}, {"a": 1.0, "b": 0.0}),
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.0}, {"a": 0.0, "b": 1.0}),
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {}, {"a": 0.0, "b": 0.0, "c": 0.0}),
([{"a": 0.0, "b": 0.0}, {"c": 1.0}], {}, {"a": 0.0, "b": 0.0, "c": 1.0}),
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"c": 0.2}, {"a": 0.0, "b": 0.0, "c": 1.0}),
([{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}], {"a": 0.0, "b": 0.0}, {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5}),
], ],
) )
def test_language_factories_combine_score_weights(weights, expected): def test_language_factories_combine_score_weights(weights, override, expected):
result = combine_score_weights(weights) result = combine_score_weights(weights, override)
assert sum(result.values()) in (0.99, 1.0, 0.0) assert sum(result.values()) in (0.99, 1.0, 0.0)
assert result == expected assert result == expected
@ -377,17 +384,17 @@ def test_language_factories_scores():
# Test with custom defaults # Test with custom defaults
config = nlp.config.copy() config = nlp.config.copy()
config["training"]["score_weights"]["a1"] = 0.0 config["training"]["score_weights"]["a1"] = 0.0
config["training"]["score_weights"]["b3"] = 1.0 config["training"]["score_weights"]["b3"] = 1.3
nlp = English.from_config(config) nlp = English.from_config(config)
score_weights = nlp.config["training"]["score_weights"] score_weights = nlp.config["training"]["score_weights"]
expected = {"a1": 0.0, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.34} expected = {"a1": 0.0, "a2": 0.12, "b1": 0.05, "b2": 0.17, "b3": 0.65}
assert score_weights == expected assert score_weights == expected
# Test with null values # Test with null values
config = nlp.config.copy() config = nlp.config.copy()
config["training"]["score_weights"]["a1"] = None config["training"]["score_weights"]["a1"] = None
nlp = English.from_config(config) nlp = English.from_config(config)
score_weights = nlp.config["training"]["score_weights"] score_weights = nlp.config["training"]["score_weights"]
expected = {"a1": None, "a2": 0.5, "b1": 0.03, "b2": 0.12, "b3": 0.35} expected = {"a1": None, "a2": 0.12, "b1": 0.05, "b2": 0.17, "b3": 0.66}
assert score_weights == expected assert score_weights == expected

View File

@ -1369,32 +1369,14 @@ def combine_score_weights(
should be preserved. should be preserved.
RETURNS (Dict[str, float]): The combined and normalized weights. RETURNS (Dict[str, float]): The combined and normalized weights.
""" """
# We divide each weight by the total weight sum.
# We first need to extract all None/null values for score weights that # We first need to extract all None/null values for score weights that
# shouldn't be shown in the table *or* be weighted # shouldn't be shown in the table *or* be weighted
result = {} result = {key: overrides.get(key, value) for w_dict in weights for (key, value) in w_dict.items()}
all_weights = [] weight_sum = sum([v if v else 0.0 for v in result.values()])
for w_dict in weights: for key, value in result.items():
filtered_weights = {} if value and weight_sum > 0:
for key, value in w_dict.items(): result[key] = round(value / weight_sum, 2)
value = overrides.get(key, value)
if value is None:
result[key] = None
else:
filtered_weights[key] = value
all_weights.append(filtered_weights)
for w_dict in all_weights:
# We need to account for weights that don't sum to 1.0 and normalize
# the score weights accordingly, then divide score by the number of
# components.
total = sum(w_dict.values())
for key, value in w_dict.items():
if total == 0:
weight = 0.0
else:
weight = round(value / total / len(all_weights), 2)
prev_weight = result.get(key, 0.0)
prev_weight = 0.0 if prev_weight is None else prev_weight
result[key] = prev_weight + weight
return result return result