mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			212 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			212 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import pytest
 | 
						|
import spacy
 | 
						|
 | 
						|
from typing import List
 | 
						|
from spacy.training import Example
 | 
						|
 | 
						|
 | 
						|
# Initial training examples for ``test_rehearse``. Each entry is a
# ``(text, annotations)`` pair in the dict format passed to
# ``Example.from_dict``:
#   - "entities": ``(start_char, end_char, label)`` offsets into ``text``;
#   - "tags", "heads", "deps", "morphs": parallel lists with one entry per
#     token of ``nlp.make_doc(text)``; "heads" holds the absolute index of
#     each token's head (the ROOT token points at itself);
#   - "cats": text-category label -> score.
TRAIN_DATA = [
    (
        "Who is Kofi Annan?",
        {
            # Characters 7-17 are exactly "Kofi Annan"; the "?" is excluded.
            "entities": [(7, 17, "PERSON")],
            "tags": ["PRON", "AUX", "PROPN", "PROPN", "PUNCT"],
            "heads": [1, 1, 3, 1, 1],
            "deps": ["attr", "ROOT", "compound", "nsubj", "punct"],
            "morphs": [
                "",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"question": 1.0},
        },
    ),
    (
        "Who is Steve Jobs?",
        {
            "entities": [(7, 17, "PERSON")],
            "tags": ["PRON", "AUX", "PROPN", "PROPN", "PUNCT"],
            "heads": [1, 1, 3, 1, 1],
            "deps": ["attr", "ROOT", "compound", "nsubj", "punct"],
            "morphs": [
                "",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"question": 1.0},
        },
    ),
    (
        "Bob is a nice person.",
        {
            "entities": [(0, 3, "PERSON")],
            "tags": ["PROPN", "AUX", "DET", "ADJ", "NOUN", "PUNCT"],
            "heads": [1, 1, 4, 4, 1, 1],
            "deps": ["nsubj", "ROOT", "det", "amod", "attr", "punct"],
            "morphs": [
                "Number=Sing",
                "Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
                "Definite=Ind|PronType=Art",
                "Degree=Pos",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
    (
        "Hi Anil, how are you?",
        {
            "entities": [(3, 7, "PERSON")],
            "tags": ["INTJ", "PROPN", "PUNCT", "ADV", "AUX", "PRON", "PUNCT"],
            "deps": ["intj", "npadvmod", "punct", "advmod", "ROOT", "nsubj", "punct"],
            "heads": [4, 0, 4, 4, 4, 4, 4],
            "morphs": [
                "",
                "Number=Sing",
                "PunctType=Comm",
                "",
                "Mood=Ind|Tense=Pres|VerbForm=Fin",
                "Case=Nom|Person=2|PronType=Prs",
                "PunctType=Peri",
            ],
            "cats": {"greeting": 1.0, "question": 1.0},
        },
    ),
    (
        "I like London and Berlin.",
        {
            "entities": [(7, 13, "LOC"), (18, 24, "LOC")],
            # "I" is a personal pronoun (see its PronType=Prs morph), not PROPN.
            "tags": ["PRON", "VERB", "PROPN", "CCONJ", "PROPN", "PUNCT"],
            "deps": ["nsubj", "ROOT", "dobj", "cc", "conj", "punct"],
            "heads": [1, 1, 1, 2, 2, 1],
            "morphs": [
                "Case=Nom|Number=Sing|Person=1|PronType=Prs",
                "Tense=Pres|VerbForm=Fin",
                "Number=Sing",
                "ConjType=Cmp",
                "Number=Sing",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
]
 | 
						|
 | 
						|
# Examples used for the rehearsal pass of ``test_rehearse``, in the same
# ``(text, annotations)`` format as TRAIN_DATA. Some labels here (e.g. the
# UTENSIL entity type and the "advcl" dependency) never occur in TRAIN_DATA.
REHEARSE_DATA = [
    (
        "Hi Anil",
        {
            "entities": [(3, 7, "PERSON")],
            "tags": ["INTJ", "PROPN"],
            "deps": ["ROOT", "npadvmod"],
            "heads": [0, 0],
            "morphs": ["", "Number=Sing"],
            "cats": {"greeting": 1.0},
        },
    ),
    (
        "Hi Ravish, how you doing?",
        {
            "entities": [(3, 9, "PERSON")],
            # Tokens: Hi / Ravish / , / how / you / doing / ?
            # "you" is a pronoun and "doing" a verb (matching their morphs).
            "tags": ["INTJ", "PROPN", "PUNCT", "ADV", "PRON", "VERB", "PUNCT"],
            "deps": ["intj", "ROOT", "punct", "advmod", "nsubj", "advcl", "punct"],
            "heads": [1, 1, 1, 5, 5, 1, 1],
            "morphs": [
                "",
                # Proper noun, annotated like "Anil" above.
                "Number=Sing",
                "PunctType=Comm",
                "",
                "Case=Nom|Person=2|PronType=Prs",
                "Aspect=Prog|Tense=Pres|VerbForm=Part",
                "PunctType=Peri",
            ],
            "cats": {"greeting": 1.0, "question": 1.0},
        },
    ),
    # UTENSIL is a new entity label not present in TRAIN_DATA, presumably to
    # check that rehearsal copes with a label added after initial training.
    (
        "Natasha bought new forks.",
        {
            "entities": [(0, 7, "PERSON"), (19, 24, "UTENSIL")],
            "tags": ["PROPN", "VERB", "ADJ", "NOUN", "PUNCT"],
            "deps": ["nsubj", "ROOT", "amod", "dobj", "punct"],
            "heads": [1, 1, 3, 1, 1],
            "morphs": [
                "Number=Sing",
                "Tense=Past|VerbForm=Fin",
                "Degree=Pos",
                "Number=Plur",
                "PunctType=Peri",
            ],
            "cats": {"statement": 1.0},
        },
    ),
]
 | 
						|
 | 
						|
 | 
						|
def _add_ner_label(ner, data):
 | 
						|
    for _, annotations in data:
 | 
						|
        for ent in annotations["entities"]:
 | 
						|
            ner.add_label(ent[2])
 | 
						|
 | 
						|
 | 
						|
def _add_tagger_label(tagger, data):
 | 
						|
    for _, annotations in data:
 | 
						|
        for tag in annotations["tags"]:
 | 
						|
            tagger.add_label(tag)
 | 
						|
 | 
						|
 | 
						|
def _add_parser_label(parser, data):
 | 
						|
    for _, annotations in data:
 | 
						|
        for dep in annotations["deps"]:
 | 
						|
            parser.add_label(dep)
 | 
						|
 | 
						|
 | 
						|
def _add_textcat_label(textcat, data):
 | 
						|
    for _, annotations in data:
 | 
						|
        for cat in annotations["cats"]:
 | 
						|
            textcat.add_label(cat)
 | 
						|
 | 
						|
 | 
						|
def _optimize(nlp, component: str, data: List, rehearse: bool, *, n_iter: int = 5):
    """Register ``component``'s labels from ``data``, then train or rehearse ``nlp``.

    Args:
        nlp: Pipeline containing a pipe named ``component``.
        component: Pipe name; one of "ner", "tagger", "parser" or
            "textcat_multilabel".
        data: ``(text, annotations)`` pairs in the dict format accepted by
            ``Example.from_dict``.
        rehearse: False to call ``nlp.initialize()`` and train with
            ``nlp.update``; True to call ``nlp.resume_training()`` and train
            with ``nlp.rehearse``.
        n_iter: Number of passes over ``data`` (keyword-only; defaults to the
            previously hard-coded 5).

    Returns:
        The same ``nlp`` object, updated in place.

    Raises:
        NotImplementedError: If ``component`` has no label helper here.
    """
    pipe = nlp.get_pipe(component)
    label_adders = {
        "ner": _add_ner_label,
        "tagger": _add_tagger_label,
        "parser": _add_parser_label,
        "textcat_multilabel": _add_textcat_label,
    }
    add_labels = label_adders.get(component)
    if add_labels is None:
        raise NotImplementedError(
            f"_optimize cannot add labels for component {component!r}; "
            f"supported components: {sorted(label_adders)}"
        )
    # Labels are registered before initialize()/resume_training(),
    # presumably so those calls see the final label set.
    add_labels(pipe, data)

    optimizer = nlp.resume_training() if rehearse else nlp.initialize()
    step = nlp.rehearse if rehearse else nlp.update

    # The texts and annotations do not change between passes, so tokenize
    # and build the examples once instead of on every iteration.
    examples = [
        Example.from_dict(nlp.make_doc(text), annotation)
        for text, annotation in data
    ]
    for _ in range(n_iter):
        for example in examples:
            # One example per call, matching the per-item update loop.
            step([example], sgd=optimizer)
    return nlp
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.parametrize("component", ["ner", "tagger", "parser", "textcat_multilabel"])
def test_rehearse(component):
    """Train a one-component blank English pipeline, then rehearse it.

    The first pass initializes ``component`` and trains on TRAIN_DATA; the
    second resumes training and rehearses on REHEARSE_DATA, which includes
    labels (e.g. the UTENSIL entity type) absent from TRAIN_DATA. The test
    passes as long as neither pass raises.
    """
    pipeline = spacy.blank("en")
    pipeline.add_pipe(component)
    trained = _optimize(pipeline, component, TRAIN_DATA, False)
    _optimize(trained, component, REHEARSE_DATA, True)
 |