Update config and commands

parent 9e48ea48a1, commit b7111da1d7
@@ -48,7 +48,7 @@ use_averages = false
 eps = 1e-8
 #learn_rate = 0.001
 
-[optimizer.learn_rate]
+[training.optimizer.learn_rate]
 @schedules = "warmup_linear.v1"
 warmup_steps = 250
 total_steps = 20000
@@ -56,8 +56,11 @@ initial_rate = 0.001
 
 [nlp]
 lang = "en"
+base_model = null
 vectors = null
 
+[nlp.pipeline]
+
 [nlp.pipeline.tok2vec]
 factory = "tok2vec"
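For context on the renamed block: `@schedules = "warmup_linear.v1"` resolves to a registered thinc schedule; only its section moved, since the optimizer now lives under [training]. A hand-rolled sketch of the behavior that schedule implements (illustrative only, not thinc's actual code; `initial_rate` comes from the surrounding config):

    def warmup_linear(initial_rate: float, warmup_steps: int, total_steps: int):
        # Ramp linearly from 0 up to initial_rate over warmup_steps, then
        # decay linearly back toward 0 at total_steps. Yields one rate per step.
        step = 0
        while True:
            if step < warmup_steps:
                factor = step / max(1, warmup_steps)
            else:
                factor = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
            yield factor * initial_rate
            step += 1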
@@ -25,6 +25,11 @@ score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
 # These settings are invalid for the transformer models.
 init_tok2vec = null
 discard_oversize = false
+omit_extra_lookups = false
+batch_by = "words"
+use_gpu = -1
+raw_text = null
+tag_map = null
 
 [training.batch_size]
 @schedules = "compounding.v1"
@@ -72,6 +77,9 @@ learn_rate = 0.001
 [nlp]
 lang = "en"
 vectors = null
+base_model = null
 
+[nlp.pipeline]
+
 [nlp.pipeline.tok2vec]
 factory = "tok2vec"
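The new `batch_by = "words"` and `discard_oversize` settings determine how the batch-size schedule is interpreted: as a cap on total words per batch rather than a count of documents. A toy sketch of that policy (purely illustrative, not spaCy's batching code):

    def minibatch_by_words(examples, max_words, discard_oversize=False):
        # examples: iterable of (doc, n_words) pairs, purely for illustration
        batch, n_in_batch = [], 0
        for doc, n_words in examples:
            if n_words > max_words:
                if discard_oversize:
                    continue  # drop docs that alone exceed the cap
                yield [(doc, n_words)]  # otherwise emit them as singleton batches
                continue
            if n_in_batch + n_words > max_words and batch:
                yield batch
                batch, n_in_batch = [], 0
            batch.append((doc, n_words))
            n_in_batch += n_words
        if batch:
            yield batch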
@@ -1,4 +1,4 @@
-from typing import Dict, Any, Union, List
+from typing import Dict, Any, Union, List, Optional
 from pathlib import Path
 from wasabi import msg
 import srsly
@@ -11,6 +11,7 @@ from configparser import InterpolationError
 import sys
 
 from ..schemas import ProjectConfigSchema, validate
+from ..util import import_file
 
 
 PROJECT_FILE = "project.yml"
@@ -172,3 +173,16 @@ def show_validation_error(title: str = "Config validation error"):
         msg.fail(title, spaced=True)
         print(str(e).replace("Config validation error", "").strip())
         sys.exit(1)
+
+
+def import_code(code_path: Optional[Union[Path, str]]) -> None:
+    """Helper to import Python file provided in training commands / commands
+    using the config. This makes custom registered functions available.
+    """
+    if code_path is not None:
+        if not Path(code_path).exists():
+            msg.fail("Path to Python code not found", code_path, exits=1)
+        try:
+            import_file("python_code", code_path)
+        except Exception as e:
+            msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
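The new `import_code` helper centralizes the `--code-path` handling that train previously inlined and that debug-data and pretrain now gain. A hypothetical user file illustrates why it exists: importing the module runs its registry decorators, so functions referenced by name in the config become resolvable. The registry name and file name below are made up for illustration:

    # my_code.py (hypothetical file passed via --code-path)
    from spacy.util import registry

    @registry.architectures("my_org.CustomTok2Vec.v1")  # made-up registry name
    def create_custom_tok2vec():
        ...

    # Inside a CLI command, a single call replaces the old inline try/except:
    # import_code(code_path)  # no-op for None, clean failure for bad paths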
@@ -1,15 +1,18 @@
-from typing import Optional, List, Sequence, Dict, Any, Tuple
+from typing import List, Sequence, Dict, Any, Tuple, Optional
 from pathlib import Path
 from collections import Counter
 import sys
 import srsly
 from wasabi import Printer, MESSAGES
+import typer
 
-from ._util import app, Arg, Opt
+from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
+from ._util import import_code
+from ..schemas import ConfigSchema
 from ..gold import Corpus, Example
 from ..syntax import nonproj
 from ..language import Language
-from ..util import load_model, get_lang_class
+from .. import util
 
 
 # Minimum number of expected occurrences of NER label in data to train new label
@@ -24,12 +27,11 @@ BLANK_MODEL_THRESHOLD = 2000
 @app.command("debug-data")
 def debug_data_cli(
     # fmt: off
-    lang: str = Arg(..., help="Model language"),
+    ctx: typer.Context,  # This is only used to read additional arguments
     train_path: Path = Arg(..., help="Location of JSON-formatted training data", exists=True),
     dev_path: Path = Arg(..., help="Location of JSON-formatted development data", exists=True),
-    tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map", exists=True, dir_okay=False),
-    base_model: Optional[str] = Opt(None, "--base-model", "-b", help="Name of model to update (optional)"),
-    pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of pipeline components to train"),
+    config_path: Path = Arg(..., help="Path to config file", exists=True),
+    code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
     ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"),
     verbose: bool = Opt(False, "--verbose", "-V", help="Print additional information and explanations"),
     no_format: bool = Opt(False, "--no-format", "-NF", help="Don't pretty-print the results"),
@@ -40,13 +42,13 @@ def debug_data_cli(
     stats, and find problems like invalid entity annotations, cyclic
     dependencies, low data labels and more.
     """
+    overrides = parse_config_overrides(ctx.args)
+    import_code(code_path)
     debug_data(
-        lang,
         train_path,
         dev_path,
-        tag_map_path=tag_map_path,
-        base_model=base_model,
-        pipeline=[p.strip() for p in pipeline.split(",")],
+        config_path,
+        config_overrides=overrides,
         ignore_warnings=ignore_warnings,
         verbose=verbose,
         no_format=no_format,
@@ -55,13 +57,11 @@ def debug_data_cli(
 
 
 def debug_data(
-    lang: str,
     train_path: Path,
     dev_path: Path,
+    config_path: Path,
     *,
-    tag_map_path: Optional[Path] = None,
-    base_model: Optional[str] = None,
-    pipeline: List[str] = ["tagger", "parser", "ner"],
+    config_overrides: Dict[str, Any] = {},
    ignore_warnings: bool = False,
     verbose: bool = False,
     no_format: bool = True,
@@ -75,25 +75,27 @@ def debug_data(
         msg.fail("Training data not found", train_path, exits=1)
     if not dev_path.exists():
         msg.fail("Development data not found", dev_path, exits=1)
+    if not config_path.exists():
+        msg.fail("Config file not found", config_path, exits=1)
+    with show_validation_error():
+        config = util.load_config(
+            config_path,
+            create_objects=False,
+            schema=ConfigSchema,
+            overrides=config_overrides,
+        )
+    nlp = util.load_model_from_config(config["nlp"])
+    lang = config["nlp"]["lang"]
+    base_model = config["nlp"]["base_model"]
+    pipeline = list(config["nlp"]["pipeline"].keys())
+    tag_map_path = util.ensure_path(config["training"]["tag_map"])
     tag_map = {}
     if tag_map_path is not None:
         tag_map = srsly.read_json(tag_map_path)
 
-    # Initialize the model and pipeline
-    if base_model:
-        nlp = load_model(base_model)
-    else:
-        lang_cls = get_lang_class(lang)
-        nlp = lang_cls()
     # Update tag map with provided mapping
     nlp.vocab.morphology.tag_map.update(tag_map)
 
-    msg.divider("Data format validation")
+    msg.divider("Data file validation")
 
-    # TODO: Validate data format using the JSON schema
-    # TODO: update once the new format is ready
-    # TODO: move validation to GoldCorpus in order to be able to load from dir
-
     # Create the gold corpus to be able to better analyze data
     loading_train_error_message = ""
@@ -380,7 +382,7 @@ def debug_data(
     if gold_dev_data["n_nonproj"] > 0:
         n_nonproj = gold_dev_data["n_nonproj"]
         msg.info(f"Found {n_nonproj} nonprojective dev sentence(s)")
-    msg.info(f"{labels_train_unpreprocessed} label(s) in train data")
+    msg.info(f"{len(labels_train_unpreprocessed)} label(s) in train data")
     msg.info(f"{len(labels_train)} label(s) in projectivized train data")
     labels_with_counts = _format_labels(
         gold_train_unpreprocessed_data["deps"].most_common(), counts=True
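After this change, debug-data reads the language, pipeline, and tag map from the config instead of from dedicated flags. A hypothetical invocation, with placeholder paths, assuming the `--section.key value` override syntax that `parse_config_overrides` handles:

    python -m spacy debug-data ./train.json ./dev.json ./config.cfg \
        --code-path ./my_code.py --training.tag_map ./tag_map.json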
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Dict, Any
 import random
 import numpy
 import time
@@ -11,8 +11,11 @@ from thinc.api import CosineDistance, L2Distance
 from wasabi import msg
 import srsly
 from functools import partial
+import typer
 
-from ._util import app, Arg, Opt
+from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
+from ._util import import_code
+from ..schemas import ConfigSchema
 from ..errors import Errors
 from ..ml.models.multi_task import build_cloze_multi_task_model
 from ..ml.models.multi_task import build_cloze_characters_multi_task_model
@@ -24,10 +27,11 @@ from .. import util
 @app.command("pretrain")
 def pretrain_cli(
     # fmt: off
+    ctx: typer.Context,  # This is only used to read additional arguments
     texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
     output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
     config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
-    use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
+    code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
     resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
     epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files."),
     # fmt: on
@@ -51,11 +55,13 @@ def pretrain_cli(
     all settings are the same between pretraining and training. Ideally,
     this is done by using the same config file for both commands.
     """
+    overrides = parse_config_overrides(ctx.args)
+    import_code(code_path)
     pretrain(
         texts_loc,
         output_dir,
         config_path,
-        use_gpu=use_gpu,
+        config_overrides=overrides,
         resume_path=resume_path,
         epoch_resume=epoch_resume,
     )
@@ -65,24 +71,34 @@ def pretrain(
     texts_loc: Path,
     output_dir: Path,
     config_path: Path,
-    use_gpu: int = -1,
+    config_overrides: Dict[str, Any] = {},
     resume_path: Optional[Path] = None,
     epoch_resume: Optional[int] = None,
 ):
-    verify_cli_args(**locals())
+    verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume)
+    msg.info(f"Loading config from: {config_path}")
+    with show_validation_error():
+        config = util.load_config(
+            config_path,
+            create_objects=False,
+            validate=True,
+            schema=ConfigSchema,
+            overrides=config_overrides,
+        )
     if not output_dir.exists():
         output_dir.mkdir()
         msg.good(f"Created output directory: {output_dir}")
 
+    use_gpu = config["training"]["use_gpu"]
     if use_gpu >= 0:
         msg.info("Using GPU")
         require_gpu(use_gpu)
     else:
         msg.info("Using CPU")
 
-    msg.info(f"Loading config from: {config_path}")
-    config = util.load_config(config_path, create_objects=False)
-    fix_random_seed(config["pretraining"]["seed"])
+    seed = config["pretraining"]["seed"]
+    if seed is not None:
+        fix_random_seed(seed)
     if use_gpu >= 0 and config["pretraining"]["use_pytorch_for_gpu_memory"]:
         use_pytorch_for_gpu_memory()
@@ -360,9 +376,7 @@ def _smart_round(figure, width=10, max_decimal=4):
     return format_str % figure
 
 
-def verify_cli_args(
-    texts_loc, output_dir, config_path, use_gpu, resume_path, epoch_resume
-):
+def verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume):
     if not config_path or not config_path.exists():
         msg.fail("Config file not found", config_path, exits=1)
     if output_dir.exists() and [p for p in output_dir.iterdir()]:
@@ -401,10 +415,3 @@ def verify_cli_args(
             f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
             exits=True,
         )
-    config = util.load_config(config_path, create_objects=False)
-    if config["pretraining"]["objective"]["type"] == "vectors":
-        if not config["nlp"]["vectors"]:
-            msg.fail(
-                "Must specify nlp.vectors if pretraining.objective.type is vectors",
-                exits=True,
-            )
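With `--use-gpu` removed, device selection now comes from `training.use_gpu` in the config, or from an override on the command line. A hypothetical resume invocation with placeholder paths, again assuming the `--section.key value` override syntax:

    python -m spacy pretrain ./texts.jsonl ./output ./config.cfg \
        --resume-path ./output/model345.bin --epoch-resume 346 \
        --training.use_gpu 0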
@@ -11,6 +11,7 @@ import random
 import typer
 
 from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
+from ._util import import_code
 from ..gold import Corpus, Example
 from ..lookups import Lookups
 from .. import util
@@ -53,17 +54,10 @@ def train_cli(
     """
     util.set_env_log(verbose)
     verify_cli_args(
-        train_path=train_path,
-        dev_path=dev_path,
-        config_path=config_path,
-        code_path=code_path,
+        train_path=train_path, dev_path=dev_path, config_path=config_path,
     )
     overrides = parse_config_overrides(ctx.args)
-    if code_path is not None:
-        try:
-            util.import_file("python_code", code_path)
-        except Exception as e:
-            msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
+    import_code(code_path)
     train(
         config_path,
         {"train": train_path, "dev": dev_path},
@@ -503,7 +497,6 @@ def verify_cli_args(
     dev_path: Path,
     config_path: Path,
     output_path: Optional[Path] = None,
-    code_path: Optional[Path] = None,
 ):
     # Make sure all files and paths exists if they are needed
     if not config_path or not config_path.exists():
@@ -524,9 +517,6 @@ def verify_cli_args(
             "the specified output path doesn't exist, the directory will be "
             "created for you.",
         )
-    if code_path is not None:
-        if not code_path.exists():
-            msg.fail("Path to Python code not found", code_path, exits=1)
 
 
 def verify_textcat_config(nlp, nlp_config):
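All three commands now share the same override flow: `parse_config_overrides(ctx.args)` collects the extra CLI arguments, and the result is passed to `util.load_config(..., overrides=...)`. A sketch of the shape involved, assuming the parser flattens `--section.key value` pairs into dotted keys (values here are illustrative; the real parsing lives in the shared CLI utilities module):

    ctx_args = ["--training.use_gpu", "0", "--nlp.vectors", "null"]
    # expected result, handed straight to util.load_config(..., overrides=...):
    overrides = {"training.use_gpu": 0, "nlp.vectors": None}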
@@ -2,7 +2,7 @@ from typing import Dict, List, Union, Optional, Sequence, Any
 from enum import Enum
 from pydantic import BaseModel, Field, ValidationError, validator
 from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
-from pydantic import FilePath, DirectoryPath
+from pydantic import FilePath, DirectoryPath, root_validator
 from collections import defaultdict
 from thinc.api import Model, Optimizer
 
@@ -242,6 +242,7 @@ class ConfigSchemaPipeline(BaseModel):
 
 class ConfigSchemaNlp(BaseModel):
     lang: StrictStr = Field(..., title="The base language to use")
+    base_model: Optional[StrictStr] = Field(..., title="The base model to use")
     vectors: Optional[DirectoryPath] = Field(..., title="Path to vectors")
     pipeline: Optional[ConfigSchemaPipeline]
 
@@ -250,9 +251,40 @@ class ConfigSchemaNlp(BaseModel):
         arbitrary_types_allowed = True
 
 
+class ConfigSchemaPretrain(BaseModel):
+    # fmt: off
+    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
+    min_length: StrictInt = Field(..., title="Minimum length of examples")
+    max_length: StrictInt = Field(..., title="Maximum length of examples")
+    dropout: StrictFloat = Field(..., title="Dropout rate")
+    n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
+    batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
+    seed: Optional[StrictInt] = Field(..., title="Random seed")
+    use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
+    tok2vec_model: StrictStr = Field(..., title="tok2vec model in config, e.g. nlp.pipeline.tok2vec.model")
+    optimizer: Optimizer = Field(..., title="The optimizer to use")
+    # TODO: use a more detailed schema for this?
+    objective: Dict[str, Any] = Field(..., title="Pretraining objective")
+    # fmt: on
+
+    class Config:
+        extra = "forbid"
+        arbitrary_types_allowed = True
+
+
 class ConfigSchema(BaseModel):
     training: ConfigSchemaTraining
     nlp: ConfigSchemaNlp
+    pretraining: Optional[ConfigSchemaPretrain]
+
+    @root_validator
+    def validate_config(cls, values):
+        """Perform additional validation for settings with dependencies."""
+        pt = values.get("pretraining")
+        if pt and pt.objective.get("type") == "vectors" and not values["nlp"].vectors:
+            err = "Need nlp.vectors if pretraining.objective.type is vectors"
+            raise ValueError(err)
+        return values
+
     class Config:
         extra = "allow"
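The `@root_validator` replaces the ad-hoc check deleted from pretrain's `verify_cli_args`, so the invalid combination now fails at config validation time rather than inside one command. A self-contained pydantic v1 sketch of the same cross-field pattern, with made-up model names standing in for the real schemas:

    from typing import Optional, Dict, Any
    from pydantic import BaseModel, root_validator

    class DemoNlp(BaseModel):          # stand-in for ConfigSchemaNlp
        vectors: Optional[str] = None

    class DemoPretrain(BaseModel):     # stand-in for ConfigSchemaPretrain
        objective: Dict[str, Any] = {}

    class DemoConfig(BaseModel):
        nlp: DemoNlp
        pretraining: Optional[DemoPretrain] = None

        @root_validator
        def validate_config(cls, values):
            # Cross-field check: the vectors objective requires nlp.vectors
            pt, nlp = values.get("pretraining"), values.get("nlp")
            if pt and nlp and pt.objective.get("type") == "vectors" and not nlp.vectors:
                raise ValueError("Need nlp.vectors if pretraining.objective.type is vectors")
            return values

    DemoConfig(nlp={"vectors": "en_vecs"}, pretraining={"objective": {"type": "vectors"}})  # passes
    # DemoConfig(nlp={}, pretraining={"objective": {"type": "vectors"}})  # raises ValidationError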
@@ -12,6 +12,8 @@ nlp_config_string = """
 [nlp]
 lang = "en"
 
+[nlp.pipeline]
+
 [nlp.pipeline.tok2vec]
 factory = "tok2vec"