mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Train CLI script fixes (#5931)
* fix dash replacement in overrides arguments
* perform interpolation on training config
* make sure only .spacy files are read
This commit is contained in:
parent
82f0e20318
commit
688e77562b
|
@ -68,11 +68,12 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
|
||||||
opt = args.pop(0)
|
opt = args.pop(0)
|
||||||
err = f"Invalid CLI argument '{opt}'"
|
err = f"Invalid CLI argument '{opt}'"
|
||||||
if opt.startswith("--"): # new argument
|
if opt.startswith("--"): # new argument
|
||||||
opt = opt.replace("--", "").replace("-", "_")
|
opt = opt.replace("--", "")
|
||||||
if "." not in opt:
|
if "." not in opt:
|
||||||
msg.fail(f"{err}: can't override top-level section", exits=1)
|
msg.fail(f"{err}: can't override top-level section", exits=1)
|
||||||
if "=" in opt: # we have --opt=value
|
if "=" in opt: # we have --opt=value
|
||||||
opt, value = opt.split("=", 1)
|
opt, value = opt.split("=", 1)
|
||||||
|
opt = opt.replace("-", "_")
|
||||||
else:
|
else:
|
||||||
if not args or args[0].startswith("--"): # flag with no value
|
if not args or args[0].startswith("--"): # flag with no value
|
||||||
value = "true"
|
value = "true"
|
||||||
|
|
|
@ -75,7 +75,7 @@ def train(
|
||||||
msg.info("Using CPU")
|
msg.info("Using CPU")
|
||||||
msg.info(f"Loading config and nlp from: {config_path}")
|
msg.info(f"Loading config and nlp from: {config_path}")
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
config = util.load_config(config_path, overrides=config_overrides)
|
config = util.load_config(config_path, overrides=config_overrides, interpolate=True)
|
||||||
if config.get("training", {}).get("seed") is not None:
|
if config.get("training", {}).get("seed") is not None:
|
||||||
fix_random_seed(config["training"]["seed"])
|
fix_random_seed(config["training"]["seed"])
|
||||||
# Use original config here before it's resolved to functions
|
# Use original config here before it's resolved to functions
|
||||||
|
|
|
@ -78,10 +78,11 @@ class Warnings:
|
||||||
"are currently: {langs}")
|
"are currently: {langs}")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
W090 = ("Could not locate any binary .spacy files in path '{path}'.")
|
||||||
W091 = ("Could not clean/remove the temp directory at {dir}: {msg}.")
|
W091 = ("Could not clean/remove the temp directory at {dir}: {msg}.")
|
||||||
W092 = ("Ignoring annotations for sentence starts, as dependency heads are set.")
|
W092 = ("Ignoring annotations for sentence starts, as dependency heads are set.")
|
||||||
W093 = ("Could not find any data to train the {name} on. Is your "
|
W093 = ("Could not find any data to train the {name} on. Is your "
|
||||||
"input data correctly formatted ?")
|
"input data correctly formatted?")
|
||||||
W094 = ("Model '{model}' ({model_version}) specifies an under-constrained "
|
W094 = ("Model '{model}' ({model_version}) specifies an under-constrained "
|
||||||
"spaCy version requirement: {version}. This can lead to compatibility "
|
"spaCy version requirement: {version}. This can lead to compatibility "
|
||||||
"problems with older versions, or as new spaCy versions are "
|
"problems with older versions, or as new spaCy versions are "
|
||||||
|
@ -600,7 +601,8 @@ class Errors:
|
||||||
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
|
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
|
||||||
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
|
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
|
||||||
E986 = ("Could not create any training batches: check your input. "
|
E986 = ("Could not create any training batches: check your input. "
|
||||||
"Perhaps discard_oversize should be set to False ?")
|
"Are the train and dev paths defined? "
|
||||||
|
"Is 'discard_oversize' set appropriately? ")
|
||||||
E987 = ("The text of an example training instance is either a Doc or "
|
E987 = ("The text of an example training instance is either a Doc or "
|
||||||
"a string, but found {type} instead.")
|
"a string, but found {type} instead.")
|
||||||
E988 = ("Could not parse any training examples. Ensure the data is "
|
E988 = ("Could not parse any training examples. Ensure the data is "
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
|
import warnings
|
||||||
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
|
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING, Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from .example import Example
|
from .example import Example
|
||||||
|
from ..errors import Warnings
|
||||||
from ..tokens import DocBin, Doc
|
from ..tokens import DocBin, Doc
|
||||||
from ..vocab import Vocab
|
from ..vocab import Vocab
|
||||||
|
|
||||||
|
@ -10,6 +12,8 @@ if TYPE_CHECKING:
|
||||||
# This lets us add type hints for mypy etc. without causing circular imports
|
# This lets us add type hints for mypy etc. without causing circular imports
|
||||||
from ..language import Language # noqa: F401
|
from ..language import Language # noqa: F401
|
||||||
|
|
||||||
|
FILE_TYPE = ".spacy"
|
||||||
|
|
||||||
|
|
||||||
@util.registry.readers("spacy.Corpus.v1")
|
@util.registry.readers("spacy.Corpus.v1")
|
||||||
def create_docbin_reader(
|
def create_docbin_reader(
|
||||||
|
@ -53,8 +57,9 @@ class Corpus:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def walk_corpus(path: Union[str, Path]) -> List[Path]:
|
def walk_corpus(path: Union[str, Path]) -> List[Path]:
|
||||||
path = util.ensure_path(path)
|
path = util.ensure_path(path)
|
||||||
if not path.is_dir():
|
if not path.is_dir() and path.parts[-1].endswith(FILE_TYPE):
|
||||||
return [path]
|
return [path]
|
||||||
|
orig_path = path
|
||||||
paths = [path]
|
paths = [path]
|
||||||
locs = []
|
locs = []
|
||||||
seen = set()
|
seen = set()
|
||||||
|
@ -66,8 +71,10 @@ class Corpus:
|
||||||
continue
|
continue
|
||||||
elif path.is_dir():
|
elif path.is_dir():
|
||||||
paths.extend(path.iterdir())
|
paths.extend(path.iterdir())
|
||||||
elif path.parts[-1].endswith(".spacy"):
|
elif path.parts[-1].endswith(FILE_TYPE):
|
||||||
locs.append(path)
|
locs.append(path)
|
||||||
|
if len(locs) == 0:
|
||||||
|
warnings.warn(Warnings.W090.format(path=orig_path))
|
||||||
return locs
|
return locs
|
||||||
|
|
||||||
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
def __call__(self, nlp: "Language") -> Iterator[Example]:
|
||||||
|
@ -135,7 +142,7 @@ class Corpus:
|
||||||
i = 0
|
i = 0
|
||||||
for loc in locs:
|
for loc in locs:
|
||||||
loc = util.ensure_path(loc)
|
loc = util.ensure_path(loc)
|
||||||
if loc.parts[-1].endswith(".spacy"):
|
if loc.parts[-1].endswith(FILE_TYPE):
|
||||||
doc_bin = DocBin().from_disk(loc)
|
doc_bin = DocBin().from_disk(loc)
|
||||||
docs = doc_bin.get_docs(vocab)
|
docs = doc_bin.get_docs(vocab)
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user