mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Prevent division by zero in score weights
This commit is contained in:
parent
74ee456374
commit
d0ef4a4cf5
|
@ -345,12 +345,13 @@ def test_language_factories_invalid():
|
||||||
[{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}],
|
[{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}],
|
||||||
{"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25},
|
{"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25},
|
||||||
),
|
),
|
||||||
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75},),
|
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75}),
|
||||||
|
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"a": 0.0, "b": 0.0, "c": 0.0}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_language_factories_combine_score_weights(weights, expected):
|
def test_language_factories_combine_score_weights(weights, expected):
|
||||||
result = combine_score_weights(weights)
|
result = combine_score_weights(weights)
|
||||||
assert sum(result.values()) in (0.99, 1.0)
|
assert sum(result.values()) in (0.99, 1.0, 0.0)
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1233,6 +1233,9 @@ def combine_score_weights(
|
||||||
# components.
|
# components.
|
||||||
total = sum(w_dict.values())
|
total = sum(w_dict.values())
|
||||||
for key, value in w_dict.items():
|
for key, value in w_dict.items():
|
||||||
|
if total == 0:
|
||||||
|
weight = 0.0
|
||||||
|
else:
|
||||||
weight = round(value / total / len(all_weights), 2)
|
weight = round(value / total / len(all_weights), 2)
|
||||||
result[key] = result.get(key, 0.0) + weight
|
result[key] = result.get(key, 0.0) + weight
|
||||||
return result
|
return result
|
||||||
|
|
Loading…
Reference in New Issue
Block a user