Update config and commands

Ines Montani 2020-07-11 13:03:53 +02:00
parent 9e48ea48a1
commit b7111da1d7
8 changed files with 122 additions and 64 deletions

View File

@@ -48,7 +48,7 @@ use_averages = false
 eps = 1e-8
 #learn_rate = 0.001

-[optimizer.learn_rate]
+[training.optimizer.learn_rate]
 @schedules = "warmup_linear.v1"
 warmup_steps = 250
 total_steps = 20000
@@ -56,8 +56,11 @@ initial_rate = 0.001
 [nlp]
 lang = "en"
+base_model = null
 vectors = null

+[nlp.pipeline]
+
 [nlp.pipeline.tok2vec]
 factory = "tok2vec"

View File

@@ -25,6 +25,11 @@ score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
 # These settings are invalid for the transformer models.
 init_tok2vec = null
 discard_oversize = false
+omit_extra_lookups = false
+batch_by = "words"
+use_gpu = -1
+raw_text = null
+tag_map = null

 [training.batch_size]
 @schedules = "compounding.v1"
@@ -72,6 +77,9 @@ learn_rate = 0.001
 [nlp]
 lang = "en"
 vectors = null
+base_model = null

+[nlp.pipeline]
+
 [nlp.pipeline.tok2vec]
 factory = "tok2vec"

View File

@@ -1,4 +1,4 @@
-from typing import Dict, Any, Union, List
+from typing import Dict, Any, Union, List, Optional
 from pathlib import Path
 from wasabi import msg
 import srsly
@@ -11,6 +11,7 @@ from configparser import InterpolationError
 import sys

 from ..schemas import ProjectConfigSchema, validate
+from ..util import import_file

 PROJECT_FILE = "project.yml"
@@ -172,3 +173,16 @@ def show_validation_error(title: str = "Config validation error"):
         msg.fail(title, spaced=True)
         print(str(e).replace("Config validation error", "").strip())
         sys.exit(1)
+
+
+def import_code(code_path: Optional[Union[Path, str]]) -> None:
+    """Helper to import Python file provided in training commands / commands
+    using the config. This makes custom registered functions available.
+    """
+    if code_path is not None:
+        if not Path(code_path).exists():
+            msg.fail("Path to Python code not found", code_path, exits=1)
+        try:
+            import_file("python_code", code_path)
+        except Exception as e:
+            msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)

View File

@@ -1,15 +1,18 @@
-from typing import Optional, List, Sequence, Dict, Any, Tuple
+from typing import List, Sequence, Dict, Any, Tuple, Optional
 from pathlib import Path
 from collections import Counter
 import sys
 import srsly
 from wasabi import Printer, MESSAGES
+import typer

-from ._util import app, Arg, Opt
+from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
+from ._util import import_code
+from ..schemas import ConfigSchema
 from ..gold import Corpus, Example
 from ..syntax import nonproj
 from ..language import Language
-from ..util import load_model, get_lang_class
+from .. import util


 # Minimum number of expected occurrences of NER label in data to train new label
@@ -24,12 +27,11 @@ BLANK_MODEL_THRESHOLD = 2000
 @app.command("debug-data")
 def debug_data_cli(
     # fmt: off
-    lang: str = Arg(..., help="Model language"),
+    ctx: typer.Context,  # This is only used to read additional arguments
     train_path: Path = Arg(..., help="Location of JSON-formatted training data", exists=True),
     dev_path: Path = Arg(..., help="Location of JSON-formatted development data", exists=True),
-    tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map", exists=True, dir_okay=False),
-    base_model: Optional[str] = Opt(None, "--base-model", "-b", help="Name of model to update (optional)"),
-    pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of pipeline components to train"),
+    config_path: Path = Arg(..., help="Path to config file", exists=True),
+    code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
     ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"),
     verbose: bool = Opt(False, "--verbose", "-V", help="Print additional information and explanations"),
     no_format: bool = Opt(False, "--no-format", "-NF", help="Don't pretty-print the results"),
@@ -40,13 +42,13 @@ def debug_data_cli(
     stats, and find problems like invalid entity annotations, cyclic
     dependencies, low data labels and more.
     """
+    overrides = parse_config_overrides(ctx.args)
+    import_code(code_path)
     debug_data(
-        lang,
         train_path,
         dev_path,
-        tag_map_path=tag_map_path,
-        base_model=base_model,
-        pipeline=[p.strip() for p in pipeline.split(",")],
+        config_path,
+        config_overrides=overrides,
         ignore_warnings=ignore_warnings,
         verbose=verbose,
         no_format=no_format,
@@ -55,13 +57,11 @@ def debug_data_cli(


 def debug_data(
-    lang: str,
     train_path: Path,
     dev_path: Path,
+    config_path: Path,
     *,
-    tag_map_path: Optional[Path] = None,
-    base_model: Optional[str] = None,
-    pipeline: List[str] = ["tagger", "parser", "ner"],
+    config_overrides: Dict[str, Any] = {},
     ignore_warnings: bool = False,
     verbose: bool = False,
     no_format: bool = True,
@@ -75,25 +75,27 @@ def debug_data(
         msg.fail("Training data not found", train_path, exits=1)
     if not dev_path.exists():
         msg.fail("Development data not found", dev_path, exits=1)
+    if not config_path.exists():
+        msg.fail("Config file not found", config_path, exists=1)
+    with show_validation_error():
+        config = util.load_config(
+            config_path,
+            create_objects=False,
+            schema=ConfigSchema,
+            overrides=config_overrides,
+        )
+    nlp = util.load_model_from_config(config["nlp"])
+    lang = config["nlp"]["lang"]
+    base_model = config["nlp"]["base_model"]
+    pipeline = list(config["nlp"]["pipeline"].keys())
+    tag_map_path = util.ensure_path(config["training"]["tag_map"])
     tag_map = {}
     if tag_map_path is not None:
         tag_map = srsly.read_json(tag_map_path)
-    # Initialize the model and pipeline
-    if base_model:
-        nlp = load_model(base_model)
-    else:
-        lang_cls = get_lang_class(lang)
-        nlp = lang_cls()
     # Update tag map with provided mapping
     nlp.vocab.morphology.tag_map.update(tag_map)
-    msg.divider("Data format validation")
-    # TODO: Validate data format using the JSON schema
-    # TODO: update once the new format is ready
-    # TODO: move validation to GoldCorpus in order to be able to load from dir
+    msg.divider("Data file validation")

     # Create the gold corpus to be able to better analyze data
     loading_train_error_message = ""
@@ -380,7 +382,7 @@ def debug_data(
     if gold_dev_data["n_nonproj"] > 0:
         n_nonproj = gold_dev_data["n_nonproj"]
         msg.info(f"Found {n_nonproj} nonprojective dev sentence(s)")
-    msg.info(f"{labels_train_unpreprocessed} label(s) in train data")
+    msg.info(f"{len(labels_train_unpreprocessed)} label(s) in train data")
     msg.info(f"{len(labels_train)} label(s) in projectivized train data")
     labels_with_counts = _format_labels(
         gold_train_unpreprocessed_data["deps"].most_common(), counts=True

View File

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Dict, Any
 import random
 import numpy
 import time
@@ -11,8 +11,11 @@ from thinc.api import CosineDistance, L2Distance
 from wasabi import msg
 import srsly
 from functools import partial
+import typer

-from ._util import app, Arg, Opt
+from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
+from ._util import import_code
+from ..schemas import ConfigSchema
 from ..errors import Errors
 from ..ml.models.multi_task import build_cloze_multi_task_model
 from ..ml.models.multi_task import build_cloze_characters_multi_task_model
@@ -24,10 +27,11 @@ from .. import util
 @app.command("pretrain")
 def pretrain_cli(
     # fmt: off
+    ctx: typer.Context,  # This is only used to read additional arguments
     texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
     output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
     config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
-    use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
+    code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
     resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
     epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files."),
     # fmt: on
@@ -51,11 +55,13 @@ def pretrain_cli(
     all settings are the same between pretraining and training. Ideally,
     this is done by using the same config file for both commands.
     """
+    overrides = parse_config_overrides(ctx.args)
+    import_code(code_path)
     pretrain(
         texts_loc,
         output_dir,
         config_path,
-        use_gpu=use_gpu,
+        config_overrides=overrides,
         resume_path=resume_path,
         epoch_resume=epoch_resume,
     )
@@ -65,24 +71,34 @@ def pretrain(
     texts_loc: Path,
     output_dir: Path,
     config_path: Path,
-    use_gpu: int = -1,
+    config_overrides: Dict[str, Any] = {},
     resume_path: Optional[Path] = None,
     epoch_resume: Optional[int] = None,
 ):
-    verify_cli_args(**locals())
+    verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume)
+    msg.info(f"Loading config from: {config_path}")
+    with show_validation_error():
+        config = util.load_config(
+            config_path,
+            create_objects=False,
+            validate=True,
+            schema=ConfigSchema,
+            overrides=config_overrides,
+        )
     if not output_dir.exists():
         output_dir.mkdir()
         msg.good(f"Created output directory: {output_dir}")
+
+    use_gpu = config["training"]["use_gpu"]
     if use_gpu >= 0:
         msg.info("Using GPU")
         require_gpu(use_gpu)
     else:
         msg.info("Using CPU")
-    msg.info(f"Loading config from: {config_path}")
-    config = util.load_config(config_path, create_objects=False)
-    fix_random_seed(config["pretraining"]["seed"])
+    seed = config["pretraining"]["seed"]
+    if seed is not None:
+        fix_random_seed(seed)
     if use_gpu >= 0 and config["pretraining"]["use_pytorch_for_gpu_memory"]:
         use_pytorch_for_gpu_memory()
@@ -360,9 +376,7 @@ def _smart_round(figure, width=10, max_decimal=4):
     return format_str % figure


-def verify_cli_args(
-    texts_loc, output_dir, config_path, use_gpu, resume_path, epoch_resume
-):
+def verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume):
     if not config_path or not config_path.exists():
         msg.fail("Config file not found", config_path, exits=1)
     if output_dir.exists() and [p for p in output_dir.iterdir()]:
@@ -401,10 +415,3 @@ def verify_cli_args(
             f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
             exits=True,
         )
-    config = util.load_config(config_path, create_objects=False)
-    if config["pretraining"]["objective"]["type"] == "vectors":
-        if not config["nlp"]["vectors"]:
-            msg.fail(
-                "Must specify nlp.vectors if pretraining.objective.type is vectors",
-                exits=True,
-            )

View File

@@ -11,6 +11,7 @@ import random
 import typer

 from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
+from ._util import import_code
 from ..gold import Corpus, Example
 from ..lookups import Lookups
 from .. import util
@@ -53,17 +54,10 @@ def train_cli(
     """
     util.set_env_log(verbose)
     verify_cli_args(
-        train_path=train_path,
-        dev_path=dev_path,
-        config_path=config_path,
-        code_path=code_path,
+        train_path=train_path, dev_path=dev_path, config_path=config_path,
     )
     overrides = parse_config_overrides(ctx.args)
-    if code_path is not None:
-        try:
-            util.import_file("python_code", code_path)
-        except Exception as e:
-            msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
+    import_code(code_path)
     train(
         config_path,
         {"train": train_path, "dev": dev_path},
@@ -503,7 +497,6 @@ def verify_cli_args(
     dev_path: Path,
     config_path: Path,
     output_path: Optional[Path] = None,
-    code_path: Optional[Path] = None,
 ):
     # Make sure all files and paths exists if they are needed
     if not config_path or not config_path.exists():
@@ -524,9 +517,6 @@ def verify_cli_args(
             "the specified output path doesn't exist, the directory will be "
             "created for you.",
         )
-    if code_path is not None:
-        if not code_path.exists():
-            msg.fail("Path to Python code not found", code_path, exits=1)


 def verify_textcat_config(nlp, nlp_config):
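The file passed via --code-path is an ordinary Python module; importing it through import_code is enough to make its registrations available to the config. A hedged sketch of what such a file could contain (the registry name and function below are made up for illustration, not part of this commit):

# functions.py -- passed via --code-path so that its registrations run on import
import spacy
from thinc.api import Linear, Model


@spacy.registry.architectures("my_custom_model.v1")  # hypothetical name
def build_my_custom_model(width: int) -> Model:
    # Stand-in architecture: a real implementation would build the actual
    # network here. Registering it makes "my_custom_model.v1" resolvable
    # from the config once this file has been imported.
    return Linear(nO=width)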

View File

@@ -2,7 +2,7 @@ from typing import Dict, List, Union, Optional, Sequence, Any
 from enum import Enum
 from pydantic import BaseModel, Field, ValidationError, validator
 from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
-from pydantic import FilePath, DirectoryPath
+from pydantic import FilePath, DirectoryPath, root_validator
 from collections import defaultdict
 from thinc.api import Model, Optimizer
@@ -242,6 +242,7 @@ class ConfigSchemaPipeline(BaseModel):
 class ConfigSchemaNlp(BaseModel):
     lang: StrictStr = Field(..., title="The base language to use")
+    base_model: Optional[StrictStr] = Field(..., title="The base model to use")
     vectors: Optional[DirectoryPath] = Field(..., title="Path to vectors")
     pipeline: Optional[ConfigSchemaPipeline]
@@ -250,9 +251,40 @@
         arbitrary_types_allowed = True


+class ConfigSchemaPretrain(BaseModel):
+    # fmt: off
+    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
+    min_length: StrictInt = Field(..., title="Minimum length of examples")
+    max_length: StrictInt = Field(..., title="Maximum length of examples")
+    dropout: StrictFloat = Field(..., title="Dropout rate")
+    n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
+    batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
+    seed: Optional[StrictInt] = Field(..., title="Random seed")
+    use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
+    tok2vec_model: StrictStr = Field(..., title="tok2vec model in config, e.g. nlp.pipeline.tok2vec.model")
+    optimizer: Optimizer = Field(..., title="The optimizer to use")
+    # TODO: use a more detailed schema for this?
+    objective: Dict[str, Any] = Field(..., title="Pretraining objective")
+    # fmt: on
+
+    class Config:
+        extra = "forbid"
+        arbitrary_types_allowed = True
+
+
 class ConfigSchema(BaseModel):
     training: ConfigSchemaTraining
     nlp: ConfigSchemaNlp
+    pretraining: Optional[ConfigSchemaPretrain]
+
+    @root_validator
+    def validate_config(cls, values):
+        """Perform additional validation for settings with dependencies."""
+        pt = values.get("pretraining")
+        if pt and pt.objective.get("type") == "vectors" and not values["nlp"].vectors:
+            err = "Need nlp.vectors if pretraining.objective.type is vectors"
+            raise ValueError(err)
+        return values

     class Config:
         extra = "allow"

View File

@ -12,6 +12,8 @@ nlp_config_string = """
[nlp] [nlp]
lang = "en" lang = "en"
[nlp.pipeline]
[nlp.pipeline.tok2vec] [nlp.pipeline.tok2vec]
factory = "tok2vec" factory = "tok2vec"