mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-05 12:50:20 +03:00
Add tok2vec rehearse
This commit is contained in:
parent
f9e020dd67
commit
77cf0ee75d
|
@ -1337,9 +1337,18 @@ class Language:
|
||||||
ops = get_current_ops()
|
ops = get_current_ops()
|
||||||
if self.vocab.vectors.shape[1] >= 1:
|
if self.vocab.vectors.shape[1] >= 1:
|
||||||
self.vocab.vectors.to_ops(ops)
|
self.vocab.vectors.to_ops(ops)
|
||||||
|
|
||||||
|
# Create rehearsal models
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if hasattr(proc, "_rehearsal_model"):
|
if hasattr(proc, "_rehearsal_model"):
|
||||||
proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined]
|
proc._rehearsal_model = deepcopy(proc.model) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# Link listeners from rehearsal models to Tok2Vec components
|
||||||
|
for i, (name1, proc1) in enumerate(self.pipeline):
|
||||||
|
if isinstance(proc1, ty.ListenedToComponent):
|
||||||
|
for name2, proc2 in self.pipeline[i + 1 :]:
|
||||||
|
proc1.find_listeners(proc2)
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self._optimizer = sgd
|
self._optimizer = sgd
|
||||||
elif self._optimizer is None:
|
elif self._optimizer is None:
|
||||||
|
|
|
@ -59,6 +59,7 @@ class Tok2Vec(TrainablePipe):
|
||||||
"""
|
"""
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self._rehearsal_model = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
|
self.listener_map: Dict[str, List["Tok2VecListener"]] = {}
|
||||||
self.cfg: Dict[str, Any] = {}
|
self.cfg: Dict[str, Any] = {}
|
||||||
|
@ -108,6 +109,11 @@ class Tok2Vec(TrainablePipe):
|
||||||
for node in component.model.walk():
|
for node in component.model.walk():
|
||||||
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
|
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
|
||||||
self.add_listener(node, component.name)
|
self.add_listener(node, component.name)
|
||||||
|
# Make sure to link to Tok2VecListeners from rehearsal models
|
||||||
|
if isinstance(getattr(component, "_rehearsal_model", None), Model):
|
||||||
|
for node in component._rehearsal_model.walk():
|
||||||
|
if isinstance(node, Tok2VecListener) and node.upstream_name in names:
|
||||||
|
self.add_listener(node, component.name + "_rehearsal_model")
|
||||||
|
|
||||||
def predict(self, docs: Iterable[Doc]):
|
def predict(self, docs: Iterable[Doc]):
|
||||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||||
|
@ -191,6 +197,57 @@ class Tok2Vec(TrainablePipe):
|
||||||
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
self.listeners[-1].receive(batch_id, tokvecs, backprop)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
def rehearse(
    self,
    examples: Iterable[Example],
    *,
    drop: float = 0.0,
    sgd: Optional[Optimizer] = None,
    losses: Optional[Dict[str, float]] = None,
):
    """Perform a "rehearsal" update from a batch of data. Rehearsal updates
    teach the current model to make predictions similar to an initial model,
    to try to address the "catastrophic forgetting" problem. This feature is
    experimental.

    The gradient for each doc is the difference between the current model's
    output and the rehearsal model's output, and the recorded loss is the
    sum of the squared differences.

    examples (Iterable[Example]): A batch of Example objects.
    drop (float): The dropout rate.
    sgd (thinc.api.Optimizer): The optimizer. If None, the gradients are
        backpropagated but finish_update is not called.
    losses (Dict[str, float]): Optional record of the loss during training.
        Updated using the component name as the key.
    RETURNS (Dict[str, float]): The updated losses dictionary.

    DOCS: https://spacy.io/api/tok2vec#rehearse
    """
    if losses is None:
        losses = {}
    # Without a rehearsal model there is nothing to rehearse against.
    if self._rehearsal_model is None:
        return losses
    validate_examples(examples, "Tok2Vec.rehearse")
    docs = [eg.predicted for eg in examples]
    set_dropout_rate(self.model, drop)
    tokvecs, bp_tokvecs = self.model.begin_update(docs)
    target, _ = self._rehearsal_model.begin_update(docs)
    losses.setdefault(self.name, 0.0)

    # The gradient must depend on the current output: pushing the raw target
    # through backprop would not move the model towards the rehearsal model.
    d_tokvecs = []
    for t2v, tgt in zip(tokvecs, target):
        diff = t2v - tgt
        d_tokvecs.append(diff)
        losses[self.name] += float((diff ** 2).sum())

    def empty_backprop(_):
        # Listeners get a no-op callback that returns zero gradients.
        return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]

    batch_id = Tok2VecListener.get_batch_id(docs)
    for listener in self.listeners:
        listener.receive(batch_id, tokvecs, empty_backprop)

    bp_tokvecs(d_tokvecs)
    if sgd is not None:
        self.finish_update(sgd)

    return losses
|
||||||
|
|
||||||
def get_loss(self, examples, scores) -> None:
|
def get_loss(self, examples, scores) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import spacy
|
import spacy
|
||||||
|
|
||||||
|
from thinc.api import Config
|
||||||
from typing import List
|
from typing import List
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
|
|
||||||
|
@ -148,6 +149,86 @@ REHEARSE_DATA = [
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
TEXTCAT_MULTILABEL_LISTENER_CONFIG = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec","textcat_multilabel"]
|
||||||
|
disabled = []
|
||||||
|
before_creation = null
|
||||||
|
after_creation = null
|
||||||
|
after_pipeline_creation = null
|
||||||
|
batch_size = 1000
|
||||||
|
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
|
||||||
|
[components]
|
||||||
|
[components.textcat_multilabel]
|
||||||
|
factory = "textcat_multilabel"
|
||||||
|
threshold = 0.5
|
||||||
|
[components.textcat_multilabel.model]
|
||||||
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
nO = null
|
||||||
|
[components.textcat_multilabel.model.linear_model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
|
exclusive_classes = false
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
[components.textcat_multilabel.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = 64
|
||||||
|
upstream = "*"
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
[components.tok2vec.model]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
[components.tok2vec.model.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v2"
|
||||||
|
width = 64
|
||||||
|
attrs = ["ORTH", "SHAPE"]
|
||||||
|
rows = [5000, 2500]
|
||||||
|
include_static_vectors = true
|
||||||
|
[components.tok2vec.model.encode]
|
||||||
|
@architectures = "spacy.MishWindowEncoder.v2"
|
||||||
|
width = 64
|
||||||
|
depth = 4
|
||||||
|
window_size = 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
NER_LISTENER_CONFIG = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec","ner"]
|
||||||
|
batch_size = 1000
|
||||||
|
[components]
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
[components.tok2vec.model]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
[components.tok2vec.model.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v2"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||||
|
rows = [5000, 1000, 2500, 2500]
|
||||||
|
include_static_vectors = false
|
||||||
|
[components.tok2vec.model.encode]
|
||||||
|
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||||
|
width = 96
|
||||||
|
depth = 4
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
[components.ner]
|
||||||
|
factory = "ner"
|
||||||
|
[components.ner.model]
|
||||||
|
@architectures = "spacy.TransitionBasedParser.v2"
|
||||||
|
state_type = "ner"
|
||||||
|
extra_state_tokens = false
|
||||||
|
hidden_width = 64
|
||||||
|
maxout_pieces = 2
|
||||||
|
use_upper = true
|
||||||
|
nO = null
|
||||||
|
[components.ner.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _add_ner_label(ner, data):
|
def _add_ner_label(ner, data):
|
||||||
for _, annotations in data:
|
for _, annotations in data:
|
||||||
|
@ -197,7 +278,11 @@ def _optimize(nlp, component: str, data: List, rehearse: bool):
|
||||||
doc = nlp.make_doc(text)
|
doc = nlp.make_doc(text)
|
||||||
example = Example.from_dict(doc, annotation)
|
example = Example.from_dict(doc, annotation)
|
||||||
if rehearse:
|
if rehearse:
|
||||||
nlp.rehearse([example], sgd=optimizer)
|
nlp.update([example], sgd=None)
|
||||||
|
nlp.rehearse([example], sgd=None)
|
||||||
|
for name, proc in nlp.pipeline:
|
||||||
|
if proc.is_trainable and proc.model not in (True, False, None):
|
||||||
|
proc.finish_update(optimizer)
|
||||||
else:
|
else:
|
||||||
nlp.update([example], sgd=optimizer)
|
nlp.update([example], sgd=optimizer)
|
||||||
return nlp
|
return nlp
|
||||||
|
@ -209,3 +294,21 @@ def test_rehearse(component):
|
||||||
nlp.add_pipe(component)
|
nlp.add_pipe(component)
|
||||||
nlp = _optimize(nlp, component, TRAIN_DATA, False)
|
nlp = _optimize(nlp, component, TRAIN_DATA, False)
|
||||||
_optimize(nlp, component, REHEARSE_DATA, True)
|
_optimize(nlp, component, REHEARSE_DATA, True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(12044)
def test_rehearse_textcat_multilabel_listener():
    """Run nlp.rehearse on a textcat_multilabel pipeline that uses a tok2vec listener"""
    component = "textcat_multilabel"
    pipeline = spacy.blank(
        "en", config=Config().from_str(TEXTCAT_MULTILABEL_LISTENER_CONFIG)
    )
    # First pass runs with rehearse=False, second pass with rehearse=True.
    trained = _optimize(pipeline, component, TRAIN_DATA, False)
    _optimize(trained, component, REHEARSE_DATA, True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(12044)
def test_rehearse_ner_listener():
    """Run nlp.rehearse on a ner pipeline that uses a tok2vec listener"""
    component = "ner"
    pipeline = spacy.blank("en", config=Config().from_str(NER_LISTENER_CONFIG))
    # First pass runs with rehearse=False, second pass with rehearse=True.
    trained = _optimize(pipeline, component, TRAIN_DATA, False)
    _optimize(trained, component, REHEARSE_DATA, True)
|
||||||
|
|
|
@ -205,6 +205,31 @@ Delegates to [`predict`](/api/tok2vec#predict).
|
||||||
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
||||||
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||||
|
|
||||||
|
## Tok2Vec.rehearse {id="rehearse",tag="method,experimental",version="3.5.1"}
|
||||||
|
|
||||||
|
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
|
||||||
|
current model to make predictions similar to an initial model, to try to address
|
||||||
|
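the "catastrophic forgetting" problem. Note that `Tok2Vec.rehearse` must be used
together with `Tok2Vec.update`: call both with `sgd=None` and then apply the
gradients with a single call to `Tok2Vec.finish_update`, as shown in the
example. This feature is experimental.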
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> tok2vec = nlp.add_pipe("tok2vec")
|
||||||
|
> optimizer = nlp.resume_training()
|
||||||
|
> update_losses = tok2vec.update(examples, sgd=None)
|
||||||
|
> rehearse_losses = tok2vec.rehearse(examples, sgd=None)
|
||||||
|
> tok2vec.finish_update(optimizer)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | ------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `drop` | The dropout rate. ~~float~~ |
|
||||||
|
| `sgd` | An optimizer. If not set, gradients are accumulated but no update is applied. ~~Optional[Optimizer]~~ |
|
||||||
|
| `losses` | Dictionary to update with the loss, keyed by pipeline component. ~~Optional[Dict[str, float]]~~ |
|
||||||
|
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||||
|
|
||||||
## Tok2Vec.create_optimizer {id="create_optimizer",tag="method"}
|
## Tok2Vec.create_optimizer {id="create_optimizer",tag="method"}
|
||||||
|
|
||||||
Create an optimizer for the pipeline component.
|
Create an optimizer for the pipeline component.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user