mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Move core training logic in CLI into standalone function (#9398)
This commit is contained in:
		
							parent
							
								
									2a7e327310
								
							
						
					
					
						commit
						5003a9c3c7
					
				| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional, Dict, Any
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from wasabi import msg
 | 
					from wasabi import msg
 | 
				
			||||||
import typer
 | 
					import typer
 | 
				
			||||||
| 
						 | 
					@ -7,7 +7,7 @@ import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
 | 
					from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
 | 
				
			||||||
from ._util import import_code, setup_gpu
 | 
					from ._util import import_code, setup_gpu
 | 
				
			||||||
from ..training.loop import train
 | 
					from ..training.loop import train as train_nlp
 | 
				
			||||||
from ..training.initialize import init_nlp
 | 
					from ..training.initialize import init_nlp
 | 
				
			||||||
from .. import util
 | 
					from .. import util
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -40,6 +40,18 @@ def train_cli(
 | 
				
			||||||
    DOCS: https://spacy.io/api/cli#train
 | 
					    DOCS: https://spacy.io/api/cli#train
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 | 
					    util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
 | 
				
			||||||
 | 
					    overrides = parse_config_overrides(ctx.args)
 | 
				
			||||||
 | 
					    import_code(code_path)
 | 
				
			||||||
 | 
					    train(config_path, output_path, use_gpu=use_gpu, overrides=overrides)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def train(
 | 
				
			||||||
 | 
					    config_path: Path,
 | 
				
			||||||
 | 
					    output_path: Optional[Path] = None,
 | 
				
			||||||
 | 
					    *,
 | 
				
			||||||
 | 
					    use_gpu: int = -1,
 | 
				
			||||||
 | 
					    overrides: Dict[str, Any] = util.SimpleFrozenDict(),
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
    # Make sure all files and paths exists if they are needed
 | 
					    # Make sure all files and paths exists if they are needed
 | 
				
			||||||
    if not config_path or (str(config_path) != "-" and not config_path.exists()):
 | 
					    if not config_path or (str(config_path) != "-" and not config_path.exists()):
 | 
				
			||||||
        msg.fail("Config file not found", config_path, exits=1)
 | 
					        msg.fail("Config file not found", config_path, exits=1)
 | 
				
			||||||
| 
						 | 
					@ -50,8 +62,6 @@ def train_cli(
 | 
				
			||||||
            output_path.mkdir(parents=True)
 | 
					            output_path.mkdir(parents=True)
 | 
				
			||||||
            msg.good(f"Created output directory: {output_path}")
 | 
					            msg.good(f"Created output directory: {output_path}")
 | 
				
			||||||
        msg.info(f"Saving to output directory: {output_path}")
 | 
					        msg.info(f"Saving to output directory: {output_path}")
 | 
				
			||||||
    overrides = parse_config_overrides(ctx.args)
 | 
					 | 
				
			||||||
    import_code(code_path)
 | 
					 | 
				
			||||||
    setup_gpu(use_gpu)
 | 
					    setup_gpu(use_gpu)
 | 
				
			||||||
    with show_validation_error(config_path):
 | 
					    with show_validation_error(config_path):
 | 
				
			||||||
        config = util.load_config(config_path, overrides=overrides, interpolate=False)
 | 
					        config = util.load_config(config_path, overrides=overrides, interpolate=False)
 | 
				
			||||||
| 
						 | 
					@ -60,4 +70,4 @@ def train_cli(
 | 
				
			||||||
        nlp = init_nlp(config, use_gpu=use_gpu)
 | 
					        nlp = init_nlp(config, use_gpu=use_gpu)
 | 
				
			||||||
    msg.good("Initialized pipeline")
 | 
					    msg.good("Initialized pipeline")
 | 
				
			||||||
    msg.divider("Training pipeline")
 | 
					    msg.divider("Training pipeline")
 | 
				
			||||||
    train(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr)
 | 
					    train_nlp(nlp, output_path, use_gpu=use_gpu, stdout=sys.stdout, stderr=sys.stderr)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user