make sure predicted and reference keeps aligned

This commit is contained in:
Kádár Ákos 2022-03-25 18:29:33 +01:00
parent 83ac0477c8
commit 7304604edd

View File

@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe):
DOCS: https://spacy.io/api/coref#predict (TODO)
"""
#print("DOCS", docs)
out = []
for doc in docs:
scores, idxs = self.model.predict([doc])
@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe):
# TODO check this causes no issues (in practice it runs)
preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
total_loss += loss
# TODO check shape here
@ -518,7 +516,8 @@ class SpanPredictor(TrainablePipe):
# takes 'coref_head_clusters' from the reference.
for key, sg in eg.reference.spans.items():
if key.startswith(self.input_prefix):
doc.spans[key] = [doc[span.start:span.end] for span in sg]
aligned_spans = eg.get_aligned_spans_x2y(sg)
doc.spans[key] = [doc[span.start:span.end] for span in aligned_spans]
span_scores, backprop = self.model.begin_update([doc])
loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss
@ -600,7 +599,7 @@ class SpanPredictor(TrainablePipe):
*,
nlp: Optional[Language] = None,
) -> None:
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
validate_get_examples(get_examples, "SpanPredictor.initialize")
X = []
Y = []