Add tok2vec rehearse

This commit is contained in:
thomashacker 2023-01-23 14:22:27 +01:00
parent f9e020dd67
commit 77cf0ee75d
4 changed files with 195 additions and 1 deletions

View File

@ -1337,9 +1337,18 @@ class Language:
ops = get_current_ops()
if self.vocab.vectors.shape[1] >= 1:
self.vocab.vectors.to_ops(ops)
# Create rehearsal models
for name, proc in self.pipeline:
if hasattr(proc, "_rehearsal_model"):
proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined]
# Link listeners from rehearsal models to Tok2Vec components
for i, (name1, proc1) in enumerate(self.pipeline):
if isinstance(proc1, ty.ListenedToComponent):
for name2, proc2 in self.pipeline[i + 1 :]:
proc1.find_listeners(proc2)
if sgd is not None:
self._optimizer = sgd
elif self._optimizer is None:

View File

@ -59,6 +59,7 @@ class Tok2Vec(TrainablePipe):
"""
self.vocab = vocab
self.model = model
self._rehearsal_model = None
self.name = name
self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
self.cfg: Dict[str, Any] = {}
@ -108,6 +109,11 @@ class Tok2Vec(TrainablePipe):
for node in component.model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
self.add_listener(node, component.name)
# Make sure to link to Tok2VecListeners from rehearsal models
if isinstance(getattr(component, "_rehearsal_model", None), Model):
for node in component._rehearsal_model.walk():
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
self.add_listener(node, component.name + "_rehearsal_model")
def predict(self, docs: Iterable[Doc]):
"""Apply the pipeline's model to a batch of docs, without modifying them.
@ -191,6 +197,57 @@ class Tok2Vec(TrainablePipe):
self.listeners[-1].receive(batch_id, tokvecs, backprop)
return losses
def rehearse(
    self,
    examples: Iterable[Example],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
):
    """Perform a "rehearsal" update from a batch of data. Rehearsal updates
    teach the current model to make predictions similar to an initial model,
    to try to address the "catastrophic forgetting" problem. This feature is
    experimental.

    The gradient is the difference between the current model's output and
    the rehearsal model's output, and the recorded loss is the sum of the
    squared differences. If no rehearsal model has been set, the losses are
    returned unchanged and nothing else happens.

    examples (Iterable[Example]): A batch of Example objects.
    drop (float): The dropout rate.
    sgd (thinc.api.Optimizer): The optimizer. If None, the gradient is
        backpropagated but finish_update is not called.
    losses (Dict[str, float]): Optional record of the loss during training.
        Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.

    DOCS: https://spacy.io/api/tok2vec#rehearse
    """
    if losses is None:
        losses = {}
    if self._rehearsal_model is None:
        return losses
    validate_examples(examples, "Tok2Vec.rehearse")
    docs = [eg.predicted for eg in examples]
    set_dropout_rate(self.model, drop)
    tokvecs, bp_tokvecs = self.model.begin_update(docs)
    target, _ = self._rehearsal_model.begin_update(docs)
    losses.setdefault(self.name, 0.0)
    # The gradient has to depend on the current output: pushing `tokvecs`
    # towards `target` means backpropagating (tokvecs - target), not the
    # target itself.
    d_tokvecs = []
    for t2v, tutor_t2v in zip(tokvecs, target):
        diff = t2v - tutor_t2v
        d_tokvecs.append(diff)
        losses[self.name] += float((diff**2).sum())

    def empty_backprop(_):
        # Listeners get a backprop callback that returns zero gradients
        # instead of feeding anything back into this component.
        return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]

    batch_id = Tok2VecListener.get_batch_id(docs)
    for listener in self.listeners:
        listener.receive(batch_id, tokvecs, empty_backprop)
    bp_tokvecs(d_tokvecs)
    if sgd is not None:
        self.finish_update(sgd)
    return losses
def get_loss(self, examples, scores) -> None:
    """Do nothing and return None; both arguments are ignored."""
    return None

View File

