mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Train CLI script fixes (#5931)
* Fix dash replacement in override arguments
* Perform interpolation on the training config
* Make sure only .spacy files are read
This commit is contained in:
parent
82f0e20318
commit
688e77562b
|
@ -68,11 +68,12 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
|
|||
opt = args.pop(0)
|
||||
err = f"Invalid CLI argument '{opt}'"
|
||||
if opt.startswith("--"): # new argument
|
||||
opt = opt.replace("--", "").replace("-", "_")
|
||||
opt = opt.replace("--", "")
|
||||
if "." not in opt:
|
||||
msg.fail(f"{err}: can't override top-level section", exits=1)
|
||||
if "=" in opt: # we have --opt=value
|
||||
opt, value = opt.split("=", 1)
|
||||
opt = opt.replace("-", "_")
|
||||
else:
|
||||
if not args or args[0].startswith("--"): # flag with no value
|
||||
value = "true"
|
||||
|
|
|
@ -75,7 +75,7 @@ def train(
|
|||
msg.info("Using CPU")
|
||||
msg.info(f"Loading config and nlp from: {config_path}")
|
||||
with show_validation_error(config_path):
|
||||
config = util.load_config(config_path, overrides=config_overrides)
|
||||
config = util.load_config(config_path, overrides=config_overrides, interpolate=True)
|
||||
if config.get("training", {}).get("seed") is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
# Use original config here before it's resolved to functions
|
||||
|
|
|
@ -78,6 +78,7 @@ class Warnings:
|
|||
"are currently: {langs}")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
W090 = ("Could not locate any binary .spacy files in path '{path}'.")
|
||||
W091 = ("Could not clean/remove the temp directory at {dir}: {msg}.")
|
||||
W092 = ("Ignoring annotations for sentence starts, as dependency heads are set.")
|
||||
W093 = ("Could not find any data to train the {name} on. Is your "
|
||||
|
@ -600,7 +601,8 @@ class Errors:
|
|||
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
|
||||
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
|
||||
E986 = ("Could not create any training batches: check your input. "
|
||||
"Perhaps discard_oversize should be set to False ?")
|
||||
"Are the train and dev paths defined? "
|
||||
"Is 'discard_oversize' set appropriately? ")
|
||||
E987 = ("The text of an example training instance is either a Doc or "
|
||||
"a string, but found {type} instead.")
|
||||
E988 = ("Could not parse any training examples. Ensure the data is "
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import warnings
|
||||
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
|
||||
from pathlib import Path
|
||||
|
||||
from .. import util
|
||||
from .example import Example
|
||||
from ..errors import Warnings
|
||||
from ..tokens import DocBin, Doc
|
||||
from ..vocab import Vocab
|
||||
|
||||
|
@ -10,6 +12,8 @@ if TYPE_CHECKING:
|
|||
# This lets us add type hints for mypy etc. without causing circular imports
|
||||
from ..language import Language # noqa: F401
|
||||
|
||||
FILE_TYPE = ".spacy"
|
||||
|
||||
|
||||
@util.registry.readers("spacy.Corpus.v1")
|
||||
def create_docbin_reader(
|
||||
|
@ -53,8 +57,9 @@ class Corpus:
|
|||
@staticmethod
|
||||
def walk_corpus(path: Union[str, Path]) -> List[Path]:
|
||||
path = util.ensure_path(path)
|
||||
if not path.is_dir():
|
||||
if not path.is_dir() and path.parts[-1].endswith(FILE_TYPE):
|
||||
return [path]
|
||||
orig_path = path
|
||||
paths = [path]
|
||||
locs = []
|
||||
seen = set()
|
||||
|
@ -66,8 +71,10 @@ class Corpus:
|
|||
continue
|
||||
elif path.is_dir():
|
||||
paths.extend(path.iterdir())
|
||||
elif path.parts[-1].endswith(".spacy"):
|
||||
elif path.parts[-1].endswith(FILE_TYPE):
|
||||
locs.append(path)
|
||||
if len(locs) == 0:
|
||||
warnings.warn(Warnings.W090.format(path=orig_path))
|
||||
return locs
|
||||
|
||||
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
||||
|
@ -135,7 +142,7 @@ class Corpus:
|
|||
i = 0
|
||||
for loc in locs:
|
||||
loc = util.ensure_path(loc)
|
||||
if loc.parts[-1].endswith(".spacy"):
|
||||
if loc.parts[-1].endswith(FILE_TYPE):
|
||||
doc_bin = DocBin().from_disk(loc)
|
||||
docs = doc_bin.get_docs(vocab)
|
||||
for doc in docs:
|
||||
|
|
Loading…
Reference in New Issue
Block a user