mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-06 21:30:22 +03:00
add force_overwrite
This commit is contained in:
parent
248388728b
commit
55f97aec4e
|
@ -28,6 +28,9 @@ out_help = "Path where to save the result .spacy file"
|
||||||
code_help = ("Path to Python file with additional "
|
code_help = ("Path to Python file with additional "
|
||||||
"code (registered functions) to be imported")
|
"code (registered functions) to be imported")
|
||||||
gold_help = "Use gold preprocessing provided in the .spacy files"
|
gold_help = "Use gold preprocessing provided in the .spacy files"
|
||||||
|
force_msg = ("The provided output file already exists. "
|
||||||
|
"To force overwriting the config file, set the --force or -F flag.")
|
||||||
|
|
||||||
|
|
||||||
DocOrStrStream = Union[Iterable[str], Iterable[Doc]]
|
DocOrStrStream = Union[Iterable[str], Iterable[Doc]]
|
||||||
|
|
||||||
|
@ -47,8 +50,6 @@ def _stream_jsonl(path: Path, field) -> Iterable[str]:
|
||||||
not found it raises error.
|
not found it raises error.
|
||||||
"""
|
"""
|
||||||
for entry in srsly.read_jsonl(path):
|
for entry in srsly.read_jsonl(path):
|
||||||
print(entry)
|
|
||||||
print(field)
|
|
||||||
if field not in entry:
|
if field not in entry:
|
||||||
raise msg.fail(
|
raise msg.fail(
|
||||||
f"{path} does not contain the required '{field}' field.",
|
f"{path} does not contain the required '{field}' field.",
|
||||||
|
@ -76,6 +77,7 @@ def apply_cli(
|
||||||
output_file: Path = Arg(..., help=out_help, dir_okay=False),
|
output_file: Path = Arg(..., help=out_help, dir_okay=False),
|
||||||
code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
|
code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
|
||||||
json_field: str = Opt("text", "--field", "-f", help="Field to grab from .jsonl"),
|
json_field: str = Opt("text", "--field", "-f", help="Field to grab from .jsonl"),
|
||||||
|
force_overwrite: bool = Opt(False, "--force", "-F", help="Force overwriting the output file"),
|
||||||
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
|
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
|
||||||
batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."),
|
batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."),
|
||||||
n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.")
|
n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.")
|
||||||
|
@ -91,6 +93,13 @@ def apply_cli(
|
||||||
a single document.
|
a single document.
|
||||||
DOCS: https://spacy.io/api/cli#tba
|
DOCS: https://spacy.io/api/cli#tba
|
||||||
"""
|
"""
|
||||||
|
data_path = ensure_path(data_path)
|
||||||
|
output_file = ensure_path(output_file)
|
||||||
|
code_path = ensure_path(code_path)
|
||||||
|
if output_file.exists() and not force_overwrite:
|
||||||
|
msg.fail(force_msg, exits=1)
|
||||||
|
if not data_path.exists():
|
||||||
|
msg.fail(f"Couldn't find data path: {data_path}", exits=1)
|
||||||
import_code(code_path)
|
import_code(code_path)
|
||||||
setup_gpu(use_gpu)
|
setup_gpu(use_gpu)
|
||||||
apply(data_path, output_file, model, json_field, batch_size, n_process)
|
apply(data_path, output_file, model, json_field, batch_size, n_process)
|
||||||
|
@ -104,10 +113,6 @@ def apply(
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
n_process: int,
|
n_process: int,
|
||||||
):
|
):
|
||||||
data_path = ensure_path(data_path)
|
|
||||||
output_file = ensure_path(output_file)
|
|
||||||
if not data_path.exists():
|
|
||||||
msg.fail("Couldn't find data path.", data_path, exits=1)
|
|
||||||
nlp = load_model(model)
|
nlp = load_model(model)
|
||||||
msg.good(f"Loaded model {model}")
|
msg.good(f"Loaded model {model}")
|
||||||
vocab = nlp.vocab
|
vocab = nlp.vocab
|
||||||
|
|
Loading…
Reference in New Issue
Block a user