@ -1,6 +1,7 @@
import pytest
import spacy
from thinc.api import Config
from typing import List
from spacy.training import Example
@ -148,6 +149,86 @@ REHEARSE_DATA = [
),
]
# Pipeline config with a "tok2vec" component followed by "textcat_multilabel".
# The textcat ensemble model's tok2vec sublayer is a "spacy.Tok2VecListener.v1"
# (upstream "*") rather than an embedded tok2vec model.
TEXTCAT_MULTILABEL_LISTENER_CONFIG = """
[nlp]
lang = "en"
pipeline = ["tok2vec","textcat_multilabel"]
disabled = []
before_creation = null
after_creation = null
after_pipeline_creation = null
batch_size = 1000
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
[components]
[components.textcat_multilabel]
factory = "textcat_multilabel"
threshold = 0.5
[components.textcat_multilabel.model]
@architectures = "spacy.TextCatEnsemble.v2"
nO = null
[components.textcat_multilabel.model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
[components.textcat_multilabel.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = 64
upstream = "*"
[components.tok2vec]
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v2"
[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v2"
width = 64
attrs = ["ORTH", "SHAPE"]
rows = [5000, 2500]
include_static_vectors = true
[components.tok2vec.model.encode]
@architectures = "spacy.MishWindowEncoder.v2"
width = 64
depth = 4
window_size = 1
"""
# Pipeline config with a "tok2vec" component followed by "ner". The NER model's
# tok2vec sublayer is a "spacy.Tok2VecListener.v1" whose width is interpolated
# from the tok2vec encoder width.
NER_LISTENER_CONFIG = """
[nlp]
lang = "en"
pipeline = ["tok2vec","ner"]
batch_size = 1000
[components]
[components.tok2vec]
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v2"
[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v2"
width = ${components.tok2vec.model.encode.width}
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
rows = [5000, 1000, 2500, 2500]
include_static_vectors = false
[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = 96
depth = 4
window_size = 1
maxout_pieces = 3
[components.ner]
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true
nO = null
[components.ner.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
"""
def _add_ner_label(ner, data):
for _, annotations in data:
@ -197,7 +278,11 @@ def _optimize(nlp, component: str, data: List, rehearse: bool):
doc = nlp.make_doc(text)
example = Example.from_dict(doc, annotation)
if rehearse:
nlp.rehearse([example], sgd=optimizer)
nlp.update([example], sgd=None)
nlp.rehearse([example], sgd=None)
for name, proc in nlp.pipeline:
if proc.is_trainable and proc.model not in (True, False, None):
proc.finish_update(optimizer)
else:
nlp.update([example], sgd=optimizer)
return nlp
@ -209,3 +294,21 @@ def test_rehearse(component):
nlp.add_pipe(component)
nlp = _optimize(nlp, component, TRAIN_DATA, False)
_optimize(nlp, component, REHEARSE_DATA, True)
@pytest.mark.issue(12044)
def test_rehearse_textcat_multilabel_listener():
    """Run a training pass and then a rehearsal pass on a pipeline where
    textcat_multilabel listens to a shared tok2vec component."""
    component = "textcat_multilabel"
    listener_config = Config().from_str(TEXTCAT_MULTILABEL_LISTENER_CONFIG)
    pipeline = spacy.blank("en", config=listener_config)
    # Regular updates first, then rehearsal on the resulting pipeline.
    trained = _optimize(pipeline, component, TRAIN_DATA, rehearse=False)
    _optimize(trained, component, REHEARSE_DATA, rehearse=True)
@pytest.mark.issue(12044)
def test_rehearse_ner_listener():
    """Run a training pass and then a rehearsal pass on a pipeline where
    ner listens to a shared tok2vec component."""
    component = "ner"
    listener_config = Config().from_str(NER_LISTENER_CONFIG)
    pipeline = spacy.blank("en", config=listener_config)
    # Regular updates first, then rehearsal on the resulting pipeline.
    trained = _optimize(pipeline, component, TRAIN_DATA, rehearse=False)
    _optimize(trained, component, REHEARSE_DATA, rehearse=True)

View File

@ -205,6 +205,31 @@ Delegates to [`predict`](/api/tok2vec#predict).
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Tok2Vec.rehearse {id="rehearse",tag="method,experimental",version="3.5.1"}
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
current model to make predictions similar to an initial model, to try to address
the "catastrophic forgetting" problem. Please note that `Tok2Vec.rehearse` needs to be used together with `Tok2Vec.update`. This feature is experimental.
> #### Example
>
> ```python
> tok2vec = nlp.add_pipe("tok2vec")
> optimizer = nlp.resume_training()
> update_losses = tok2vec.update(examples, sgd=None)
> rehearse_losses = tok2vec.rehearse(examples, sgd=None)
> tok2vec.finish_update(optimizer)
> ```
| Name | Description |
| -------------- | ------------------------------------------------------------------------------------------------------------- |
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
| _keyword-only_ | |
| `drop` | The dropout rate. ~~float~~ |
| `sgd`          | An optimizer. If not set, no weight update is applied by this call and `finish_update` has to be called separately. ~~Optional[Optimizer]~~ |
| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## Tok2Vec.create_optimizer {id="create_optimizer",tag="method"}
Create an optimizer for the pipeline component.