mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
update with eg.predited as other components
This commit is contained in:
parent
ef141ad399
commit
3ba913109d
|
@ -457,7 +457,6 @@ class SpanPredictor(TrainablePipe):
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
span_scores = self.model.predict([doc])
|
span_scores = self.model.predict([doc])
|
||||||
print(span_scores)
|
|
||||||
if span_scores.size:
|
if span_scores.size:
|
||||||
# the information about clustering has to come from the input docs
|
# the information about clustering has to come from the input docs
|
||||||
# first let's convert the scores to a list of span idxs
|
# first let's convert the scores to a list of span idxs
|
||||||
|
@ -513,7 +512,7 @@ class SpanPredictor(TrainablePipe):
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
# For update we use the gold coref_head_clusters
|
# For update we use the gold coref_head_clusters
|
||||||
# in the reference.
|
# in the reference.
|
||||||
span_scores, backprop = self.model.begin_update([eg.reference])
|
span_scores, backprop = self.model.begin_update([eg.predicted])
|
||||||
loss, d_scores = self.get_loss([eg], span_scores)
|
loss, d_scores = self.get_loss([eg], span_scores)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
|
@ -622,10 +621,9 @@ class SpanPredictor(TrainablePipe):
|
||||||
ref = eg.reference
|
ref = eg.reference
|
||||||
pred = eg.predicted
|
pred = eg.predicted
|
||||||
for key, gold_sg in ref.spans.items():
|
for key, gold_sg in ref.spans.items():
|
||||||
if key.startswith(self.input_prefix):
|
if key.startswith(self.output_prefix):
|
||||||
cluster_id = key.split('_')[-1]
|
cluster_id = key.split('_')[-1]
|
||||||
# FIXME THIS DOESN'T WORK BECAUSE pred.spans are empty?
|
pred_sg = pred.spans[key]
|
||||||
pred_sg = pred.spans[f"{self.output_prefix}_{cluster_id}"]
|
|
||||||
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
|
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
|
||||||
starts.append(gold_mention.start)
|
starts.append(gold_mention.start)
|
||||||
ends.append(gold_mention.end)
|
ends.append(gold_mention.end)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user