mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-29 11:26:28 +03:00
TextCat.predict: return dict
This commit is contained in:
parent
c792019350
commit
1cfbb934ed
|
@ -207,12 +207,14 @@ class TextCategorizer(TrainablePipe):
|
|||
tensors = [doc.tensor for doc in docs]
|
||||
xp = self.model.ops.xp
|
||||
scores = xp.zeros((len(list(docs)), len(self.labels)))
|
||||
return scores
|
||||
return {"probs": scores}
|
||||
scores = self.model.predict(docs)
|
||||
scores = self.model.ops.asarray(scores)
|
||||
return scores
|
||||
return {"probs": scores}
|
||||
|
||||
def set_annotations(self, docs: Iterable[Doc], scores: Floats2d) -> None:
|
||||
def set_annotations(
|
||||
self, docs: Iterable[Doc], activations: Dict[str, Floats2d]
|
||||
) -> None:
|
||||
"""Modify a batch of Doc objects, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
|
@ -220,12 +222,13 @@ class TextCategorizer(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/textcategorizer#set_annotations
|
||||
"""
|
||||
probs = activations["probs"]
|
||||
for i, doc in enumerate(docs):
|
||||
doc.activations[self.name] = {}
|
||||
if "probs" in self.store_activations:
|
||||
doc.activations[self.name]["probs"] = scores[i]
|
||||
doc.activations[self.name]["probs"] = probs[i]
|
||||
for j, label in enumerate(self.labels):
|
||||
doc.cats[label] = float(scores[i, j])
|
||||
doc.cats[label] = float(probs[i, j])
|
||||
|
||||
def update(
|
||||
self,
|
||||
|
|
|
@ -286,7 +286,7 @@ def test_issue9904():
|
|||
nlp.initialize(get_examples)
|
||||
|
||||
examples = get_examples()
|
||||
scores = textcat.predict([eg.predicted for eg in examples])
|
||||
scores = textcat.predict([eg.predicted for eg in examples])["probs"]
|
||||
|
||||
loss = textcat.get_loss(examples, scores)[0]
|
||||
loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]
|
||||
|
|
Loading…
Reference in New Issue
Block a user