mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
update with eg.predited as other components
This commit is contained in:
parent
ef141ad399
commit
3ba913109d
|
@ -457,7 +457,6 @@ class SpanPredictor(TrainablePipe):
|
|||
for doc in docs:
|
||||
# TODO check shape here
|
||||
span_scores = self.model.predict([doc])
|
||||
print(span_scores)
|
||||
if span_scores.size:
|
||||
# the information about clustering has to come from the input docs
|
||||
# first let's convert the scores to a list of span idxs
|
||||
|
@ -513,7 +512,7 @@ class SpanPredictor(TrainablePipe):
|
|||
for eg in examples:
|
||||
# For update we use the gold coref_head_clusters
|
||||
# in the reference.
|
||||
span_scores, backprop = self.model.begin_update([eg.reference])
|
||||
span_scores, backprop = self.model.begin_update([eg.predicted])
|
||||
loss, d_scores = self.get_loss([eg], span_scores)
|
||||
total_loss += loss
|
||||
# TODO check shape here
|
||||
|
@ -622,10 +621,9 @@ class SpanPredictor(TrainablePipe):
|
|||
ref = eg.reference
|
||||
pred = eg.predicted
|
||||
for key, gold_sg in ref.spans.items():
|
||||
if key.startswith(self.input_prefix):
|
||||
if key.startswith(self.output_prefix):
|
||||
cluster_id = key.split('_')[-1]
|
||||
# FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty?
|
||||
pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"]
|
||||
pred_sg = pred.spans[key]
|
||||
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
|
||||
starts.append(gold_mention.start)
|
||||
ends.append(gold_mention.end)
|
||||
|
|
Loading…
Reference in New Issue
Block a user