spaCy/spacy/tests/training/test_rehearse.py

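"""Tests for rehearsal updates (``nlp.rehearse``): each supported component is
first trained on TRAIN_DATA and then updated on REHEARSE_DATA via rehearsal."""
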
from typing import List
import pytest
import spacy
from spacy.training import Example
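
# Fully annotated examples used for the initial training pass.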
TRAIN_DATA = [
    (
        "Who is Kofi Annan?",
        {
            "entities": [(7, 18, "PERSON")],
            "tags": ["PRON", "AUX", "PROPN", "PRON", "PUNCT"],
            "heads": [1, 1, 3, 1, 1],
            "deps": ["attr", "ROOT", "compound", "nsubj", "punct"],
            "morphs": [
                "",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"question": 1.0},
        },
    ),
    (
        "Who is Steve Jobs?",
        {
            "entities": [(7, 17, "PERSON")],
            "tags": ["PRON", "AUX", "PROPN", "PRON", "PUNCT"],
            "heads": [1, 1, 3, 1, 1],
            "deps": ["attr", "ROOT", "compound", "nsubj", "punct"],
            "morphs": [
                "",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"question": 1.0},
        },
    ),
    (
        "Bob is a nice person.",
        {
            "entities": [(0, 3, "PERSON")],
            "tags": ["PROPN", "AUX", "DET", "ADJ", "NOUN", "PUNCT"],
            "heads": [1, 1, 4, 4, 1, 1],
            "deps": ["nsubj", "ROOT", "det", "amod", "attr", "punct"],
            "morphs": [
                "Number=Sing",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Definite=Ind|PronType=Art",
                "Degree=Pos",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
    (
        "Hi Anil, how are you?",
        {
            "entities": [(3, 7, "PERSON")],
            "tags": ["INTJ", "PROPN", "PUNCT", "ADV", "AUX", "PRON", "PUNCT"],
            "deps": ["intj", "npadvmod", "punct", "advmod", "ROOT", "nsubj", "punct"],
            "heads": [4, 0, 4, 4, 4, 4, 4],
            "morphs": [
                "",
                "Number=Sing",
                "PunctType=Comm",
                "",
                "Mood=Ind|Tense=Pres|VerbForm=Fin",
                "Case=Nom|Person=2|PronType=Prs",
                "PunctType=Peri",
            ],
            "cats": {"greeting": 1.0, "question": 1.0},
        },
    ),
    (
        "I like London and Berlin.",
        {
            "entities": [(7, 13, "LOC"), (18, 24, "LOC")],
            "tags": ["PROPN", "VERB", "PROPN", "CCONJ", "PROPN", "PUNCT"],
            "deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
            "heads": [1, 1, 1, 2, 2, 1],
            "morphs": [
                "Case=Nom|Number=Sing|Person=1|PronType=Prs",
                "Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "ConjType=Cmp",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
]
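
# Examples used for the rehearsal updates after the initial training pass.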
REHEARSE_DATA = [
    (
        "Hi Anil",
        {
            "entities": [(3, 7, "PERSON")],
            "tags": ["INTJ", "PROPN"],
            "deps": ["ROOT", "npadvmod"],
            "heads": [0, 0],
            "morphs": ["", "Number=Sing"],
            "cats": {"greeting": 1.0},
        },
    ),
    (
        "Hi Ravish, how you doing?",
        {
            "entities": [(3, 9, "PERSON")],
            "tags": ["INTJ", "PROPN", "PUNCT", "ADV", "AUX", "PRON", "PUNCT"],
            "deps": ["intj", "ROOT", "punct", "advmod", "nsubj", "advcl", "punct"],
            "heads": [1, 1, 1, 5, 5, 1, 1],
            "morphs": [
                "",
                "VerbForm=Inf",
                "PunctType=Comm",
                "",
                "Case=Nom|Person=2|PronType=Prs",
                "Aspect=Prog|Tense=Pres|VerbForm=Part",
                "PunctType=Peri",
            ],
            "cats": {"greeting": 1.0, "question": 1.0},
        },
    ),
    # UTENSIL new label
    (
        "Natasha bought new forks.",
        {
            "entities": [(0, 7, "PERSON"), (19, 24, "UTENSIL")],
            "tags": ["PROPN", "VERB", "ADJ", "NOUN", "PUNCT"],
            "deps": ["nsubj", "ROOT", "amod", "dobj", "punct"],
            "heads": [1, 1, 3, 1, 1],
            "morphs": [
                "Number=Sing",
                "Tense=Past|VerbForm=Fin",
                "Degree=Pos",
                "Number=Plur",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
]
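

# Helpers that register every label occurring in the annotations on the
# corresponding pipeline component before training.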
def _add_ner_label(ner, data):
    for _, annotations in data:
        for ent in annotations["entities"]:
            ner.add_label(ent[2])


def _add_tagger_label(tagger, data):
    for _, annotations in data:
        for tag in annotations["tags"]:
            tagger.add_label(tag)


def _add_parser_label(parser, data):
    for _, annotations in data:
        for dep in annotations["deps"]:
            parser.add_label(dep)


def _add_textcat_label(textcat, data):
    for _, annotations in data:
        for cat in annotations["cats"]:
            textcat.add_label(cat)


def _optimize(nlp, component: str, data: List, rehearse: bool):
    """Run either train or rehearse."""
    pipe = nlp.get_pipe(component)
    if component == "ner":
        _add_ner_label(pipe, data)
    elif component == "tagger":
        _add_tagger_label(pipe, data)
    elif component == "parser":
        _add_parser_label(pipe, data)
    elif component == "textcat_multilabel":
        _add_textcat_label(pipe, data)
    else:
        raise NotImplementedError
    # Rehearsal resumes from the already-trained pipeline; otherwise the
    # models are initialized from scratch.
    if rehearse:
        optimizer = nlp.resume_training()
    else:
        optimizer = nlp.initialize()
    # Make a few passes over the data, updating on one example at a time.
    for _ in range(5):
        for text, annotation in data:
            doc = nlp.make_doc(text)
            example = Example.from_dict(doc, annotation)
            if rehearse:
                nlp.rehearse([example], sgd=optimizer)
            else:
                nlp.update([example], sgd=optimizer)
    return nlp
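

# Train a blank English pipeline with a single component on TRAIN_DATA, then
# check that rehearsing it on REHEARSE_DATA runs without errors.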
@pytest.mark.parametrize("component", ["ner", "tagger", "parser", "textcat_multilabel"])
def test_rehearse(component):
    nlp = spacy.blank("en")
    nlp.add_pipe(component)
    nlp = _optimize(nlp, component, TRAIN_DATA, False)
    _optimize(nlp, component, REHEARSE_DATA, True)