mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Start updating train script
This commit is contained in:
parent
39b178999c
commit
b5556093e2
|
@ -16,6 +16,7 @@ from ._util import import_code, get_sourced_components
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..training.example import Example
|
from ..training.example import Example
|
||||||
|
from ..training.initialize import must_initialize, init_pipeline
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..util import dot_to_object
|
from ..util import dot_to_object
|
||||||
|
|
||||||
|
@ -31,8 +32,6 @@ def train_cli(
|
||||||
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||||
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||||
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
|
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
|
||||||
resume: bool = Opt(False, "--resume", "-R", help="Resume training"),
|
|
||||||
dave_path: Optional[Path] = Opt(None, "--dave", "-D", help="etc etc"),
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -53,38 +52,37 @@ def train_cli(
|
||||||
verify_cli_args(config_path, output_path)
|
verify_cli_args(config_path, output_path)
|
||||||
overrides = parse_config_overrides(ctx.args)
|
overrides = parse_config_overrides(ctx.args)
|
||||||
import_code(code_path)
|
import_code(code_path)
|
||||||
if prepared is None:
|
|
||||||
prepare(config_path, output_path / "prepared", config_overrides=overrides)
|
|
||||||
train(
|
|
||||||
config_path,
|
|
||||||
output_path=output_path,
|
|
||||||
dave_path=dave_path,
|
|
||||||
config_overrides=overrides,
|
|
||||||
use_gpu=use_gpu,
|
|
||||||
resume_training=resume,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
|
||||||
output_path: Path,
|
|
||||||
config_overrides: Dict[str, Any] = {},
|
|
||||||
use_gpu: int = -1,
|
|
||||||
resume_training: bool = False,
|
|
||||||
) -> None:
|
|
||||||
if use_gpu >= 0:
|
if use_gpu >= 0:
|
||||||
msg.info(f"Using GPU: {use_gpu}")
|
msg.info(f"Using GPU: {use_gpu}")
|
||||||
require_gpu(use_gpu)
|
require_gpu(use_gpu)
|
||||||
else:
|
else:
|
||||||
msg.info("Using CPU")
|
msg.info("Using CPU")
|
||||||
msg.info(f"Loading config and nlp from: {config_path}")
|
config = util.load_config(
|
||||||
# TODO: The details of this will change
|
config_path, overrides=config_overrides, interpolate=True
|
||||||
dave_path = output_path / "dave"
|
)
|
||||||
config_path = dave_path / "config.cfg"
|
if output_path is None:
|
||||||
with show_validation_error(config_path):
|
nlp = init_pipeline(config)
|
||||||
config = fill_config_etc_etc(config_path)
|
else:
|
||||||
nlp = make_and_load_nlp_etc_etc(config, dave_path)
|
init_path = output_path / "model-initial"
|
||||||
optimizer, train_corpus, dev_corpus, score_weights, T_cfg = resolve_more_things_etc_etc(config)
|
if must_reinitialize(config, init_path):
|
||||||
|
nlp = init_pipeline(config)
|
||||||
|
nlp.to_disk(init_path)
|
||||||
|
else:
|
||||||
|
nlp = spacy.load(output_path / "model-initial")
|
||||||
|
msg.info("Start training")
|
||||||
|
train(nlp, config, output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||||
|
# Create iterator, which yields out info after each optimization step.
|
||||||
|
config = nlp.config
|
||||||
|
T_cfg = config["training"]
|
||||||
|
score_weights = T_cfg["score_weights"]
|
||||||
|
optimizer = T_cfg["optimizer"]
|
||||||
|
train_corpus = dot_to_object(config, T_cfg["train_corpus"])
|
||||||
|
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
|
||||||
|
batcher = T_cfg["batcher"]
|
||||||
|
|
||||||
training_step_iterator = train_while_improving(
|
training_step_iterator = train_while_improving(
|
||||||
nlp,
|
nlp,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
@ -142,6 +140,7 @@ def train(
|
||||||
msg.good(f"Saved pipeline to output directory {final_model_path}")
|
msg.good(f"Saved pipeline to output directory {final_model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def add_vectors(nlp: Language, vectors: str) -> None:
|
def add_vectors(nlp: Language, vectors: str) -> None:
|
||||||
title = f"Config validation error for vectors {vectors}"
|
title = f"Config validation error for vectors {vectors}"
|
||||||
desc = (
|
desc = (
|
||||||
|
|
Loading…
Reference in New Issue
Block a user