mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
Suggestions from code review, cleanup, typing
This commit is contained in:
parent
6999436270
commit
6087da9675
|
@@ -2,7 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
|
|||
import warnings
|
||||
|
||||
from thinc.types import Floats2d, Floats3d, Ints2d
|
||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||
from thinc.api import Model, Config, Optimizer
|
||||
from thinc.api import set_dropout_rate, to_categorical
|
||||
from itertools import islice
|
||||
from statistics import mean
|
||||
|
@@ -11,7 +11,6 @@ from .trainable_pipe import TrainablePipe
|
|||
from ..language import Language
|
||||
from ..training import Example, validate_examples, validate_get_examples
|
||||
from ..errors import Errors
|
||||
from ..scorer import Scorer
|
||||
from ..tokens import Doc
|
||||
from ..vocab import Vocab
|
||||
|
||||
|
@@ -118,7 +117,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
self.span_cluster_prefix = span_cluster_prefix
|
||||
self._rehearsal_model = None
|
||||
|
||||
self.cfg = {}
|
||||
self.cfg: Dict[str, Any] = {}
|
||||
|
||||
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||
|
@@ -154,6 +153,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/coref#set_annotations (TODO)
|
||||
"""
|
||||
docs = list(docs)
|
||||
if len(docs) != len(clusters_by_doc):
|
||||
raise ValueError(
|
||||
"Found coref clusters incompatible with the "
|
||||
|
@@ -219,49 +219,8 @@ class CoreferenceResolver(TrainablePipe):
|
|||
losses[self.name] += total_loss
|
||||
return losses
|
||||
|
||||
def rehearse(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
drop: float = 0.0,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
losses: Optional[Dict[str, float]] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
|
||||
teach the current model to make predictions similar to an initial model,
|
||||
to try to address the "catastrophic forgetting" problem. This feature is
|
||||
experimental.
|
||||
|
||||
examples (Iterable[Example]): A batch of Example objects.
|
||||
drop (float): The dropout rate.
|
||||
sgd (thinc.api.Optimizer): The optimizer.
|
||||
losses (Dict[str, float]): Optional record of the loss during training.
|
||||
Updated using the component name as the key.
|
||||
RETURNS (Dict[str, float]): The updated losses dictionary.
|
||||
|
||||
DOCS: https://spacy.io/api/coref#rehearse (TODO)
|
||||
"""
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
if self._rehearsal_model is None:
|
||||
return losses
|
||||
validate_examples(examples, "CoreferenceResolver.rehearse")
|
||||
# TODO test this whole function
|
||||
docs = [eg.predicted for eg in examples]
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
scores, bp_scores = self.model.begin_update(docs)
|
||||
# TODO below
|
||||
target = self._rehearsal_model(examples)
|
||||
gradient = scores - target
|
||||
bp_scores(gradient)
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
if losses is not None:
|
||||
losses[self.name] += (gradient**2).sum()
|
||||
return losses
|
||||
def rehearse(self, examples, *, sgd=None, losses=None, **config):
|
||||
raise NotImplementedError
|
||||
|
||||
def add_label(self, label: str) -> int:
|
||||
"""Technically this method should be implemented from TrainablePipe,
|
||||
|
@@ -276,7 +235,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
def get_loss(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
score_matrix: List[Tuple[Floats2d, Ints2d]],
|
||||
score_matrix: Floats2d,
|
||||
mention_idx: Ints2d,
|
||||
):
|
||||
"""Find the loss and gradient of loss for the batch of documents and
|
||||
|
@@ -293,13 +252,13 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
# TODO if there is more than one example, give an error
|
||||
# (or actually rework this to take multiple things)
|
||||
example = examples[0]
|
||||
cscores = score_matrix
|
||||
example = list(examples)[0]
|
||||
cidx = mention_idx
|
||||
|
||||
clusters = get_clusters_from_doc(example.reference)
|
||||
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
||||
gscores = create_gold_scores(span_idxs, clusters)
|
||||
# TODO fix type here. This is bools but asarray2f wants ints.
|
||||
gscores = ops.asarray2f(gscores)
|
||||
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
||||
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
||||
|
@@ -313,8 +272,8 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
|
||||
log_norm = ops.softmax(cscores, axis=1)
|
||||
log_marg = ops.softmax(score_matrix + ops.xp.log(top_gscores), axis=1)
|
||||
log_norm = ops.softmax(score_matrix, axis=1)
|
||||
grad = log_norm - log_marg
|
||||
# gradients.append((grad, cidx))
|
||||
loss = float((grad**2).sum())
|
||||
|
|
Loading…
Reference in New Issue
Block a user