Train CLI script fixes (#5931)

* fix dash replacement in overrides arguments

* perform interpolation on training config

* make sure only .spacy files are read
Sofie Van Landeghem 2020-08-18 16:06:37 +02:00 committed by GitHub
parent 82f0e20318
commit 688e77562b
4 changed files with 17 additions and 7 deletions


@@ -68,11 +68,12 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
         opt = args.pop(0)
         err = f"Invalid CLI argument '{opt}'"
         if opt.startswith("--"): # new argument
-            opt = opt.replace("--", "").replace("-", "_")
+            opt = opt.replace("--", "")
             if "." not in opt:
                 msg.fail(f"{err}: can't override top-level section", exits=1)
             if "=" in opt: # we have --opt=value
                 opt, value = opt.split("=", 1)
+                opt = opt.replace("-", "_")
             else:
                 if not args or args[0].startswith("--"): # flag with no value
                     value = "true"


@@ -75,7 +75,7 @@ def train(
         msg.info("Using CPU")
     msg.info(f"Loading config and nlp from: {config_path}")
     with show_validation_error(config_path):
-        config = util.load_config(config_path, overrides=config_overrides)
+        config = util.load_config(config_path, overrides=config_overrides, interpolate=True)
     if config.get("training", {}).get("seed") is not None:
         fix_random_seed(config["training"]["seed"])
     # Use original config here before it's resolved to functions
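
Loading with interpolate=True resolves ${section.key} references in the config before the values are read, so a setting defined once can be reused elsewhere and the training code sees concrete values. A minimal sketch of the effect, using the same util.load_config call as above (the config contents and file name are made up for illustration; the ${...} variable syntax shown is the one used by current spaCy configs):

from pathlib import Path
from spacy import util

cfg_text = """
[system]
seed = 0

[training]
seed = ${system.seed}
"""

path = Path("tmp_config.cfg")
path.write_text(cfg_text)
config = util.load_config(path, interpolate=True)
# With interpolation the reference is replaced by its value, so this prints 0
# rather than the raw "${system.seed}" string.
print(config["training"]["seed"])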


@@ -78,10 +78,11 @@ class Warnings:
             "are currently: {langs}")
 
     # TODO: fix numbering after merging develop into master
+    W090 = ("Could not locate any binary .spacy files in path '{path}'.")
     W091 = ("Could not clean/remove the temp directory at {dir}: {msg}.")
     W092 = ("Ignoring annotations for sentence starts, as dependency heads are set.")
     W093 = ("Could not find any data to train the {name} on. Is your "
-            "input data correctly formatted ?")
+            "input data correctly formatted?")
     W094 = ("Model '{model}' ({model_version}) specifies an under-constrained "
             "spaCy version requirement: {version}. This can lead to compatibility "
             "problems with older versions, or as new spaCy versions are "
@@ -600,7 +601,8 @@ class Errors:
             "\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
     E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
     E986 = ("Could not create any training batches: check your input. "
-            "Perhaps discard_oversize should be set to False ?")
+            "Are the train and dev paths defined? "
+            "Is 'discard_oversize' set appropriately? ")
     E987 = ("The text of an example training instance is either a Doc or "
             "a string, but found {type} instead.")
     E988 = ("Could not parse any training examples. Ensure the data is "


@@ -1,8 +1,10 @@
+import warnings
 from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
 from pathlib import Path
 
 from .. import util
 from .example import Example
+from ..errors import Warnings
 from ..tokens import DocBin, Doc
 from ..vocab import Vocab
@@ -10,6 +12,8 @@ if TYPE_CHECKING:
     # This lets us add type hints for mypy etc. without causing circular imports
     from ..language import Language # noqa: F401
 
+FILE_TYPE = ".spacy"
+
 
 @util.registry.readers("spacy.Corpus.v1")
 def create_docbin_reader(
@@ -53,8 +57,9 @@ class Corpus:
     @staticmethod
     def walk_corpus(path: Union[str, Path]) -> List[Path]:
         path = util.ensure_path(path)
-        if not path.is_dir():
+        if not path.is_dir() and path.parts[-1].endswith(FILE_TYPE):
             return [path]
+        orig_path = path
         paths = [path]
         locs = []
         seen = set()
@@ -66,8 +71,10 @@
                 continue
             elif path.is_dir():
                 paths.extend(path.iterdir())
-            elif path.parts[-1].endswith(".spacy"):
+            elif path.parts[-1].endswith(FILE_TYPE):
                 locs.append(path)
+        if len(locs) == 0:
+            warnings.warn(Warnings.W090.format(path=orig_path))
         return locs
 
     def __call__(self, nlp: "Language") -> Iterator[Example]:
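
Taken together, walk_corpus now only ever returns paths ending in the .spacy extension, whether it is handed a single file or a directory, and it warns instead of silently returning an empty list when nothing is found. A standalone sketch of the same idea (simplified and hypothetical, not the actual method):

import warnings
from pathlib import Path
from typing import List, Union

FILE_TYPE = ".spacy"

def find_spacy_files(path: Union[str, Path]) -> List[Path]:
    # A direct file path must carry the .spacy extension; a directory is
    # searched recursively for such files, skipping hidden entries.
    path = Path(path)
    if path.is_dir():
        locs = [p for p in path.rglob(f"*{FILE_TYPE}") if not p.name.startswith(".")]
    elif path.name.endswith(FILE_TYPE):
        locs = [path]
    else:
        locs = []
    if not locs:
        warnings.warn(f"Could not locate any binary {FILE_TYPE} files in path '{path}'.")
    return locs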
@@ -135,7 +142,7 @@
         i = 0
         for loc in locs:
             loc = util.ensure_path(loc)
-            if loc.parts[-1].endswith(".spacy"):
+            if loc.parts[-1].endswith(FILE_TYPE):
                 doc_bin = DocBin().from_disk(loc)
                 docs = doc_bin.get_docs(vocab)
                 for doc in docs:
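
Since the reading loop above is built on DocBin, a .spacy corpus file can also be inspected directly outside of training; a minimal sketch (the file path is a placeholder):

import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")
# "./corpus/train.spacy" stands in for any serialized DocBin used for training.
doc_bin = DocBin().from_disk("./corpus/train.spacy")
for doc in doc_bin.get_docs(nlp.vocab):
    print(doc.text)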