mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Merge pull request #5747 from explosion/feature/refactor-config-args
This commit is contained in:
commit
872938ec76
|
@ -26,6 +26,10 @@ score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
|
|||
init_tok2vec = null
|
||||
discard_oversize = false
|
||||
omit_extra_lookups = false
|
||||
batch_by = "words"
|
||||
use_gpu = -1
|
||||
raw_text = null
|
||||
tag_map = null
|
||||
|
||||
[training.batch_size]
|
||||
@schedules = "compounding.v1"
|
||||
|
@ -44,7 +48,7 @@ use_averages = false
|
|||
eps = 1e-8
|
||||
#learn_rate = 0.001
|
||||
|
||||
[optimizer.learn_rate]
|
||||
[training.optimizer.learn_rate]
|
||||
@schedules = "warmup_linear.v1"
|
||||
warmup_steps = 250
|
||||
total_steps = 20000
|
||||
|
@ -52,8 +56,11 @@ initial_rate = 0.001
|
|||
|
||||
[nlp]
|
||||
lang = "en"
|
||||
base_model = null
|
||||
vectors = null
|
||||
|
||||
[nlp.pipeline]
|
||||
|
||||
[nlp.pipeline.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
|
|
|
@ -25,6 +25,11 @@ score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
|
|||
# These settings are invalid for the transformer models.
|
||||
init_tok2vec = null
|
||||
discard_oversize = false
|
||||
omit_extra_lookups = false
|
||||
batch_by = "words"
|
||||
use_gpu = -1
|
||||
raw_text = null
|
||||
tag_map = null
|
||||
|
||||
[training.batch_size]
|
||||
@schedules = "compounding.v1"
|
||||
|
@ -72,6 +77,9 @@ learn_rate = 0.001
|
|||
[nlp]
|
||||
lang = "en"
|
||||
vectors = null
|
||||
base_model = null
|
||||
|
||||
[nlp.pipeline]
|
||||
|
||||
[nlp.pipeline.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
|
|
@ -6,7 +6,7 @@ requires = [
|
|||
"cymem>=2.0.2,<2.1.0",
|
||||
"preshed>=3.0.2,<3.1.0",
|
||||
"murmurhash>=0.28.0,<1.1.0",
|
||||
"thinc>=8.0.0a12,<8.0.0a20",
|
||||
"thinc>=8.0.0a17,<8.0.0a20",
|
||||
"blis>=0.4.0,<0.5.0",
|
||||
"pytokenizations"
|
||||
]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Our libraries
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.0a12,<8.0.0a20
|
||||
thinc>=8.0.0a17,<8.0.0a20
|
||||
blis>=0.4.0,<0.5.0
|
||||
ml_datasets>=0.1.1
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
|
|
|
@ -34,13 +34,13 @@ setup_requires =
|
|||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
thinc>=8.0.0a12,<8.0.0a20
|
||||
thinc>=8.0.0a17,<8.0.0a20
|
||||
install_requires =
|
||||
# Our libraries
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.0a12,<8.0.0a20
|
||||
thinc>=8.0.0a17,<8.0.0a20
|
||||
blis>=0.4.0,<0.5.0
|
||||
wasabi>=0.7.0,<1.1.0
|
||||
srsly>=2.1.0,<3.0.0
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from wasabi import msg
|
||||
|
||||
from ._app import app, setup_cli # noqa: F401
|
||||
from ._util import app, setup_cli # noqa: F401
|
||||
|
||||
# These are the actual functions, NOT the wrapped CLI commands. The CLI commands
|
||||
# are registered automatically and won't have to be imported here.
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
import typer
|
||||
from typer.main import get_command
|
||||
|
||||
|
||||
COMMAND = "python -m spacy"
|
||||
NAME = "spacy"
|
||||
HELP = """spaCy Command-line Interface
|
||||
|
||||
DOCS: https://spacy.io/api/cli
|
||||
"""
|
||||
PROJECT_HELP = f"""Command-line interface for spaCy projects and working with
|
||||
project templates. You'd typically start by cloning a project template to a local
|
||||
directory and fetching its assets like datasets etc. See the project's
|
||||
project.yml for the available commands.
|
||||
"""
|
||||
|
||||
|
||||
app = typer.Typer(name=NAME, help=HELP)
|
||||
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
|
||||
app.add_typer(project_cli)
|
||||
|
||||
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
||||
# keep the names short, but not needed at the moment.
|
||||
Arg = typer.Argument
|
||||
Opt = typer.Option
|
||||
|
||||
|
||||
def setup_cli() -> None:
|
||||
# Ensure that the help messages always display the correct prompt
|
||||
command = get_command(app)
|
||||
command(prog_name=COMMAND)
|
195
spacy/cli/_util.py
Normal file
195
spacy/cli/_util.py
Normal file
|
@ -0,0 +1,195 @@
|
|||
from typing import Dict, Any, Union, List, Optional
|
||||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
import srsly
|
||||
import hashlib
|
||||
import typer
|
||||
from typer.main import get_command
|
||||
from contextlib import contextmanager
|
||||
from thinc.config import ConfigValidationError
|
||||
from configparser import InterpolationError
|
||||
import sys
|
||||
|
||||
from ..schemas import ProjectConfigSchema, validate
|
||||
from ..util import import_file
|
||||
|
||||
|
||||
PROJECT_FILE = "project.yml"
|
||||
PROJECT_LOCK = "project.lock"
|
||||
COMMAND = "python -m spacy"
|
||||
NAME = "spacy"
|
||||
HELP = """spaCy Command-line Interface
|
||||
|
||||
DOCS: https://spacy.io/api/cli
|
||||
"""
|
||||
PROJECT_HELP = f"""Command-line interface for spaCy projects and templates.
|
||||
You'd typically start by cloning a project template to a local directory and
|
||||
fetching its assets like datasets etc. See the project's {PROJECT_FILE} for the
|
||||
available commands.
|
||||
"""
|
||||
DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes
|
||||
commands to check and validate your config files, training and evaluation data,
|
||||
and custom model implementations.
|
||||
"""
|
||||
|
||||
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
||||
# keep the names short, but not needed at the moment.
|
||||
Arg = typer.Argument
|
||||
Opt = typer.Option
|
||||
|
||||
app = typer.Typer(name=NAME, help=HELP)
|
||||
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
|
||||
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
|
||||
|
||||
app.add_typer(project_cli)
|
||||
app.add_typer(debug_cli)
|
||||
|
||||
|
||||
def setup_cli() -> None:
|
||||
# Ensure that the help messages always display the correct prompt
|
||||
command = get_command(app)
|
||||
command(prog_name=COMMAND)
|
||||
|
||||
|
||||
def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
|
||||
"""Generate a dictionary of config overrides based on the extra arguments
|
||||
provided on the CLI, e.g. --training.batch_size to override
|
||||
"training.batch_size". Arguments without a "." are considered invalid,
|
||||
since the config only allows top-level sections to exist.
|
||||
|
||||
args (List[str]): The extra arguments from the command line.
|
||||
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
|
||||
"""
|
||||
result = {}
|
||||
while args:
|
||||
opt = args.pop(0)
|
||||
err = f"Invalid config override '{opt}'"
|
||||
if opt.startswith("--"): # new argument
|
||||
opt = opt.replace("--", "").replace("-", "_")
|
||||
if "." not in opt:
|
||||
msg.fail(f"{err}: can't override top-level section", exits=1)
|
||||
if not args or args[0].startswith("--"): # flag with no value
|
||||
value = "true"
|
||||
else:
|
||||
value = args.pop(0)
|
||||
# Just like we do in the config, we're calling json.loads on the
|
||||
# values. But since they come from the CLI, it'd b unintuitive to
|
||||
# explicitly mark strings with escaped quotes. So we're working
|
||||
# around that here by falling back to a string if parsing fails.
|
||||
# TODO: improve logic to handle simple types like list of strings?
|
||||
try:
|
||||
result[opt] = srsly.json_loads(value)
|
||||
except ValueError:
|
||||
result[opt] = str(value)
|
||||
else:
|
||||
msg.fail(f"{err}: options need to start with --", exits=1)
|
||||
return result
|
||||
|
||||
|
||||
def load_project_config(path: Path) -> Dict[str, Any]:
|
||||
"""Load the project.yml file from a directory and validate it. Also make
|
||||
sure that all directories defined in the config exist.
|
||||
|
||||
path (Path): The path to the project directory.
|
||||
RETURNS (Dict[str, Any]): The loaded project.yml.
|
||||
"""
|
||||
config_path = path / PROJECT_FILE
|
||||
if not config_path.exists():
|
||||
msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1)
|
||||
invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct."
|
||||
try:
|
||||
config = srsly.read_yaml(config_path)
|
||||
except ValueError as e:
|
||||
msg.fail(invalid_err, e, exits=1)
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
if errors:
|
||||
msg.fail(invalid_err, "\n".join(errors), exits=1)
|
||||
validate_project_commands(config)
|
||||
# Make sure directories defined in config exist
|
||||
for subdir in config.get("directories", []):
|
||||
dir_path = path / subdir
|
||||
if not dir_path.exists():
|
||||
dir_path.mkdir(parents=True)
|
||||
return config
|
||||
|
||||
|
||||
def validate_project_commands(config: Dict[str, Any]) -> None:
|
||||
"""Check that project commands and workflows are valid, don't contain
|
||||
duplicates, don't clash and only refer to commands that exist.
|
||||
|
||||
config (Dict[str, Any]): The loaded config.
|
||||
"""
|
||||
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
||||
workflows = config.get("workflows", {})
|
||||
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
||||
if duplicates:
|
||||
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
||||
msg.fail(err, exits=1)
|
||||
for workflow_name, workflow_steps in workflows.items():
|
||||
if workflow_name in command_names:
|
||||
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
||||
msg.fail(err, exits=1)
|
||||
for step in workflow_steps:
|
||||
if step not in command_names:
|
||||
msg.fail(
|
||||
f"Unknown command specified in workflow '{workflow_name}': {step}",
|
||||
f"Workflows can only refer to commands defined in the 'commands' "
|
||||
f"section of the {PROJECT_FILE}.",
|
||||
exits=1,
|
||||
)
|
||||
|
||||
|
||||
def get_hash(data) -> str:
|
||||
"""Get the hash for a JSON-serializable object.
|
||||
|
||||
data: The data to hash.
|
||||
RETURNS (str): The hash.
|
||||
"""
|
||||
data_str = srsly.json_dumps(data, sort_keys=True).encode("utf8")
|
||||
return hashlib.md5(data_str).hexdigest()
|
||||
|
||||
|
||||
def get_checksum(path: Union[Path, str]) -> str:
|
||||
"""Get the checksum for a file or directory given its file path. If a
|
||||
directory path is provided, this uses all files in that directory.
|
||||
|
||||
path (Union[Path, str]): The file or directory path.
|
||||
RETURNS (str): The checksum.
|
||||
"""
|
||||
path = Path(path)
|
||||
if path.is_file():
|
||||
return hashlib.md5(Path(path).read_bytes()).hexdigest()
|
||||
if path.is_dir():
|
||||
# TODO: this is currently pretty slow
|
||||
dir_checksum = hashlib.md5()
|
||||
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
|
||||
dir_checksum.update(sub_file.read_bytes())
|
||||
return dir_checksum.hexdigest()
|
||||
raise ValueError(f"Can't get checksum for {path}: not a file or directory")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def show_validation_error(title: str = "Config validation error"):
|
||||
"""Helper to show custom config validation errors on the CLI.
|
||||
|
||||
title (str): Title of the custom formatted error.
|
||||
"""
|
||||
try:
|
||||
yield
|
||||
except (ConfigValidationError, InterpolationError) as e:
|
||||
msg.fail(title, spaced=True)
|
||||
print(str(e).replace("Config validation error", "").strip())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def import_code(code_path: Optional[Union[Path, str]]) -> None:
|
||||
"""Helper to import Python file provided in training commands / commands
|
||||
using the config. This makes custom registered functions available.
|
||||
"""
|
||||
if code_path is not None:
|
||||
if not Path(code_path).exists():
|
||||
msg.fail("Path to Python code not found", code_path, exits=1)
|
||||
try:
|
||||
import_file("python_code", code_path)
|
||||
except Exception as e:
|
||||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
|
@ -6,7 +6,7 @@ import srsly
|
|||
import re
|
||||
import sys
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt
|
||||
from ..gold import docs_to_json
|
||||
from ..tokens import DocBin
|
||||
from ..gold.converters import iob2docs, conll_ner2docs, json2docs, conllu2docs
|
||||
|
@ -53,10 +53,13 @@ def convert_cli(
|
|||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Convert files into json or DocBin format for use with train command and other
|
||||
experiment management functions. If no output_dir is specified, the data
|
||||
Convert files into json or DocBin format for training. The resulting .spacy
|
||||
file can be used with the train command and other experiment management
|
||||
functions.
|
||||
|
||||
If no output_dir is specified and the output format is JSON, the data
|
||||
is written to stdout, so you can pipe them forward to a JSON file:
|
||||
$ spacy convert some_file.conllu > some_file.json
|
||||
$ spacy convert some_file.conllu --file-type json > some_file.json
|
||||
"""
|
||||
if isinstance(file_type, FileTypes):
|
||||
# We get an instance of the FileTypes from the CLI so we need its string value
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
from typing import Optional, List, Sequence, Dict, Any, Tuple
|
||||
from typing import List, Sequence, Dict, Any, Tuple, Optional
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
import sys
|
||||
import srsly
|
||||
from wasabi import Printer, MESSAGES
|
||||
from wasabi import Printer, MESSAGES, msg
|
||||
import typer
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code, debug_cli
|
||||
from ..schemas import ConfigSchema
|
||||
from ..gold import Corpus, Example
|
||||
from ..syntax import nonproj
|
||||
from ..language import Language
|
||||
from ..util import load_model, get_lang_class
|
||||
from .. import util
|
||||
|
||||
|
||||
# Minimum number of expected occurrences of NER label in data to train new label
|
||||
|
@ -21,32 +24,70 @@ BLANK_MODEL_MIN_THRESHOLD = 100
|
|||
BLANK_MODEL_THRESHOLD = 2000
|
||||
|
||||
|
||||
@app.command("debug-data")
|
||||
@debug_cli.command(
|
||||
"config",
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
def debug_config_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||
# fmt: on
|
||||
):
|
||||
"""Debug a config.cfg file and show validation errors. The command will
|
||||
create all objects in the tree and validate them. Note that some config
|
||||
validation errors are blocking and will prevent the rest of the config from
|
||||
being resolved. This means that you may not see all validation errors at
|
||||
once and some issues are only shown once previous errors have been fixed.
|
||||
"""
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
with show_validation_error():
|
||||
util.load_config(
|
||||
config_path, create_objects=False, schema=ConfigSchema, overrides=overrides,
|
||||
)
|
||||
msg.good("Config is valid")
|
||||
|
||||
|
||||
@debug_cli.command(
|
||||
"data", context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
@app.command(
|
||||
"debug-data",
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
hidden=True, # hide this from main CLI help but still allow it to work with warning
|
||||
)
|
||||
def debug_data_cli(
|
||||
# fmt: off
|
||||
lang: str = Arg(..., help="Model language"),
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
train_path: Path = Arg(..., help="Location of JSON-formatted training data", exists=True),
|
||||
dev_path: Path = Arg(..., help="Location of JSON-formatted development data", exists=True),
|
||||
tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map", exists=True, dir_okay=False),
|
||||
base_model: Optional[str] = Opt(None, "--base-model", "-b", help="Name of model to update (optional)"),
|
||||
pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of pipeline components to train"),
|
||||
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||
ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"),
|
||||
verbose: bool = Opt(False, "--verbose", "-V", help="Print additional information and explanations"),
|
||||
no_format: bool = Opt(False, "--no-format", "-NF", help="Don't pretty-print the results"),
|
||||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Analyze, debug and validate your training and development data, get useful
|
||||
stats, and find problems like invalid entity annotations, cyclic
|
||||
dependencies, low data labels and more.
|
||||
Analyze, debug and validate your training and development data. Outputs
|
||||
useful stats, and can help you find problems like invalid entity annotations,
|
||||
cyclic dependencies, low data labels and more.
|
||||
"""
|
||||
if ctx.command.name == "debug-data":
|
||||
msg.warn(
|
||||
"The debug-data command is now available via the 'debug data' "
|
||||
"subcommand (without the hyphen). You can run python -m spacy debug "
|
||||
"--help for an overview of the other available debugging commands."
|
||||
)
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
debug_data(
|
||||
lang,
|
||||
train_path,
|
||||
dev_path,
|
||||
tag_map_path=tag_map_path,
|
||||
base_model=base_model,
|
||||
pipeline=[p.strip() for p in pipeline.split(",")],
|
||||
config_path,
|
||||
config_overrides=overrides,
|
||||
ignore_warnings=ignore_warnings,
|
||||
verbose=verbose,
|
||||
no_format=no_format,
|
||||
|
@ -55,13 +96,11 @@ def debug_data_cli(
|
|||
|
||||
|
||||
def debug_data(
|
||||
lang: str,
|
||||
train_path: Path,
|
||||
dev_path: Path,
|
||||
config_path: Path,
|
||||
*,
|
||||
tag_map_path: Optional[Path] = None,
|
||||
base_model: Optional[str] = None,
|
||||
pipeline: List[str] = ["tagger", "parser", "ner"],
|
||||
config_overrides: Dict[str, Any] = {},
|
||||
ignore_warnings: bool = False,
|
||||
verbose: bool = False,
|
||||
no_format: bool = True,
|
||||
|
@ -75,25 +114,27 @@ def debug_data(
|
|||
msg.fail("Training data not found", train_path, exits=1)
|
||||
if not dev_path.exists():
|
||||
msg.fail("Development data not found", dev_path, exits=1)
|
||||
|
||||
if not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exists=1)
|
||||
with show_validation_error():
|
||||
config = util.load_config(
|
||||
config_path,
|
||||
create_objects=False,
|
||||
schema=ConfigSchema,
|
||||
overrides=config_overrides,
|
||||
)
|
||||
nlp = util.load_model_from_config(config["nlp"])
|
||||
lang = config["nlp"]["lang"]
|
||||
base_model = config["nlp"]["base_model"]
|
||||
pipeline = list(config["nlp"]["pipeline"].keys())
|
||||
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
||||
tag_map = {}
|
||||
if tag_map_path is not None:
|
||||
tag_map = srsly.read_json(tag_map_path)
|
||||
|
||||
# Initialize the model and pipeline
|
||||
if base_model:
|
||||
nlp = load_model(base_model)
|
||||
else:
|
||||
lang_cls = get_lang_class(lang)
|
||||
nlp = lang_cls()
|
||||
# Update tag map with provided mapping
|
||||
nlp.vocab.morphology.tag_map.update(tag_map)
|
||||
|
||||
msg.divider("Data format validation")
|
||||
|
||||
# TODO: Validate data format using the JSON schema
|
||||
# TODO: update once the new format is ready
|
||||
# TODO: move validation to GoldCorpus in order to be able to load from dir
|
||||
msg.divider("Data file validation")
|
||||
|
||||
# Create the gold corpus to be able to better analyze data
|
||||
loading_train_error_message = ""
|
||||
|
@ -380,7 +421,7 @@ def debug_data(
|
|||
if gold_dev_data["n_nonproj"] > 0:
|
||||
n_nonproj = gold_dev_data["n_nonproj"]
|
||||
msg.info(f"Found {n_nonproj} nonprojective dev sentence(s)")
|
||||
msg.info(f"{labels_train_unpreprocessed} label(s) in train data")
|
||||
msg.info(f"{len(labels_train_unpreprocessed)} label(s) in train data")
|
||||
msg.info(f"{len(labels_train)} label(s) in projectivized train data")
|
||||
labels_with_counts = _format_labels(
|
||||
gold_train_unpreprocessed_data["deps"].most_common(), counts=True
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
from typing import List
|
||||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from .. import util
|
||||
from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
|
||||
|
||||
from ._util import Arg, Opt, debug_cli
|
||||
from .. import util
|
||||
from ..lang.en import English
|
||||
|
||||
|
||||
@app.command("debug-model")
|
||||
@debug_cli.command("model")
|
||||
def debug_model_cli(
|
||||
# fmt: off
|
||||
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
||||
|
@ -26,7 +25,8 @@ def debug_model_cli(
|
|||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Analyze a Thinc ML model - internal structure and activations during training
|
||||
Analyze a Thinc model implementation. Includes checks for internal structure
|
||||
and activations during training.
|
||||
"""
|
||||
print_settings = {
|
||||
"dimensions": dimensions,
|
||||
|
@ -50,16 +50,11 @@ def debug_model_cli(
|
|||
msg.info(f"Using CPU")
|
||||
|
||||
debug_model(
|
||||
config_path,
|
||||
print_settings=print_settings,
|
||||
config_path, print_settings=print_settings,
|
||||
)
|
||||
|
||||
|
||||
def debug_model(
|
||||
config_path: Path,
|
||||
*,
|
||||
print_settings=None
|
||||
):
|
||||
def debug_model(config_path: Path, *, print_settings=None):
|
||||
if print_settings is None:
|
||||
print_settings = {}
|
||||
|
||||
|
@ -83,7 +78,7 @@ def debug_model(
|
|||
for e in range(3):
|
||||
Y, get_dX = model.begin_update(_get_docs())
|
||||
dY = get_gradient(model, Y)
|
||||
_ = get_dX(dY)
|
||||
get_dX(dY)
|
||||
model.finish_update(optimizer)
|
||||
if print_settings.get("print_after_training"):
|
||||
msg.info(f"After training:")
|
||||
|
@ -115,7 +110,12 @@ def _get_docs():
|
|||
|
||||
|
||||
def _get_output(xp):
|
||||
return xp.asarray([xp.asarray([i+10, i+20, i+30], dtype="float32") for i, _ in enumerate(_get_docs())])
|
||||
return xp.asarray(
|
||||
[
|
||||
xp.asarray([i + 10, i + 20, i + 30], dtype="float32")
|
||||
for i, _ in enumerate(_get_docs())
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _print_model(model, print_settings):
|
||||
|
@ -161,7 +161,7 @@ def _print_matrix(value):
|
|||
return value
|
||||
result = str(value.shape) + " - sample: "
|
||||
sample_matrix = value
|
||||
for d in range(value.ndim-1):
|
||||
for d in range(value.ndim - 1):
|
||||
sample_matrix = sample_matrix[0]
|
||||
sample_matrix = sample_matrix[0:5]
|
||||
result = result + str(sample_matrix)
|
||||
|
|
|
@ -4,7 +4,7 @@ import sys
|
|||
from wasabi import msg
|
||||
import typer
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt
|
||||
from .. import about
|
||||
from ..util import is_package, get_base_version, run_command
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from thinc.api import require_gpu, fix_random_seed
|
|||
|
||||
from ..gold import Corpus
|
||||
from ..tokens import Doc
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt
|
||||
from ..scorer import Scorer
|
||||
from .. import util
|
||||
from .. import displacy
|
||||
|
|
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||
from wasabi import Printer
|
||||
import srsly
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt
|
||||
from .. import util
|
||||
from .. import about
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import srsly
|
|||
import warnings
|
||||
from wasabi import Printer
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt
|
||||
from ..vectors import Vectors
|
||||
from ..errors import Errors, Warnings
|
||||
from ..language import Language
|
||||
|
@ -46,9 +46,8 @@ def init_model_cli(
|
|||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Create a new model from raw data, like word frequencies, Brown clusters
|
||||
and word vectors. If vectors are provided in Word2Vec format, they can
|
||||
be either a .txt or zipped as a .zip or .tar.gz.
|
||||
Create a new model from raw data. If vectors are provided in Word2Vec format,
|
||||
they can be either a .txt or zipped as a .zip or .tar.gz.
|
||||
"""
|
||||
init_model(
|
||||
lang,
|
||||
|
|
|
@ -5,7 +5,7 @@ from wasabi import Printer, get_raw_input
|
|||
import srsly
|
||||
import sys
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt
|
||||
from ..schemas import validate, ModelMetaSchema
|
||||
from .. import util
|
||||
from .. import about
|
||||
|
@ -23,11 +23,13 @@ def package_cli(
|
|||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Generate Python package for model data, including meta and required
|
||||
installation files. A new directory will be created in the specified
|
||||
output directory, and model data will be copied over. If --create-meta is
|
||||
set and a meta.json already exists in the output directory, the existing
|
||||
values will be used as the defaults in the command-line prompt.
|
||||
Generate an installable Python package for a model. Includes model data,
|
||||
meta and required installation files. A new directory will be created in the
|
||||
specified output directory, and model data will be copied over. If
|
||||
--create-meta is set and a meta.json already exists in the output directory,
|
||||
the existing values will be used as the defaults in the command-line prompt.
|
||||
After packaging, "python setup.py sdist" is run in the package directory,
|
||||
which will create a .tar.gz archive that can be installed via "pip install".
|
||||
"""
|
||||
package(
|
||||
input_dir,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Dict, Any
|
||||
import random
|
||||
import numpy
|
||||
import time
|
||||
|
@ -11,8 +11,11 @@ from thinc.api import CosineDistance, L2Distance
|
|||
from wasabi import msg
|
||||
import srsly
|
||||
from functools import partial
|
||||
import typer
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||
from ._util import import_code
|
||||
from ..schemas import ConfigSchema
|
||||
from ..errors import Errors
|
||||
from ..ml.models.multi_task import build_cloze_multi_task_model
|
||||
from ..ml.models.multi_task import build_cloze_characters_multi_task_model
|
||||
|
@ -21,13 +24,17 @@ from ..attrs import ID, HEAD
|
|||
from .. import util
|
||||
|
||||
|
||||
@app.command("pretrain")
|
||||
@app.command(
|
||||
"pretrain",
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
def pretrain_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
texts_loc: Path = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", exists=True),
|
||||
output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
|
||||
config_path: Path = Arg(..., help="Path to config file", exists=True, dir_okay=False),
|
||||
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
|
||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||
resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
|
||||
epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files."),
|
||||
# fmt: on
|
||||
|
@ -51,11 +58,13 @@ def pretrain_cli(
|
|||
all settings are the same between pretraining and training. Ideally,
|
||||
this is done by using the same config file for both commands.
|
||||
"""
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
pretrain(
|
||||
texts_loc,
|
||||
output_dir,
|
||||
config_path,
|
||||
use_gpu=use_gpu,
|
||||
config_overrides=overrides,
|
||||
resume_path=resume_path,
|
||||
epoch_resume=epoch_resume,
|
||||
)
|
||||
|
@ -65,24 +74,34 @@ def pretrain(
|
|||
texts_loc: Path,
|
||||
output_dir: Path,
|
||||
config_path: Path,
|
||||
use_gpu: int = -1,
|
||||
config_overrides: Dict[str, Any] = {},
|
||||
resume_path: Optional[Path] = None,
|
||||
epoch_resume: Optional[int] = None,
|
||||
):
|
||||
verify_cli_args(**locals())
|
||||
verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume)
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
with show_validation_error():
|
||||
config = util.load_config(
|
||||
config_path,
|
||||
create_objects=False,
|
||||
validate=True,
|
||||
schema=ConfigSchema,
|
||||
overrides=config_overrides,
|
||||
)
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
msg.good(f"Created output directory: {output_dir}")
|
||||
|
||||
use_gpu = config["training"]["use_gpu"]
|
||||
if use_gpu >= 0:
|
||||
msg.info("Using GPU")
|
||||
require_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
config = util.load_config(config_path, create_objects=False)
|
||||
fix_random_seed(config["pretraining"]["seed"])
|
||||
seed = config["pretraining"]["seed"]
|
||||
if seed is not None:
|
||||
fix_random_seed(seed)
|
||||
if use_gpu >= 0 and config["pretraining"]["use_pytorch_for_gpu_memory"]:
|
||||
use_pytorch_for_gpu_memory()
|
||||
|
||||
|
@ -360,9 +379,7 @@ def _smart_round(figure, width=10, max_decimal=4):
|
|||
return format_str % figure
|
||||
|
||||
|
||||
def verify_cli_args(
|
||||
texts_loc, output_dir, config_path, use_gpu, resume_path, epoch_resume
|
||||
):
|
||||
def verify_cli_args(texts_loc, output_dir, config_path, resume_path, epoch_resume):
|
||||
if not config_path or not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exits=1)
|
||||
if output_dir.exists() and [p for p in output_dir.iterdir()]:
|
||||
|
@ -401,10 +418,3 @@ def verify_cli_args(
|
|||
f"The argument --epoch-resume has to be greater or equal to 0. {epoch_resume} is invalid",
|
||||
exits=True,
|
||||
)
|
||||
config = util.load_config(config_path, create_objects=False)
|
||||
if config["pretraining"]["objective"]["type"] == "vectors":
|
||||
if not config["nlp"]["vectors"]:
|
||||
msg.fail(
|
||||
"Must specify nlp.vectors if pretraining.objective.type is vectors",
|
||||
exits=True,
|
||||
)
|
||||
|
|
|
@ -7,15 +7,18 @@ import pstats
|
|||
import sys
|
||||
import itertools
|
||||
from wasabi import msg, Printer
|
||||
import typer
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, debug_cli, Arg, Opt, NAME
|
||||
from ..language import Language
|
||||
from ..util import load_model
|
||||
|
||||
|
||||
@app.command("profile")
|
||||
@debug_cli.command("profile")
|
||||
@app.command("profile", hidden=True)
|
||||
def profile_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read current calling context
|
||||
model: str = Arg(..., help="Model to load"),
|
||||
inputs: Optional[Path] = Arg(None, help="Location of input file. '-' for stdin.", exists=True, allow_dash=True),
|
||||
n_texts: int = Opt(10000, "--n-texts", "-n", help="Maximum number of texts to use if available"),
|
||||
|
@ -27,6 +30,12 @@ def profile_cli(
|
|||
It can either be provided as a JSONL file, or be read from sys.sytdin.
|
||||
If no input file is specified, the IMDB dataset is loaded via Thinc.
|
||||
"""
|
||||
if ctx.parent.command.name == NAME: # called as top-level command
|
||||
msg.warn(
|
||||
"The profile command is now available via the 'debug profile' "
|
||||
"subcommand. You can run python -m spacy debug --help for an "
|
||||
"overview of the other available debugging commands."
|
||||
)
|
||||
profile(model, inputs=inputs, n_texts=n_texts)
|
||||
|
||||
|
||||
|
|
|
@ -7,8 +7,7 @@ import re
|
|||
import shutil
|
||||
|
||||
from ...util import ensure_path, working_dir
|
||||
from .._app import project_cli, Arg
|
||||
from .util import PROJECT_FILE, load_project_config, get_checksum
|
||||
from .._util import project_cli, Arg, PROJECT_FILE, load_project_config, get_checksum
|
||||
|
||||
|
||||
# TODO: find a solution for caches
|
||||
|
|
|
@ -7,8 +7,7 @@ import re
|
|||
|
||||
from ... import about
|
||||
from ...util import ensure_path, run_command, make_tempdir
|
||||
from .._app import project_cli, Arg, Opt, COMMAND
|
||||
from .util import PROJECT_FILE
|
||||
from .._util import project_cli, Arg, Opt, COMMAND, PROJECT_FILE
|
||||
|
||||
|
||||
@project_cli.command("clone")
|
||||
|
|
|
@ -5,8 +5,8 @@ import subprocess
|
|||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
|
||||
from .util import PROJECT_FILE, load_project_config, get_hash
|
||||
from .._app import project_cli, Arg, Opt, NAME, COMMAND
|
||||
from .._util import PROJECT_FILE, load_project_config, get_hash, project_cli
|
||||
from .._util import Arg, Opt, NAME, COMMAND
|
||||
from ...util import working_dir, split_command, join_command, run_command
|
||||
|
||||
|
||||
|
|
|
@ -5,9 +5,8 @@ import sys
|
|||
import srsly
|
||||
|
||||
from ...util import working_dir, run_command, split_command, is_cwd, join_command
|
||||
from .._app import project_cli, Arg, Opt, COMMAND
|
||||
from .util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
|
||||
from .util import get_checksum
|
||||
from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
|
||||
from .._util import get_checksum, project_cli, Arg, Opt, COMMAND
|
||||
|
||||
|
||||
@project_cli.command("run")
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
from typing import Dict, Any, Union
|
||||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
import srsly
|
||||
import hashlib
|
||||
|
||||
from ...schemas import ProjectConfigSchema, validate
|
||||
|
||||
|
||||
PROJECT_FILE = "project.yml"
|
||||
PROJECT_LOCK = "project.lock"
|
||||
|
||||
|
||||
def load_project_config(path: Path) -> Dict[str, Any]:
|
||||
"""Load the project.yml file from a directory and validate it. Also make
|
||||
sure that all directories defined in the config exist.
|
||||
|
||||
path (Path): The path to the project directory.
|
||||
RETURNS (Dict[str, Any]): The loaded project.yml.
|
||||
"""
|
||||
config_path = path / PROJECT_FILE
|
||||
if not config_path.exists():
|
||||
msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1)
|
||||
invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct."
|
||||
try:
|
||||
config = srsly.read_yaml(config_path)
|
||||
except ValueError as e:
|
||||
msg.fail(invalid_err, e, exits=1)
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
if errors:
|
||||
msg.fail(invalid_err, "\n".join(errors), exits=1)
|
||||
validate_project_commands(config)
|
||||
# Make sure directories defined in config exist
|
||||
for subdir in config.get("directories", []):
|
||||
dir_path = path / subdir
|
||||
if not dir_path.exists():
|
||||
dir_path.mkdir(parents=True)
|
||||
return config
|
||||
|
||||
|
||||
def validate_project_commands(config: Dict[str, Any]) -> None:
|
||||
"""Check that project commands and workflows are valid, don't contain
|
||||
duplicates, don't clash and only refer to commands that exist.
|
||||
|
||||
config (Dict[str, Any]): The loaded config.
|
||||
"""
|
||||
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
||||
workflows = config.get("workflows", {})
|
||||
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
||||
if duplicates:
|
||||
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
||||
msg.fail(err, exits=1)
|
||||
for workflow_name, workflow_steps in workflows.items():
|
||||
if workflow_name in command_names:
|
||||
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
||||
msg.fail(err, exits=1)
|
||||
for step in workflow_steps:
|
||||
if step not in command_names:
|
||||
msg.fail(
|
||||
f"Unknown command specified in workflow '{workflow_name}': {step}",
|
||||
f"Workflows can only refer to commands defined in the 'commands' "
|
||||
f"section of the {PROJECT_FILE}.",
|
||||
exits=1,
|
||||
)
|
||||
|
||||
|
||||
def get_hash(data) -> str:
|
||||
"""Get the hash for a JSON-serializable object.
|
||||
|
||||
data: The data to hash.
|
||||
RETURNS (str): The hash.
|
||||
"""
|
||||
data_str = srsly.json_dumps(data, sort_keys=True).encode("utf8")
|
||||
return hashlib.md5(data_str).hexdigest()
|
||||
|
||||
|
||||
def get_checksum(path: Union[Path, str]) -> str:
|
||||
"""Get the checksum for a file or directory given its file path. If a
|
||||
directory path is provided, this uses all files in that directory.
|
||||
|
||||
path (Union[Path, str]): The file or directory path.
|
||||
RETURNS (str): The checksum.
|
||||
"""
|
||||
path = Path(path)
|
||||
if path.is_file():
|
||||
return hashlib.md5(Path(path).read_bytes()).hexdigest()
|
||||
if path.is_dir():
|
||||
# TODO: this is currently pretty slow
|
||||
dir_checksum = hashlib.md5()
|
||||
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
|
||||
dir_checksum.update(sub_file.read_bytes())
|
||||
return dir_checksum.hexdigest()
|
||||
raise ValueError(f"Can't get checksum for {path}: not a file or directory")
|
|
@ -1,172 +1,68 @@
|
|||
from typing import Optional, Dict, List, Union, Sequence
|
||||
from typing import Optional, Dict, Any
|
||||
from timeit import default_timer as timer
|
||||
import srsly
|
||||
import tqdm
|
||||
from pydantic import BaseModel, FilePath
|
||||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
import thinc
|
||||
import thinc.schedules
|
||||
from thinc.api import Model, use_pytorch_for_gpu_memory, require_gpu, fix_random_seed
|
||||
from thinc.api import use_pytorch_for_gpu_memory, require_gpu, fix_random_seed
|
||||
import random
|
||||
import typer
|
||||
|
||||
from ._app import app, Arg, Opt
|
||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||
from ._util import import_code
|
||||
from ..gold import Corpus, Example
|
||||
from ..lookups import Lookups
|
||||
from .. import util
|
||||
from ..errors import Errors
|
||||
from ..schemas import ConfigSchema
|
||||
|
||||
|
||||
# Don't remove - required to load the built-in architectures
|
||||
from ..ml import models # noqa: F401
|
||||
|
||||
# from ..schemas import ConfigSchema # TODO: include?
|
||||
|
||||
|
||||
registry = util.registry
|
||||
|
||||
CONFIG_STR = """
|
||||
[training]
|
||||
patience = 10
|
||||
eval_frequency = 10
|
||||
dropout = 0.2
|
||||
init_tok2vec = null
|
||||
max_epochs = 100
|
||||
orth_variant_level = 0.0
|
||||
gold_preproc = false
|
||||
max_length = 0
|
||||
use_gpu = 0
|
||||
scores = ["ents_p", "ents_r", "ents_f"]
|
||||
score_weights = {"ents_f": 1.0}
|
||||
limit = 0
|
||||
|
||||
[training.batch_size]
|
||||
@schedules = "compounding.v1"
|
||||
start = 100
|
||||
stop = 1000
|
||||
compound = 1.001
|
||||
|
||||
[optimizer]
|
||||
@optimizers = "Adam.v1"
|
||||
learn_rate = 0.001
|
||||
beta1 = 0.9
|
||||
beta2 = 0.999
|
||||
|
||||
[nlp]
|
||||
lang = "en"
|
||||
vectors = null
|
||||
|
||||
[nlp.pipeline.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
[nlp.pipeline.ner]
|
||||
factory = "ner"
|
||||
|
||||
[nlp.pipeline.ner.model]
|
||||
@architectures = "spacy.TransitionBasedParser.v1"
|
||||
nr_feature_tokens = 3
|
||||
hidden_width = 64
|
||||
maxout_pieces = 3
|
||||
|
||||
[nlp.pipeline.ner.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${nlp.pipeline.tok2vec.model:width}
|
||||
|
||||
[nlp.pipeline.tok2vec.model]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = ${nlp:vectors}
|
||||
width = 128
|
||||
depth = 4
|
||||
window_size = 1
|
||||
embed_size = 10000
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
"""
|
||||
|
||||
|
||||
class PipelineComponent(BaseModel):
|
||||
factory: str
|
||||
model: Model
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
optimizer: Optional["Optimizer"]
|
||||
|
||||
class training(BaseModel):
|
||||
patience: int = 10
|
||||
eval_frequency: int = 100
|
||||
dropout: float = 0.2
|
||||
init_tok2vec: Optional[FilePath] = None
|
||||
max_epochs: int = 100
|
||||
orth_variant_level: float = 0.0
|
||||
gold_preproc: bool = False
|
||||
max_length: int = 0
|
||||
use_gpu: int = 0
|
||||
scores: List[str] = ["ents_p", "ents_r", "ents_f"]
|
||||
score_weights: Dict[str, Union[int, float]] = {"ents_f": 1.0}
|
||||
limit: int = 0
|
||||
batch_size: Union[Sequence[int], int]
|
||||
|
||||
class nlp(BaseModel):
|
||||
lang: str
|
||||
vectors: Optional[str]
|
||||
pipeline: Optional[Dict[str, PipelineComponent]]
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
@app.command("train")
|
||||
@app.command(
|
||||
"train", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
|
||||
)
|
||||
def train_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
train_path: Path = Arg(..., help="Location of training data", exists=True),
|
||||
dev_path: Path = Arg(..., help="Location of development data", exists=True),
|
||||
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
||||
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store model in"),
|
||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||
init_tok2vec: Optional[Path] = Opt(None, "--init-tok2vec", "-t2v", help="Path to pretrained weights for the tok2vec components. See 'spacy pretrain'. Experimental."),
|
||||
raw_text: Optional[Path] = Opt(None, "--raw-text", "-rt", help="Path to jsonl file with unlabelled text documents."),
|
||||
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
|
||||
tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
|
||||
omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
|
||||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Train or update a spaCy model. Requires data to be formatted in spaCy's
|
||||
JSON format. To convert data from other formats, use the `spacy convert`
|
||||
command.
|
||||
Train or update a spaCy model. Requires data in spaCy's binary format. To
|
||||
convert data from other formats, use the `spacy convert` command. The
|
||||
config file includes all settings and hyperparameters used during traing.
|
||||
To override settings in the config, e.g. settings that point to local
|
||||
paths or that you want to experiment with, you can override them as
|
||||
command line options. For instance, --training.batch_size 128 overrides
|
||||
the value of "batch_size" in the block "[training]". The --code argument
|
||||
lets you pass in a Python file that's imported before training. It can be
|
||||
used to register custom functions and architectures that can then be
|
||||
referenced in the config.
|
||||
"""
|
||||
util.set_env_log(verbose)
|
||||
verify_cli_args(**locals())
|
||||
|
||||
if raw_text is not None:
|
||||
raw_text = list(srsly.read_jsonl(raw_text))
|
||||
tag_map = {}
|
||||
if tag_map_path is not None:
|
||||
tag_map = srsly.read_json(tag_map_path)
|
||||
|
||||
weights_data = None
|
||||
if init_tok2vec is not None:
|
||||
with init_tok2vec.open("rb") as file_:
|
||||
weights_data = file_.read()
|
||||
|
||||
if use_gpu >= 0:
|
||||
msg.info(f"Using GPU: {use_gpu}")
|
||||
require_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
|
||||
verify_cli_args(
|
||||
train_path=train_path, dev_path=dev_path, config_path=config_path,
|
||||
)
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
train(
|
||||
config_path,
|
||||
{"train": train_path, "dev": dev_path},
|
||||
output_path=output_path,
|
||||
raw_text=raw_text,
|
||||
tag_map=tag_map,
|
||||
weights_data=weights_data,
|
||||
omit_extra_lookups=omit_extra_lookups,
|
||||
config_overrides=overrides,
|
||||
)
|
||||
|
||||
|
||||
|
@ -175,20 +71,36 @@ def train(
|
|||
data_paths: Dict[str, Path],
|
||||
raw_text: Optional[Path] = None,
|
||||
output_path: Optional[Path] = None,
|
||||
tag_map: Optional[Path] = None,
|
||||
weights_data: Optional[bytes] = None,
|
||||
omit_extra_lookups: bool = False,
|
||||
config_overrides: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
msg.info(f"Loading config from: {config_path}")
|
||||
# Read the config first without creating objects, to get to the original nlp_config
|
||||
config = util.load_config(config_path, create_objects=False)
|
||||
if config["training"].get("seed"):
|
||||
with show_validation_error():
|
||||
config = util.load_config(
|
||||
config_path,
|
||||
create_objects=False,
|
||||
schema=ConfigSchema,
|
||||
overrides=config_overrides,
|
||||
)
|
||||
use_gpu = config["training"]["use_gpu"]
|
||||
if use_gpu >= 0:
|
||||
msg.info(f"Using GPU: {use_gpu}")
|
||||
require_gpu(use_gpu)
|
||||
else:
|
||||
msg.info("Using CPU")
|
||||
raw_text, tag_map, weights_data = load_from_paths(config)
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
if config["training"].get("use_pytorch_for_gpu_memory"):
|
||||
# It feels kind of weird to not have a default for this.
|
||||
use_pytorch_for_gpu_memory()
|
||||
nlp_config = config["nlp"]
|
||||
config = util.load_config(config_path, create_objects=True)
|
||||
config = util.load_config(
|
||||
config_path,
|
||||
create_objects=True,
|
||||
schema=ConfigSchema,
|
||||
overrides=config_overrides,
|
||||
)
|
||||
training = config["training"]
|
||||
msg.info("Creating nlp from config")
|
||||
nlp = util.load_model_from_config(nlp_config)
|
||||
|
@ -217,7 +129,7 @@ def train(
|
|||
|
||||
# Create empty extra lexeme tables so the data from spacy-lookups-data
|
||||
# isn't loaded if these features are accessed
|
||||
if omit_extra_lookups:
|
||||
if config["training"]["omit_extra_lookups"]:
|
||||
nlp.vocab.lookups_extra = Lookups()
|
||||
nlp.vocab.lookups_extra.add_table("lexeme_cluster")
|
||||
nlp.vocab.lookups_extra.add_table("lexeme_prob")
|
||||
|
@ -557,18 +469,34 @@ def update_meta(training, nlp, info):
|
|||
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
|
||||
|
||||
|
||||
def load_from_paths(config):
|
||||
# TODO: separate checks from loading
|
||||
raw_text = util.ensure_path(config["training"]["raw_text"])
|
||||
if raw_text is not None:
|
||||
if not raw_text.exists():
|
||||
msg.fail("Can't find raw text", raw_text, exits=1)
|
||||
raw_text = list(srsly.read_jsonl(config["training"]["raw_text"]))
|
||||
tag_map = {}
|
||||
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
||||
if tag_map_path is not None:
|
||||
if not tag_map_path.exists():
|
||||
msg.fail("Can't find tag map path", tag_map_path, exits=1)
|
||||
tag_map = srsly.read_json(config["training"]["tag_map"])
|
||||
weights_data = None
|
||||
init_tok2vec = util.ensure_path(config["training"]["init_tok2vec"])
|
||||
if init_tok2vec is not None:
|
||||
if not init_tok2vec.exists():
|
||||
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
||||
with init_tok2vec.open("rb") as file_:
|
||||
weights_data = file_.read()
|
||||
return raw_text, tag_map, weights_data
|
||||
|
||||
|
||||
def verify_cli_args(
|
||||
train_path,
|
||||
dev_path,
|
||||
config_path,
|
||||
output_path=None,
|
||||
code_path=None,
|
||||
init_tok2vec=None,
|
||||
raw_text=None,
|
||||
verbose=False,
|
||||
use_gpu=-1,
|
||||
tag_map_path=None,
|
||||
omit_extra_lookups=False,
|
||||
train_path: Path,
|
||||
dev_path: Path,
|
||||
config_path: Path,
|
||||
output_path: Optional[Path] = None,
|
||||
):
|
||||
# Make sure all files and paths exists if they are needed
|
||||
if not config_path or not config_path.exists():
|
||||
|
@ -589,15 +517,6 @@ def verify_cli_args(
|
|||
"the specified output path doesn't exist, the directory will be "
|
||||
"created for you.",
|
||||
)
|
||||
if code_path is not None:
|
||||
if not code_path.exists():
|
||||
msg.fail("Path to Python code not found", code_path, exits=1)
|
||||
try:
|
||||
util.import_file("python_code", code_path)
|
||||
except Exception as e:
|
||||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||
if init_tok2vec is not None and not init_tok2vec.exists():
|
||||
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
||||
|
||||
|
||||
def verify_textcat_config(nlp, nlp_config):
|
||||
|
|
|
@ -4,7 +4,7 @@ import sys
|
|||
import requests
|
||||
from wasabi import msg, Printer
|
||||
|
||||
from ._app import app
|
||||
from ._util import app
|
||||
from .. import about
|
||||
from ..util import get_package_version, get_installed_models, get_base_version
|
||||
from ..util import get_package_path, get_model_meta, is_compatible_version
|
||||
|
|
149
spacy/schemas.py
149
spacy/schemas.py
|
@ -1,9 +1,10 @@
|
|||
from typing import Dict, List, Union, Optional, Sequence, Any
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, ValidationError, validator
|
||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool, FilePath
|
||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||
from pydantic import root_validator
|
||||
from collections import defaultdict
|
||||
from thinc.api import Model
|
||||
from thinc.api import Model, Optimizer
|
||||
|
||||
from .attrs import NAMES
|
||||
|
||||
|
@ -173,41 +174,6 @@ class ModelMetaSchema(BaseModel):
|
|||
# JSON training format
|
||||
|
||||
|
||||
class PipelineComponent(BaseModel):
|
||||
factory: str
|
||||
model: Model
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
optimizer: Optional["Optimizer"]
|
||||
|
||||
class training(BaseModel):
|
||||
patience: int = 10
|
||||
eval_frequency: int = 100
|
||||
dropout: float = 0.2
|
||||
init_tok2vec: Optional[FilePath] = None
|
||||
max_epochs: int = 100
|
||||
orth_variant_level: float = 0.0
|
||||
gold_preproc: bool = False
|
||||
max_length: int = 0
|
||||
use_gpu: int = 0
|
||||
scores: List[str] = ["ents_p", "ents_r", "ents_f"]
|
||||
score_weights: Dict[str, Union[int, float]] = {"ents_f": 1.0}
|
||||
limit: int = 0
|
||||
batch_size: Union[Sequence[int], int]
|
||||
|
||||
class nlp(BaseModel):
|
||||
lang: str
|
||||
vectors: Optional[str]
|
||||
pipeline: Optional[Dict[str, PipelineComponent]]
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class TrainingSchema(BaseModel):
|
||||
# TODO: write
|
||||
|
||||
|
@ -216,6 +182,115 @@ class TrainingSchema(BaseModel):
|
|||
extra = "forbid"
|
||||
|
||||
|
||||
# Config schema
|
||||
# We're not setting any defaults here (which is too messy) and are making all
|
||||
# fields required, so we can raise validation errors for missing values. To
|
||||
# provide a default, we include a separate .cfg file with all values and
|
||||
# check that against this schema in the test suite to make sure it's always
|
||||
# up to date.
|
||||
|
||||
|
||||
class ConfigSchemaTraining(BaseModel):
|
||||
# fmt: off
|
||||
gold_preproc: StrictBool = Field(..., title="Whether to train on gold-standard sentences and tokens")
|
||||
max_length: StrictInt = Field(..., title="Maximum length of examples (longer examples are divided into sentences if possible)")
|
||||
limit: StrictInt = Field(..., title="Number of examples to use (0 for all)")
|
||||
orth_variant_level: StrictFloat = Field(..., title="Orth variants for data augmentation")
|
||||
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||
patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
|
||||
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
|
||||
max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
|
||||
eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
|
||||
seed: Optional[StrictInt] = Field(..., title="Random seed")
|
||||
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
|
||||
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
|
||||
use_gpu: StrictInt = Field(..., title="GPU ID or -1 for CPU")
|
||||
scores: List[StrictStr] = Field(..., title="Score types to be printed in overview")
|
||||
score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Weights of each score type for selecting final model")
|
||||
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
|
||||
discard_oversize: StrictBool = Field(..., title="Whether to skip examples longer than batch size")
|
||||
omit_extra_lookups: StrictBool = Field(..., title="Don't include extra lookups in model")
|
||||
batch_by: StrictStr = Field(..., title="Batch examples by type")
|
||||
raw_text: Optional[StrictStr] = Field(..., title="Raw text")
|
||||
tag_map: Optional[StrictStr] = Field(..., title="Path to JSON-formatted tag map")
|
||||
batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
|
||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ConfigSchemaNlpComponent(BaseModel):
|
||||
factory: StrictStr = Field(..., title="Component factory name")
|
||||
model: Model = Field(..., title="Component model")
|
||||
# TODO: add config schema / types for components so we can fill and validate
|
||||
# component options like learn_tokens, min_action_freq etc.
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ConfigSchemaPipeline(BaseModel):
|
||||
__root__: Dict[str, ConfigSchemaNlpComponent]
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class ConfigSchemaNlp(BaseModel):
|
||||
lang: StrictStr = Field(..., title="The base language to use")
|
||||
base_model: Optional[StrictStr] = Field(..., title="The base model to use")
|
||||
vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
|
||||
pipeline: Optional[ConfigSchemaPipeline]
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ConfigSchemaPretrain(BaseModel):
|
||||
# fmt: off
|
||||
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
|
||||
min_length: StrictInt = Field(..., title="Minimum length of examples")
|
||||
max_length: StrictInt = Field(..., title="Maximum length of examples")
|
||||
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||
n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
|
||||
batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
|
||||
seed: Optional[StrictInt] = Field(..., title="Random seed")
|
||||
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
|
||||
tok2vec_model: StrictStr = Field(..., title="tok2vec model in config, e.g. nlp.pipeline.tok2vec.model")
|
||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||
# TODO: use a more detailed schema for this?
|
||||
objective: Dict[str, Any] = Field(..., title="Pretraining objective")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ConfigSchema(BaseModel):
|
||||
training: ConfigSchemaTraining
|
||||
nlp: ConfigSchemaNlp
|
||||
pretraining: Optional[ConfigSchemaPretrain]
|
||||
|
||||
@root_validator
|
||||
def validate_config(cls, values):
|
||||
"""Perform additional validation for settings with dependencies."""
|
||||
pt = values.get("pretraining")
|
||||
if pt and pt.objective.get("type") == "vectors" and not values["nlp"].vectors:
|
||||
err = "Need nlp.vectors if pretraining.objective.type is vectors"
|
||||
raise ValueError(err)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
# Project config Schema
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,9 @@ import pytest
|
|||
from spacy.gold import docs_to_json, biluo_tags_from_offsets
|
||||
from spacy.gold.converters import iob2docs, conll_ner2docs, conllu2docs
|
||||
from spacy.lang.en import English
|
||||
from spacy.schemas import ProjectConfigSchema, validate
|
||||
from spacy.cli.pretrain import make_docs
|
||||
from spacy.cli._util import validate_project_commands, parse_config_overrides
|
||||
|
||||
|
||||
def test_cli_converters_conllu2json():
|
||||
|
@ -261,3 +263,55 @@ def test_pretrain_make_docs():
|
|||
docs, skip_count = make_docs(nlp, [too_long_jsonl], 1, 5)
|
||||
assert len(docs) == 0
|
||||
assert skip_count == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config",
|
||||
[
|
||||
{"commands": [{"name": "a"}, {"name": "a"}]},
|
||||
{"commands": [{"name": "a"}], "workflows": {"a": []}},
|
||||
{"commands": [{"name": "a"}], "workflows": {"b": ["c"]}},
|
||||
],
|
||||
)
|
||||
def test_project_config_validation1(config):
|
||||
with pytest.raises(SystemExit):
|
||||
validate_project_commands(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,n_errors",
|
||||
[
|
||||
({"commands": {"a": []}}, 1),
|
||||
({"commands": [{"help": "..."}]}, 1),
|
||||
({"commands": [{"name": "a", "extra": "b"}]}, 1),
|
||||
({"commands": [{"extra": "b"}]}, 2),
|
||||
({"commands": [{"name": "a", "deps": [123]}]}, 1),
|
||||
],
|
||||
)
|
||||
def test_project_config_validation2(config, n_errors):
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
assert len(errors) == n_errors
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args,expected",
|
||||
[
|
||||
# fmt: off
|
||||
(["--x.foo", "10"], {"x.foo": 10}),
|
||||
(["--x.foo", "bar"], {"x.foo": "bar"}),
|
||||
(["--x.foo", "--x.bar", "baz"], {"x.foo": True, "x.bar": "baz"}),
|
||||
(["--x.foo", "10.1", "--x.bar", "--x.baz", "false"], {"x.foo": 10.1, "x.bar": True, "x.baz": False})
|
||||
# fmt: on
|
||||
],
|
||||
)
|
||||
def test_parse_config_overrides(args, expected):
|
||||
assert parse_config_overrides(args) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args",
|
||||
[["--foo"], ["--x.foo", "bar", "--baz"], ["--x.foo", "bar", "baz"], ["x.foo"]],
|
||||
)
|
||||
def test_parse_config_overrides_invalid(args):
|
||||
with pytest.raises(SystemExit):
|
||||
parse_config_overrides(args)
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
import pytest
|
||||
from spacy.cli.project.util import validate_project_commands
|
||||
from spacy.schemas import ProjectConfigSchema, validate
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config",
|
||||
[
|
||||
{"commands": [{"name": "a"}, {"name": "a"}]},
|
||||
{"commands": [{"name": "a"}], "workflows": {"a": []}},
|
||||
{"commands": [{"name": "a"}], "workflows": {"b": ["c"]}},
|
||||
],
|
||||
)
|
||||
def test_project_config_validation1(config):
|
||||
with pytest.raises(SystemExit):
|
||||
validate_project_commands(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,n_errors",
|
||||
[
|
||||
({"commands": {"a": []}}, 1),
|
||||
({"commands": [{"help": "..."}]}, 1),
|
||||
({"commands": [{"name": "a", "extra": "b"}]}, 1),
|
||||
({"commands": [{"extra": "b"}]}, 2),
|
||||
({"commands": [{"name": "a", "deps": [123]}]}, 1),
|
||||
],
|
||||
)
|
||||
def test_project_config_validation2(config, n_errors):
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
assert len(errors) == n_errors
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Union
|
||||
from typing import List, Union, Type, Dict, Any
|
||||
import os
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
@ -6,6 +6,8 @@ import re
|
|||
from pathlib import Path
|
||||
import thinc
|
||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config
|
||||
from thinc.config import EmptySchema
|
||||
from pydantic import BaseModel
|
||||
import functools
|
||||
import itertools
|
||||
import numpy.random
|
||||
|
@ -326,20 +328,33 @@ def get_base_version(version):
|
|||
return Version(version).base_version
|
||||
|
||||
|
||||
def load_config(path, create_objects=False):
|
||||
def load_config(
|
||||
path: Union[Path, str],
|
||||
*,
|
||||
create_objects: bool = False,
|
||||
schema: Type[BaseModel] = EmptySchema,
|
||||
overrides: Dict[str, Any] = {},
|
||||
validate: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Load a Thinc-formatted config file, optionally filling in objects where
|
||||
the config references registry entries. See "Thinc config files" for details.
|
||||
|
||||
path (str / Path): Path to the config file
|
||||
create_objects (bool): Whether to automatically create objects when the config
|
||||
references registry entries. Defaults to False.
|
||||
|
||||
schema (BaseModel): Optional pydantic base schema to use for validation.
|
||||
overrides (Dict[str, Any]): Optional overrides to substitute in config.
|
||||
validate (bool): Whether to validate against schema.
|
||||
RETURNS (dict): The objects from the config file.
|
||||
"""
|
||||
config = thinc.config.Config().from_disk(path)
|
||||
kwargs = {"validate": validate, "schema": schema, "overrides": overrides}
|
||||
if create_objects:
|
||||
return registry.make_from_config(config, validate=True)
|
||||
return registry.make_from_config(config, **kwargs)
|
||||
else:
|
||||
# Just fill config here so we can validate and fail early
|
||||
if validate and schema:
|
||||
registry.fill_config(config, **kwargs)
|
||||
return config
|
||||
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ menu:
|
|||
- ['Info', 'info']
|
||||
- ['Validate', 'validate']
|
||||
- ['Convert', 'convert']
|
||||
- ['Debug data', 'debug-data']
|
||||
- ['Debug', 'debug']
|
||||
- ['Train', 'train']
|
||||
- ['Pretrain', 'pretrain']
|
||||
- ['Init Model', 'init-model']
|
||||
|
@ -133,30 +133,82 @@ $ python -m spacy convert [input_file] [output_dir] [--converter]
|
|||
| `ner` | NER with IOB/IOB2 tags, one token per line with columns separated by whitespace. The first column is the token and the final column is the IOB tag. Sentences are separated by blank lines and documents are separated by the line `-DOCSTART- -X- O O`. Supports CoNLL 2003 NER format. See [sample data](https://github.com/explosion/spaCy/tree/master/examples/training/ner_example_data). |
|
||||
| `iob` | NER with IOB/IOB2 tags, one sentence per line with tokens separated by whitespace and annotation separated by `|`, either `word|B-ENT` or `word|POS|B-ENT`. See [sample data](https://github.com/explosion/spaCy/tree/master/examples/training/ner_example_data). |
|
||||
|
||||
## Debug data {#debug-data new="2.2"}
|
||||
## Debug {#debug new="3"}
|
||||
|
||||
The `spacy debug` CLI includes helpful commands for debugging and profiling your
|
||||
configs, data and implementations.
|
||||
|
||||
### debug config {#debug-config}
|
||||
|
||||
Debug a [`config.cfg` file](/usage/training#config) and show validation errors.
|
||||
The command will create all objects in the tree and validate them. Note that
|
||||
some config validation errors are blocking and will prevent the rest of the
|
||||
config from being resolved. This means that you may not see all validation
|
||||
errors at once and some issues are only shown once previous errors have been
|
||||
fixed.
|
||||
|
||||
```bash
|
||||
$ python -m spacy debug config [config_path] [--code] [overrides]
|
||||
```
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```bash
|
||||
> $ python -m spacy debug config ./config.cfg
|
||||
> ```
|
||||
|
||||
<Accordion title="Example output" spaced>
|
||||
|
||||
```
|
||||
✘ Config validation error
|
||||
|
||||
training -> use_gpu field required
|
||||
training -> omit_extra_lookups field required
|
||||
training -> batch_by field required
|
||||
training -> raw_text field required
|
||||
training -> tag_map field required
|
||||
training -> evaluation_batch_size extra fields not permitted
|
||||
training -> vectors extra fields not permitted
|
||||
training -> width extra fields not permitted
|
||||
|
||||
{'gold_preproc': False, 'max_length': 3000, 'limit': 0, 'orth_variant_level': 0.0, 'dropout': 0.1, 'patience': 6000, 'max_epochs': 0, 'max_steps': 100000, 'eval_frequency': 400, 'seed': 0, 'accumulate_gradient': 4, 'width': 768, 'use_pytorch_for_gpu_memory': True, 'scores': ['speed', 'tags_acc', 'uas', 'las', 'ents_f'], 'score_weights': {'las': 0.4, 'ents_f': 0.4, 'tags_acc': 0.2}, 'init_tok2vec': None, 'vectors': None, 'discard_oversize': True, 'evaluation_batch_size': 16, 'batch_size': {'@schedules': 'compounding.v1', 'start': 800, 'stop': 800, 'compound': 1.001}, 'optimizer': {'@optimizers': 'Adam.v1', 'beta1': 0.9, 'beta2': 0.999, 'L2_is_weight_decay': True, 'L2': 0.01, 'grad_clip': 1.0, 'use_averages': False, 'eps': 1e-08, 'learn_rate': {'@schedules': 'warmup_linear.v1', 'warmup_steps': 250, 'total_steps': 20000, 'initial_rate': 5e-05}}}
|
||||
```
|
||||
|
||||
</Accordion>
|
||||
|
||||
| Argument | Type | Description |
|
||||
| -------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `config_path` | positional | Path to [training config](/api/data-formats#config) file containing all settings and hyperparameters. |
|
||||
| `--code`, `-c` | option | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-models) for new architectures. |
|
||||
| `--help`, `-h` | flag | Show help message and available arguments. |
|
||||
| overrides | | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--training.use_gpu 1`. |
|
||||
|
||||
### debug data {#debug-data}
|
||||
|
||||
Analyze, debug, and validate your training and development data. Get useful
|
||||
stats, and find problems like invalid entity annotations, cyclic dependencies,
|
||||
low data labels and more.
|
||||
|
||||
<Infobox title="New in v3.0" variant="warning">
|
||||
|
||||
The `debug-data` command is now available as a subcommand of `spacy debug`. It
|
||||
takes the same arguments as `train` and reads settings off the
|
||||
[`config.cfg` file](/usage/training#config).
|
||||
|
||||
</Infobox>
|
||||
|
||||
```bash
|
||||
$ python -m spacy debug-data [lang] [train_path] [dev_path] [--base-model]
|
||||
[--pipeline] [--tag-map-path] [--ignore-warnings] [--verbose] [--no-format]
|
||||
$ python -m spacy debug data [train_path] [dev_path] [config_path] [--code]
|
||||
[--ignore-warnings] [--verbose] [--no-format] [overrides]
|
||||
```
|
||||
|
||||
| Argument | Type | Description |
|
||||
| ------------------------------------------------------ | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `lang` | positional | Model language. |
|
||||
| `train_path` | positional | Location of [binary training data](/usage/training#data-format). Can be a file or a directory of files. |
|
||||
| `dev_path` | positional | Location of [binary development data](/usage/training#data-format) for evaluation. Can be a file or a directory of files. |
|
||||
| `--tag-map-path`, `-tm` <Tag variant="new">2.2.4</Tag> | option | Location of JSON-formatted tag map. |
|
||||
| `--base-model`, `-b` | option | Optional name of base model to update. Can be any loadable spaCy model. |
|
||||
| `--pipeline`, `-p` | option | Comma-separated names of pipeline components to train. Defaults to `'tagger,parser,ner'`. |
|
||||
| `--ignore-warnings`, `-IW` | flag | Ignore warnings, only show stats and errors. |
|
||||
| `--verbose`, `-V` | flag | Print additional information and explanations. |
|
||||
| `--no-format`, `-NF` | flag | Don't pretty-print the results. Use this if you want to write to a file. |
|
||||
> #### Example
|
||||
>
|
||||
> ```bash
|
||||
> $ python -m spacy debug data ./train.spacy ./dev.spacy ./config.cfg
|
||||
> ```
|
||||
|
||||
<Accordion title="Example output">
|
||||
<Accordion title="Example output" spaced>
|
||||
|
||||
```
|
||||
=========================== Data format validation ===========================
|
||||
|
@ -295,6 +347,20 @@ will not be available.
|
|||
|
||||
</Accordion>
|
||||
|
||||
| Argument | Type | Description |
|
||||
| -------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `train_path` | positional | Location of [binary training data](/usage/training#data-format). Can be a file or a directory of files. |
|
||||
| `dev_path` | positional | Location of [binary development data](/usage/training#data-format) for evaluation. Can be a file or a directory of files. |
|
||||
| `config_path` | positional | Path to [training config](/api/data-formats#config) file containing all settings and hyperparameters. |
|
||||
| `--code`, `-c` | option | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-models) for new architectures. |
|
||||
| `--ignore-warnings`, `-IW` | flag | Ignore warnings, only show stats and errors. |
|
||||
| `--verbose`, `-V` | flag | Print additional information and explanations. |
|
||||
| `--no-format`, `-NF` | flag | Don't pretty-print the results. Use this if you want to write to a file. |
|
||||
| `--help`, `-h` | flag | Show help message and available arguments. |
|
||||
| overrides | | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--training.use_gpu 1`. |
|
||||
|
||||
<!-- TODO: document debug profile and debug model? -->
|
||||
|
||||
## Train {#train}
|
||||
|
||||
Train a model. Expects data in spaCy's
|
||||
|
@ -310,28 +376,28 @@ you need to manage complex multi-step training workflows, check out the new
|
|||
|
||||
<Infobox title="New in v3.0" variant="warning">
|
||||
|
||||
As of spaCy v3.0, the `train` command doesn't take a long list of command-line
|
||||
arguments anymore and instead expects a single
|
||||
[`config.cfg` file](/usage/training#config) containing all settings for the
|
||||
pipeline, training process and hyperparameters.
|
||||
The `train` command doesn't take a long list of command-line arguments anymore
|
||||
and instead expects a single [`config.cfg` file](/usage/training#config)
|
||||
containing all settings for the pipeline, training process and hyperparameters.
|
||||
|
||||
</Infobox>
|
||||
|
||||
```bash
|
||||
$ python -m spacy train [train_path] [dev_path] [config_path] [--output]
|
||||
[--code] [--verbose]
|
||||
[--code] [--verbose] [overrides]
|
||||
```
|
||||
|
||||
| Argument | Type | Description |
|
||||
| ----------------- | ---------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `train_path` | positional | Location of training data in spaCy's [binary format](/api/data-formats#training). Can be a file or a directory of files. |
|
||||
| `dev_path` | positional | Location of development data for evaluation in spaCy's [binary format](/api/data-formats#training). Can be a file or a directory of files. |
|
||||
| `config_path` | positional | Path to [training config](/api/data-formats#config) file containing all settings and hyperparameters. |
|
||||
| `--output`, `-o` | positional | Directory to store model in. Will be created if it doesn't exist. |
|
||||
| `--code`, `-c` | option | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-models) for new architectures. |
|
||||
| `--verbose`, `-V` | flag | Show more detailed messages during training. |
|
||||
| `--help`, `-h` | flag | Show help message and available arguments. |
|
||||
| **CREATES** | model | The final model and the best model. |
|
||||
| Argument | Type | Description |
|
||||
| ----------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `train_path` | positional | Location of training data in spaCy's [binary format](/api/data-formats#training). Can be a file or a directory of files. |
|
||||
| `dev_path` | positional | Location of development data for evaluation in spaCy's [binary format](/api/data-formats#training). Can be a file or a directory of files. |
|
||||
| `config_path` | positional | Path to [training config](/api/data-formats#config) file containing all settings and hyperparameters. |
|
||||
| `--output`, `-o` | positional | Directory to store model in. Will be created if it doesn't exist. |
|
||||
| `--code`, `-c` | option | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-models) for new architectures. |
|
||||
| `--verbose`, `-V` | flag | Show more detailed messages during training. |
|
||||
| `--help`, `-h` | flag | Show help message and available arguments. |
|
||||
| overrides | | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--training.use_gpu 1`. |
|
||||
| **CREATES** | model | The final model and the best model. |
|
||||
|
||||
## Pretrain {#pretrain new="2.1" tag="experimental"}
|
||||
|
||||
|
@ -342,46 +408,35 @@ an approximate language-modeling objective. Specifically, we load pretrained
|
|||
vectors, and train a component like a CNN, BiLSTM, etc to predict vectors which
|
||||
match the pretrained ones. The weights are saved to a directory after each
|
||||
epoch. You can then pass a path to one of these pretrained weights files to the
|
||||
`spacy train` command.
|
||||
`spacy train` command. This technique may be especially helpful if you have
|
||||
little labelled data.
|
||||
|
||||
This technique may be especially helpful if you have little labelled data.
|
||||
However, it's still quite experimental, so your mileage may vary. To load the
|
||||
weights back in during `spacy train`, you need to ensure all settings are the
|
||||
same between pretraining and training. The API and errors around this need some
|
||||
improvement.
|
||||
<Infobox title="Changed in v3.0" variant="warning">
|
||||
|
||||
As of spaCy v3.0, the `pretrain` command takes the same
|
||||
[config file](/usage/training#config) as the `train` command. This ensures that
|
||||
settings are consistent between pretraining and training. Settings for
|
||||
pretraining can be defined in the `[pretraining]` block of the config file. See
|
||||
the [data format](/api/data-formats#config) for details.
|
||||
|
||||
</Infobox>
|
||||
|
||||
```bash
|
||||
$ python -m spacy pretrain [texts_loc] [vectors_model] [output_dir]
|
||||
[--width] [--conv-depth] [--cnn-window] [--cnn-pieces] [--use-chars] [--sa-depth]
|
||||
[--embed-rows] [--loss_func] [--dropout] [--batch-size] [--max-length]
|
||||
[--min-length] [--seed] [--n-iter] [--use-vectors] [--n-save-every]
|
||||
[--init-tok2vec] [--epoch-start]
|
||||
$ python -m spacy pretrain [texts_loc] [output_dir] [config_path]
|
||||
[--code] [--resume-path] [--epoch-resume] [overrides]
|
||||
```
|
||||
|
||||
| Argument | Type | Description |
|
||||
| ----------------------------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `texts_loc` | positional | Path to JSONL file with raw texts to learn from, with text provided as the key `"text"` or tokens as the key `"tokens"`. [See here](#pretrain-jsonl) for details. |
|
||||
| `vectors_model` | positional | Name or path to spaCy model with vectors to learn from. |
|
||||
| `output_dir` | positional | Directory to write models to on each epoch. |
|
||||
| `--width`, `-cw` | option | Width of CNN layers. |
|
||||
| `--conv-depth`, `-cd` | option | Depth of CNN layers. |
|
||||
| `--cnn-window`, `-cW` <Tag variant="new">2.2.2</Tag> | option | Window size for CNN layers. |
|
||||
| `--cnn-pieces`, `-cP` <Tag variant="new">2.2.2</Tag> | option | Maxout size for CNN layers. `1` for [Mish](https://github.com/digantamisra98/Mish). |
|
||||
| `--use-chars`, `-chr` <Tag variant="new">2.2.2</Tag> | flag | Whether to use character-based embedding. |
|
||||
| `--sa-depth`, `-sa` <Tag variant="new">2.2.2</Tag> | option | Depth of self-attention layers. |
|
||||
| `--embed-rows`, `-er` | option | Number of embedding rows. |
|
||||
| `--loss-func`, `-L` | option | Loss function to use for the objective. Either `"L2"` or `"cosine"`. |
|
||||
| `--dropout`, `-d` | option | Dropout rate. |
|
||||
| `--batch-size`, `-bs` | option | Number of words per training batch. |
|
||||
| `--max-length`, `-xw` | option | Maximum words per example. Longer examples are discarded. |
|
||||
| `--min-length`, `-nw` | option | Minimum words per example. Shorter examples are discarded. |
|
||||
| `--seed`, `-s` | option | Seed for random number generators. |
|
||||
| `--n-iter`, `-i` | option | Number of iterations to pretrain. |
|
||||
| `--use-vectors`, `-uv` | flag | Whether to use the static vectors as input features. |
|
||||
| `--n-save-every`, `-se` | option | Save model every X batches. |
|
||||
| `--init-tok2vec`, `-t2v` <Tag variant="new">2.1</Tag> | option | Path to pretrained weights for the token-to-vector parts of the models. See `spacy pretrain`. Experimental. |
|
||||
| `--epoch-start`, `-es` <Tag variant="new">2.1.5</Tag> | option | The epoch to start counting at. Only relevant when using `--init-tok2vec` and the given weight file has been renamed. Prevents unintended overwriting of existing weight files. |
|
||||
| **CREATES** | weights | The pretrained weights that can be used to initialize `spacy train`. |
|
||||
| Argument | Type | Description |
|
||||
| ----------------------- | ---------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `texts_loc` | positional | Path to JSONL file with raw texts to learn from, with text provided as the key `"text"` or tokens as the key `"tokens"`. [See here](#pretrain-jsonl) for details. |
|
||||
| `output_dir` | positional | Directory to write models to on each epoch. |
|
||||
| `config_path` | positional | Path to [training config](/api/data-formats#config) file containing all settings and hyperparameters. |
|
||||
| `--code`, `-c` | option | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-models) for new architectures. |
|
||||
| `--resume-path`, `-r` | option | TODO: |
|
||||
| `--epoch-resume`, `-er` | option | TODO: |
|
||||
| `--help`, `-h` | flag | Show help message and available arguments. |
|
||||
| overrides | | Config parameters to override. Should be options starting with `--` that correspond to the config section and value to override, e.g. `--training.use_gpu 1`. |
|
||||
| **CREATES** | weights | The pretrained weights that can be used to initialize `spacy train`. |
|
||||
|
||||
### JSONL format for raw text {#pretrain-jsonl}
|
||||
|
||||
|
|
|
@ -136,69 +136,32 @@ Some of the main advantages and features of spaCy's training config are:
|
|||
Python [type hints](https://docs.python.org/3/library/typing.html) to tell the
|
||||
config which types of data to expect.
|
||||
|
||||
<!-- TODO: instead of hard-coding a full config here, we probably want to embed it from GitHub, e.g. from one of the project templates. This also makes it easier to keep it up to date, and the embed widgets take up less space-->
|
||||
<!-- TODO: update this config? -->
|
||||
|
||||
```ini
|
||||
[training]
|
||||
use_gpu = -1
|
||||
limit = 0
|
||||
dropout = 0.2
|
||||
patience = 1000
|
||||
eval_frequency = 20
|
||||
scores = ["ents_p", "ents_r", "ents_f"]
|
||||
score_weights = {"ents_f": 1}
|
||||
orth_variant_level = 0.0
|
||||
gold_preproc = false
|
||||
max_length = 0
|
||||
seed = 0
|
||||
accumulate_gradient = 1
|
||||
discard_oversize = false
|
||||
|
||||
[training.batch_size]
|
||||
@schedules = "compounding.v1"
|
||||
start = 100
|
||||
stop = 1000
|
||||
compound = 1.001
|
||||
|
||||
[training.optimizer]
|
||||
@optimizers = "Adam.v1"
|
||||
learn_rate = 0.001
|
||||
beta1 = 0.9
|
||||
beta2 = 0.999
|
||||
use_averages = false
|
||||
|
||||
[nlp]
|
||||
lang = "en"
|
||||
vectors = null
|
||||
|
||||
[nlp.pipeline.ner]
|
||||
factory = "ner"
|
||||
|
||||
[nlp.pipeline.ner.model]
|
||||
@architectures = "spacy.TransitionBasedParser.v1"
|
||||
nr_feature_tokens = 3
|
||||
hidden_width = 128
|
||||
maxout_pieces = 3
|
||||
use_upper = true
|
||||
|
||||
[nlp.pipeline.ner.model.tok2vec]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
width = 128
|
||||
depth = 4
|
||||
embed_size = 7000
|
||||
maxout_pieces = 3
|
||||
window_size = 1
|
||||
subword_features = true
|
||||
pretrained_vectors = null
|
||||
dropout = null
|
||||
https://github.com/explosion/spaCy/blob/develop/examples/experiments/onto-joint/defaults.cfg
|
||||
```
|
||||
|
||||
<!-- TODO: explain settings and @ notation, refer to function registry docs -->
|
||||
Under the hood, the config is parsed into a dictionary. It's divided into
|
||||
sections and subsections, indicated by the square brackets and dot notation. For
|
||||
example, `[training]` is a section and `[training.batch_size]` a subsections.
|
||||
Subsections can define values, just like a dictionary, or use the `@` syntax to
|
||||
refer to [registered functions](#config-functions). This allows the config to
|
||||
not just define static settings, but also construct objects like architectures,
|
||||
schedules, optimizers or any other custom components. The main top-level
|
||||
sections of a config file are:
|
||||
|
||||
| Section | Description |
|
||||
| ------------- | ----------------------------------------------------------------------------------------------------- |
|
||||
| `training` | Settings and controls for the training and evaluation process. |
|
||||
| `pretraining` | Optional settings and controls for the [language model pretraining](#pretraining). |
|
||||
| `nlp` | Definition of the [processing pipeline](/docs/processing-pipelines), its components and their models. |
|
||||
|
||||
<Infobox title="Config format and settings" emoji="📖">
|
||||
|
||||
For a full overview of spaCy's config format and settings, see the
|
||||
[training format documentation](/api/data-formats#config). The settings
|
||||
[training format documentation](/api/data-formats#config) and
|
||||
[Thinc's config system docs](https://thinc.ai/usage/config). The settings
|
||||
available for the different architectures are documented with the
|
||||
[model architectures API](/api/architectures). See the Thinc documentation for
|
||||
[optimizers](https://thinc.ai/docs/api-optimizers) and
|
||||
|
@ -206,6 +169,30 @@ available for the different architectures are documented with the
|
|||
|
||||
</Infobox>
|
||||
|
||||
#### Overwriting config settings on the command line {#config-overrides}
|
||||
|
||||
The config system means that you can define all settings **in one place** and in
|
||||
a consistent format. There are no command-line arguments that need to be set,
|
||||
and no hidden defaults. However, there can still be scenarios where you may want
|
||||
to override config settings when you run [`spacy train`](/api/cli#train). This
|
||||
includes **file paths** to vectors or other resources that shouldn't be
|
||||
hard-code in a config file, or **system-dependent settings** like the GPU ID.
|
||||
|
||||
For cases like this, you can set additional command-line options starting with
|
||||
`--` that correspond to the config section and value to override. For example,
|
||||
`--training.use_gpu 1` sets the `use_gpu` value in the `[training]` block to
|
||||
`1`.
|
||||
|
||||
```bash
|
||||
$ python -m spacy train train.spacy dev.spacy config.cfg
|
||||
--training.use_gpu 1 --nlp.vectors /path/to/vectors
|
||||
```
|
||||
|
||||
Only existing sections and values in the config can be overwritten. At the end
|
||||
of the training, the final filled `config.cfg` is exported with your model, so
|
||||
you'll always have a record of the settings that were used, including your
|
||||
overrides.
|
||||
|
||||
#### Using registered functions {#config-functions}
|
||||
|
||||
The training configuration defined in the config file doesn't have to only
|
||||
|
@ -229,9 +216,14 @@ You can also use this mechanism to register
|
|||
[custom implementations and architectures](#custom-models) and reference them
|
||||
from your configs.
|
||||
|
||||
> #### TODO
|
||||
> #### How the config is resolved
|
||||
>
|
||||
> TODO: something about how the tree is built bottom-up?
|
||||
> The config file is parsed into a regular dictionary and is resolved and
|
||||
> validated **bottom-up**. Arguments provided for registered functions are
|
||||
> checked against the function's signature and type annotations. The return
|
||||
> value of a registered function can also be passed into another function – for
|
||||
> instance, a learning rate schedule can be provided as the an argument of an
|
||||
> optimizer.
|
||||
|
||||
```ini
|
||||
### With registered function
|
||||
|
@ -382,6 +374,9 @@ cases, it's recommended to train your models via the
|
|||
[`spacy train`](/api/cli#train) command with a [`config.cfg`](#config) to keep
|
||||
track of your settings and hyperparameters, instead of writing your own training
|
||||
scripts from scratch.
|
||||
[Custom registered functions](/usage/training/#custom-code) should typically
|
||||
give you everything you need to train fully custom models with
|
||||
[`spacy train`](/api/cli#train).
|
||||
|
||||
</Infobox>
|
||||
|
||||
|
|
|
@ -5,8 +5,11 @@ import classNames from 'classnames'
|
|||
import Link from './link'
|
||||
import classes from '../styles/accordion.module.sass'
|
||||
|
||||
const Accordion = ({ title, id, expanded, children }) => {
|
||||
const Accordion = ({ title, id, expanded, spaced, children }) => {
|
||||
const [isExpanded, setIsExpanded] = useState(true)
|
||||
const rootClassNames = classNames(classes.root, {
|
||||
[classes.spaced]: !!spaced,
|
||||
})
|
||||
const contentClassNames = classNames(classes.content, {
|
||||
[classes.hidden]: !isExpanded,
|
||||
})
|
||||
|
@ -17,7 +20,7 @@ const Accordion = ({ title, id, expanded, children }) => {
|
|||
useEffect(() => setIsExpanded(expanded), [])
|
||||
return (
|
||||
<section className="accordion" id={id}>
|
||||
<div className={classes.root}>
|
||||
<div className={rootClassNames}>
|
||||
<h4>
|
||||
<button
|
||||
className={classes.button}
|
||||
|
|
|
@ -4,6 +4,9 @@
|
|||
margin-bottom: var(--spacing-xs)
|
||||
border-radius: var(--border-radius)
|
||||
|
||||
.spaced
|
||||
margin-bottom: var(--spacing-md)
|
||||
|
||||
.button
|
||||
font: bold var(--font-size-lg)/var(--line-height-md) var(--font-primary)
|
||||
color: var(--color-theme-dark)
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
.juniper-input pre,
|
||||
.juniper-output
|
||||
font: var(--font-size-code)/var(--line-height-code) var(--font-code) !important
|
||||
font-variant-ligatures: none !important
|
||||
-webkit-font-smoothing: subpixel-antialiased
|
||||
-moz-osx-font-smoothing: auto
|
||||
|
||||
|
@ -44,6 +45,7 @@
|
|||
box-decoration-break: clone
|
||||
white-space: nowrap
|
||||
text-shadow: none
|
||||
font-variant-ligatures: none
|
||||
-webkit-font-smoothing: subpixel-antialiased
|
||||
-moz-osx-font-smoothing: auto
|
||||
|
||||
|
|
|
@ -358,6 +358,14 @@ body [id]:target
|
|||
&.italic
|
||||
font-style: italic
|
||||
|
||||
// Settings for ini syntax (config files)
|
||||
[class*="language-ini"]
|
||||
color: var(--syntax-comment)
|
||||
|
||||
.token
|
||||
color: var(--color-subtle)
|
||||
|
||||
|
||||
.gatsby-highlight-code-line
|
||||
background-color: var(--color-dark-secondary)
|
||||
border-left: 0.35em solid var(--color-theme)
|
||||
|
@ -371,7 +379,6 @@ body [id]:target
|
|||
// Fix issue where empty lines would disappear
|
||||
content: " "
|
||||
|
||||
|
||||
// CodeMirror
|
||||
|
||||
.CodeMirror.cm-s-default
|
||||
|
|
Loading…
Reference in New Issue
Block a user