update with eg.predited as other components

This commit is contained in:
Kádár Ákos 2022-04-07 13:20:12 +02:00
parent ef141ad399
commit 3ba913109d

View File

@ -457,7 +457,6 @@ class SpanPredictor(TrainablePipe):
for doc in docs:
# TODO check shape here
span_scores = self.model.predict([doc])
print(span_scores)
if span_scores.size:
# the information about clustering has to come from the input docs
# first let's convert the scores to a list of span idxs
@ -513,7 +512,7 @@ class SpanPredictor(TrainablePipe):
for eg in examples:
# For update we use the gold coref_head_clusters
# in the reference.
span_scores, backprop = self.model.begin_update([eg.reference])
span_scores, backprop = self.model.begin_update([eg.predicted])
loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss
# TODO check shape here
@ -622,10 +621,9 @@ class SpanPredictor(TrainablePipe):
ref = eg.reference
pred = eg.predicted
for key, gold_sg in ref.spans.items():
if key.startswith(self.input_prefix):
if key.startswith(self.output_prefix):
cluster_id = key.split('_')[-1]
# FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty?
pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"]
pred_sg = pred.spans[key]
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
starts.append(gold_mention.start)
ends.append(gold_mention.end)