From 6087da9675439753c91b78301c46d1fa4453ed5f Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Wed, 25 May 2022 19:11:48 +0900 Subject: [PATCH] Suggestions from code review, cleanup, typing --- spacy/pipeline/coref.py | 61 +++++++---------------------------------- 1 file changed, 10 insertions(+), 51 deletions(-) diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index c5bf8fbbe..76e790896 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -2,7 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List import warnings from thinc.types import Floats2d, Floats3d, Ints2d -from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy +from thinc.api import Model, Config, Optimizer from thinc.api import set_dropout_rate, to_categorical from itertools import islice from statistics import mean @@ -11,7 +11,6 @@ from .trainable_pipe import TrainablePipe from ..language import Language from ..training import Example, validate_examples, validate_get_examples from ..errors import Errors -from ..scorer import Scorer from ..tokens import Doc from ..vocab import Vocab @@ -118,7 +117,7 @@ class CoreferenceResolver(TrainablePipe): self.span_cluster_prefix = span_cluster_prefix self._rehearsal_model = None - self.cfg = {} + self.cfg: Dict[str, Any] = {} def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]: """Apply the pipeline's model to a batch of docs, without modifying them. @@ -154,6 +153,7 @@ class CoreferenceResolver(TrainablePipe): DOCS: https://spacy.io/api/coref#set_annotations (TODO) """ + docs = list(docs) if len(docs) != len(clusters_by_doc): raise ValueError( "Found coref clusters incompatible with the " @@ -219,49 +219,8 @@ class CoreferenceResolver(TrainablePipe): losses[self.name] += total_loss return losses - def rehearse( - self, - examples: Iterable[Example], - *, - drop: float = 0.0, - sgd: Optional[Optimizer] = None, - losses: Optional[Dict[str, float]] = None, - ) -> Dict[str, float]: - """Perform a "rehearsal" update from a batch of data. Rehearsal updates - teach the current model to make predictions similar to an initial model, - to try to address the "catastrophic forgetting" problem. This feature is - experimental. - - examples (Iterable[Example]): A batch of Example objects. - drop (float): The dropout rate. - sgd (thinc.api.Optimizer): The optimizer. - losses (Dict[str, float]): Optional record of the loss during training. - Updated using the component name as the key. - RETURNS (Dict[str, float]): The updated losses dictionary. - - DOCS: https://spacy.io/api/coref#rehearse (TODO) - """ - if losses is not None: - losses.setdefault(self.name, 0.0) - if self._rehearsal_model is None: - return losses - validate_examples(examples, "CoreferenceResolver.rehearse") - # TODO test this whole function - docs = [eg.predicted for eg in examples] - if not any(len(doc) for doc in docs): - # Handle cases where there are no tokens in any docs. - return losses - set_dropout_rate(self.model, drop) - scores, bp_scores = self.model.begin_update(docs) - # TODO below - target = self._rehearsal_model(examples) - gradient = scores - target - bp_scores(gradient) - if sgd is not None: - self.finish_update(sgd) - if losses is not None: - losses[self.name] += (gradient**2).sum() - return losses + def rehearse(self, examples, *, sgd=None, losses=None, **config): + raise NotImplementedError def add_label(self, label: str) -> int: """Technically this method should be implemented from TrainablePipe, @@ -276,7 +235,7 @@ class CoreferenceResolver(TrainablePipe): def get_loss( self, examples: Iterable[Example], - score_matrix: List[Tuple[Floats2d, Ints2d]], + score_matrix: Floats2d, mention_idx: Ints2d, ): """Find the loss and gradient of loss for the batch of documents and @@ -293,13 +252,13 @@ class CoreferenceResolver(TrainablePipe): # TODO if there is more than one example, give an error # (or actually rework this to take multiple things) - example = examples[0] - cscores = score_matrix + example = list(examples)[0] cidx = mention_idx clusters = get_clusters_from_doc(example.reference) span_idxs = create_head_span_idxs(ops, len(example.predicted)) gscores = create_gold_scores(span_idxs, clusters) + # TODO fix type here. This is bools but asarray2f wants ints. gscores = ops.asarray2f(gscores) # top_gscores = xp.take_along_axis(gscores, cidx, axis=1) top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1) @@ -313,8 +272,8 @@ class CoreferenceResolver(TrainablePipe): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1) - log_norm = ops.softmax(cscores, axis=1) + log_marg = ops.softmax(score_matrix + ops.xp.log(top_gscores), axis=1) + log_norm = ops.softmax(score_matrix, axis=1) grad = log_norm - log_marg # gradients.append((grad, cidx)) loss = float((grad**2).sum())