mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Training works now
This commit is contained in:
parent
8eadf3781b
commit
dfec6993d6
|
@ -25,6 +25,8 @@ from ..ml.models.coref_util import (
|
|||
doc2clusters,
|
||||
)
|
||||
|
||||
from ..ml.models.coref_util_wl import make_head_only_clusters
|
||||
|
||||
from ..coref_scorer import Evaluator, get_cluster_info, b_cubed, muc, ceafe
|
||||
|
||||
# TODO remove this - kept for reference for now
|
||||
|
@ -235,6 +237,8 @@ class CoreferenceResolver(TrainablePipe):
|
|||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
|
||||
make_head_only_clusters(examples)
|
||||
|
||||
inputs = [example.predicted for example in examples]
|
||||
preds, backprop = self.model.begin_update(inputs)
|
||||
score_matrix, mention_idx = preds
|
||||
|
@ -275,6 +279,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
if self._rehearsal_model is None:
|
||||
return losses
|
||||
validate_examples(examples, "CoreferenceResolver.rehearse")
|
||||
#TODO test this whole function
|
||||
docs = [eg.predicted for eg in examples]
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
|
@ -394,6 +399,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
def score(self, examples, **kwargs):
|
||||
"""Score a batch of examples."""
|
||||
|
||||
make_head_only_clusters(examples)
|
||||
# NOTE traditionally coref uses the average of b_cubed, muc, and ceaf.
|
||||
# we need to handle the average ourselves.
|
||||
scores = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user