Port CLI to Typer and add project stubs

Ines Montani 2020-06-21 13:44:00 +02:00
parent 988d2a4eda
commit c12713a8be
17 changed files with 327 additions and 170 deletions

spacy/__main__.py

@@ -1,31 +1,4 @@
-if __name__ == "__main__":
-    import plac
-    import sys
-    from wasabi import msg
-    from spacy.cli import download, link, info, package, pretrain, convert
-    from spacy.cli import init_model, profile, evaluate, validate, debug_data
-    from spacy.cli import train_cli
-
-    commands = {
-        "download": download,
-        "link": link,
-        "info": info,
-        "train": train_cli,
-        "pretrain": pretrain,
-        "debug-data": debug_data,
-        "evaluate": evaluate,
-        "convert": convert,
-        "package": package,
-        "init-model": init_model,
-        "profile": profile,
-        "validate": validate,
-    }
-    if len(sys.argv) == 1:
-        msg.info("Available commands", ", ".join(commands), exits=1)
-    command = sys.argv.pop(1)
-    sys.argv[0] = f"spacy {command}"
-    if command in commands:
-        plac.call(commands[command], sys.argv[1:])
-    else:
-        available = f"Available: {', '.join(commands)}"
-        msg.fail(f"Unknown command: {command}", available, exits=1)
+from spacy.cli import app
+
+if __name__ == "__main__":
+    app()

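Note: the hand-rolled dispatch removed above (a commands dict plus plac.call) is what Typer now provides for free: each subcommand registers itself on a shared app object, and calling app() parses sys.argv, routes to the matching function, and lists the available commands when none is given. A minimal, self-contained sketch of that pattern, with toy commands rather than spaCy's:

import typer

app = typer.Typer(name="demo", help="Toy CLI illustrating Typer dispatch")


@app.command("hello")
def hello(name: str = typer.Argument(..., help="Who to greet")):
    """Greet somebody."""
    typer.echo(f"Hello, {name}!")


@app.command("goodbye")
def goodbye(formal: bool = typer.Option(False, "--formal", "-f")):
    """Say goodbye."""
    typer.echo("Goodbye." if formal else "Bye!")


if __name__ == "__main__":
    app()  # replaces the manual sys.argv dispatch and "unknown command" handling

Running the sketch without arguments prints a help screen listing both commands, which is what the deleted msg.info("Available commands", ...) branch used to do by hand.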
spacy/about.py

@@ -5,3 +5,4 @@ __release__ = True
 __download_url__ = "https://github.com/explosion/spacy-models/releases/download"
 __compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
 __shortcuts__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/shortcuts-v2.json"
+__projects__ = "https://github.com/explosion/spacy-boilerplates"

spacy/cli/__init__.py

@@ -1,5 +1,4 @@
-from wasabi import msg
+from ._app import app  # noqa: F401
 from .download import download  # noqa: F401
 from .info import info  # noqa: F401
 from .package import package  # noqa: F401
@@ -11,10 +10,4 @@ from .evaluate import evaluate  # noqa: F401
 from .convert import convert  # noqa: F401
 from .init_model import init_model  # noqa: F401
 from .validate import validate  # noqa: F401
-
-
-def link(*args, **kwargs):
-    msg.warn(
-        "As of spaCy v3.0, model symlinks are deprecated. You can load models "
-        "using their full names or from a directory path."
-    )
+from .project import project_cli  # noqa: F401

spacy/cli/_app.py (new file)

@@ -0,0 +1,31 @@
+import typer
+from wasabi import msg
+
+
+def Arg(*args, help=None, **kwargs):
+    # Filter out help for now until it's officially supported
+    return typer.Argument(*args, **kwargs)
+
+
+def Opt(*args, **kwargs):
+    return typer.Option(*args, show_default=True, **kwargs)
+
+
+app = typer.Typer(
+    name="spacy",
+    help="""spaCy Command-line Interface
+
+    DOCS: https://spacy.io/api/cli
+    """,
+)
+
+
+@app.command("link", no_args_is_help=True, deprecated=True, hidden=True)
+def link(*args, **kwargs):
+    """As of spaCy v3.0, model symlinks are deprecated. You can load models
+    using their full names or from a directory path."""
+    msg.warn(
+        "As of spaCy v3.0, model symlinks are deprecated. You can load models "
+        "using their full names or from a directory path."
+    )

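The Arg and Opt helpers above keep the converted command signatures compact: Arg wraps typer.Argument, with ... as Typer's "required" sentinel and the help text dropped for the reason given in the comment, while Opt wraps typer.Option with show_default=True so defaults appear in --help. A hedged sketch of how a command consumes them; the count-lines command is hypothetical, only the helpers come from this commit:

# Hypothetical command, assuming this commit is installed so the helpers import.
from pathlib import Path
from spacy.cli._app import app, Arg, Opt


@app.command("count-lines")
def count_lines(
    input_path: Path = Arg(..., help="File to read"),                       # required positional
    encoding: str = Opt("utf8", "--encoding", "-e", help="File encoding"),  # option, default shown in --help
):
    """Toy command: print the number of lines in a file."""
    print(len(input_path.read_text(encoding).splitlines()))


if __name__ == "__main__":
    app()  # the toy command shows up alongside the real spaCy commands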
spacy/cli/convert.py

@@ -1,8 +1,11 @@
+from typing import Optional
+from enum import Enum
 from pathlib import Path
 from wasabi import Printer
 import srsly
 import re

+from ._app import app, Arg, Opt
 from .converters import conllu2json, iob2json, conll_ner2json
 from .converters import ner_jsonl2json
@@ -21,23 +24,29 @@ CONVERTERS = {
 }

 # File types
-FILE_TYPES = ("json", "jsonl", "msg")
 FILE_TYPES_STDOUT = ("json", "jsonl")

+
+class FileTypes(str, Enum):
+    json = "json"
+    jsonl = "jsonl"
+    msg = "msg"
+
+
+@app.command("convert")
 def convert(
     # fmt: off
-    input_file: ("Input file", "positional", None, str),
-    output_dir: ("Output directory. '-' for stdout.", "positional", None, str) = "-",
-    file_type: (f"Type of data to produce: {FILE_TYPES}", "option", "t", str, FILE_TYPES) = "json",
-    n_sents: ("Number of sentences per doc (0 to disable)", "option", "n", int) = 1,
-    seg_sents: ("Segment sentences (for -c ner)", "flag", "s") = False,
-    model: ("Model for sentence segmentation (for -s)", "option", "b", str) = None,
-    morphology: ("Enable appending morphology to tags", "flag", "m", bool) = False,
-    merge_subtokens: ("Merge CoNLL-U subtokens", "flag", "T", bool) = False,
-    converter: (f"Converter: {tuple(CONVERTERS.keys())}", "option", "c", str) = "auto",
-    ner_map_path: ("NER tag mapping (as JSON-encoded dict of entity types)", "option", "N", Path) = None,
-    lang: ("Language (if tokenizer required)", "option", "l", str) = None,
+    input_file: str = Arg(..., help="Input file"),
+    output_dir: str = Arg("-", help="Output directory. '-' for stdout."),
+    file_type: FileTypes = Opt(FileTypes.json.value, "--file-type", "-t", help="Type of data to produce"),
+    n_sents: int = Opt(1, "--n-sents", "-n", help="Number of sentences per doc (0 to disable)"),
+    seg_sents: bool = Opt(False, "--seg-sents", "-s", help="Segment sentences (for -c ner)"),
+    model: Optional[str] = Opt(None, "--model", "-b", help="Model for sentence segmentation (for -s)"),
+    morphology: bool = Opt(False, "--morphology", "-m", help="Enable appending morphology to tags"),
+    merge_subtokens: bool = Opt(False, "--merge-subtokens", "-T", help="Merge CoNLL-U subtokens"),
+    converter: str = Opt("auto", "--converter", "-c", help=f"Converter: {tuple(CONVERTERS.keys())}"),
+    ner_map_path: Optional[Path] = Opt(None, "--ner-map-path", "-N", help="NER tag mapping (as JSON-encoded dict of entity types)"),
+    lang: Optional[str] = Opt(None, "--lang", "-l", help="Language (if tokenizer required)"),
     # fmt: on
 ):
     """
@@ -46,6 +55,9 @@ def convert(
     is written to stdout, so you can pipe them forward to a JSON file:
     $ spacy convert some_file.conllu > some_file.json
     """
+    if isinstance(file_type, FileTypes):
+        # We get an instance of the FileTypes from the CLI so we need its string value
+        file_type = file_type.value
     no_print = output_dir == "-"
     msg = Printer(no_print=no_print)
     input_path = Path(input_file)

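Backing file_type with an Enum that subclasses str is what replaces the old plac choices tuple: Typer (via Click) turns the enum into a validated choice parameter, so an unsupported value is rejected at the command line with the allowed choices listed, and the isinstance unwrap in the function body keeps the rest of the code working on plain strings. A standalone sketch of the same pattern, independent of spaCy:

from enum import Enum
import typer


class Fmt(str, Enum):
    json = "json"
    jsonl = "jsonl"
    msg = "msg"


cli = typer.Typer()


@cli.command()
def export(fmt: Fmt = typer.Option(Fmt.json.value, "--fmt", "-f", help="Output format")):
    if isinstance(fmt, Fmt):  # the CLI hands us the enum member; unwrap to its string value
        fmt = fmt.value
    typer.echo(f"exporting as {fmt}")
    # `export --fmt xml` aborts with an "invalid choice" error naming json/jsonl/msg


if __name__ == "__main__":
    cli()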
spacy/cli/debug_data.py

@@ -1,9 +1,11 @@
+from typing import Optional
 from pathlib import Path
 from collections import Counter
 import sys
 import srsly
 from wasabi import Printer, MESSAGES

+from ._app import app, Arg, Opt
 from ..gold import GoldCorpus
 from ..syntax import nonproj
 from ..util import load_model, get_lang_class
@@ -18,17 +20,18 @@ BLANK_MODEL_MIN_THRESHOLD = 100
 BLANK_MODEL_THRESHOLD = 2000

+@app.command("debug-data")
 def debug_data(
     # fmt: off
-    lang: ("Model language", "positional", None, str),
-    train_path: ("Location of JSON-formatted training data", "positional", None, Path),
-    dev_path: ("Location of JSON-formatted development data", "positional", None, Path),
-    tag_map_path: ("Location of JSON-formatted tag map", "option", "tm", Path) = None,
-    base_model: ("Name of model to update (optional)", "option", "b", str) = None,
-    pipeline: ("Comma-separated names of pipeline components to train", "option", "p", str) = "tagger,parser,ner",
-    ignore_warnings: ("Ignore warnings, only show stats and errors", "flag", "IW", bool) = False,
-    verbose: ("Print additional information and explanations", "flag", "V", bool) = False,
-    no_format: ("Don't pretty-print the results", "flag", "NF", bool) = False,
+    lang: str = Arg(..., help="Model language"),
+    train_path: Path = Arg(..., help="Location of JSON-formatted training data"),
+    dev_path: Path = Arg(..., help="Location of JSON-formatted development data"),
+    tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
+    base_model: Optional[str] = Opt(None, "--base-model", "-b", help="Name of model to update (optional)"),
+    pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of pipeline components to train"),
+    ignore_warnings: bool = Opt(False, "--ignore-warnings", "-IW", help="Ignore warnings, only show stats and errors"),
+    verbose: bool = Opt(False, "--verbose", "-V", help="Print additional information and explanations"),
+    no_format: bool = Opt(False, "--no-format", "-NF", help="Don't pretty-print the results"),
     # fmt: on
 ):
     """

spacy/cli/download.py

@@ -1,17 +1,25 @@
+from typing import List
 import requests
 import os
 import subprocess
 import sys
 from wasabi import msg

+from ._app import app, Arg, Opt
 from .. import about
 from ..util import is_package, get_base_version

+@app.command(
+    "download",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
 def download(
-    model: ("Model to download (shortcut or name)", "positional", None, str),
-    direct: ("Force direct download of name + version", "flag", "d", bool) = False,
-    *pip_args: ("Additional arguments to be passed to `pip install` on model install"),
+    # fmt: off
+    model: str = Arg(..., help="Model to download (shortcut or name)"),
+    direct: bool = Opt(False, "--direct", "-d", help="Force direct download of name + version"),
+    pip_args: List[str] = Arg(..., help="Additional arguments to be passed to `pip install` on model install"),
+    # fmt: on
 ):
     """
     Download compatible model from default download path using pip. If --direct

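The context_settings passed to @app.command are Click settings: allow_extra_args keeps unrecognized tokens instead of erroring, and ignore_unknown_options stops them from being interpreted as options, which is how additional pip install arguments can survive parsing. A hedged sketch of the general passthrough pattern, reading the leftovers from typer.Context; this is the standard Typer/Click idiom, not necessarily the exact body of spaCy's download:

import subprocess
import sys
import typer

cli = typer.Typer()


@cli.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
def install(ctx: typer.Context, package: str):
    """Install a package, forwarding any unknown flags straight to pip."""
    pip_args = list(ctx.args)  # e.g. ["--no-deps", "--quiet"]
    subprocess.run([sys.executable, "-m", "pip", "install", *pip_args, package], check=True)


if __name__ == "__main__":
    cli()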
spacy/cli/evaluate.py

@@ -1,20 +1,23 @@
+from typing import Optional
 from timeit import default_timer as timer
 from wasabi import msg

+from ._app import app, Arg, Opt
 from ..gold import GoldCorpus
 from .. import util
 from .. import displacy

+@app.command("evaluate")
 def evaluate(
     # fmt: off
-    model: ("Model name or path", "positional", None, str),
-    data_path: ("Location of JSON-formatted evaluation data", "positional", None, str),
-    gpu_id: ("Use GPU", "option", "g", int) = -1,
-    gold_preproc: ("Use gold preprocessing", "flag", "G", bool) = False,
-    displacy_path: ("Directory to output rendered parses as HTML", "option", "dp", str) = None,
-    displacy_limit: ("Limit of parses to render as HTML", "option", "dl", int) = 25,
-    return_scores: ("Return dict containing model scores", "flag", "R", bool) = False,
+    model: str = Arg(..., help="Model name or path"),
+    data_path: str = Arg(..., help="Location of JSON-formatted evaluation data"),
+    gpu_id: int = Opt(-1, "--gpu-id", "-g", help="Use GPU"),
+    gold_preproc: bool = Opt(False, "--gold-preproc", "-G", help="Use gold preprocessing"),
+    displacy_path: Optional[str] = Opt(None, "--displacy-path", "-dp", help="Directory to output rendered parses as HTML"),
+    displacy_limit: int = Opt(25, "--displacy-limit", "-dl", help="Limit of parses to render as HTML"),
+    return_scores: bool = Opt(False, "--return-scores", "-R", help="Return dict containing model scores"),
     # fmt: on
 ):
     """

spacy/cli/info.py

@@ -1,17 +1,22 @@
+from typing import Optional
 import platform
 from pathlib import Path
 from wasabi import msg
 import srsly

+from ._app import app, Arg, Opt
 from .validate import get_model_pkgs
 from .. import util
 from .. import about

+@app.command("info")
 def info(
-    model: ("Optional model name", "positional", None, str) = None,
-    markdown: ("Generate Markdown for GitHub issues", "flag", "md", str) = False,
-    silent: ("Don't print anything (just return)", "flag", "s") = False,
+    # fmt: off
+    model: Optional[str] = Arg(None, help="Optional model name"),
+    markdown: bool = Opt(False, "--markdown", "-md", help="Generate Markdown for GitHub issues"),
+    silent: bool = Opt(False, "--silent", "-s", help="Don't print anything (just return)"),
+    # fmt: on
 ):
     """
     Print info about spaCy installation. If a model is specified as an argument,

spacy/cli/init_model.py

@@ -1,3 +1,4 @@
+from typing import Optional
 import math
 from tqdm import tqdm
 import numpy
@@ -11,6 +12,7 @@ import srsly
 import warnings
 from wasabi import msg

+from ._app import app, Arg, Opt
 from ..vectors import Vectors
 from ..errors import Errors, Warnings
 from ..util import ensure_path, get_lang_class, load_model, OOV_RANK
@@ -25,20 +27,21 @@ except ImportError:
 DEFAULT_OOV_PROB = -20

+@app.command("init-model")
 def init_model(
     # fmt: off
-    lang: ("Model language", "positional", None, str),
-    output_dir: ("Model output directory", "positional", None, Path),
-    freqs_loc: ("Location of words frequencies file", "option", "f", Path) = None,
-    clusters_loc: ("Optional location of brown clusters data", "option", "c", str) = None,
-    jsonl_loc: ("Location of JSONL-formatted attributes file", "option", "j", Path) = None,
-    vectors_loc: ("Optional vectors file in Word2Vec format", "option", "v", str) = None,
-    prune_vectors: ("Optional number of vectors to prune to", "option", "V", int) = -1,
-    truncate_vectors: ("Optional number of vectors to truncate to when reading in vectors file", "option", "t", int) = 0,
-    vectors_name: ("Optional name for the word vectors, e.g. en_core_web_lg.vectors", "option", "vn", str) = None,
-    model_name: ("Optional name for the model meta", "option", "mn", str) = None,
-    omit_extra_lookups: ("Don't include extra lookups in model", "flag", "OEL", bool) = False,
-    base_model: ("Base model (for languages with custom tokenizers)", "option", "b", str) = None
+    lang: str = Arg(..., help="Model language"),
+    output_dir: Path = Arg(..., help="Model output directory"),
+    freqs_loc: Optional[Path] = Arg(None, help="Location of words frequencies file"),
+    clusters_loc: Optional[str] = Opt(None, "--clusters-loc", "-c", help="Optional location of brown clusters data"),
+    jsonl_loc: Optional[Path] = Opt(None, "--jsonl-loc", "-j", help="Location of JSONL-formatted attributes file"),
+    vectors_loc: Optional[str] = Opt(None, "--vectors-loc", "-v", help="Optional vectors file in Word2Vec format"),
+    prune_vectors: int = Opt(-1, "--prune-vectors", "-V", help="Optional number of vectors to prune to"),
+    truncate_vectors: int = Opt(0, "--truncate-vectors", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
+    vectors_name: Optional[str] = Opt(None, "--vectors-name", "-vn", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
+    model_name: Optional[str] = Opt(None, "--model-name", "-mn", help="Optional name for the model meta"),
+    omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
+    base_model: Optional[str] = Opt(None, "--base-model", "-b", help="Base model (for languages with custom tokenizers)")
     # fmt: on
 ):
     """

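One subtlety in init_model above: freqs_loc stays a positional argument but becomes optional, because giving typer.Argument a concrete default (here None) instead of the ... sentinel makes the positional omittable. A tiny standalone sketch:

import typer

cli = typer.Typer()


@cli.command()
def stats(freqs_loc: str = typer.Argument(None)):  # optional positional: may be left out
    typer.echo("no frequencies file" if freqs_loc is None else f"reading {freqs_loc}")


if __name__ == "__main__":
    cli()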
spacy/cli/package.py

@@ -1,19 +1,22 @@
+from typing import Optional
 import shutil
 from pathlib import Path
 from wasabi import msg, get_raw_input
 import srsly

+from ._app import app, Arg, Opt
 from .. import util
 from .. import about

+@app.command("package")
 def package(
     # fmt: off
-    input_dir: ("Directory with model data", "positional", None, str),
-    output_dir: ("Output parent directory", "positional", None, str),
-    meta_path: ("Path to meta.json", "option", "m", str) = None,
-    create_meta: ("Create meta.json, even if one exists", "flag", "c", bool) = False,
-    force: ("Force overwriting existing model in output directory", "flag", "f", bool) = False,
+    input_dir: str = Arg(..., help="Directory with model data"),
+    output_dir: str = Arg(..., help="Output parent directory"),
+    meta_path: Optional[str] = Opt(None, "--meta-path", "-m", help="Path to meta.json"),
+    create_meta: bool = Opt(False, "--create-meta", "-c", help="Create meta.json, even if one exists"),
+    force: bool = Opt(False, "--force", "-f", help="Force overwriting existing model in output directory"),
     # fmt: on
 ):
     """

spacy/cli/pretrain.py

@@ -1,14 +1,15 @@
+from typing import Optional
 import random
 import numpy
 import time
 import re
 from collections import Counter
-import plac
 from pathlib import Path
 from thinc.api import Linear, Maxout, chain, list2array, use_pytorch_for_gpu_memory
 from wasabi import msg
 import srsly

+from ._app import app, Arg, Opt
 from ..errors import Errors
 from ..ml.models.multi_task import build_masked_language_model
 from ..tokens import Doc
@@ -17,25 +18,17 @@ from .. import util
 from ..gold import Example

-@plac.annotations(
-    # fmt: off
-    texts_loc=("Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'", "positional", None, str),
-    vectors_model=("Name or path to spaCy model with vectors to learn from", "positional", None, str),
-    output_dir=("Directory to write models to on each epoch", "positional", None, Path),
-    config_path=("Path to config file", "positional", None, Path),
-    use_gpu=("Use GPU", "option", "g", int),
-    resume_path=("Path to pretrained weights from which to resume pretraining", "option", "r", Path),
-    epoch_resume=("The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files.", "option", "er", int),
-    # fmt: on
-)
+@app.command("pretrain")
 def pretrain(
-    texts_loc,
-    vectors_model,
-    config_path,
-    output_dir,
-    use_gpu=-1,
-    resume_path=None,
-    epoch_resume=None,
+    # fmt: off
+    texts_loc: str = Arg(..., help="Path to JSONL file with raw texts to learn from, with text provided as the key 'text' or tokens as the key 'tokens'"),
+    vectors_model: str = Arg(..., help="Name or path to spaCy model with vectors to learn from"),
+    output_dir: Path = Arg(..., help="Directory to write models to on each epoch"),
+    config_path: Path = Arg(..., help="Path to config file"),
+    use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
+    resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
+    epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using '--resume_path'. Prevents unintended overwriting of existing weight files."),
+    # fmt: on
 ):
     """
     Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,

spacy/cli/profile.py

@@ -1,3 +1,4 @@
+from typing import Optional
 import tqdm
 from pathlib import Path
 import srsly
@@ -8,14 +9,16 @@ import itertools
 import ml_datasets
 from wasabi import msg

+from ._app import app, Arg, Opt
 from ..util import load_model

+@app.command("profile")
 def profile(
     # fmt: off
-    model: ("Model to load", "positional", None, str),
-    inputs: ("Location of input file. '-' for stdin.", "positional", None, str) = None,
-    n_texts: ("Maximum number of texts to use if available", "option", "n", int) = 10000,
+    model: str = Arg(..., help="Model to load"),
+    inputs: Optional[str] = Arg(None, help="Location of input file. '-' for stdin."),
+    n_texts: int = Opt(10000, "--n-texts", "-n", help="Maximum number of texts to use if available"),
     # fmt: on
 ):
     """
spacy/cli/project.py (new file)

@@ -0,0 +1,100 @@
+from typing import List, Dict
+import typer
+import srsly
+from pathlib import Path
+import os
+import subprocess
+import sys
+from wasabi import msg
+import shlex
+
+from ._app import app, Arg, Opt
+from .. import about
+from ..schemas import ProjectConfigSchema, validate
+
+CONFIG_FILE = "project.yml"
+SUBDIRS = [
+    "assets",
+    "configs",
+    "packages",
+    "metrics",
+    "scripts",
+    "notebooks",
+    "training",
+]
+
+project_cli = typer.Typer(help="Command-line interface for spaCy projects")
+
+
+def load_project_config(path):
+    config_path = path / CONFIG_FILE
+    if not config_path.exists():
+        msg.fail("Can't find project config", config_path, exits=1)
+    config = srsly.read_yaml(config_path)
+    errors = validate(ProjectConfigSchema, config)
+    if errors:
+        msg.fail(f"Invalid project config in {CONFIG_FILE}", "\n".join(errors), exits=1)
+    return config
+
+
+def create_dirs(project_dir: Path):
+    for subdir in SUBDIRS:
+        (project_dir / subdir).mkdir(parents=True)
+
+
+def run_cmd(command: str):
+    status = subprocess.call(shlex.split(command), env=os.environ.copy())
+    if status != 0:
+        sys.exit(status)
+
+
+def run_commands(commands: List[str] = tuple(), variables: Dict[str, str] = {}):
+    for command in commands:
+        # Substitute variables, e.g. "./{NAME}.json"
+        command = command.format(**variables)
+        msg.info(command)
+        run_cmd(command)
+
+
+@project_cli.command("clone")
+def project_clone(
+    # fmt: off
+    name: str = Arg(..., help="The name of the template to fetch"),
+    dest: Path = Arg(Path.cwd(), help="Where to download and work. Defaults to current working directory.", exists=True, file_okay=False),
+    repo: str = Opt(about.__projects__, "--repo", "-r", help="The repository to look in."),
+    # fmt: on
+):
+    """Clone a project template from a repository."""
+    print("Cloning", repo)
+
+
+@project_cli.command("run")
+def project_run(
+    # fmt: off
+    project_dir: Path = Arg(..., help="Location of project directory", exists=True, file_okay=False),
+    subcommand: str = Arg(None, help="Name of command defined in project config")
+    # fmt: on
+):
+    """Run scripts defined in the project."""
+    config = load_project_config(project_dir)
+    config_commands = config.get("commands", [])
+    variables = config.get("variables", {})
+    commands = {cmd["name"]: cmd for cmd in config_commands}
+    if subcommand is None:
+        all_commands = config.get("run", [])
+        if not all_commands:
+            msg.warn("No run commands defined in project config", exits=0)
+        msg.table([(cmd["name"], cmd.get("help", "")) for cmd in config_commands])
+        for command in all_commands:
+            if command not in commands:
+                msg.fail(f"Can't find command '{command}' in project config", exits=1)
+            msg.divider(command)
+            run_commands(commands[command]["script"], variables)
+        return
+    if subcommand not in commands:
+        msg.fail(f"Can't find command '{subcommand}' in project config", exits=1)
+    run_commands(commands[subcommand]["script"], variables)
+
+
+app.add_typer(project_cli, name="project")

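load_project_config above reads project.yml and validates it against ProjectConfigSchema, which this commit adds to spacy/schemas.py below; validate returns a list of human-readable error strings. A hedged sketch of a config that should pass validation; the contents are invented for illustration and assume this commit's spacy.schemas is importable:

# Illustrative only: a made-up project config checked against the new schema.
from spacy.schemas import ProjectConfigSchema, validate

config = {
    "variables": {"NAME": "demo"},  # substituted into scripts, e.g. "./{NAME}.json"
    "assets": [{"dest": "assets/data.json", "url": "https://example.com/data.json"}],
    "run": ["preprocess", "train"],  # what `spacy project run some_dir` executes, in order
    "commands": [
        {"name": "preprocess", "help": "Prepare the data", "script": ["python scripts/preprocess.py {NAME}"]},
        {"name": "train", "script": ["python scripts/train.py {NAME}"]},
    ],
}

errors = validate(ProjectConfigSchema, config)  # [] when the config is well-formed
assert not errors, errors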
spacy/cli/train_from_config.py

@@ -1,16 +1,15 @@
-from typing import Optional, Dict, List, Union, Sequence
+from typing import Optional
 from timeit import default_timer as timer
 import srsly
-from pydantic import BaseModel, FilePath
 import tqdm
 from pathlib import Path
 from wasabi import msg
 import thinc
 import thinc.schedules
-from thinc.api import Model, use_pytorch_for_gpu_memory
+from thinc.api import use_pytorch_for_gpu_memory
 import random

+from ._app import app, Arg, Opt
 from ..gold import GoldCorpus
 from ..lookups import Lookups
 from .. import util
@@ -19,6 +18,9 @@ from ..errors import Errors
 # Don't remove - required to load the built-in architectures
 from ..ml import models  # noqa: F401

+# from ..schemas import ConfigSchema  # TODO: include?
+
 registry = util.registry

 CONFIG_STR = """
@@ -80,54 +82,20 @@ subword_features = true
 """

-class PipelineComponent(BaseModel):
-    factory: str
-    model: Model
-
-    class Config:
-        arbitrary_types_allowed = True
-
-
-class ConfigSchema(BaseModel):
-    optimizer: Optional["Optimizer"]
-
-    class training(BaseModel):
-        patience: int = 10
-        eval_frequency: int = 100
-        dropout: float = 0.2
-        init_tok2vec: Optional[FilePath] = None
-        max_epochs: int = 100
-        orth_variant_level: float = 0.0
-        gold_preproc: bool = False
-        max_length: int = 0
-        use_gpu: int = 0
-        scores: List[str] = ["ents_p", "ents_r", "ents_f"]
-        score_weights: Dict[str, Union[int, float]] = {"ents_f": 1.0}
-        limit: int = 0
-        batch_size: Union[Sequence[int], int]
-
-    class nlp(BaseModel):
-        lang: str
-        vectors: Optional[str]
-        pipeline: Optional[Dict[str, PipelineComponent]]
-
-    class Config:
-        extra = "allow"
-
+@app.command("train")
 def train_cli(
     # fmt: off
-    train_path: ("Location of JSON-formatted training data", "positional", None, Path),
-    dev_path: ("Location of JSON-formatted development data", "positional", None, Path),
-    config_path: ("Path to config file", "positional", None, Path),
-    output_path: ("Output directory to store model in", "option", "o", Path) = None,
-    code_path: ("Path to Python file with additional code (registered functions) to be imported", "option", "c", Path) = None,
-    init_tok2vec: ("Path to pretrained weights for the tok2vec components. See 'spacy pretrain'. Experimental.", "option", "t2v", Path) = None,
-    raw_text: ("Path to jsonl file with unlabelled text documents.", "option", "rt", Path) = None,
-    verbose: ("Display more information for debugging purposes", "flag", "VV", bool) = False,
-    use_gpu: ("Use GPU", "option", "g", int) = -1,
-    tag_map_path: ("Location of JSON-formatted tag map", "option", "tm", Path) = None,
-    omit_extra_lookups: ("Don't include extra lookups in model", "flag", "OEL", bool) = False,
+    train_path: Path = Arg(..., help="Location of JSON-formatted training data"),
+    dev_path: Path = Arg(..., help="Location of JSON-formatted development data"),
+    config_path: Path = Arg(..., help="Path to config file"),
+    output_path: Optional[Path] = Opt(None, "--output-path", "-o", help="Output directory to store model in"),
+    code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
+    init_tok2vec: Optional[Path] = Opt(None, "--init-tok2vec", "-t2v", help="Path to pretrained weights for the tok2vec components. See 'spacy pretrain'. Experimental."),
+    raw_text: Optional[Path] = Opt(None, "--raw-text", "-rt", help="Path to jsonl file with unlabelled text documents."),
+    verbose: bool = Opt(False, "--verbose", "-VV", help="Display more information for debugging purposes"),
+    use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
+    tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
+    omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
     # fmt: on
 ):
     """

spacy/cli/validate.py

@@ -3,11 +3,13 @@ import sys
 import requests
 from wasabi import msg

+from ._app import app
 from .. import about
 from ..util import get_package_version, get_installed_models, get_base_version
 from ..util import get_package_path, get_model_meta, is_compatible_version

+@app.command("validate")
 def validate():
     """
     Validate that the currently installed version of spaCy is compatible

spacy/schemas.py

@@ -1,8 +1,9 @@
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Union, Optional, Sequence
 from enum import Enum
 from pydantic import BaseModel, Field, ValidationError, validator
-from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
+from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool, FilePath
 from collections import defaultdict
+from thinc.api import Model
 from .attrs import NAMES
@@ -169,18 +170,42 @@ class ModelMetaSchema(BaseModel):
     # fmt: on

-# Training data object in "simple training style"
+# JSON training format

-class SimpleTrainingSchema(BaseModel):
-    # TODO: write
+class PipelineComponent(BaseModel):
+    factory: str
+    model: Model

     class Config:
-        title = "Schema for training data dict in passed to nlp.update"
-        extra = "forbid"
+        arbitrary_types_allowed = True

-# JSON training format
+class ConfigSchema(BaseModel):
+    optimizer: Optional["Optimizer"]
+
+    class training(BaseModel):
+        patience: int = 10
+        eval_frequency: int = 100
+        dropout: float = 0.2
+        init_tok2vec: Optional[FilePath] = None
+        max_epochs: int = 100
+        orth_variant_level: float = 0.0
+        gold_preproc: bool = False
+        max_length: int = 0
+        use_gpu: int = 0
+        scores: List[str] = ["ents_p", "ents_r", "ents_f"]
+        score_weights: Dict[str, Union[int, float]] = {"ents_f": 1.0}
+        limit: int = 0
+        batch_size: Union[Sequence[int], int]
+
+    class nlp(BaseModel):
+        lang: str
+        vectors: Optional[str]
+        pipeline: Optional[Dict[str, PipelineComponent]]
+
+    class Config:
+        extra = "allow"

 class TrainingSchema(BaseModel):
@@ -189,3 +214,34 @@ class TrainingSchema(BaseModel):
     class Config:
         title = "Schema for training data in spaCy's JSON format"
         extra = "forbid"
+
+
+# Project config Schema
+
+class ProjectConfigAsset(BaseModel):
+    dest: StrictStr = Field(..., title="Destination of downloaded asset")
+    url: StrictStr = Field(..., title="URL of asset")
+
+
+class ProjectConfigCommand(BaseModel):
+    # fmt: off
+    name: StrictStr = Field(..., title="Name of command")
+    help: Optional[StrictStr] = Field(None, title="Command description")
+    script: List[StrictStr] = Field([], title="List of CLI commands to run, in order")
+    dvc_deps: List[StrictStr] = Field([], title="Data Version Control dependencies")
+    dvc_outputs: List[StrictStr] = Field([], title="Data Version Control outputs")
+    dvc_outputs_no_cache: List[StrictStr] = Field([], title="Data Version Control outputs (no cache)")
+    # fmt: on
+
+
+class ProjectConfigSchema(BaseModel):
+    # fmt: off
+    variables: Dict[StrictStr, Union[str, int, float, bool]] = Field({}, title="Optional variables to substitute in commands")
+    assets: List[ProjectConfigAsset] = Field([], title="Data assets")
+    run: List[StrictStr] = Field([], title="Names of project commands to execute, in order")
+    commands: List[ProjectConfigCommand] = Field([], title="Project command shortcuts")
+    # fmt: on
+
+    class Config:
+        title = "Schema for project configuration file"
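
The Strict* field types used throughout the project schema disable pydantic's usual type coercion, so an integer will not pass for a StrictStr and a malformed project.yml fails loudly at load time rather than mid-run. A short sketch, again assuming this commit's spacy.schemas is importable:

from pydantic import ValidationError
from spacy.schemas import ProjectConfigCommand

# Well-formed: name is a real string, script a list of strings
cmd = ProjectConfigCommand(name="train", script=["python scripts/train.py"])
print(cmd.help)  # None: the description is optional

try:
    ProjectConfigCommand(name=123, script=["echo hi"])  # StrictStr refuses to coerce 123
except ValidationError as err:
    print(err)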