from typing import List

import pytest
import spacy
from spacy.training import Example

# Annotated examples used for the initial training pass. Each item is a
# (text, annotations) pair in the dict format accepted by Example.from_dict.
TRAIN_DATA = [
    (
        "Who is Kofi Annan?",
        {
            # "Kofi Annan" occupies characters 7-17; the trailing "?" is not
            # part of the entity.
            "entities": [(7, 17, "PERSON")],
            "tags": ["PRON", "AUX", "PROPN", "PRON", "PUNCT"],
            "heads": [1, 1, 3, 1, 1],
            "deps": ["attr", "ROOT", "compound", "nsubj", "punct"],
            "morphs": [
                "",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"question": 1.0},
        },
    ),
    (
        "Who is Steve Jobs?",
        {
            "entities": [(7, 17, "PERSON")],
            "tags": ["PRON", "AUX", "PROPN", "PRON", "PUNCT"],
            "heads": [1, 1, 3, 1, 1],
            "deps": ["attr", "ROOT", "compound", "nsubj", "punct"],
            "morphs": [
                "",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"question": 1.0},
        },
    ),
    (
        "Bob is a nice person.",
        {
            "entities": [(0, 3, "PERSON")],
            "tags": ["PROPN", "AUX", "DET", "ADJ", "NOUN", "PUNCT"],
            "heads": [1, 1, 4, 4, 1, 1],
            "deps": ["nsubj", "ROOT", "det", "amod", "attr", "punct"],
            "morphs": [
                "Number=Sing",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Definite=Ind|PronType=Art",
                "Degree=Pos",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
    (
        "Hi Anil, how are you?",
        {
            "entities": [(3, 7, "PERSON")],
            "tags": ["INTJ", "PROPN", "PUNCT", "ADV", "AUX", "PRON", "PUNCT"],
            "deps": ["intj", "npadvmod", "punct", "advmod", "ROOT", "nsubj", "punct"],
            "heads": [4, 0, 4, 4, 4, 4, 4],
            "morphs": [
                "",
                "Number=Sing",
                "PunctType=Comm",
                "",
                "Mood=Ind|Tense=Pres|VerbForm=Fin",
                "Case=Nom|Person=2|PronType=Prs",
                "PunctType=Peri",
            ],
            "cats": {"greeting": 1.0, "question": 1.0},
        },
    ),
    (
        "I like London and Berlin.",
        {
            "entities": [(7, 13, "LOC"), (18, 24, "LOC")],
            "tags": ["PROPN", "VERB", "PROPN", "CCONJ", "PROPN", "PUNCT"],
            "deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
            "heads": [1, 1, 1, 2, 2, 1],
            "morphs": [
                "Case=Nom|Number=Sing|Person=1|PronType=Prs",
                "Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "ConjType=Cmp",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
]

# Annotated examples used for the rehearsal pass, same format as TRAIN_DATA.
REHEARSE_DATA = [
    (
        "Hi Anil",
        {
            "entities": [(3, 7, "PERSON")],
            "tags": ["INTJ", "PROPN"],
            "deps": ["ROOT", "npadvmod"],
            "heads": [0, 0],
            "morphs": ["", "Number=Sing"],
            "cats": {"greeting": 1.0},
        },
    ),
    (
        "Hi Ravish, how you doing?",
        {
            "entities": [(3, 9, "PERSON")],
            "tags": ["INTJ", "PROPN", "PUNCT", "ADV", "AUX", "PRON", "PUNCT"],
            "deps": ["intj", "ROOT", "punct", "advmod", "nsubj", "advcl", "punct"],
            "heads": [1, 1, 1, 5, 5, 1, 1],
            "morphs": [
                "",
                "VerbForm=Inf",
                "PunctType=Comm",
                "",
                "Case=Nom|Person=2|PronType=Prs",
                "Aspect=Prog|Tense=Pres|VerbForm=Part",
                "PunctType=Peri",
            ],
            "cats": {"greeting": 1.0, "question": 1.0},
        },
    ),
    # Introduces UTENSIL, an entity label that does not occur in TRAIN_DATA.
    (
        "Natasha bought new forks.",
        {
            "entities": [(0, 7, "PERSON"), (19, 24, "UTENSIL")],
            "tags": ["PROPN", "VERB", "ADJ", "NOUN", "PUNCT"],
            "deps": ["nsubj", "ROOT", "amod", "dobj", "punct"],
            "heads": [1, 1, 3, 1, 1],
            "morphs": [
                "Number=Sing",
                "Tense=Past|VerbForm=Fin",
                "Degree=Pos",
                "Number=Plur",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
]


def _add_ner_label(ner, data):
    """Add the label of every entity span annotated in *data* to *ner*."""
    for _, annotations in data:
        for ent in annotations["entities"]:
            # Entity spans are (start_char, end_char, label) tuples.
            ner.add_label(ent[2])


def _add_tagger_label(tagger, data):
    """Add every tag annotated in *data* to *tagger*."""
    for _, annotations in data:
        for tag in annotations["tags"]:
            tagger.add_label(tag)


def _add_parser_label(parser, data):
    """Add every dependency label annotated in *data* to *parser*."""
    for _, annotations in data:
        for dep in annotations["deps"]:
            parser.add_label(dep)


def _add_textcat_label(textcat, data):
    """Add every category name annotated in *data* to *textcat*."""
    for _, annotations in data:
        for cat in annotations["cats"]:
            textcat.add_label(cat)


# Maps a pipeline component name to the helper that registers its labels.
# Keeping this as a table ensures each component is paired with exactly one
# helper (the parser must receive dependency labels, not POS tags).
_LABEL_ADDERS = {
    "ner": _add_ner_label,
    "tagger": _add_tagger_label,
    "parser": _add_parser_label,
    "textcat_multilabel": _add_textcat_label,
}


def _optimize(nlp, component: str, data: List, rehearse: bool):
    """Run either train or rehearse.

    Registers the labels found in *data* on the named component, obtains an
    optimizer, then makes five passes over *data*, updating one example at a
    time.

    nlp: The pipeline to optimize; it must already contain *component*.
    component: Name of the pipe to add labels to. One of "ner", "tagger",
        "parser" or "textcat_multilabel".
    data: List of (text, annotations) pairs, shaped like TRAIN_DATA.
    rehearse: If True, call nlp.resume_training() and nlp.rehearse();
        otherwise call nlp.initialize() and nlp.update().
    RETURNS: The same nlp object that was passed in.
    RAISES: NotImplementedError if *component* is not one of the names above.
    """
    pipe = nlp.get_pipe(component)
    add_labels = _LABEL_ADDERS.get(component)
    if add_labels is None:
        raise NotImplementedError(f"Unsupported component: {component!r}")
    add_labels(pipe, data)

    # Labels are added before the optimizer is created, in both modes.
    if rehearse:
        optimizer = nlp.resume_training()
    else:
        optimizer = nlp.initialize()

    step = nlp.rehearse if rehearse else nlp.update
    for _ in range(5):
        for text, annotation in data:
            doc = nlp.make_doc(text)
            example = Example.from_dict(doc, annotation)
            step([example], sgd=optimizer)
    return nlp


@pytest.mark.parametrize("component", ["ner", "tagger", "parser", "textcat_multilabel"])
def test_rehearse(component):
    """Train a blank English pipeline on TRAIN_DATA, then rehearse it on
    REHEARSE_DATA. The test passes if neither phase raises."""
    nlp = spacy.blank("en")
    nlp.add_pipe(component)
    nlp = _optimize(nlp, component, TRAIN_DATA, False)
    _optimize(nlp, component, REHEARSE_DATA, True)