mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-28 21:03:41 +03:00
Tok2Vec: Refactor update
This commit is contained in:
parent
c96786152b
commit
c21ebb31ec
|
@ -1,5 +1,6 @@
|
||||||
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any
|
from typing import Sequence, Iterable, Optional, Dict, Callable, List, Any, Tuple
|
||||||
from thinc.api import Model, set_dropout_rate, Optimizer, Config
|
from thinc.api import Model, set_dropout_rate, Optimizer, Config
|
||||||
|
from thinc.types import Floats2d
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
|
@ -157,39 +158,9 @@ class Tok2Vec(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/tok2vec#update
|
DOCS: https://spacy.io/api/tok2vec#update
|
||||||
"""
|
"""
|
||||||
if losses is None:
|
|
||||||
losses = {}
|
|
||||||
validate_examples(examples, "Tok2Vec.update")
|
validate_examples(examples, "Tok2Vec.update")
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
set_dropout_rate(self.model, drop)
|
return self._update_with_docs(docs, drop=drop, sgd=sgd, losses=losses)
|
||||||
tokvecs, bp_tokvecs = self.model.begin_update(docs)
|
|
||||||
d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
|
||||||
losses.setdefault(self.name, 0.0)
|
|
||||||
|
|
||||||
def accumulate_gradient(one_d_tokvecs):
|
|
||||||
"""Accumulate tok2vec loss and gradient. This is passed as a callback
|
|
||||||
to all but the last listener. Only the last one does the backprop.
|
|
||||||
"""
|
|
||||||
nonlocal d_tokvecs
|
|
||||||
for i in range(len(one_d_tokvecs)):
|
|
||||||
d_tokvecs[i] += one_d_tokvecs[i]
|
|
||||||
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
|
|
||||||
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
|
||||||
|
|
||||||
def backprop(one_d_tokvecs):
|
|
||||||
"""Callback to actually do the backprop. Passed to last listener."""
|
|
||||||
accumulate_gradient(one_d_tokvecs)
|
|
||||||
d_docs = bp_tokvecs(d_tokvecs)
|
|
||||||
if sgd is not None:
|
|
||||||
self.finish_update(sgd)
|
|
||||||
return d_docs
|
|
||||||
|
|
||||||
batch_id = Tok2VecListener.get_batch_id(docs)
|
|
||||||
for listener in self.listeners[:-1]:
|
|
||||||
listener.receive(batch_id, tokvecs, accumulate_gradient)
|
|
||||||
if self.listeners:
|
|
||||||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
|
||||||
return losses
|
|
||||||
|
|
||||||
def get_loss(self, examples, scores) -> None:
    """No-op. Tok2Vec accumulates its loss inside `update` through the
    listener callbacks, so there is no separate loss to compute here.
    """
    return None
|
||||||
|
@ -229,9 +200,74 @@ class Tok2Vec(TrainablePipe):
|
||||||
sgd: Optional[Optimizer] = None,
|
sgd: Optional[Optimizer] = None,
|
||||||
losses: Optional[Dict[str, float]] = None,
|
losses: Optional[Dict[str, float]] = None,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
teacher_pipe.set_annotations(teacher_docs, teacher_pipe.predict(teacher_docs))
|
teacher_preds = teacher_pipe.predict(teacher_docs)
|
||||||
examples = [Example(doc, doc) for doc in student_docs]
|
teacher_pipe.set_annotations(teacher_docs, teacher_preds)
|
||||||
return self.update(examples, drop=drop, sgd=sgd, losses=losses)
|
return self._update_with_docs(student_docs, drop=drop, sgd=sgd, losses=losses)
|
||||||
|
|
||||||
|
def _update_with_docs(
    self,
    docs: Iterable[Doc],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Learn from a batch of documents, updating the pipe's model.

    docs (Iterable[Doc]): The documents to learn from (fix: the previous
        docstring described an `examples` parameter of Example objects,
        which this method does not take).
    drop (float): The dropout rate.
    sgd (thinc.api.Optimizer): The optimizer.
    losses (Dict[str, float]): Optional record of the loss during training.
        Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.

    DOCS: https://spacy.io/api/tok2vec#update
    """
    if losses is None:
        losses = {}
    losses.setdefault(self.name, 0.0)
    set_dropout_rate(self.model, drop)

    # Forward pass plus the two callbacks the listeners will invoke.
    tokvecs, accumulate_gradient, backprop = self._create_backprops(
        docs, losses, sgd=sgd
    )
    batch_id = Tok2VecListener.get_batch_id(docs)
    # All listeners except the last only accumulate gradients; the last
    # one triggers the actual backward pass.
    for listener in self.listeners[:-1]:
        listener.receive(batch_id, tokvecs, accumulate_gradient)
    if self.listeners:
        self.listeners[-1].receive(batch_id, tokvecs, backprop)
    return losses
|
||||||
|
|
||||||
|
def _create_backprops(
    self,
    docs: Iterable[Doc],
    losses: Dict[str, float],
    *,
    sgd: Optional[Optimizer] = None,
) -> Tuple[Floats2d, Callable, Callable]:
    """Run the model forward on *docs* and build the gradient callbacks
    that are handed to the Tok2VecListener components.

    docs (Iterable[Doc]): The documents to process.
    losses (Dict[str, float]): Loss record, updated in place under this
        component's name.
    sgd (thinc.api.Optimizer): Optional optimizer, applied after backprop.
    RETURNS (Tuple): The tok2vec outputs and two callbacks — one that only
        accumulates gradients (given to all but the last listener) and one
        that also runs the backward pass (given to the last listener).
    """
    tokvecs, bp_tokvecs = self.model.begin_update(docs)
    # Shared per-doc gradient buffers, filled in by the callbacks below.
    pending_grads = [self.model.ops.alloc2f(*vecs.shape) for vecs in tokvecs]

    def accumulate_gradient(one_d_tokvecs):
        """Accumulate tok2vec loss and gradient. Passed as a callback to
        all but the last listener; only the last one does the backprop.
        """
        nonlocal pending_grads
        for i, grad in enumerate(one_d_tokvecs):
            pending_grads[i] += grad
            losses[self.name] += float((grad ** 2).sum())
        # Hand the caller fresh zeroed buffers of the same shapes.
        return [self.model.ops.alloc2f(*vecs.shape) for vecs in tokvecs]

    def backprop(one_d_tokvecs):
        """Callback that actually does the backprop. Passed to the last listener."""
        accumulate_gradient(one_d_tokvecs)
        d_docs = bp_tokvecs(pending_grads)
        if sgd is not None:
            self.finish_update(sgd)
        return d_docs

    return tokvecs, accumulate_gradient, backprop
|
||||||
|
|
||||||
|
|
||||||
class Tok2VecListener(Model):
|
class Tok2VecListener(Model):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user