span predictor debug start

This commit is contained in:
Kádár Ákos 2022-03-23 11:24:27 +01:00
parent 2190cbc0e6
commit 1eaf8fb0cf
2 changed files with 11 additions and 12 deletions

View File

@ -91,7 +91,7 @@ def build_span_predictor(
# TODO fix device - should be automatic
device = "cuda:0"
span_predictor = PyTorchWrapper(
SpanPredictor(hidden_size, dist_emb_size, device),
SpanPredictor(dim, dist_emb_size, device),
convert_inputs=convert_span_predictor_inputs
)
# TODO use proper parameter for prefix
@ -148,7 +148,6 @@ def convert_span_predictor_inputs(
# Normally we shoudl use the input is_train, but for these two it's not relevant
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
head_ids = xp2torch(head_ids[0], requires_grad=False)
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
@ -557,7 +556,6 @@ class SpanPredictor(torch.nn.Module):
sent_id = torch.tensor(sent_id, device=words.device)
heads_ids = heads_ids.long()
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
# To save memory, only pass candidates from one sentence for each head
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
# for each candidate among the words in the same sentence as span_head
@ -568,11 +566,11 @@ class SpanPredictor(torch.nn.Module):
words[cols],
self.emb(emb_ids[rows, cols]),
), dim=1)
input(len(heads_ids))
lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
input(padding_mask.shape)
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
# This is necessary to allow the convolution layer to look at several
# word scores
@ -592,6 +590,7 @@ class SpanPredictor(torch.nn.Module):
valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
return scores + valid_positions
return scores
class DistancePairwiseEncoder(torch.nn.Module):
def __init__(self, embedding_size, dropout_rate):

View File

@ -3,7 +3,7 @@ import warnings
from thinc.types import Floats2d, Floats3d, Ints2d
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
from thinc.api import set_dropout_rate
from thinc.api import set_dropout_rate, to_categorical
from itertools import islice
from statistics import mean
@ -513,10 +513,8 @@ class SpanPredictor(TrainablePipe):
total_loss = 0
for eg in examples:
preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
span_scores, backprop = self.model.begin_update([eg.predicted])
loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss
# TODO check shape here
backprop((d_scores, mention_idx))
@ -573,8 +571,10 @@ class SpanPredictor(TrainablePipe):
for cluster in gold:
for mention in cluster:
starts.append(mention[0])
ends.append(mention[1])
# XXX I think this was missing here
ends.append(mention[1] - 1)
starts = self.model.ops.xp.asarray(starts)
ends = self.model.ops.xp.asarray(ends)
start_scores = span_scores[:, :, 0]
end_scores = span_scores[:, :, 1]
n_classes = start_scores.shape[1]