diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 888cd0178..fc883fb68 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -207,12 +207,14 @@ class TextCategorizer(TrainablePipe): tensors = [doc.tensor for doc in docs] xp = self.model.ops.xp scores = xp.zeros((len(list(docs)), len(self.labels))) - return scores + return {"probs": scores} scores = self.model.predict(docs) scores = self.model.ops.asarray(scores) - return scores + return {"probs": scores} - def set_annotations(self, docs: Iterable[Doc], scores: Floats2d) -> None: + def set_annotations( + self, docs: Iterable[Doc], activations: Dict[str, Floats2d] + ) -> None: """Modify a batch of Doc objects, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. @@ -220,12 +222,13 @@ class TextCategorizer(TrainablePipe): DOCS: https://spacy.io/api/textcategorizer#set_annotations """ + probs = activations["probs"] for i, doc in enumerate(docs): doc.activations[self.name] = {} if "probs" in self.store_activations: - doc.activations[self.name]["probs"] = scores[i] + doc.activations[self.name]["probs"] = probs[i] for j, label in enumerate(self.labels): - doc.cats[label] = float(scores[i, j]) + doc.cats[label] = float(probs[i, j]) def update( self, diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 5024c1fd5..be94e18c8 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -286,7 +286,7 @@ def test_issue9904(): nlp.initialize(get_examples) examples = get_examples() - scores = textcat.predict([eg.predicted for eg in examples]) + scores = textcat.predict([eg.predicted for eg in examples])["probs"] loss = textcat.get_loss(examples, scores)[0] loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]