make sure predicted and reference stay aligned

This commit is contained in:
Kádár Ákos 2022-03-25 18:29:33 +01:00
parent 83ac0477c8
commit 7304604edd

View File

@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe):
DOCS: https://spacy.io/api/coref#predict (TODO) DOCS: https://spacy.io/api/coref#predict (TODO)
""" """
#print("DOCS", docs)
out = [] out = []
for doc in docs: for doc in docs:
scores, idxs = self.model.predict([doc]) scores, idxs = self.model.predict([doc])
@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe):
# TODO check this causes no issues (in practice it runs) # TODO check this causes no issues (in practice it runs)
preds, backprop = self.model.begin_update([eg.predicted]) preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds score_matrix, mention_idx = preds
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx) loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
total_loss += loss total_loss += loss
# TODO check shape here # TODO check shape here
@ -518,7 +516,8 @@ class SpanPredictor(TrainablePipe):
# takes 'coref_head_clusters' from the reference. # takes 'coref_head_clusters' from the reference.
for key, sg in eg.reference.spans.items(): for key, sg in eg.reference.spans.items():
if key.startswith(self.input_prefix): if key.startswith(self.input_prefix):
doc.spans[key] = [doc[span.start:span.end] for span in sg] aligned_spans = eg.get_aligned_spans_x2y(sg)
doc.spans[key] = [doc[span.start:span.end] for span in aligned_spans]
span_scores, backprop = self.model.begin_update([doc]) span_scores, backprop = self.model.begin_update([doc])
loss, d_scores = self.get_loss([eg], span_scores) loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss total_loss += loss
@ -600,7 +599,7 @@ class SpanPredictor(TrainablePipe):
*, *,
nlp: Optional[Language] = None, nlp: Optional[Language] = None,
) -> None: ) -> None:
validate_get_examples(get_examples, "CoreferenceResolver.initialize") validate_get_examples(get_examples, "SpanPredictor.initialize")
X = [] X = []
Y = [] Y = []