span predictor debug start

Kádár Ákos 2022-03-23 11:24:27 +01:00
parent 2190cbc0e6
commit 1eaf8fb0cf
2 changed files with 11 additions and 12 deletions


@@ -91,7 +91,7 @@ def build_span_predictor(
 # TODO fix device - should be automatic
 device = "cuda:0"
 span_predictor = PyTorchWrapper(
-SpanPredictor(hidden_size, dist_emb_size, device),
+SpanPredictor(dim, dist_emb_size, device),
 convert_inputs=convert_span_predictor_inputs
 )
 # TODO use proper parameter for prefix
@@ -148,7 +148,6 @@ def convert_span_predictor_inputs(
 # Normally we shoudl use the input is_train, but for these two it's not relevant
 sent_ids = xp2torch(sent_ids[0], requires_grad=False)
 head_ids = xp2torch(head_ids[0], requires_grad=False)
 word_features = xp2torch(tok2vec[0], requires_grad=is_train)
 argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
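
The converter above follows thinc's PyTorchWrapper convert_inputs contract: arrays go in, an ArgsKwargs of torch tensors comes out, plus a callback that maps torch gradients back to array space. A minimal self-contained sketch of that contract, using a plain torch.nn.Linear as a stand-in for SpanPredictor (everything here except thinc's own API is made up for illustration):

import numpy
import torch
from thinc.api import ArgsKwargs, PyTorchWrapper, xp2torch, torch2xp

def convert_inputs(model, X, is_train):
    # Arrays -> torch tensors; only the features need a gradient.
    word_features = xp2torch(X, requires_grad=is_train)

    def backprop(d_inputs):
        # d_inputs is an ArgsKwargs of gradients w.r.t. the torch inputs.
        return torch2xp(d_inputs.args[0])

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop

model = PyTorchWrapper(torch.nn.Linear(4, 2), convert_inputs=convert_inputs)
Y, backprop = model.begin_update(numpy.zeros((3, 4), dtype="f"))
print(Y.shape)  # (3, 2)
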
@@ -557,7 +556,6 @@ class SpanPredictor(torch.nn.Module):
 sent_id = torch.tensor(sent_id, device=words.device)
 heads_ids = heads_ids.long()
 same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
 # To save memory, only pass candidates from one sentence for each head
 # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
 # for each candidate among the words in the same sentence as span_head
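
For reference, the same_sent line reduces to a column-against-row broadcast; a tiny standalone example with made-up sentence ids and head indices:

import torch

sent_id = torch.tensor([0, 0, 1, 1, 1])        # sentence id per word
heads_ids = torch.tensor([1, 4]).long()        # head word indices
same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
print(same_sent)
# tensor([[ True,  True, False, False, False],
#         [False, False,  True,  True,  True]])
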
@@ -568,11 +566,11 @@ class SpanPredictor(torch.nn.Module):
 words[cols],
 self.emb(emb_ids[rows, cols]),
 ), dim=1)
+input(len(heads_ids))
 lengths = same_sent.sum(dim=1)
 padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
 padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
+input(padding_mask.shape)
 # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
 # This is necessary to allow the convolution layer to look at several
 # word scores
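
The two padding_mask lines are the standard length-to-mask trick; a standalone example with made-up lengths shows the resulting [n_heads, max_sent_len] boolean matrix:

import torch

lengths = torch.tensor([3, 1, 2])                    # same-sentence candidates per head
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = padding_mask < lengths.unsqueeze(1)   # [n_heads, max_sent_len]
print(padding_mask)
# tensor([[ True,  True,  True],
#         [ True, False, False],
#         [ True,  True, False]])
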
@@ -592,6 +590,7 @@ class SpanPredictor(torch.nn.Module):
 valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
 return scores + valid_positions
 return scores
 class DistancePairwiseEncoder(torch.nn.Module):
 def __init__(self, embedding_size, dropout_rate):
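
The `scores + valid_positions` return reads naturally if valid_starts and valid_ends are log-masks, as in the wl-coref reference implementation this module appears to follow; that is an assumption here, not something visible in the hunk. Under that assumption, adding log(1) = 0 leaves admissible positions untouched, while log(0) = -inf drives inadmissible start/end positions to zero probability after softmax:

import torch

scores = torch.tensor([[1.0, 2.0, 3.0]])
admissible = torch.tensor([[1.0, 1.0, 0.0]])   # last position not a valid end
masked = scores + torch.log(admissible)        # [1., 2., -inf]
print(torch.softmax(masked, dim=-1))           # last probability is exactly 0
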


@@ -3,7 +3,7 @@ import warnings
 from thinc.types import Floats2d, Floats3d, Ints2d
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
-from thinc.api import set_dropout_rate
+from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
 from statistics import mean
@@ -513,10 +513,8 @@ class SpanPredictor(TrainablePipe):
 total_loss = 0
 for eg in examples:
-preds, backprop = self.model.begin_update([eg.predicted])
-score_matrix, mention_idx = preds
-loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
+span_scores, backprop = self.model.begin_update([eg.predicted])
+loss, d_scores = self.get_loss([eg], span_scores)
 total_loss += loss
 # TODO check shape here
 backprop((d_scores, mention_idx))
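
The rewritten loop follows the usual thinc begin_update / get_loss / backprop / finish_update cycle. A generic, self-contained sketch of that cycle with a plain Linear model (not the span-predictor pipe) for reference; the gradient handed to backprop must match the shape of the scores returned by begin_update:

import numpy
from thinc.api import Linear, Adam

model = Linear(nO=2, nI=4)
X = numpy.zeros((3, 4), dtype="f")
gold = numpy.zeros((3, 2), dtype="f")
model.initialize(X=X, Y=gold)

scores, backprop = model.begin_update(X)
d_scores = scores - gold            # stand-in for what get_loss would return
backprop(d_scores)
model.finish_update(Adam(0.001))
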
@@ -573,8 +571,10 @@ class SpanPredictor(TrainablePipe):
 for cluster in gold:
 for mention in cluster:
 starts.append(mention[0])
-ends.append(mention[1])
+# XXX I think this was missing here
+ends.append(mention[1] - 1)
+starts = self.model.ops.xp.asarray(starts)
+ends = self.model.ops.xp.asarray(ends)
 start_scores = span_scores[:, :, 0]
 end_scores = span_scores[:, :, 1]
 n_classes = start_scores.shape[1]
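
With to_categorical now imported, the gold start/end indices collected above can be turned into one-hot targets with the same width as the score matrices. A hedged sketch of one way get_loss could combine them with the already-imported CategoricalCrossentropy; the shapes, the dummy data, and the loss choice are assumptions for illustration, not the final implementation:

import numpy
from thinc.api import CategoricalCrossentropy, to_categorical

n_words = 10                                    # candidate positions per mention
starts = numpy.asarray([2, 5])                  # gold start token indices
ends = numpy.asarray([3, 7])                    # gold end indices (inclusive)
span_scores = numpy.random.rand(2, n_words, 2).astype("f")

start_scores = span_scores[:, :, 0]
end_scores = span_scores[:, :, 1]
n_classes = start_scores.shape[1]

loss_fn = CategoricalCrossentropy(normalize=True)
d_start, start_loss = loss_fn(start_scores, to_categorical(starts, n_classes))
d_end, end_loss = loss_fn(end_scores, to_categorical(ends, n_classes))
d_scores = numpy.stack((d_start, d_end), axis=2)  # back to [n_spans, n_words, 2]
print(start_loss + end_loss, d_scores.shape)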