Move the AUTO converter check into convert and fix verification of CLI args

This commit is contained in:
svlandeg 2023-01-04 14:08:08 +01:00
parent d56139b082
commit eaf74b24d8
3 changed files with 25 additions and 15 deletions

View File

@ -29,8 +29,6 @@ if TYPE_CHECKING:
SDIST_SUFFIX = ".tar.gz"
WHEEL_SUFFIX = "-py3-none-any.whl"
AUTO = "auto"
PROJECT_FILE = "project.yml"
PROJECT_LOCK = "project.lock"
COMMAND = "python -m spacy"
@ -585,6 +583,10 @@ def setup_gpu(use_gpu: int, silent=None) -> None:
def walk_directory(path: Path, suffix: Optional[str] = None) -> List[Path]:
"""Given a directory and a suffix, recursively find all files matching the suffix.
Directories or files with names beginning with a . are ignored, but hidden flags on
filesystems are not checked.
When provided with a suffix `None`, all files are returned without filtering."""
if not path.is_dir():
return [path]
paths = [path]
@ -598,8 +600,6 @@ def walk_directory(path: Path, suffix: Optional[str] = None) -> List[Path]:
continue
elif path.is_dir():
paths.extend(path.iterdir())
elif suffix == AUTO:
locs.append(path)
elif suffix is not None and not path.parts[-1].endswith(suffix):
continue
else:

View File

@ -7,7 +7,7 @@ import re
import sys
import itertools
from ._util import app, Arg, Opt, walk_directory, AUTO
from ._util import app, Arg, Opt, walk_directory
from ..training import docs_to_json
from ..tokens import Doc, DocBin
from ..training.converters import iob_to_docs, conll_ner_to_docs, json_to_docs
@ -28,6 +28,8 @@ CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
"json": json_to_docs,
}
AUTO = "auto"
# File types that can be written to stdout
FILE_TYPES_STDOUT = ("json",)
@ -66,12 +68,16 @@ def convert_cli(
DOCS: https://spacy.io/api/cli#convert
"""
print("CONVERT")
input_path = Path(input_path)
output_dir: Union[str, Path] = "-" if output_dir == Path("-") else output_dir
silent = output_dir == "-"
msg = Printer(no_print=silent)
verify_cli_args(msg, input_path, output_dir, file_type.value, converter, ner_map)
print("BEFORE", converter)
converter = _get_converter(msg, converter, input_path)
print("AFTER", converter)
print("VERIFYING")
verify_cli_args(msg, input_path, output_dir, file_type.value, converter, ner_map)
convert(
input_path,
output_dir,
@ -100,7 +106,7 @@ def convert(
model: Optional[str] = None,
morphology: bool = False,
merge_subtokens: bool = False,
converter: str = AUTO,
converter: str,
ner_map: Optional[Path] = None,
lang: Optional[str] = None,
concatenate: bool = False,
@ -212,17 +218,21 @@ def verify_cli_args(
input_locs = walk_directory(input_path, converter)
if len(input_locs) == 0:
msg.fail("No input files in directory", input_path, exits=1)
file_types = list(set([loc.suffix[1:] for loc in input_locs]))
if converter == AUTO and len(file_types) >= 2:
file_types_str = ",".join(file_types)
msg.fail("All input files must be same type", file_types_str, exits=1)
if converter != AUTO and converter not in CONVERTERS:
if converter not in CONVERTERS:
msg.fail(f"Can't find converter for {converter}", exits=1)
def _get_converter(msg, converter, input_path: Path):
if input_path.is_dir():
input_path = walk_directory(input_path, converter)[0]
if converter == AUTO:
input_locs = walk_directory(input_path, suffix=None)
file_types = list(set([loc.suffix[1:] for loc in input_locs]))
if len(file_types) >= 2:
file_types_str = ",".join(file_types)
msg.fail("All input files must be same type", file_types_str, exits=1)
input_path = input_locs[0]
else:
input_path = walk_directory(input_path, suffix=converter)[0]
if converter == AUTO:
converter = input_path.suffix[1:]
if converter == "ner" or converter == "iob":
@ -241,4 +251,5 @@ def _get_converter(msg, converter, input_path: Path):
"Conversion may not succeed. "
"See https://spacy.io/api/cli#convert"
)
print("got convertor", converter)
return converter

View File

@ -20,7 +20,7 @@ from spacy.cli._util import is_subpath_of, load_project_config, walk_directory
from spacy.cli._util import parse_config_overrides, string_to_list
from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands
from spacy.cli._util import upload_file, download_file, AUTO
from spacy.cli._util import upload_file, download_file
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
@ -1209,4 +1209,3 @@ def test_walk_directory():
assert (len(walk_directory(d, suffix="iob"))) == 2
assert (len(walk_directory(d, suffix="conll"))) == 3
assert (len(walk_directory(d, suffix="pdf"))) == 0
assert (len(walk_directory(d, suffix=AUTO))) == 7