mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Tests for CLI app - init config generates train-able config (#12173)
				
					
				
			* remove migration support form * initial test commit * add fixture * add combo test * pull out parameter example data * fix formatting on examples * remove unused import * remove unnecessary fmt:off instructions * only set logger level if verbose flag is explicitly set --------- Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
		
							parent
							
								
									186889ec9c
								
							
						
					
					
						commit
						a0a195688f
					
				| 
						 | 
					@ -40,7 +40,8 @@ def assemble_cli(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DOCS: https://spacy.io/api/cli#assemble
 | 
					    DOCS: https://spacy.io/api/cli#assemble
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 | 
					    if verbose:
 | 
				
			||||||
 | 
					        util.logger.setLevel(logging.DEBUG)
 | 
				
			||||||
    # Make sure all files and paths exists if they are needed
 | 
					    # Make sure all files and paths exists if they are needed
 | 
				
			||||||
    if not config_path or (str(config_path) != "-" and not config_path.exists()):
 | 
					    if not config_path or (str(config_path) != "-" and not config_path.exists()):
 | 
				
			||||||
        msg.fail("Config file not found", config_path, exits=1)
 | 
					        msg.fail("Config file not found", config_path, exits=1)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -52,8 +52,8 @@ def find_threshold_cli(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DOCS: https://spacy.io/api/cli#find-threshold
 | 
					    DOCS: https://spacy.io/api/cli#find-threshold
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					    if verbose:
 | 
				
			||||||
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 | 
					        util.logger.setLevel(logging.DEBUG)
 | 
				
			||||||
    import_code(code_path)
 | 
					    import_code(code_path)
 | 
				
			||||||
    find_threshold(
 | 
					    find_threshold(
 | 
				
			||||||
        model=model,
 | 
					        model=model,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,7 +39,8 @@ def init_vectors_cli(
 | 
				
			||||||
    you can use in the [initialize] block of your config to initialize
 | 
					    you can use in the [initialize] block of your config to initialize
 | 
				
			||||||
    a model with vectors.
 | 
					    a model with vectors.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 | 
					    if verbose:
 | 
				
			||||||
 | 
					        util.logger.setLevel(logging.DEBUG)
 | 
				
			||||||
    msg.info(f"Creating blank nlp object for language '{lang}'")
 | 
					    msg.info(f"Creating blank nlp object for language '{lang}'")
 | 
				
			||||||
    nlp = util.get_lang_class(lang)()
 | 
					    nlp = util.get_lang_class(lang)()
 | 
				
			||||||
    if jsonl_loc is not None:
 | 
					    if jsonl_loc is not None:
 | 
				
			||||||
| 
						 | 
					@ -87,7 +88,8 @@ def init_pipeline_cli(
 | 
				
			||||||
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
 | 
					    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
 | 
				
			||||||
    # fmt: on
 | 
					    # fmt: on
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 | 
					    if verbose:
 | 
				
			||||||
 | 
					        util.logger.setLevel(logging.DEBUG)
 | 
				
			||||||
    overrides = parse_config_overrides(ctx.args)
 | 
					    overrides = parse_config_overrides(ctx.args)
 | 
				
			||||||
    import_code(code_path)
 | 
					    import_code(code_path)
 | 
				
			||||||
    setup_gpu(use_gpu)
 | 
					    setup_gpu(use_gpu)
 | 
				
			||||||
| 
						 | 
					@ -116,7 +118,8 @@ def init_labels_cli(
 | 
				
			||||||
    """Generate JSON files for the labels in the data. This helps speed up the
 | 
					    """Generate JSON files for the labels in the data. This helps speed up the
 | 
				
			||||||
    training process, since spaCy won't have to preprocess the data to
 | 
					    training process, since spaCy won't have to preprocess the data to
 | 
				
			||||||
    extract the labels."""
 | 
					    extract the labels."""
 | 
				
			||||||
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 | 
					    if verbose:
 | 
				
			||||||
 | 
					        util.logger.setLevel(logging.DEBUG)
 | 
				
			||||||
    if not output_path.exists():
 | 
					    if not output_path.exists():
 | 
				
			||||||
        output_path.mkdir(parents=True)
 | 
					        output_path.mkdir(parents=True)
 | 
				
			||||||
    overrides = parse_config_overrides(ctx.args)
 | 
					    overrides = parse_config_overrides(ctx.args)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -47,7 +47,8 @@ def train_cli(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    DOCS: https://spacy.io/api/cli#train
 | 
					    DOCS: https://spacy.io/api/cli#train
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 | 
					    if verbose:
 | 
				
			||||||
 | 
					        util.logger.setLevel(logging.DEBUG)
 | 
				
			||||||
    overrides = parse_config_overrides(ctx.args)
 | 
					    overrides = parse_config_overrides(ctx.args)
 | 
				
			||||||
    import_code(code_path)
 | 
					    import_code(code_path)
 | 
				
			||||||
    train(config_path, output_path, use_gpu=use_gpu, overrides=overrides)
 | 
					    train(config_path, output_path, use_gpu=use_gpu, overrides=overrides)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,7 +6,7 @@ import srsly
 | 
				
			||||||
from typer.testing import CliRunner
 | 
					from typer.testing import CliRunner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from spacy.cli._util import app, get_git_version
 | 
					from spacy.cli._util import app, get_git_version
 | 
				
			||||||
from spacy.tokens import Doc, DocBin
 | 
					from spacy.tokens import Doc, DocBin, Span
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .util import make_tempdir, normalize_whitespace
 | 
					from .util import make_tempdir, normalize_whitespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -267,3 +267,162 @@ def test_find_function_invalid():
 | 
				
			||||||
    function = "spacy.TextCatBOW.v666"
 | 
					    function = "spacy.TextCatBOW.v666"
 | 
				
			||||||
    result = CliRunner().invoke(app, ["find-function", function])
 | 
					    result = CliRunner().invoke(app, ["find-function", function])
 | 
				
			||||||
    assert f"Couldn't find registered function: '{function}'" in result.stdout
 | 
					    assert f"Couldn't find registered function: '{function}'" in result.stdout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					example_words_1 = ["I", "like", "cats"]
 | 
				
			||||||
 | 
					example_words_2 = ["I", "like", "dogs"]
 | 
				
			||||||
 | 
					example_lemmas_1 = ["I", "like", "cat"]
 | 
				
			||||||
 | 
					example_lemmas_2 = ["I", "like", "dog"]
 | 
				
			||||||
 | 
					example_tags = ["PRP", "VBP", "NNS"]
 | 
				
			||||||
 | 
					example_morphs = [
 | 
				
			||||||
 | 
					    "Case=Nom|Number=Sing|Person=1|PronType=Prs",
 | 
				
			||||||
 | 
					    "Tense=Pres|VerbForm=Fin",
 | 
				
			||||||
 | 
					    "Number=Plur",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					example_deps = ["nsubj", "ROOT", "dobj"]
 | 
				
			||||||
 | 
					example_pos = ["PRON", "VERB", "NOUN"]
 | 
				
			||||||
 | 
					example_ents = ["O", "O", "I-ANIMAL"]
 | 
				
			||||||
 | 
					example_spans = [(2, 3, "ANIMAL")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TRAIN_EXAMPLE_1 = dict(
 | 
				
			||||||
 | 
					    words=example_words_1,
 | 
				
			||||||
 | 
					    lemmas=example_lemmas_1,
 | 
				
			||||||
 | 
					    tags=example_tags,
 | 
				
			||||||
 | 
					    morphs=example_morphs,
 | 
				
			||||||
 | 
					    deps=example_deps,
 | 
				
			||||||
 | 
					    heads=[1, 1, 1],
 | 
				
			||||||
 | 
					    pos=example_pos,
 | 
				
			||||||
 | 
					    ents=example_ents,
 | 
				
			||||||
 | 
					    spans=example_spans,
 | 
				
			||||||
 | 
					    cats={"CAT": 1.0, "DOG": 0.0},
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					TRAIN_EXAMPLE_2 = dict(
 | 
				
			||||||
 | 
					    words=example_words_2,
 | 
				
			||||||
 | 
					    lemmas=example_lemmas_2,
 | 
				
			||||||
 | 
					    tags=example_tags,
 | 
				
			||||||
 | 
					    morphs=example_morphs,
 | 
				
			||||||
 | 
					    deps=example_deps,
 | 
				
			||||||
 | 
					    heads=[1, 1, 1],
 | 
				
			||||||
 | 
					    pos=example_pos,
 | 
				
			||||||
 | 
					    ents=example_ents,
 | 
				
			||||||
 | 
					    spans=example_spans,
 | 
				
			||||||
 | 
					    cats={"CAT": 0.0, "DOG": 1.0},
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.slow
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    "component,examples",
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        ("tagger", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
 | 
				
			||||||
 | 
					        ("morphologizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
 | 
				
			||||||
 | 
					        ("trainable_lemmatizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
 | 
				
			||||||
 | 
					        ("parser", [TRAIN_EXAMPLE_1] * 30),
 | 
				
			||||||
 | 
					        ("ner", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
 | 
				
			||||||
 | 
					        ("spancat", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
 | 
				
			||||||
 | 
					        ("textcat", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_init_config_trainable(component, examples, en_vocab):
 | 
				
			||||||
 | 
					    if component == "textcat":
 | 
				
			||||||
 | 
					        train_docs = []
 | 
				
			||||||
 | 
					        for example in examples:
 | 
				
			||||||
 | 
					            doc = Doc(en_vocab, words=example["words"])
 | 
				
			||||||
 | 
					            doc.cats = example["cats"]
 | 
				
			||||||
 | 
					            train_docs.append(doc)
 | 
				
			||||||
 | 
					    elif component == "spancat":
 | 
				
			||||||
 | 
					        train_docs = []
 | 
				
			||||||
 | 
					        for example in examples:
 | 
				
			||||||
 | 
					            doc = Doc(en_vocab, words=example["words"])
 | 
				
			||||||
 | 
					            doc.spans["sc"] = [
 | 
				
			||||||
 | 
					                Span(doc, start, end, label) for start, end, label in example["spans"]
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					            train_docs.append(doc)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        train_docs = []
 | 
				
			||||||
 | 
					        for example in examples:
 | 
				
			||||||
 | 
					            # cats, spans are not valid kwargs for instantiating a Doc
 | 
				
			||||||
 | 
					            example = {k: v for k, v in example.items() if k not in ("cats", "spans")}
 | 
				
			||||||
 | 
					            doc = Doc(en_vocab, **example)
 | 
				
			||||||
 | 
					            train_docs.append(doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with make_tempdir() as d_in:
 | 
				
			||||||
 | 
					        train_bin = DocBin(docs=train_docs)
 | 
				
			||||||
 | 
					        train_bin.to_disk(d_in / "train.spacy")
 | 
				
			||||||
 | 
					        dev_bin = DocBin(docs=train_docs)
 | 
				
			||||||
 | 
					        dev_bin.to_disk(d_in / "dev.spacy")
 | 
				
			||||||
 | 
					        init_config_result = CliRunner().invoke(
 | 
				
			||||||
 | 
					            app,
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                "init",
 | 
				
			||||||
 | 
					                "config",
 | 
				
			||||||
 | 
					                f"{d_in}/config.cfg",
 | 
				
			||||||
 | 
					                "--lang",
 | 
				
			||||||
 | 
					                "en",
 | 
				
			||||||
 | 
					                "--pipeline",
 | 
				
			||||||
 | 
					                component,
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        assert init_config_result.exit_code == 0
 | 
				
			||||||
 | 
					        train_result = CliRunner().invoke(
 | 
				
			||||||
 | 
					            app,
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                "train",
 | 
				
			||||||
 | 
					                f"{d_in}/config.cfg",
 | 
				
			||||||
 | 
					                "--paths.train",
 | 
				
			||||||
 | 
					                f"{d_in}/train.spacy",
 | 
				
			||||||
 | 
					                "--paths.dev",
 | 
				
			||||||
 | 
					                f"{d_in}/dev.spacy",
 | 
				
			||||||
 | 
					                "--output",
 | 
				
			||||||
 | 
					                f"{d_in}/model",
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        assert train_result.exit_code == 0
 | 
				
			||||||
 | 
					        assert Path(d_in / "model" / "model-last").exists()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.slow
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    "component,examples",
 | 
				
			||||||
 | 
					    [("tagger,parser,morphologizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2] * 15)],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_init_config_trainable_multiple(component, examples, en_vocab):
 | 
				
			||||||
 | 
					    train_docs = []
 | 
				
			||||||
 | 
					    for example in examples:
 | 
				
			||||||
 | 
					        example = {k: v for k, v in example.items() if k not in ("cats", "spans")}
 | 
				
			||||||
 | 
					        doc = Doc(en_vocab, **example)
 | 
				
			||||||
 | 
					        train_docs.append(doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with make_tempdir() as d_in:
 | 
				
			||||||
 | 
					        train_bin = DocBin(docs=train_docs)
 | 
				
			||||||
 | 
					        train_bin.to_disk(d_in / "train.spacy")
 | 
				
			||||||
 | 
					        dev_bin = DocBin(docs=train_docs)
 | 
				
			||||||
 | 
					        dev_bin.to_disk(d_in / "dev.spacy")
 | 
				
			||||||
 | 
					        init_config_result = CliRunner().invoke(
 | 
				
			||||||
 | 
					            app,
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                "init",
 | 
				
			||||||
 | 
					                "config",
 | 
				
			||||||
 | 
					                f"{d_in}/config.cfg",
 | 
				
			||||||
 | 
					                "--lang",
 | 
				
			||||||
 | 
					                "en",
 | 
				
			||||||
 | 
					                "--pipeline",
 | 
				
			||||||
 | 
					                component,
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        assert init_config_result.exit_code == 0
 | 
				
			||||||
 | 
					        train_result = CliRunner().invoke(
 | 
				
			||||||
 | 
					            app,
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                "train",
 | 
				
			||||||
 | 
					                f"{d_in}/config.cfg",
 | 
				
			||||||
 | 
					                "--paths.train",
 | 
				
			||||||
 | 
					                f"{d_in}/train.spacy",
 | 
				
			||||||
 | 
					                "--paths.dev",
 | 
				
			||||||
 | 
					                f"{d_in}/dev.spacy",
 | 
				
			||||||
 | 
					                "--output",
 | 
				
			||||||
 | 
					                f"{d_in}/model",
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        assert train_result.exit_code == 0
 | 
				
			||||||
 | 
					        assert Path(d_in / "model" / "model-last").exists()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user