diff --git a/spacy/tests/pipeline/test_pipe_factories.py b/spacy/tests/pipeline/test_pipe_factories.py index 4c197005e..07648024c 100644 --- a/spacy/tests/pipeline/test_pipe_factories.py +++ b/spacy/tests/pipeline/test_pipe_factories.py @@ -345,12 +345,13 @@ def test_language_factories_invalid(): [{"a": 100, "b": 400}, {"c": 0.5, "d": 0.5}], {"a": 0.1, "b": 0.4, "c": 0.25, "d": 0.25}, ), - ([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75},), + ([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.25, "b": 0.75}), + ([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"a": 0.0, "b": 0.0, "c": 0.0}), ], ) def test_language_factories_combine_score_weights(weights, expected): result = combine_score_weights(weights) - assert sum(result.values()) in (0.99, 1.0) + assert sum(result.values()) in (0.99, 1.0, 0.0) assert result == expected diff --git a/spacy/util.py b/spacy/util.py index 709da8d29..ad3298651 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1233,7 +1233,10 @@ def combine_score_weights( # components. total = sum(w_dict.values()) for key, value in w_dict.items(): - weight = round(value / total / len(all_weights), 2) + if total == 0: + weight = 0.0 + else: + weight = round(value / total / len(all_weights), 2) result[key] = result.get(key, 0.0) + weight return result