mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
make sure predicted and reference keeps aligned
This commit is contained in:
parent
83ac0477c8
commit
7304604edd
|
@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/coref#predict (TODO)
|
||||
"""
|
||||
#print("DOCS", docs)
|
||||
out = []
|
||||
for doc in docs:
|
||||
scores, idxs = self.model.predict([doc])
|
||||
|
@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe):
|
|||
# TODO check this causes no issues (in practice it runs)
|
||||
preds, backprop = self.model.begin_update([eg.predicted])
|
||||
score_matrix, mention_idx = preds
|
||||
|
||||
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
||||
total_loss += loss
|
||||
# TODO check shape here
|
||||
|
@ -518,7 +516,8 @@ class SpanPredictor(TrainablePipe):
|
|||
# takes 'coref_head_clusters' from the reference.
|
||||
for key, sg in eg.reference.spans.items():
|
||||
if key.startswith(self.input_prefix):
|
||||
doc.spans[key] = [doc[span.start:span.end] for span in sg]
|
||||
aligned_spans = eg.get_aligned_spans_x2y(sg)
|
||||
doc.spans[key] = [doc[span.start:span.end] for span in aligned_spans]
|
||||
span_scores, backprop = self.model.begin_update([doc])
|
||||
loss, d_scores = self.get_loss([eg], span_scores)
|
||||
total_loss += loss
|
||||
|
@ -600,7 +599,7 @@ class SpanPredictor(TrainablePipe):
|
|||
*,
|
||||
nlp: Optional[Language] = None,
|
||||
) -> None:
|
||||
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
||||
validate_get_examples(get_examples, "SpanPredictor.initialize")
|
||||
|
||||
X = []
|
||||
Y = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user