Suggestions from code review, cleanup, typing

Paul O'Leary McCann 2022-05-25 19:11:48 +09:00
parent 6999436270
commit 6087da9675


@@ -2,7 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
 import warnings
 from thinc.types import Floats2d, Floats3d, Ints2d
-from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
+from thinc.api import Model, Config, Optimizer
 from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
 from statistics import mean
@@ -11,7 +11,6 @@ from .trainable_pipe import TrainablePipe
 from ..language import Language
 from ..training import Example, validate_examples, validate_get_examples
 from ..errors import Errors
-from ..scorer import Scorer
 from ..tokens import Doc
 from ..vocab import Vocab
@@ -118,7 +117,7 @@ class CoreferenceResolver(TrainablePipe):
         self.span_cluster_prefix = span_cluster_prefix
         self._rehearsal_model = None
-        self.cfg = {}
+        self.cfg: Dict[str, Any] = {}

     def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
         """Apply the pipeline's model to a batch of docs, without modifying them.
@@ -154,6 +153,7 @@ class CoreferenceResolver(TrainablePipe):
         DOCS: https://spacy.io/api/coref#set_annotations (TODO)
         """
+        docs = list(docs)
        if len(docs) != len(clusters_by_doc):
            raise ValueError(
                "Found coref clusters incompatible with the "
@@ -219,49 +219,8 @@ class CoreferenceResolver(TrainablePipe):
             losses[self.name] += total_loss
         return losses

-    def rehearse(
-        self,
-        examples: Iterable[Example],
-        *,
-        drop: float = 0.0,
-        sgd: Optional[Optimizer] = None,
-        losses: Optional[Dict[str, float]] = None,
-    ) -> Dict[str, float]:
-        """Perform a "rehearsal" update from a batch of data. Rehearsal updates
-        teach the current model to make predictions similar to an initial model,
-        to try to address the "catastrophic forgetting" problem. This feature is
-        experimental.
-
-        examples (Iterable[Example]): A batch of Example objects.
-        drop (float): The dropout rate.
-        sgd (thinc.api.Optimizer): The optimizer.
-        losses (Dict[str, float]): Optional record of the loss during training.
-            Updated using the component name as the key.
-        RETURNS (Dict[str, float]): The updated losses dictionary.
-
-        DOCS: https://spacy.io/api/coref#rehearse (TODO)
-        """
-        if losses is not None:
-            losses.setdefault(self.name, 0.0)
-        if self._rehearsal_model is None:
-            return losses
-        validate_examples(examples, "CoreferenceResolver.rehearse")
-        # TODO test this whole function
-        docs = [eg.predicted for eg in examples]
-        if not any(len(doc) for doc in docs):
-            # Handle cases where there are no tokens in any docs.
-            return losses
-        set_dropout_rate(self.model, drop)
-        scores, bp_scores = self.model.begin_update(docs)
-        # TODO below
-        target = self._rehearsal_model(examples)
-        gradient = scores - target
-        bp_scores(gradient)
-        if sgd is not None:
-            self.finish_update(sgd)
-        if losses is not None:
-            losses[self.name] += (gradient**2).sum()
-        return losses
+    def rehearse(self, examples, *, sgd=None, losses=None, **config):
+        raise NotImplementedError

     def add_label(self, label: str) -> int:
         """Technically this method should be implemented from TrainablePipe,
@@ -276,7 +235,7 @@ class CoreferenceResolver(TrainablePipe):
     def get_loss(
         self,
         examples: Iterable[Example],
-        score_matrix: List[Tuple[Floats2d, Ints2d]],
+        score_matrix: Floats2d,
         mention_idx: Ints2d,
     ):
         """Find the loss and gradient of loss for the batch of documents and
@@ -293,13 +252,13 @@ class CoreferenceResolver(TrainablePipe):
         # TODO if there is more than one example, give an error
         # (or actually rework this to take multiple things)
-        example = examples[0]
-        cscores = score_matrix
+        example = list(examples)[0]
         cidx = mention_idx

         clusters = get_clusters_from_doc(example.reference)
         span_idxs = create_head_span_idxs(ops, len(example.predicted))
         gscores = create_gold_scores(span_idxs, clusters)
+        # TODO fix type here. This is bools but asarray2f wants ints.
         gscores = ops.asarray2f(gscores)
         # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
         top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
@@ -313,8 +272,8 @@ class CoreferenceResolver(TrainablePipe):
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=RuntimeWarning)
-            log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
-            log_norm = ops.softmax(cscores, axis=1)
+            log_marg = ops.softmax(score_matrix + ops.xp.log(top_gscores), axis=1)
+            log_norm = ops.softmax(score_matrix, axis=1)
         grad = log_norm - log_marg
         # gradients.append((grad, cidx))
         loss = float((grad**2).sum())
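The two renamed `softmax` lines are the gradient of the marginal log-likelihood loss used in end-to-end coreference: each mention only needs to put probability mass on *some* gold antecedent, so the loss is -log(sum of p(j) over gold j), and its gradient with respect to the raw scores works out to `softmax(scores) - softmax(scores + log(gold_mask))`, since adding the log of a 0/1 mask sends non-gold scores to -inf. That is also why the `catch_warnings` block exists: `log(0)` emits an expected RuntimeWarning. (Note the `log_*` names in the diff hold probabilities after the softmax, not logs.) A small self-contained numpy sketch of the same computation, with illustrative shapes and values:

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

scores = np.array([[2.0, 1.0, 0.5]])  # one mention, three candidate antecedents
gold = np.array([[1.0, 0.0, 1.0]])    # 0/1 mask: which candidates are gold

with np.errstate(divide="ignore"):    # log(0) -> -inf is intended here
    log_marg = softmax(scores + np.log(gold), axis=1)  # renormalized over gold only
log_norm = softmax(scores, axis=1)    # normalized over all candidates

grad = log_norm - log_marg            # gradient of -log(sum of gold probabilities)
loss = float((grad ** 2).sum())
print(grad, loss)
```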