mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Bugfixes and test for rehearse (#10347)
* fixing argument order for rehearse * rehearse test for ner and tagger * rehearse bugfix * added test for parser * test for multilabel textcat * rehearse fix * remove debug line * Update spacy/tests/training/test_rehearse.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Update spacy/tests/training/test_rehearse.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Kádár Ákos <akos@onyx.uvt.nl> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
b7ba7f78a2
commit
249b97184d
|
@ -1222,8 +1222,9 @@ class Language:
|
|||
component_cfg = {}
|
||||
grads = {}
|
||||
|
||||
def get_grads(W, dW, key=None):
|
||||
def get_grads(key, W, dW):
|
||||
grads[key] = (W, dW)
|
||||
return W, dW
|
||||
|
||||
get_grads.learn_rate = sgd.learn_rate # type: ignore[attr-defined, union-attr]
|
||||
get_grads.b1 = sgd.b1 # type: ignore[attr-defined, union-attr]
|
||||
|
@ -1236,7 +1237,7 @@ class Language:
|
|||
examples, sgd=get_grads, losses=losses, **component_cfg.get(name, {})
|
||||
)
|
||||
for key, (W, dW) in grads.items():
|
||||
sgd(W, dW, key=key) # type: ignore[call-arg, misc]
|
||||
sgd(key, W, dW) # type: ignore[call-arg, misc]
|
||||
return losses
|
||||
|
||||
def begin_training(
|
||||
|
|
|
@ -225,6 +225,7 @@ class Tagger(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/tagger#rehearse
|
||||
"""
|
||||
loss_func = SequenceCategoricalCrossentropy()
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
|
@ -236,12 +237,12 @@ class Tagger(TrainablePipe):
|
|||
# Handle cases where there are no tokens in any docs.
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
guesses, backprop = self.model.begin_update(docs)
|
||||
target = self._rehearsal_model(examples)
|
||||
gradient = guesses - target
|
||||
backprop(gradient)
|
||||
tag_scores, bp_tag_scores = self.model.begin_update(docs)
|
||||
tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
|
||||
grads, loss = loss_func(tag_scores, tutor_tag_scores)
|
||||
bp_tag_scores(grads)
|
||||
self.finish_update(sgd)
|
||||
losses[self.name] += (gradient**2).sum()
|
||||
losses[self.name] += loss
|
||||
return losses
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
|
|
|
@ -283,7 +283,7 @@ class TextCategorizer(TrainablePipe):
|
|||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
scores, bp_scores = self.model.begin_update(docs)
|
||||
target = self._rehearsal_model(examples)
|
||||
target, _ = self._rehearsal_model.begin_update(docs)
|
||||
gradient = scores - target
|
||||
bp_scores(gradient)
|
||||
if sgd is not None:
|
||||
|
|
168
spacy/tests/training/test_rehearse.py
Normal file
168
spacy/tests/training/test_rehearse.py
Normal file
|
@ -0,0 +1,168 @@
|
|||
import pytest
|
||||
import spacy
|
||||
|
||||
from typing import List
|
||||
from spacy.training import Example
|
||||
|
||||
|
||||
TRAIN_DATA = [
|
||||
(
|
||||
'Who is Kofi Annan?',
|
||||
{
|
||||
'entities': [(7, 18, 'PERSON')],
|
||||
'tags': ['PRON', 'AUX', 'PROPN', 'PRON', 'PUNCT'],
|
||||
'heads': [1, 1, 3, 1, 1],
|
||||
'deps': ['attr', 'ROOT', 'compound', 'nsubj', 'punct'],
|
||||
'morphs': ['', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', 'Number=Sing', 'Number=Sing', 'PunctType=Peri'],
|
||||
'cats': {'question': 1.0}
|
||||
}
|
||||
),
|
||||
(
|
||||
'Who is Steve Jobs?',
|
||||
{
|
||||
'entities': [(7, 17, 'PERSON')],
|
||||
'tags': ['PRON', 'AUX', 'PROPN', 'PRON', 'PUNCT'],
|
||||
'heads': [1, 1, 3, 1, 1],
|
||||
'deps': ['attr', 'ROOT', 'compound', 'nsubj', 'punct'],
|
||||
'morphs': ['', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', 'Number=Sing', 'Number=Sing', 'PunctType=Peri'],
|
||||
'cats': {'question': 1.0}
|
||||
}
|
||||
),
|
||||
(
|
||||
'Bob is a nice person.',
|
||||
{
|
||||
'entities': [(0, 3, 'PERSON')],
|
||||
'tags': ['PROPN', 'AUX', 'DET', 'ADJ', 'NOUN', 'PUNCT'],
|
||||
'heads': [1, 1, 4, 4, 1, 1],
|
||||
'deps': ['nsubj', 'ROOT', 'det', 'amod', 'attr', 'punct'],
|
||||
'morphs': ['Number=Sing', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', 'Definite=Ind|PronType=Art', 'Degree=Pos', 'Number=Sing', 'PunctType=Peri'],
|
||||
'cats': {'statement': 1.0}
|
||||
},
|
||||
),
|
||||
(
|
||||
'Hi Anil, how are you?',
|
||||
{
|
||||
'entities': [(3, 7, 'PERSON')],
|
||||
'tags': ['INTJ', 'PROPN', 'PUNCT', 'ADV', 'AUX', 'PRON', 'PUNCT'],
|
||||
'deps': ['intj', 'npadvmod', 'punct', 'advmod', 'ROOT', 'nsubj', 'punct'],
|
||||
'heads': [4, 0, 4, 4, 4, 4, 4],
|
||||
'morphs': ['', 'Number=Sing', 'PunctType=Comm', '', 'Mood=Ind|Tense=Pres|VerbForm=Fin', 'Case=Nom|Person=2|PronType=Prs', 'PunctType=Peri'],
|
||||
'cats': {'greeting': 1.0, 'question': 1.0}
|
||||
}
|
||||
),
|
||||
(
|
||||
'I like London and Berlin.',
|
||||
{
|
||||
'entities': [(7, 13, 'LOC'), (18, 24, 'LOC')],
|
||||
'tags': ['PROPN', 'VERB', 'PROPN', 'CCONJ', 'PROPN', 'PUNCT'],
|
||||
'deps': ['nsubj', 'ROOT', 'dobj', 'cc', 'conj', 'punct'],
|
||||
'heads': [1, 1, 1, 2, 2, 1],
|
||||
'morphs': ['Case=Nom|Number=Sing|Person=1|PronType=Prs', 'Tense=Pres|VerbForm=Fin', 'Number=Sing', 'ConjType=Cmp', 'Number=Sing', 'PunctType=Peri'],
|
||||
'cats': {'statement': 1.0}
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
REHEARSE_DATA = [
|
||||
(
|
||||
'Hi Anil',
|
||||
{
|
||||
'entities': [(3, 7, 'PERSON')],
|
||||
'tags': ['INTJ', 'PROPN'],
|
||||
'deps': ['ROOT', 'npadvmod'],
|
||||
'heads': [0, 0],
|
||||
'morphs': ['', 'Number=Sing'],
|
||||
'cats': {'greeting': 1.0}
|
||||
}
|
||||
),
|
||||
(
|
||||
'Hi Ravish, how you doing?',
|
||||
{
|
||||
'entities': [(3, 9, 'PERSON')],
|
||||
'tags': ['INTJ', 'PROPN', 'PUNCT', 'ADV', 'AUX', 'PRON', 'PUNCT'],
|
||||
'deps': ['intj', 'ROOT', 'punct', 'advmod', 'nsubj', 'advcl', 'punct'],
|
||||
'heads': [1, 1, 1, 5, 5, 1, 1],
|
||||
'morphs': ['', 'VerbForm=Inf', 'PunctType=Comm', '', 'Case=Nom|Person=2|PronType=Prs', 'Aspect=Prog|Tense=Pres|VerbForm=Part', 'PunctType=Peri'],
|
||||
'cats': {'greeting': 1.0, 'question': 1.0}
|
||||
}
|
||||
),
|
||||
# UTENSIL new label
|
||||
(
|
||||
'Natasha bought new forks.',
|
||||
{
|
||||
'entities': [(0, 7, 'PERSON'), (19, 24, 'UTENSIL')],
|
||||
'tags': ['PROPN', 'VERB', 'ADJ', 'NOUN', 'PUNCT'],
|
||||
'deps': ['nsubj', 'ROOT', 'amod', 'dobj', 'punct'],
|
||||
'heads': [1, 1, 3, 1, 1],
|
||||
'morphs': ['Number=Sing', 'Tense=Past|VerbForm=Fin', 'Degree=Pos', 'Number=Plur', 'PunctType=Peri'],
|
||||
'cats': {'statement': 1.0}
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _add_ner_label(ner, data):
|
||||
for _, annotations in data:
|
||||
for ent in annotations['entities']:
|
||||
ner.add_label(ent[2])
|
||||
|
||||
|
||||
def _add_tagger_label(tagger, data):
|
||||
for _, annotations in data:
|
||||
for tag in annotations['tags']:
|
||||
tagger.add_label(tag)
|
||||
|
||||
|
||||
def _add_parser_label(parser, data):
|
||||
for _, annotations in data:
|
||||
for dep in annotations['deps']:
|
||||
parser.add_label(dep)
|
||||
|
||||
|
||||
def _add_textcat_label(textcat, data):
|
||||
for _, annotations in data:
|
||||
for cat in annotations['cats']:
|
||||
textcat.add_label(cat)
|
||||
|
||||
|
||||
def _optimize(
|
||||
nlp,
|
||||
component: str,
|
||||
data: List,
|
||||
rehearse: bool
|
||||
):
|
||||
"""Run either train or rehearse."""
|
||||
pipe = nlp.get_pipe(component)
|
||||
if component == 'ner':
|
||||
_add_ner_label(pipe, data)
|
||||
elif component == 'tagger':
|
||||
_add_tagger_label(pipe, data)
|
||||
elif component == 'parser':
|
||||
_add_tagger_label(pipe, data)
|
||||
elif component == 'textcat_multilabel':
|
||||
_add_textcat_label(pipe, data)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if rehearse:
|
||||
optimizer = nlp.resume_training()
|
||||
else:
|
||||
optimizer = nlp.initialize()
|
||||
|
||||
for _ in range(5):
|
||||
for text, annotation in data:
|
||||
doc = nlp.make_doc(text)
|
||||
example = Example.from_dict(doc, annotation)
|
||||
if rehearse:
|
||||
nlp.rehearse([example], sgd=optimizer)
|
||||
else:
|
||||
nlp.update([example], sgd=optimizer)
|
||||
return nlp
|
||||
|
||||
|
||||
@pytest.mark.parametrize("component", ['ner', 'tagger', 'parser', 'textcat_multilabel'])
|
||||
def test_rehearse(component):
|
||||
nlp = spacy.blank("en")
|
||||
nlp.add_pipe(component)
|
||||
nlp = _optimize(nlp, component, TRAIN_DATA, False)
|
||||
_optimize(nlp, component, REHEARSE_DATA, True)
|
Loading…
Reference in New Issue
Block a user