mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
Suggestions from code review, cleanup, typing
This commit is contained in:
parent
6999436270
commit
6087da9675
|
@ -2,7 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from thinc.types import Floats2d, Floats3d, Ints2d
|
from thinc.types import Floats2d, Floats3d, Ints2d
|
||||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
from thinc.api import Model, Config, Optimizer
|
||||||
from thinc.api import set_dropout_rate, to_categorical
|
from thinc.api import set_dropout_rate, to_categorical
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from statistics import mean
|
from statistics import mean
|
||||||
|
@ -11,7 +11,6 @@ from .trainable_pipe import TrainablePipe
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..training import Example, validate_examples, validate_get_examples
|
from ..training import Example, validate_examples, validate_get_examples
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..scorer import Scorer
|
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..vocab import Vocab
|
from ..vocab import Vocab
|
||||||
|
|
||||||
|
@ -118,7 +117,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
self.span_cluster_prefix = span_cluster_prefix
|
self.span_cluster_prefix = span_cluster_prefix
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
|
|
||||||
self.cfg = {}
|
self.cfg: Dict[str, Any] = {}
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
@ -154,6 +153,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#set_annotations (TODO)
|
DOCS: https://spacy.io/api/coref#set_annotations (TODO)
|
||||||
"""
|
"""
|
||||||
|
docs = list(docs)
|
||||||
if len(docs) != len(clusters_by_doc):
|
if len(docs) != len(clusters_by_doc):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Found coref clusters incompatible with the "
|
"Found coref clusters incompatible with the "
|
||||||
|
@ -219,49 +219,8 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
losses[self.name] += total_loss
|
losses[self.name] += total_loss
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def rehearse(
|
def rehearse(self, examples, *, sgd=None, losses=None, **config):
|
||||||
self,
|
raise NotImplementedError
|
||||||
examples: Iterable[Example],
|
|
||||||
*,
|
|
||||||
drop: float = 0.0,
|
|
||||||
sgd: Optional[Optimizer] = None,
|
|
||||||
losses: Optional[Dict[str, float]] = None,
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
|
|
||||||
teach the current model to make predictions similar to an initial model,
|
|
||||||
to try to address the "catastrophic forgetting" problem. This feature is
|
|
||||||
experimental.
|
|
||||||
|
|
||||||
examples (Iterable[Example]): A batch of Example objects.
|
|
||||||
drop (float): The dropout rate.
|
|
||||||
sgd (thinc.api.Optimizer): The optimizer.
|
|
||||||
losses (Dict[str, float]): Optional record of the loss during training.
|
|
||||||
Updated using the component name as the key.
|
|
||||||
RETURNS (Dict[str, float]): The updated losses dictionary.
|
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#rehearse (TODO)
|
|
||||||
"""
|
|
||||||
if losses is not None:
|
|
||||||
losses.setdefault(self.name, 0.0)
|
|
||||||
if self._rehearsal_model is None:
|
|
||||||
return losses
|
|
||||||
validate_examples(examples, "CoreferenceResolver.rehearse")
|
|
||||||
# TODO test this whole function
|
|
||||||
docs = [eg.predicted for eg in examples]
|
|
||||||
if not any(len(doc) for doc in docs):
|
|
||||||
# Handle cases where there are no tokens in any docs.
|
|
||||||
return losses
|
|
||||||
set_dropout_rate(self.model, drop)
|
|
||||||
scores, bp_scores = self.model.begin_update(docs)
|
|
||||||
# TODO below
|
|
||||||
target = self._rehearsal_model(examples)
|
|
||||||
gradient = scores - target
|
|
||||||
bp_scores(gradient)
|
|
||||||
if sgd is not None:
|
|
||||||
self.finish_update(sgd)
|
|
||||||
if losses is not None:
|
|
||||||
losses[self.name] += (gradient**2).sum()
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def add_label(self, label: str) -> int:
|
def add_label(self, label: str) -> int:
|
||||||
"""Technically this method should be implemented from TrainablePipe,
|
"""Technically this method should be implemented from TrainablePipe,
|
||||||
|
@ -276,7 +235,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
def get_loss(
|
def get_loss(
|
||||||
self,
|
self,
|
||||||
examples: Iterable[Example],
|
examples: Iterable[Example],
|
||||||
score_matrix: List[Tuple[Floats2d, Ints2d]],
|
score_matrix: Floats2d,
|
||||||
mention_idx: Ints2d,
|
mention_idx: Ints2d,
|
||||||
):
|
):
|
||||||
"""Find the loss and gradient of loss for the batch of documents and
|
"""Find the loss and gradient of loss for the batch of documents and
|
||||||
|
@ -293,13 +252,13 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
|
|
||||||
# TODO if there is more than one example, give an error
|
# TODO if there is more than one example, give an error
|
||||||
# (or actually rework this to take multiple things)
|
# (or actually rework this to take multiple things)
|
||||||
example = examples[0]
|
example = list(examples)[0]
|
||||||
cscores = score_matrix
|
|
||||||
cidx = mention_idx
|
cidx = mention_idx
|
||||||
|
|
||||||
clusters = get_clusters_from_doc(example.reference)
|
clusters = get_clusters_from_doc(example.reference)
|
||||||
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
||||||
gscores = create_gold_scores(span_idxs, clusters)
|
gscores = create_gold_scores(span_idxs, clusters)
|
||||||
|
# TODO fix type here. This is bools but asarray2f wants ints.
|
||||||
gscores = ops.asarray2f(gscores)
|
gscores = ops.asarray2f(gscores)
|
||||||
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
||||||
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
||||||
|
@ -313,8 +272,8 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
|
log_marg = ops.softmax(score_matrix + ops.xp.log(top_gscores), axis=1)
|
||||||
log_norm = ops.softmax(cscores, axis=1)
|
log_norm = ops.softmax(score_matrix, axis=1)
|
||||||
grad = log_norm - log_marg
|
grad = log_norm - log_marg
|
||||||
# gradients.append((grad, cidx))
|
# gradients.append((grad, cidx))
|
||||||
loss = float((grad**2).sum())
|
loss = float((grad**2).sum())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user