mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-01 04:46:38 +03:00
TextCat.predict: return dict
This commit is contained in:
parent
c792019350
commit
1cfbb934ed
|
@ -207,12 +207,14 @@ class TextCategorizer(TrainablePipe):
|
||||||
tensors = [doc.tensor for doc in docs]
|
tensors = [doc.tensor for doc in docs]
|
||||||
xp = self.model.ops.xp
|
xp = self.model.ops.xp
|
||||||
scores = xp.zeros((len(list(docs)), len(self.labels)))
|
scores = xp.zeros((len(list(docs)), len(self.labels)))
|
||||||
return scores
|
return {"probs": scores}
|
||||||
scores = self.model.predict(docs)
|
scores = self.model.predict(docs)
|
||||||
scores = self.model.ops.asarray(scores)
|
scores = self.model.ops.asarray(scores)
|
||||||
return scores
|
return {"probs": scores}
|
||||||
|
|
||||||
def set_annotations(self, docs: Iterable[Doc], scores: Floats2d) -> None:
|
def set_annotations(
|
||||||
|
self, docs: Iterable[Doc], activations: Dict[str, Floats2d]
|
||||||
|
) -> None:
|
||||||
"""Modify a batch of Doc objects, using pre-computed scores.
|
"""Modify a batch of Doc objects, using pre-computed scores.
|
||||||
|
|
||||||
docs (Iterable[Doc]): The documents to modify.
|
docs (Iterable[Doc]): The documents to modify.
|
||||||
|
@ -220,12 +222,13 @@ class TextCategorizer(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/textcategorizer#set_annotations
|
DOCS: https://spacy.io/api/textcategorizer#set_annotations
|
||||||
"""
|
"""
|
||||||
|
probs = activations["probs"]
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
doc.activations[self.name] = {}
|
doc.activations[self.name] = {}
|
||||||
if "probs" in self.store_activations:
|
if "probs" in self.store_activations:
|
||||||
doc.activations[self.name]["probs"] = scores[i]
|
doc.activations[self.name]["probs"] = probs[i]
|
||||||
for j, label in enumerate(self.labels):
|
for j, label in enumerate(self.labels):
|
||||||
doc.cats[label] = float(scores[i, j])
|
doc.cats[label] = float(probs[i, j])
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -286,7 +286,7 @@ def test_issue9904():
|
||||||
nlp.initialize(get_examples)
|
nlp.initialize(get_examples)
|
||||||
|
|
||||||
examples = get_examples()
|
examples = get_examples()
|
||||||
scores = textcat.predict([eg.predicted for eg in examples])
|
scores = textcat.predict([eg.predicted for eg in examples])["probs"]
|
||||||
|
|
||||||
loss = textcat.get_loss(examples, scores)[0]
|
loss = textcat.get_loss(examples, scores)[0]
|
||||||
loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]
|
loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user