Training works now

This commit is contained in:
Paul O'Leary McCann 2022-03-14 19:27:23 +09:00
parent 8eadf3781b
commit dfec6993d6

View File

@@ -25,6 +25,8 @@ from ..ml.models.coref_util import (
doc2clusters, doc2clusters,
) )
from ..ml.models.coref_util_wl import make_head_only_clusters
from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
# TODO remove this - kept for reference for now # TODO remove this - kept for reference for now
@@ -235,6 +237,8 @@ class CoreferenceResolver(TrainablePipe):
return losses return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
make_head_only_clusters(examples)
inputs = [example.predicted for example in examples] inputs = [example.predicted for example in examples]
preds, backprop = self.model.begin_update(inputs) preds, backprop = self.model.begin_update(inputs)
score_matrix, mention_idx = preds score_matrix, mention_idx = preds
@@ -275,6 +279,7 @@ class CoreferenceResolver(TrainablePipe):
if self._rehearsal_model is None: if self._rehearsal_model is None:
return losses return losses
validate_examples(examples, "CoreferenceResolver.rehearse") validate_examples(examples, "CoreferenceResolver.rehearse")
#TODO test this whole function
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
@@ -394,6 +399,7 @@ class CoreferenceResolver(TrainablePipe):
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
"""Score a batch of examples.""" """Score a batch of examples."""
make_head_only_clusters(examples)
# NOTE traditionally coref uses the average of b_cubed, muc, and ceaf. # NOTE traditionally coref uses the average of b_cubed, muc, and ceaf.
# we need to handle the average ourselves. # we need to handle the average ourselves.
scores = [] scores = []