diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index e20ae87f1..dd5fdc078 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -320,9 +320,9 @@ class TextCategorizer(TrainablePipe):
         self._validate_categories(examples)
         truths, not_missing = self._examples_to_truth(examples)
         not_missing = self.model.ops.asarray(not_missing)  # type: ignore
-        d_scores = (scores - truths) / scores.shape[0]
+        d_scores = (scores - truths)
         d_scores *= not_missing
-        mean_square_error = (d_scores ** 2).sum(axis=1).mean()
+        mean_square_error = (d_scores ** 2).mean()
         return float(mean_square_error), d_scores
 
     def add_label(self, label: str) -> int:
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 52bf6ec5c..798dd165e 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -277,6 +277,21 @@ def test_issue7019():
     print_prf_per_type(msg, scores, name="foo", type="bar")
 
 
+@pytest.mark.issue(9904)
+def test_issue9904():
+    nlp = Language()
+    textcat = nlp.add_pipe("textcat")
+    get_examples = make_get_examples_single_label(nlp)
+    nlp.initialize(get_examples)
+
+    examples = get_examples()
+    scores = textcat.predict([eg.predicted for eg in examples])
+
+    loss = textcat.get_loss(examples, scores)[0]
+    loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]
+    assert loss == pytest.approx(loss_double_bs)
+
+
 @pytest.mark.skip(reason="Test is flakey when run with others")
 def test_simple_train():
     nlp = Language()