mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 12:20:20 +03:00
Refactor _scores2guesses
This commit is contained in:
parent
eef3d950b4
commit
278181ab59
|
@@ -160,27 +160,18 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
def _scores2guesses(self, docs, scores):
|
def _scores2guesses(self, docs, scores):
|
||||||
guesses = []
|
guesses = []
|
||||||
for doc, doc_scores in zip(docs, scores):
|
for doc, doc_scores in zip(docs, scores):
|
||||||
if self.top_k == 1:
|
|
||||||
doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
|
|
||||||
else:
|
|
||||||
doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
|
|
||||||
|
|
||||||
if not isinstance(doc_guesses, np.ndarray):
|
|
||||||
doc_guesses = doc_guesses.get()
|
|
||||||
|
|
||||||
doc_compat_guesses = []
|
doc_compat_guesses = []
|
||||||
for token, candidates in zip(doc, doc_guesses):
|
for i, token in enumerate(doc):
|
||||||
tree_id = -1
|
for _ in range(self.top_k):
|
||||||
for candidate in candidates:
|
candidate = doc_scores[i].argmax()
|
||||||
candidate_tree_id = self.cfg["labels"][candidate]
|
candidate_tree_id = self.cfg["labels"][candidate]
|
||||||
|
|
||||||
if self.trees.apply(candidate_tree_id, token.text) is not None:
|
if self.trees.apply(candidate_tree_id, token.text) is not None:
|
||||||
tree_id = candidate_tree_id
|
doc_compat_guesses.append(candidate_tree_id)
|
||||||
break
|
break
|
||||||
doc_compat_guesses.append(tree_id)
|
doc_scores[i, candidate] = -1
|
||||||
|
else:
|
||||||
|
doc_compat_guesses.append(-1)
|
||||||
guesses.append(np.array(doc_compat_guesses))
|
guesses.append(np.array(doc_compat_guesses))
|
||||||
|
|
||||||
return guesses
|
return guesses
|
||||||
|
|
||||||
def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
|
def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user