Mirror of https://github.com/explosion/spaCy.git, synced 2025-07-18 12:12:20 +03:00
span predictor debug start
parent 2190cbc0e6
commit 1eaf8fb0cf
@@ -91,7 +91,7 @@ def build_span_predictor(
     # TODO fix device - should be automatic
     device = "cuda:0"
     span_predictor = PyTorchWrapper(
-        SpanPredictor(hidden_size, dist_emb_size, device),
+        SpanPredictor(dim, dist_emb_size, device),
         convert_inputs=convert_span_predictor_inputs
     )
     # TODO use proper parameter for prefix
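For context: `PyTorchWrapper` from `thinc.api` wraps a `torch.nn.Module` as a Thinc `Model`, with `convert_inputs` controlling how Thinc arrays become tensor arguments; the fix swaps `hidden_size` for `dim`, presumably the variable actually in scope. A minimal sketch of the wrapping pattern, with an invented `ToyPredictor` standing in for `SpanPredictor`:

import torch
from thinc.api import PyTorchWrapper

class ToyPredictor(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(dim, 2)

    def forward(self, words: torch.Tensor) -> torch.Tensor:
        # [n_words, dim] -> [n_words, 2]: one start and one end score per token
        return self.linear(words)

model = PyTorchWrapper(ToyPredictor(64))  # default converters handle a single array
scores = model.predict(model.ops.alloc2f(10, 64))
print(scores.shape)  # (10, 2)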
@@ -148,7 +148,6 @@ def convert_span_predictor_inputs(
     # Normally we shoudl use the input is_train, but for these two it's not relevant
     sent_ids = xp2torch(sent_ids[0], requires_grad=False)
     head_ids = xp2torch(head_ids[0], requires_grad=False)
-
     word_features = xp2torch(tok2vec[0], requires_grad=is_train)

     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
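The converter above follows Thinc's `convert_inputs` contract: take `(model, X, is_train)` and return an `ArgsKwargs` for the PyTorch forward pass plus a callback that routes gradients back. A hedged sketch of that shape (the layout of `X` and of the returned `dX` are assumptions for illustration, not the repository's exact code):

from thinc.api import ArgsKwargs, xp2torch, torch2xp

def convert_span_predictor_inputs_sketch(model, X, is_train: bool):
    # The batch is effectively size one here, hence the [0] indexing in the diff.
    tok2vec, (sent_ids, head_ids) = X
    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
    head_ids = xp2torch(head_ids[0], requires_grad=False)
    word_features = xp2torch(tok2vec[0], requires_grad=is_train)

    def backprop(dXtorch: ArgsKwargs):
        # Only word_features carries a gradient back to the tok2vec layer;
        # the returned structure must mirror X (placeholder below).
        return [torch2xp(dXtorch.args[0])]

    return ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={}), backprop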
@@ -557,7 +556,6 @@ class SpanPredictor(torch.nn.Module):
         sent_id = torch.tensor(sent_id, device=words.device)
         heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
-
         # To save memory, only pass candidates from one sentence for each head
         # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
         # for each candidate among the words in the same sentence as span_head
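The `same_sent` line is a broadcast trick: indexing `sent_id` by the head positions gives each head's sentence index, and comparing a `[n_heads, 1]` column against a `[1, n_words]` row yields a boolean `[n_heads, n_words]` mask. A toy run:

import torch

sent_id = torch.tensor([0, 0, 0, 1, 1])  # sentence index of each word
heads_ids = torch.tensor([1, 4]).long()  # two span heads, words 1 and 4

same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
print(same_sent)
# tensor([[ True,  True,  True, False, False],
#         [False, False, False,  True,  True]])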
@@ -568,11 +566,11 @@ class SpanPredictor(torch.nn.Module):
             words[cols],
             self.emb(emb_ids[rows, cols]),
         ), dim=1)
-
+        input(len(heads_ids))
         lengths = same_sent.sum(dim=1)
         padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
         padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
-
+        input(padding_mask.shape)
         # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
         # This is necessary to allow the convolution layer to look at several
         # word scores
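The two added `input(...)` calls pause execution and echo the number of heads and the mask shape at the terminal, an interactive checkpoint in line with the commit subject. The padding mask itself is the standard arange-versus-lengths broadcast; a toy version:

import torch

lengths = torch.tensor([3, 2])  # candidates per head, i.e. same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
print(padding_mask)
# tensor([[ True,  True,  True],
#         [ True,  True, False]])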
@@ -592,6 +590,7 @@ class SpanPredictor(torch.nn.Module):
         valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
         return scores + valid_positions
+        return scores

 class DistancePairwiseEncoder(torch.nn.Module):

     def __init__(self, embedding_size, dropout_rate):
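The added `return scores` sits after an unconditional `return scores + valid_positions`, so it is unreachable; it reads like a toggle left in place to compare raw and masked scores while debugging. For reference, a sketch of what the valid-position term does, with toy shapes and invented names: adding `log(0) = -inf` at disallowed boundary positions drives their softmax probability to zero.

import torch

scores = torch.randn(2, 4, 2)                       # [n_heads, n_words, (start, end)]
valid = torch.tensor([[1, 1, 0, 0], [0, 1, 1, 1]])  # 1 = allowed boundary position
valid_positions = torch.log(valid.float()).unsqueeze(2).expand(-1, -1, 2)
masked = scores + valid_positions                   # -inf where disallowed
print(masked.softmax(dim=1)[0, :, 0])               # head 0's invalid starts get ~0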
@@ -3,7 +3,7 @@ import warnings

 from thinc.types import Floats2d, Floats3d, Ints2d
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
-from thinc.api import set_dropout_rate
+from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
 from statistics import mean

@@ -513,10 +513,8 @@ class SpanPredictor(TrainablePipe):
         total_loss = 0

         for eg in examples:
-            preds, backprop = self.model.begin_update([eg.predicted])
-            score_matrix, mention_idx = preds
-
-            loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
+            span_scores, backprop = self.model.begin_update([eg.predicted])
+            loss, d_scores = self.get_loss([eg], span_scores)
             total_loss += loss
             # TODO check shape here
             backprop((d_scores, mention_idx))
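Two things worth noting. First, the surviving context line `backprop((d_scores, mention_idx))` still references `mention_idx`, which the new code no longer unpacks, so `update()` would raise a NameError as of this commit; fixing that is presumably part of the debugging still to come. Second, the loop follows Thinc's standard train step. A miniature of that pattern with a toy model (names illustrative only):

import numpy
from thinc.api import SGD, Linear

model = Linear(nO=2, nI=4)
model.initialize(X=numpy.zeros((1, 4), dtype="f"))

X = numpy.ones((3, 4), dtype="f")
Y, backprop = model.begin_update(X)  # forward pass, keep the callback
d_scores = numpy.ones_like(Y)        # stand-in gradient from the loss
dX = backprop(d_scores)              # push gradients back to the inputs
model.finish_update(SGD(0.001))      # apply the accumulated gradients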
@@ -573,8 +571,10 @@ class SpanPredictor(TrainablePipe):
         for cluster in gold:
             for mention in cluster:
                 starts.append(mention[0])
-                ends.append(mention[1])
+
+                # XXX I think this was missing here
+                ends.append(mention[1] - 1)
         starts = self.model.ops.xp.asarray(starts)
         ends = self.model.ops.xp.asarray(ends)
         start_scores = span_scores[:, :, 0]
         end_scores = span_scores[:, :, 1]
         n_classes = start_scores.shape[1]
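The `- 1` makes sense if mention boundaries are stored spaCy-style with an exclusive end: the gold label for the end classifier should then be the index of the last token inside the span, not one past it. A small sketch of building one-hot targets that way with `to_categorical`, the helper the import hunk above adds:

import numpy
from thinc.api import to_categorical

n_words = 5
mention = (1, 4)  # tokens 1..3; 4 is the exclusive end
starts = numpy.asarray([mention[0]])
ends = numpy.asarray([mention[1] - 1])  # 3, the actual last token
print(to_categorical(starts, n_classes=n_words))  # [[0. 1. 0. 0. 0.]]
print(to_categorical(ends, n_classes=n_words))    # [[0. 0. 0. 1. 0.]]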