Fix texcat loss scaling (#9904) (#10002)

* add failing test for issue 9904

* remove division by batch size and summation before applying the mean

Co-authored-by: jonas <jsnfly@gmx.de>
This commit is contained in:
jsnfly 2022-01-13 09:03:23 +01:00 committed by GitHub
parent d8a3012539
commit 176a90edee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 2 deletions

View File

@ -320,9 +320,9 @@ class TextCategorizer(TrainablePipe):
self._validate_categories(examples) self._validate_categories(examples)
truths, not_missing = self._examples_to_truth(examples) truths, not_missing = self._examples_to_truth(examples)
not_missing = self.model.ops.asarray(not_missing) # type: ignore not_missing = self.model.ops.asarray(not_missing) # type: ignore
d_scores = (scores - truths) / scores.shape[0] d_scores = (scores - truths)
d_scores *= not_missing d_scores *= not_missing
mean_square_error = (d_scores ** 2).sum(axis=1).mean() mean_square_error = (d_scores ** 2).mean()
return float(mean_square_error), d_scores return float(mean_square_error), d_scores
def add_label(self, label: str) -> int: def add_label(self, label: str) -> int:

View File

@ -277,6 +277,21 @@ def test_issue7019():
print_prf_per_type(msg, scores, name="foo", type="bar") print_prf_per_type(msg, scores, name="foo", type="bar")
@pytest.mark.issue(9904)
def test_issue9904():
nlp = Language()
textcat = nlp.add_pipe("textcat")
get_examples = make_get_examples_single_label(nlp)
nlp.initialize(get_examples)
examples = get_examples()
scores = textcat.predict([eg.predicted for eg in examples])
loss = textcat.get_loss(examples, scores)[0]
loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]
assert loss == pytest.approx(loss_double_bs)
@pytest.mark.skip(reason="Test is flakey when run with others") @pytest.mark.skip(reason="Test is flakey when run with others")
def test_simple_train(): def test_simple_train():
nlp = Language() nlp = Language()