mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Add tag_map argument to CLI debug-data and train (#4750)
Add an argument for a path to a JSON-formatted tag map, which is used to update and extend the default language tag map.
This commit is contained in:
		
							parent
							
								
									eb9b1858c4
								
							
						
					
					
						commit
						a4cacd3402
					
				| 
						 | 
				
			
			@ -26,6 +26,7 @@ BLANK_MODEL_THRESHOLD = 2000
 | 
			
		|||
    lang=("model language", "positional", None, str),
 | 
			
		||||
    train_path=("location of JSON-formatted training data", "positional", None, Path),
 | 
			
		||||
    dev_path=("location of JSON-formatted development data", "positional", None, Path),
 | 
			
		||||
    tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
 | 
			
		||||
    base_model=("name of model to update (optional)", "option", "b", str),
 | 
			
		||||
    pipeline=(
 | 
			
		||||
        "Comma-separated names of pipeline components to train",
 | 
			
		||||
| 
						 | 
				
			
			@ -41,6 +42,7 @@ def debug_data(
 | 
			
		|||
    lang,
 | 
			
		||||
    train_path,
 | 
			
		||||
    dev_path,
 | 
			
		||||
    tag_map_path=None,
 | 
			
		||||
    base_model=None,
 | 
			
		||||
    pipeline="tagger,parser,ner",
 | 
			
		||||
    ignore_warnings=False,
 | 
			
		||||
| 
						 | 
				
			
			@ -60,6 +62,10 @@ def debug_data(
 | 
			
		|||
    if not dev_path.exists():
 | 
			
		||||
        msg.fail("Development data not found", dev_path, exits=1)
 | 
			
		||||
 | 
			
		||||
    tag_map = {}
 | 
			
		||||
    if tag_map_path is not None:
 | 
			
		||||
        tag_map = srsly.read_json(tag_map_path)
 | 
			
		||||
 | 
			
		||||
    # Initialize the model and pipeline
 | 
			
		||||
    pipeline = [p.strip() for p in pipeline.split(",")]
 | 
			
		||||
    if base_model:
 | 
			
		||||
| 
						 | 
				
			
			@ -67,6 +73,8 @@ def debug_data(
 | 
			
		|||
    else:
 | 
			
		||||
        lang_cls = get_lang_class(lang)
 | 
			
		||||
        nlp = lang_cls()
 | 
			
		||||
    # Update tag map with provided mapping
 | 
			
		||||
    nlp.vocab.morphology.tag_map.update(tag_map)
 | 
			
		||||
 | 
			
		||||
    msg.divider("Data format validation")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -329,7 +337,7 @@ def debug_data(
 | 
			
		|||
    if "tagger" in pipeline:
 | 
			
		||||
        msg.divider("Part-of-speech Tagging")
 | 
			
		||||
        labels = [label for label in gold_train_data["tags"]]
 | 
			
		||||
        tag_map = nlp.Defaults.tag_map
 | 
			
		||||
        tag_map = nlp.vocab.morphology.tag_map
 | 
			
		||||
        msg.info(
 | 
			
		||||
            "{} {} in data ({} {} in tag map)".format(
 | 
			
		||||
                len(labels),
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -48,6 +48,7 @@ from .. import about
 | 
			
		|||
    textcat_multilabel=("Textcat classes aren't mutually exclusive (multilabel)", "flag", "TML", bool),
 | 
			
		||||
    textcat_arch=("Textcat model architecture", "option", "ta", str),
 | 
			
		||||
    textcat_positive_label=("Textcat positive label for binary classes with two labels", "option", "tpl", str),
 | 
			
		||||
    tag_map_path=("Location of JSON-formatted tag map", "option", "tm", Path),
 | 
			
		||||
    verbose=("Display more information for debug", "flag", "VV", bool),
 | 
			
		||||
    debug=("Run data diagnostics before training", "flag", "D", bool),
 | 
			
		||||
    # fmt: on
 | 
			
		||||
| 
						 | 
				
			
			@ -78,6 +79,7 @@ def train(
 | 
			
		|||
    textcat_multilabel=False,
 | 
			
		||||
    textcat_arch="bow",
 | 
			
		||||
    textcat_positive_label=None,
 | 
			
		||||
    tag_map_path=None,
 | 
			
		||||
    verbose=False,
 | 
			
		||||
    debug=False,
 | 
			
		||||
):
 | 
			
		||||
| 
						 | 
				
			
			@ -118,6 +120,9 @@ def train(
 | 
			
		|||
    if not output_path.exists():
 | 
			
		||||
        output_path.mkdir()
 | 
			
		||||
 | 
			
		||||
    tag_map = {}
 | 
			
		||||
    if tag_map_path is not None:
 | 
			
		||||
        tag_map = srsly.read_json(tag_map_path)
 | 
			
		||||
    # Take dropout and batch size as generators of values -- dropout
 | 
			
		||||
    # starts high and decays sharply, to force the optimizer to explore.
 | 
			
		||||
    # Batch size starts at 1 and grows, so that we make updates quickly
 | 
			
		||||
| 
						 | 
				
			
			@ -209,6 +214,9 @@ def train(
 | 
			
		|||
                pipe_cfg = {}
 | 
			
		||||
            nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
 | 
			
		||||
 | 
			
		||||
    # Update tag map with provided mapping
 | 
			
		||||
    nlp.vocab.morphology.tag_map.update(tag_map)
 | 
			
		||||
 | 
			
		||||
    if vectors:
 | 
			
		||||
        msg.text("Loading vector from model '{}'".format(vectors))
 | 
			
		||||
        _load_vectors(nlp, vectors)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user