diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 723984768..fd36c84f7 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -373,7 +373,7 @@ def ant_scorer_forward( warnings.filterwarnings('ignore', category=RuntimeWarning) mask = xp.log( (xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1 - ).astype(float) + ).astype('f') scores = pw_prod + pw_sum + mask