mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Prevent division by zero in score weights
This commit is contained in:
parent
74ee456374
commit
d0ef4a4cf5
|
@ -345,12 +345,13 @@ def test_language_factories_invalid():
|
|||
[{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}],
|
||||
{"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25},
|
||||
),
|
||||
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75},),
|
||||
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75}),
|
||||
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"a": 0.0, "b": 0.0, "c": 0.0}),
|
||||
],
|
||||
)
|
||||
def test_language_factories_combine_score_weights(weights, expected):
|
||||
result = combine_score_weights(weights)
|
||||
assert sum(result.values()) in (0.99, 1.0)
|
||||
assert sum(result.values()) in (0.99, 1.0, 0.0)
|
||||
assert result == expected
|
||||
|
||||
|
||||
|
|
|
@ -1233,7 +1233,10 @@ def combine_score_weights(
|
|||
# components.
|
||||
total = sum(w_dict.values())
|
||||
for key, value in w_dict.items():
|
||||
weight = round(value / total / len(all_weights), 2)
|
||||
if total == 0:
|
||||
weight = 0.0
|
||||
else:
|
||||
weight = round(value / total / len(all_weights), 2)
|
||||
result[key] = result.get(key, 0.0) + weight
|
||||
return result
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user