From 449559cc2d658ad348345aea3436ddc664866c98 Mon Sep 17 00:00:00 2001 From: richardpaulhudson Date: Fri, 9 Dec 2022 17:03:48 +0100 Subject: [PATCH] Fix mypy issues --- spacy/pipeline/edit_tree_lemmatizer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index d8279da31..b19511790 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -137,15 +137,15 @@ class EditTreeLemmatizer(TrainablePipe): ): if gold_lemma is None: label = -1 - eg_lowercasing_truths.append([0]) + eg_lowercasing_truths.append([0.0]) else: if self.lowercasing and _should_lowercased( predicted.text, gold_lemma ): - eg_lowercasing_truths.append([1]) + eg_lowercasing_truths.append([1.0]) text = predicted.lower_ else: - eg_lowercasing_truths.append([0]) + eg_lowercasing_truths.append([0.0]) text = predicted.text tree_id = self.trees.add(text, gold_lemma) label = self.tree2label.get(tree_id, 0) @@ -163,7 +163,8 @@ class EditTreeLemmatizer(TrainablePipe): for i, doc_d_tree_scores in enumerate(d_tree_scores): eg_lowercasing_flags = lowercasing_flags[i] eg_d_lowercasing_flags, eg_lowercasing_loss = lowercasing_loss_func( - eg_lowercasing_flags, self.model.ops.asarray2i(lowercasing_truths[i]) + eg_lowercasing_flags, + self.model.ops.asarray2f(lowercasing_truths[i]), # type: ignore[arg-type] ) doc_d_scores = self.model.ops.xp.hstack( [