Train CLI script fixes (#5931)

* fix dash replacement in config override arguments (for `--opt=value`, replace dashes only in the option name, not in the value)

* perform interpolation on training config

* make sure only .spacy files are read, and warn when no .spacy files are found in the given path
This commit is contained in:
Sofie Van Landeghem 2020-08-18 16:06:37 +02:00 committed by GitHub
parent 82f0e20318
commit 688e77562b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 17 additions and 7 deletions

View File

@ -68,11 +68,12 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
opt = args.pop(0) opt = args.pop(0)
err = f"Invalid CLI argument '{opt}'" err = f"Invalid CLI argument '{opt}'"
if opt.startswith("--"): # new argument if opt.startswith("--"): # new argument
opt = opt.replace("--", "").replace("-", "_") opt = opt.replace("--", "")
if "." not in opt: if "." not in opt:
msg.fail(f"{err}: can't override top-level section", exits=1) msg.fail(f"{err}: can't override top-level section", exits=1)
if "=" in opt: # we have --opt=value if "=" in opt: # we have --opt=value
opt, value = opt.split("=", 1) opt, value = opt.split("=", 1)
opt = opt.replace("-", "_")
else: else:
if not args or args[0].startswith("--"): # flag with no value if not args or args[0].startswith("--"): # flag with no value
value = "true" value = "true"

View File

@ -75,7 +75,7 @@ def train(
msg.info("Using CPU") msg.info("Using CPU")
msg.info(f"Loading config and nlp from: {config_path}") msg.info(f"Loading config and nlp from: {config_path}")
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config(config_path, overrides=config_overrides) config = util.load_config(config_path, overrides=config_overrides, interpolate=True)
if config.get("training", {}).get("seed") is not None: if config.get("training", {}).get("seed") is not None:
fix_random_seed(config["training"]["seed"]) fix_random_seed(config["training"]["seed"])
# Use original config here before it's resolved to functions # Use original config here before it's resolved to functions

View File

@ -78,10 +78,11 @@ class Warnings:
"are currently: {langs}") "are currently: {langs}")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
W090 = ("Could not locate any binary .spacy files in path '{path}'.")
W091 = ("Could not clean/remove the temp directory at {dir}: {msg}.") W091 = ("Could not clean/remove the temp directory at {dir}: {msg}.")
W092 = ("Ignoring annotations for sentence starts, as dependency heads are set.") W092 = ("Ignoring annotations for sentence starts, as dependency heads are set.")
W093 = ("Could not find any data to train the {name} on. Is your " W093 = ("Could not find any data to train the {name} on. Is your "
"input data correctly formatted ?") "input data correctly formatted?")
W094 = ("Model '{model}' ({model_version}) specifies an under-constrained " W094 = ("Model '{model}' ({model_version}) specifies an under-constrained "
"spaCy version requirement: {version}. This can lead to compatibility " "spaCy version requirement: {version}. This can lead to compatibility "
"problems with older versions, or as new spaCy versions are " "problems with older versions, or as new spaCy versions are "
@ -600,7 +601,8 @@ class Errors:
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}") "\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}") E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
E986 = ("Could not create any training batches: check your input. " E986 = ("Could not create any training batches: check your input. "
"Perhaps discard_oversize should be set to False ?") "Are the train and dev paths defined? "
"Is 'discard_oversize' set appropriately? ")
E987 = ("The text of an example training instance is either a Doc or " E987 = ("The text of an example training instance is either a Doc or "
"a string, but found {type} instead.") "a string, but found {type} instead.")
E988 = ("Could not parse any training examples. Ensure the data is " E988 = ("Could not parse any training examples. Ensure the data is "

View File

@ -1,8 +1,10 @@
import warnings
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
from pathlib import Path from pathlib import Path
from .. import util from .. import util
from .example import Example from .example import Example
from ..errors import Warnings
from ..tokens import DocBin, Doc from ..tokens import DocBin, Doc
from ..vocab import Vocab from ..vocab import Vocab
@ -10,6 +12,8 @@ if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports # This lets us add type hints for mypy etc. without causing circular imports
from ..language import Language # noqa: F401 from ..language import Language # noqa: F401
FILE_TYPE = ".spacy"
@util.registry.readers("spacy.Corpus.v1") @util.registry.readers("spacy.Corpus.v1")
def create_docbin_reader( def create_docbin_reader(
@ -53,8 +57,9 @@ class Corpus:
@staticmethod @staticmethod
def walk_corpus(path: Union[str, Path]) -> List[Path]: def walk_corpus(path: Union[str, Path]) -> List[Path]:
path = util.ensure_path(path) path = util.ensure_path(path)
if not path.is_dir(): if not path.is_dir() and path.parts[-1].endswith(FILE_TYPE):
return [path] return [path]
orig_path = path
paths = [path] paths = [path]
locs = [] locs = []
seen = set() seen = set()
@ -66,8 +71,10 @@ class Corpus:
continue continue
elif path.is_dir(): elif path.is_dir():
paths.extend(path.iterdir()) paths.extend(path.iterdir())
elif path.parts[-1].endswith(".spacy"): elif path.parts[-1].endswith(FILE_TYPE):
locs.append(path) locs.append(path)
if len(locs) == 0:
warnings.warn(Warnings.W090.format(path=orig_path))
return locs return locs
def __call__(self, nlp: "Language") -> Iterator[Example]: def __call__(self, nlp: "Language") -> Iterator[Example]:
@ -135,7 +142,7 @@ class Corpus:
i = 0 i = 0
for loc in locs: for loc in locs:
loc = util.ensure_path(loc) loc = util.ensure_path(loc)
if loc.parts[-1].endswith(".spacy"): if loc.parts[-1].endswith(FILE_TYPE):
doc_bin = DocBin().from_disk(loc) doc_bin = DocBin().from_disk(loc)
docs = doc_bin.get_docs(vocab) docs = doc_bin.get_docs(vocab)
for doc in docs: for doc in docs: