mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Update config and commands
This commit is contained in:
parent
9e48ea48a1
commit
b7111da1d7
|
@ -48,7 +48,7 @@ use_averages = false
|
|||
eps = 1e-8
|
||||
#learn_rate = 0.001
|
||||
|
||||
[optimizer.learn_rate]
|
||||
[training.optimizer.learn_rate]
|
||||
@schedules = "warmup_linear.v1"
|
||||
warmup_steps = 250
|
||||
total_steps = 20000
|
||||
|
@ -56,8 +56,11 @@ initial_rate = 0.001
|
|||
|
||||
[nlp]
|
||||
lang = "en"
|
||||
base_model = null
|
||||
vectors = null
|
||||
|
||||
[nlp.pipeline]
|
||||
|
||||
[nlp.pipeline.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
|
|
|
@ -25,6 +25,11 @@ score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
|
|||
# These settings are invalid for the transformer models.
|
||||
init_tok2vec = null
|
||||
discard_oversize = false
|
||||
omit_extra_lookups = false
|
||||
batch_by = "words"
|
||||
use_gpu = -1
|
||||
raw_text = null
|
||||
tag_map = null
|
||||
|
||||
[training.batch_size]
|
||||
@schedules = "compounding.v1"
|
||||
|
@ -72,6 +77,9 @@ learn_rate = 0.001
|
|||
[nlp]
|
||||
lang = "en"
|
||||
vectors = null
|
||||
base_model = null
|
||||
|
||||
[nlp.pipeline]
|
||||
|
||||
[nlp.pipeline.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Any, Union, List
|
||||
from typing import Dict, Any, Union, List, Optional
|
||||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
import srsly
|
||||
|
@ -11,6 +11,7 @@ from configparser import InterpolationError
|
|||
import sys
|
||||
|
||||
from ..schemas import ProjectConfigSchema, validate
|
||||
from ..util import import_file
|
||||
|
||||
|
||||
PROJECT_FILE = "project.yml"
|
||||
|
@ -172,3 +173,16 @@ def show_validation_error(title: str = "Config validation error"):
|
|||
msg.fail(title, spaced=True)
|
||||
print(str(e).replace("Config validation error", "").strip())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def import_code(code_path: Optional[Union[Path, str]]) -> None:
|
||||
"""Helper to import Python file provided in training commands / commands
|
||||
using the config. This makes custom registered functions available.
|
||||
"""
|
||||
if code_path is not None:
|
||||
if not Path(code_path).exists():
|
||||
msg.fail("Path to Python code not found", code_path, exits=1)
|
||||
try:
|
||||
import_file("python_code", code_path)
|
||||
except Exception as e:
|
||||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
from typing import Optional, List, Sequence, Dict, Any, Tuple
|
||||
from typing import List, Sequence, Dict, Any, Tuple, Optional
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
import sys
|
||||
import srsly
|
||||
from wasabi import Printer, MESSAGES
|
||||
import typer
|
||||
|
||||
from ._util import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code
|
||||
from ..schemas import ConfigSchema
|
||||
from ..gold import Corpus, Example
|
||||
from ..syntax import nonproj
|
||||
from ..language import Language
|
||||
from ..util import load_model, get_lang_class
|
||||
from .. import util
|
||||
|
||||
|
||||
# Minimum number of expected occurrences of NER label in data to train new label
|
||||
|
@ -24,12 +27,11 @@ BLANK_MODEL_THRESHOLD = 2000
|
|||
@app.command("debug-data")
|
||||
def debug_data_cli(
|
||||
# fmt: off
|
||||
lang: str = Arg(..., help="Model language"),
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
train_path: Path = Arg(..., help="Location of JSON-formatted training data", exists=True),
|
||||
dev_path: Path = Arg(..., help="Location of JSON-formatted development data", exists=True),
|
||||
tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map", exists=True, dir_okay=False),
|
||||
base_model: Optional[str] = Opt(None, "--base-model", "-b", help="Name of model to update (optional)"),
|
||||
pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of pipeline components to train"),
|
||||
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||
ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"),
|
||||
verbose: bool = Opt(False, "--verbose", "-V", help="Print additional information and explanations"),
|
||||
no_format: bool = Opt(False, "--no-format", "-NF", help="Don't pretty-print the results"),
|
||||
|
@ -40,13 +42,13 @@ def debug_data_cli(
|
|||
stats, and find problems like invalid entity annotations, cyclic
|
||||
dependencies, low data labels and more.
|
||||
"""
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
debug_data(
|
||||
lang,
|
||||
train_path,
|
||||
dev_path,
|
||||
tag_map_path=tag_map_path,
|
||||
base_model=base_model,
|
||||
pipeline=[p.strip() for p in pipeline.split(",")],
|
||||
config_path,
|
||||
config_overrides=overrides,
|
||||
ignore_warnings=ignore_warnings,
|
||||
verbose=verbose,
|
||||
no_format=no_format,
|
||||
|
@ -55,13 +57,11 @@ def debug_data_cli(
|
|||
|
||||
|
||||
def debug_data(
|
||||
lang: str,
|
||||
train_path: Path,
|
||||
dev_path: Path,
|
||||
config_path: Path,
|
||||
*,
|
||||
tag_map_path: Optional[Path] = None,
|
||||
base_model: Optional[str] = None,
|
||||
pipeline: List[str] = ["tagger", "parser", "ner"],
|
||||
config_overrides: Dict[str, Any] = {},
|
||||
ignore_warnings: bool = False,
|
||||
verbose: bool = False,
|
||||
no_format: bool = True,
|
||||
|
@ -75,25 +75,27 @@ def debug_data(
|
|||
msg.fail("Training data not found", train_path, exits=1)
|
||||
if not dev_path.exists():
|
||||
msg.fail("Development data not found", dev_path, exits=1)
|
||||
|
||||
if not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exists=1)
|
||||
with show_validation_error():
|
||||
config = util.load_config(
|
||||
config_path,
|
||||
create_objects=False,
|
||||
schema=ConfigSchema,
|
||||
overrides=config_overrides,
|
||||
)
|
||||
nlp = util.load_model_from_config(config["nlp"])
|
||||
lang = config["nlp"]["lang"]
|
||||
base_model = config["nlp"]["base_model"]
|
||||
pipeline = list(config["nlp"]["pipeline"].keys())
|
||||
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
||||
tag_map = {}
|
||||
if tag_map_path is not None:
|
||||
tag_map = srsly.read_json(tag_map_path)
|
||||
|
||||
# Initialize the model and pipeline
|
||||
if base_model:
|
||||
nlp = load_model(base_model)
|
||||
else:
|
||||
lang_cls = get_lang_class(lang)
|
||||
nlp = lang_cls()
|
||||
# Update tag map with provided mapping
|
||||
nlp.vocab.morphology.tag_map.update(tag_map)
|
||||
|
||||
msg.divider("Data format validation")
|
||||
|
||||
# TODO: Validate data format using the JSON schema
|
||||
# TODO: update once the new format is ready
|
||||
# TODO: move validation to GoldCorpus in order to be able to load from dir
|
||||
msg.divider("Data file validation")
|
||||
|
||||
# Create the gold corpus to be able to better analyze data
|
||||
loading_train_error_message = ""
|
||||
|
@ -380,7 +382,7 @@ def debug_data(
|
|||
if gold_dev_data["n_nonproj"] > 0:
|
||||
n_nonproj = gold_dev_data["n_nonproj"]
|
||||
msg.info(f"Found {n_nonproj} nonprojective dev sentence(s)")
|
||||
msg.info(f"{labels_train_unpreprocessed} label(s) in train data")
|
||||
msg.info(f"{len(labels_train_unpreprocessed)} label(s) in train data")
|
||||
msg.info(f"{len(labels_train)} label(s) in projectivized train data")
|
||||
labels_with_counts = _format_labels(
|
||||
gold_train_unpreprocessed_data["deps"].most_common(), counts=True
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Dict, Any
|
||||
import random
|
||||
import numpy
|
||||
import time
|
||||
|
@ -11,8 +11,11 @@ from thinc.api import CosineDistance, L2Distance
|
|||
from wasabi import msg
|
||||
import srsly
|
||||
from functools import partial
|
||||
import typer
|
||||
|
||||
from ._util import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||
from ._util import import_code
|
||||
from ..schemas import ConfigSchema
|
||||
from ..errors import Errors
|
||||
from ..ml.models.multi_task import build_cloze_multi_task_model
|
||||
from ..ml.models.multi_task import build_cloze_characters_multi_task_model
|
||||
|
@ -24,10 +27,11 @@ from .. import util
|
|||
@app.command("pretrain")
|
||||
def pretrain_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
|
||||
output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
|
||||
config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
|
||||
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
|
||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||
resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
|
||||
epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files."),
|
||||
# fmt: on
|
||||
|
@ -51,11 +55,13 @@ def pretrain_cli(
|
|||
all settings are the same between pretraining and training. Ideally,
|
||||
this is done by using the same config file for both commands.
|
||||
"""
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
pretrain(
|
||||
texts_loc,
|
||||
output_dir,
|
||||
config_path,
|
||||
use_gpu=use_gpu,
|
||||
config_overrides=overrides,
|
||||
resume_path=resume_path,
|
||||
epoch_resume=epoch_resume,
|
||||
)
|
||||
|
@ -65,24 +71,34 @@ def pretrain(
|
|||
texts_loc: Path,
|
||||
output_dir: Path,
|
||||
config_path: Path,
|
||||
use_gpu: int = -1,
|
||||
config_overrides: Dict[str, Any] = {},
|
||||
resume_path: Optional[Path] = None,
|
||||
epoch_resume: Optional[int] = None,
|
||||
):
|
||||
verify_cli_args(**locals())
|
||||
verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume)
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
with show_validation_error():
|
||||
config = util.load_config(
|
||||
config_path,
|
||||
create_objects=False,
|
||||
validate=True,
|
||||
schema=ConfigSchema,
|
||||
overrides=config_overrides,
|
||||
)
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
msg.good(f"Created output directory: {output_dir}")
|
||||
|
||||
use_gpu = config["training"]["use_gpu"]
|
||||
if use_gpu >= 0:
|
||||
msg.info("Using GPU")
|
||||
require_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
config = util.load_config(config_path, create_objects=False)
|
||||
fix_random_seed(config["pretraining"]["seed"])
|
||||
seed = config["pretraining"]["seed"]
|
||||
if seed is not None:
|
||||
fix_random_seed(seed)
|
||||
if use_gpu >= 0 and config["pretraining"]["use_pytorch_for_gpu_memory"]:
|
||||
use_pytorch_for_gpu_memory()
|
||||
|
||||
|
@ -360,9 +376,7 @@ def _smart_round(figure, width=10, max_decimal=4):
|
|||
return format_str % figure
|
||||
|
||||
|
||||
def verify_cli_args(
|
||||
texts_loc, output_dir, config_path, use_gpu, resume_path, epoch_resume
|
||||
):
|
||||
def verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume):
|
||||
if not config_path or not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exits=1)
|
||||
if output_dir.exists() and [p for p in output_dir.iterdir()]:
|
||||
|
@ -401,10 +415,3 @@ def verify_cli_args(
|
|||
f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
|
||||
exits=True,
|
||||
)
|
||||
config = util.load_config(config_path, create_objects=False)
|
||||
if config["pretraining"]["objective"]["type"] == "vectors":
|
||||
if not config["nlp"]["vectors"]:
|
||||
msg.fail(
|
||||
"Must specify nlp.vectors if pretraining.objective.type is vectors",
|
||||
exits=True,
|
||||
)
|
||||
|
|
|
@ -11,6 +11,7 @@ import random
|
|||
import typer
|
||||
|
||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||
from ._util import import_code
|
||||
from ..gold import Corpus, Example
|
||||
from ..lookups import Lookups
|
||||
from .. import util
|
||||
|
@ -53,17 +54,10 @@ def train_cli(
|
|||
"""
|
||||
util.set_env_log(verbose)
|
||||
verify_cli_args(
|
||||
train_path=train_path,
|
||||
dev_path=dev_path,
|
||||
config_path=config_path,
|
||||
code_path=code_path,
|
||||
train_path=train_path, dev_path=dev_path, config_path=config_path,
|
||||
)
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
if code_path is not None:
|
||||
try:
|
||||
util.import_file("python_code", code_path)
|
||||
except Exception as e:
|
||||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||
import_code(code_path)
|
||||
train(
|
||||
config_path,
|
||||
{"train": train_path, "dev": dev_path},
|
||||
|
@ -503,7 +497,6 @@ def verify_cli_args(
|
|||
dev_path: Path,
|
||||
config_path: Path,
|
||||
output_path: Optional[Path] = None,
|
||||
code_path: Optional[Path] = None,
|
||||
):
|
||||
# Make sure all files and paths exists if they are needed
|
||||
if not config_path or not config_path.exists():
|
||||
|
@ -524,9 +517,6 @@ def verify_cli_args(
|
|||
"the specified output path doesn't exist, the directory will be "
|
||||
"created for you.",
|
||||
)
|
||||
if code_path is not None:
|
||||
if not code_path.exists():
|
||||
msg.fail("Path to Python code not found", code_path, exits=1)
|
||||
|
||||
|
||||
def verify_textcat_config(nlp, nlp_config):
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import Dict, List, Union, Optional, Sequence, Any
|
|||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, ValidationError, validator
|
||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||
from pydantic import FilePath, DirectoryPath
|
||||
from pydantic import FilePath, DirectoryPath, root_validator
|
||||
from collections import defaultdict
|
||||
from thinc.api import Model, Optimizer
|
||||
|
||||
|
@ -242,6 +242,7 @@ class ConfigSchemaPipeline(BaseModel):
|
|||
|
||||
class ConfigSchemaNlp(BaseModel):
|
||||
lang: StrictStr = Field(..., title="The base language to use")
|
||||
base_model: Optional[StrictStr] = Field(..., title="The base model to use")
|
||||
vectors: Optional[DirectoryPath] = Field(..., title="Path to vectors")
|
||||
pipeline: Optional[ConfigSchemaPipeline]
|
||||
|
||||
|
@ -250,9 +251,40 @@ class ConfigSchemaNlp(BaseModel):
|
|||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ConfigSchemaPretrain(BaseModel):
|
||||
# fmt: off
|
||||
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
|
||||
min_length: StrictInt = Field(..., title="Minimum length of examples")
|
||||
max_length: StrictInt = Field(..., title="Maximum length of examples")
|
||||
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||
n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
|
||||
batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
|
||||
seed: Optional[StrictInt] = Field(..., title="Random seed")
|
||||
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
|
||||
tok2vec_model: StrictStr = Field(..., title="tok2vec model in config, e.g. nlp.pipeline.tok2vec.model")
|
||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||
# TODO: use a more detailed schema for this?
|
||||
objective: Dict[str, Any] = Field(..., title="Pretraining objective")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
training: ConfigSchemaTraining
|
||||
nlp: ConfigSchemaNlp
|
||||
pretraining: Optional[ConfigSchemaPretrain]
|
||||
|
||||
@root_validator
|
||||
def validate_config(cls, values):
|
||||
"""Perform additional validation for settings with dependencies."""
|
||||
pt = values.get("pretraining")
|
||||
if pt and pt.objective.get("type") == "vectors" and not values["nlp"].vectors:
|
||||
err = "Need nlp.vectors if pretraining.objective.type is vectors"
|
||||
raise ValueError(err)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
|
|
@ -12,6 +12,8 @@ nlp_config_string = """
|
|||
[nlp]
|
||||
lang = "en"
|
||||
|
||||
[nlp.pipeline]
|
||||
|
||||
[nlp.pipeline.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user