mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
* add failing test for issue 9904 * remove division by batch size and summation before applying the mean Co-authored-by: jonas <jsnfly@gmx.de>
This commit is contained in:
parent
d8a3012539
commit
176a90edee
|
@ -320,9 +320,9 @@ class TextCategorizer(TrainablePipe):
|
||||||
self._validate_categories(examples)
|
self._validate_categories(examples)
|
||||||
truths, not_missing = self._examples_to_truth(examples)
|
truths, not_missing = self._examples_to_truth(examples)
|
||||||
not_missing = self.model.ops.asarray(not_missing) # type: ignore
|
not_missing = self.model.ops.asarray(not_missing) # type: ignore
|
||||||
d_scores = (scores - truths) / scores.shape[0]
|
d_scores = (scores - truths)
|
||||||
d_scores *= not_missing
|
d_scores *= not_missing
|
||||||
mean_square_error = (d_scores ** 2).sum(axis=1).mean()
|
mean_square_error = (d_scores ** 2).mean()
|
||||||
return float(mean_square_error), d_scores
|
return float(mean_square_error), d_scores
|
||||||
|
|
||||||
def add_label(self, label: str) -> int:
|
def add_label(self, label: str) -> int:
|
||||||
|
|
|
@ -277,6 +277,21 @@ def test_issue7019():
|
||||||
print_prf_per_type(msg, scores, name="foo", type="bar")
|
print_prf_per_type(msg, scores, name="foo", type="bar")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(9904)
|
||||||
|
def test_issue9904():
|
||||||
|
nlp = Language()
|
||||||
|
textcat = nlp.add_pipe("textcat")
|
||||||
|
get_examples = make_get_examples_single_label(nlp)
|
||||||
|
nlp.initialize(get_examples)
|
||||||
|
|
||||||
|
examples = get_examples()
|
||||||
|
scores = textcat.predict([eg.predicted for eg in examples])
|
||||||
|
|
||||||
|
loss = textcat.get_loss(examples, scores)[0]
|
||||||
|
loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]
|
||||||
|
assert loss == pytest.approx(loss_double_bs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Test is flakey when run with others")
|
@pytest.mark.skip(reason="Test is flakey when run with others")
|
||||||
def test_simple_train():
|
def test_simple_train():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user