From 4fc40340f94d6dc47398dfa264804723b7e52b65 Mon Sep 17 00:00:00 2001
From: Kádár Ákos
Date: Mon, 28 Mar 2022 11:28:21 +0200
Subject: [PATCH] handle empty head_ids

---
 spacy/ml/models/coref.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 71082e7ac..7972f9160 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -133,6 +133,7 @@ def convert_coref_scorer_outputs(
     indices_xp = torch2xp(indices)
     return (scores_xp, indices_xp), convert_for_torch_backward
 
+
 def convert_span_predictor_inputs(
     model: Model,
     X: Tuple[Ints1d, Floats2d, Ints1d],
@@ -141,13 +142,17 @@ def convert_span_predictor_inputs(
     tok2vec, (sent_ids, head_ids) = X
     # Normally we should use the input is_train, but for these two it's not relevant
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
-    head_ids = xp2torch(head_ids[0], requires_grad=False)
+    if not head_ids[0].size:
+        head_ids = torch.empty(size=(0,))
+    else:
+        head_ids = xp2torch(head_ids[0], requires_grad=False)
     word_features = xp2torch(tok2vec[0], requires_grad=is_train)
     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
     # TODO actually support backprop
     return argskwargs, lambda dX: [[]]
 
 
+
 # TODO This probably belongs in the component, not the model.
 def predict_span_clusters(span_predictor: Model,
                           sent_ids: Ints1d,
@@ -543,6 +548,9 @@ class SpanPredictor(torch.nn.Module):
         Returns:
             torch.Tensor: span start/end scores, [n_heads, n_words, 2]
         """
+        # If we don't receive heads, return empty
+        if heads_ids.nelement() == 0:
+            return torch.empty(size=(0,))
         # Obtain distance embedding indices, [n_heads, n_words]
         relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
         # make all valid distances positive
@@ -550,7 +558,6 @@ class SpanPredictor(torch.nn.Module):
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
-        sent_id = torch.tensor(sent_id)
         heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
         # To save memory, only pass candidates from one sentence for each head
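Note for reviewers: the two guards above are easiest to sanity-check outside the pipeline. Below is a minimal, runnable sketch of the behavior, not spaCy's actual code path: convert_head_ids and toy_forward are hypothetical stand-ins for convert_span_predictor_inputs and SpanPredictor.forward, and torch.as_tensor stands in for thinc's xp2torch.

import numpy
import torch


def convert_head_ids(head_ids: numpy.ndarray) -> torch.Tensor:
    # Same guard as in convert_span_predictor_inputs: an empty array has
    # size 0, so hand torch an explicitly empty tensor instead of
    # converting the array.
    if not head_ids.size:
        return torch.empty(size=(0,))
    return torch.as_tensor(head_ids)


def toy_forward(sent_id: torch.Tensor, words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
    # Same early return as in SpanPredictor.forward: no heads means there
    # are no spans to score, so bail out before any indexing runs.
    if heads_ids.nelement() == 0:
        return torch.empty(size=(0,))
    heads_ids = heads_ids.long()
    # [n_heads, n_words] relative positions, as in the real forward pass
    return (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0)).float()


words = torch.randn(5, 8)                   # 5 tokens, 8-dim vectors
sent_id = torch.zeros(5, dtype=torch.long)  # all tokens in one sentence

no_heads = convert_head_ids(numpy.zeros((0,), dtype="int32"))
two_heads = convert_head_ids(numpy.array([1, 3], dtype="int32"))
print(toy_forward(sent_id, words, no_heads).shape)   # torch.Size([0])
print(toy_forward(sent_id, words, two_heads).shape)  # torch.Size([2, 5])

Returning torch.empty(size=(0,)) in both places lets downstream code branch on .nelement() or .size instead of special-casing None.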