mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Tests for CLI app - init config generates train-able config (#12173)
				
					
				
			* remove migration support form * initial test commit * add fixture * add combo test * pull out parameter example data * fix formatting on examples * remove unused import * remove unncessary fmt:off instructions * only set logger level if verbose flag is explicitly set --------- Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
		
							parent
							
								
									186889ec9c
								
							
						
					
					
						commit
						a0a195688f
					
				|  | @ -40,7 +40,8 @@ def assemble_cli( | ||||||
| 
 | 
 | ||||||
|     DOCS: https://spacy.io/api/cli#assemble |     DOCS: https://spacy.io/api/cli#assemble | ||||||
|     """ |     """ | ||||||
|     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) |     if verbose: | ||||||
|  |         util.logger.setLevel(logging.DEBUG) | ||||||
|     # Make sure all files and paths exists if they are needed |     # Make sure all files and paths exists if they are needed | ||||||
|     if not config_path or (str(config_path) != "-" and not config_path.exists()): |     if not config_path or (str(config_path) != "-" and not config_path.exists()): | ||||||
|         msg.fail("Config file not found", config_path, exits=1) |         msg.fail("Config file not found", config_path, exits=1) | ||||||
|  |  | ||||||
|  | @ -52,8 +52,8 @@ def find_threshold_cli( | ||||||
| 
 | 
 | ||||||
|     DOCS: https://spacy.io/api/cli#find-threshold |     DOCS: https://spacy.io/api/cli#find-threshold | ||||||
|     """ |     """ | ||||||
| 
 |     if verbose: | ||||||
|     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) |         util.logger.setLevel(logging.DEBUG) | ||||||
|     import_code(code_path) |     import_code(code_path) | ||||||
|     find_threshold( |     find_threshold( | ||||||
|         model=model, |         model=model, | ||||||
|  |  | ||||||
|  | @ -39,7 +39,8 @@ def init_vectors_cli( | ||||||
|     you can use in the [initialize] block of your config to initialize |     you can use in the [initialize] block of your config to initialize | ||||||
|     a model with vectors. |     a model with vectors. | ||||||
|     """ |     """ | ||||||
|     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) |     if verbose: | ||||||
|  |         util.logger.setLevel(logging.DEBUG) | ||||||
|     msg.info(f"Creating blank nlp object for language '{lang}'") |     msg.info(f"Creating blank nlp object for language '{lang}'") | ||||||
|     nlp = util.get_lang_class(lang)() |     nlp = util.get_lang_class(lang)() | ||||||
|     if jsonl_loc is not None: |     if jsonl_loc is not None: | ||||||
|  | @ -87,7 +88,8 @@ def init_pipeline_cli( | ||||||
|     use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU") |     use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU") | ||||||
|     # fmt: on |     # fmt: on | ||||||
| ): | ): | ||||||
|     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) |     if verbose: | ||||||
|  |         util.logger.setLevel(logging.DEBUG) | ||||||
|     overrides = parse_config_overrides(ctx.args) |     overrides = parse_config_overrides(ctx.args) | ||||||
|     import_code(code_path) |     import_code(code_path) | ||||||
|     setup_gpu(use_gpu) |     setup_gpu(use_gpu) | ||||||
|  | @ -116,7 +118,8 @@ def init_labels_cli( | ||||||
|     """Generate JSON files for the labels in the data. This helps speed up the |     """Generate JSON files for the labels in the data. This helps speed up the | ||||||
|     training process, since spaCy won't have to preprocess the data to |     training process, since spaCy won't have to preprocess the data to | ||||||
|     extract the labels.""" |     extract the labels.""" | ||||||
|     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) |     if verbose: | ||||||
|  |         util.logger.setLevel(logging.DEBUG) | ||||||
|     if not output_path.exists(): |     if not output_path.exists(): | ||||||
|         output_path.mkdir(parents=True) |         output_path.mkdir(parents=True) | ||||||
|     overrides = parse_config_overrides(ctx.args) |     overrides = parse_config_overrides(ctx.args) | ||||||
|  |  | ||||||
|  | @ -47,7 +47,8 @@ def train_cli( | ||||||
| 
 | 
 | ||||||
|     DOCS: https://spacy.io/api/cli#train |     DOCS: https://spacy.io/api/cli#train | ||||||
|     """ |     """ | ||||||
|     util.logger.setLevel(logging.DEBUG if verbose else logging.INFO) |     if verbose: | ||||||
|  |         util.logger.setLevel(logging.DEBUG) | ||||||
|     overrides = parse_config_overrides(ctx.args) |     overrides = parse_config_overrides(ctx.args) | ||||||
|     import_code(code_path) |     import_code(code_path) | ||||||
|     train(config_path, output_path, use_gpu=use_gpu, overrides=overrides) |     train(config_path, output_path, use_gpu=use_gpu, overrides=overrides) | ||||||
|  |  | ||||||
|  | @ -6,7 +6,7 @@ import srsly | ||||||
| from typer.testing import CliRunner | from typer.testing import CliRunner | ||||||
| 
 | 
 | ||||||
| from spacy.cli._util import app, get_git_version | from spacy.cli._util import app, get_git_version | ||||||
| from spacy.tokens import Doc, DocBin | from spacy.tokens import Doc, DocBin, Span | ||||||
| 
 | 
 | ||||||
| from .util import make_tempdir, normalize_whitespace | from .util import make_tempdir, normalize_whitespace | ||||||
| 
 | 
 | ||||||
|  | @ -267,3 +267,162 @@ def test_find_function_invalid(): | ||||||
|     function = "spacy.TextCatBOW.v666" |     function = "spacy.TextCatBOW.v666" | ||||||
|     result = CliRunner().invoke(app, ["find-function", function]) |     result = CliRunner().invoke(app, ["find-function", function]) | ||||||
|     assert f"Couldn't find registered function: '{function}'" in result.stdout |     assert f"Couldn't find registered function: '{function}'" in result.stdout | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | example_words_1 = ["I", "like", "cats"] | ||||||
|  | example_words_2 = ["I", "like", "dogs"] | ||||||
|  | example_lemmas_1 = ["I", "like", "cat"] | ||||||
|  | example_lemmas_2 = ["I", "like", "dog"] | ||||||
|  | example_tags = ["PRP", "VBP", "NNS"] | ||||||
|  | example_morphs = [ | ||||||
|  |     "Case=Nom|Number=Sing|Person=1|PronType=Prs", | ||||||
|  |     "Tense=Pres|VerbForm=Fin", | ||||||
|  |     "Number=Plur", | ||||||
|  | ] | ||||||
|  | example_deps = ["nsubj", "ROOT", "dobj"] | ||||||
|  | example_pos = ["PRON", "VERB", "NOUN"] | ||||||
|  | example_ents = ["O", "O", "I-ANIMAL"] | ||||||
|  | example_spans = [(2, 3, "ANIMAL")] | ||||||
|  | 
 | ||||||
|  | TRAIN_EXAMPLE_1 = dict( | ||||||
|  |     words=example_words_1, | ||||||
|  |     lemmas=example_lemmas_1, | ||||||
|  |     tags=example_tags, | ||||||
|  |     morphs=example_morphs, | ||||||
|  |     deps=example_deps, | ||||||
|  |     heads=[1, 1, 1], | ||||||
|  |     pos=example_pos, | ||||||
|  |     ents=example_ents, | ||||||
|  |     spans=example_spans, | ||||||
|  |     cats={"CAT": 1.0, "DOG": 0.0}, | ||||||
|  | ) | ||||||
|  | TRAIN_EXAMPLE_2 = dict( | ||||||
|  |     words=example_words_2, | ||||||
|  |     lemmas=example_lemmas_2, | ||||||
|  |     tags=example_tags, | ||||||
|  |     morphs=example_morphs, | ||||||
|  |     deps=example_deps, | ||||||
|  |     heads=[1, 1, 1], | ||||||
|  |     pos=example_pos, | ||||||
|  |     ents=example_ents, | ||||||
|  |     spans=example_spans, | ||||||
|  |     cats={"CAT": 0.0, "DOG": 1.0}, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.slow | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "component,examples", | ||||||
|  |     [ | ||||||
|  |         ("tagger", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]), | ||||||
|  |         ("morphologizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]), | ||||||
|  |         ("trainable_lemmatizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]), | ||||||
|  |         ("parser", [TRAIN_EXAMPLE_1] * 30), | ||||||
|  |         ("ner", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]), | ||||||
|  |         ("spancat", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]), | ||||||
|  |         ("textcat", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]), | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | def test_init_config_trainable(component, examples, en_vocab): | ||||||
|  |     if component == "textcat": | ||||||
|  |         train_docs = [] | ||||||
|  |         for example in examples: | ||||||
|  |             doc = Doc(en_vocab, words=example["words"]) | ||||||
|  |             doc.cats = example["cats"] | ||||||
|  |             train_docs.append(doc) | ||||||
|  |     elif component == "spancat": | ||||||
|  |         train_docs = [] | ||||||
|  |         for example in examples: | ||||||
|  |             doc = Doc(en_vocab, words=example["words"]) | ||||||
|  |             doc.spans["sc"] = [ | ||||||
|  |                 Span(doc, start, end, label) for start, end, label in example["spans"] | ||||||
|  |             ] | ||||||
|  |             train_docs.append(doc) | ||||||
|  |     else: | ||||||
|  |         train_docs = [] | ||||||
|  |         for example in examples: | ||||||
|  |             # cats, spans are not valid kwargs for instantiating a Doc | ||||||
|  |             example = {k: v for k, v in example.items() if k not in ("cats", "spans")} | ||||||
|  |             doc = Doc(en_vocab, **example) | ||||||
|  |             train_docs.append(doc) | ||||||
|  | 
 | ||||||
|  |     with make_tempdir() as d_in: | ||||||
|  |         train_bin = DocBin(docs=train_docs) | ||||||
|  |         train_bin.to_disk(d_in / "train.spacy") | ||||||
|  |         dev_bin = DocBin(docs=train_docs) | ||||||
|  |         dev_bin.to_disk(d_in / "dev.spacy") | ||||||
|  |         init_config_result = CliRunner().invoke( | ||||||
|  |             app, | ||||||
|  |             [ | ||||||
|  |                 "init", | ||||||
|  |                 "config", | ||||||
|  |                 f"{d_in}/config.cfg", | ||||||
|  |                 "--lang", | ||||||
|  |                 "en", | ||||||
|  |                 "--pipeline", | ||||||
|  |                 component, | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |         assert init_config_result.exit_code == 0 | ||||||
|  |         train_result = CliRunner().invoke( | ||||||
|  |             app, | ||||||
|  |             [ | ||||||
|  |                 "train", | ||||||
|  |                 f"{d_in}/config.cfg", | ||||||
|  |                 "--paths.train", | ||||||
|  |                 f"{d_in}/train.spacy", | ||||||
|  |                 "--paths.dev", | ||||||
|  |                 f"{d_in}/dev.spacy", | ||||||
|  |                 "--output", | ||||||
|  |                 f"{d_in}/model", | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |         assert train_result.exit_code == 0 | ||||||
|  |         assert Path(d_in / "model" / "model-last").exists() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.slow | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "component,examples", | ||||||
|  |     [("tagger,parser,morphologizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2] * 15)], | ||||||
|  | ) | ||||||
|  | def test_init_config_trainable_multiple(component, examples, en_vocab): | ||||||
|  |     train_docs = [] | ||||||
|  |     for example in examples: | ||||||
|  |         example = {k: v for k, v in example.items() if k not in ("cats", "spans")} | ||||||
|  |         doc = Doc(en_vocab, **example) | ||||||
|  |         train_docs.append(doc) | ||||||
|  | 
 | ||||||
|  |     with make_tempdir() as d_in: | ||||||
|  |         train_bin = DocBin(docs=train_docs) | ||||||
|  |         train_bin.to_disk(d_in / "train.spacy") | ||||||
|  |         dev_bin = DocBin(docs=train_docs) | ||||||
|  |         dev_bin.to_disk(d_in / "dev.spacy") | ||||||
|  |         init_config_result = CliRunner().invoke( | ||||||
|  |             app, | ||||||
|  |             [ | ||||||
|  |                 "init", | ||||||
|  |                 "config", | ||||||
|  |                 f"{d_in}/config.cfg", | ||||||
|  |                 "--lang", | ||||||
|  |                 "en", | ||||||
|  |                 "--pipeline", | ||||||
|  |                 component, | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |         assert init_config_result.exit_code == 0 | ||||||
|  |         train_result = CliRunner().invoke( | ||||||
|  |             app, | ||||||
|  |             [ | ||||||
|  |                 "train", | ||||||
|  |                 f"{d_in}/config.cfg", | ||||||
|  |                 "--paths.train", | ||||||
|  |                 f"{d_in}/train.spacy", | ||||||
|  |                 "--paths.dev", | ||||||
|  |                 f"{d_in}/dev.spacy", | ||||||
|  |                 "--output", | ||||||
|  |                 f"{d_in}/model", | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |         assert train_result.exit_code == 0 | ||||||
|  |         assert Path(d_in / "model" / "model-last").exists() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user