diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index d8279da31..b19511790 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -137,15 +137,15 @@ class EditTreeLemmatizer(TrainablePipe): ): if gold_lemma is None: label = -1 - eg_lowercasing_truths.append([0]) + eg_lowercasing_truths.append([0.0]) else: if self.lowercasing and _should_lowercased( predicted.text, gold_lemma ): - eg_lowercasing_truths.append([1]) + eg_lowercasing_truths.append([1.0]) text = predicted.lower_ else: - eg_lowercasing_truths.append([0]) + eg_lowercasing_truths.append([0.0]) text = predicted.text tree_id = self.trees.add(text, gold_lemma) label = self.tree2label.get(tree_id, 0) @@ -163,7 +163,8 @@ class EditTreeLemmatizer(TrainablePipe): for i, doc_d_tree_scores in enumerate(d_tree_scores): eg_lowercasing_flags = lowercasing_flags[i] eg_d_lowercasing_flags, eg_lowercasing_loss = lowercasing_loss_func( - eg_lowercasing_flags, self.model.ops.asarray2i(lowercasing_truths[i]) + eg_lowercasing_flags, + self.model.ops.asarray2f(lowercasing_truths[i]), # type: ignore[arg-type] ) doc_d_scores = self.model.ops.xp.hstack( [