TextCat.predict: return dict

This commit is contained in:
Daniël de Kok 2022-08-05 14:40:20 +02:00
parent c792019350
commit 1cfbb934ed
2 changed files with 9 additions and 6 deletions

View File

@@ -207,12 +207,14 @@ class TextCategorizer(TrainablePipe):
tensors = [doc.tensor for doc in docs] tensors = [doc.tensor for doc in docs]
xp = self.model.ops.xp xp = self.model.ops.xp
scores = xp.zeros((len(list(docs)), len(self.labels))) scores = xp.zeros((len(list(docs)), len(self.labels)))
return scores return {"probs": scores}
scores = self.model.predict(docs) scores = self.model.predict(docs)
scores = self.model.ops.asarray(scores) scores = self.model.ops.asarray(scores)
return scores return {"probs": scores}
def set_annotations(self, docs: Iterable[Doc], scores: Floats2d) -> None: def set_annotations(
self, docs: Iterable[Doc], activations: Dict[str, Floats2d]
) -> None:
"""Modify a batch of Doc objects, using pre-computed scores. """Modify a batch of Doc objects, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
@@ -220,12 +222,13 @@ class TextCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/textcategorizer#set_annotations DOCS: https://spacy.io/api/textcategorizer#set_annotations
""" """
probs = activations["probs"]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.activations[self.name] = {} doc.activations[self.name] = {}
if "probs" in self.store_activations: if "probs" in self.store_activations:
doc.activations[self.name]["probs"] = scores[i] doc.activations[self.name]["probs"] = probs[i]
for j, label in enumerate(self.labels): for j, label in enumerate(self.labels):
doc.cats[label] = float(scores[i, j]) doc.cats[label] = float(probs[i, j])
def update( def update(
self, self,

View File

@@ -286,7 +286,7 @@ def test_issue9904():
nlp.initialize(get_examples) nlp.initialize(get_examples)
examples = get_examples() examples = get_examples()
scores = textcat.predict([eg.predicted for eg in examples]) scores = textcat.predict([eg.predicted for eg in examples])["probs"]
loss = textcat.get_loss(examples, scores)[0] loss = textcat.get_loss(examples, scores)[0]
loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0] loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]