From fd7e2999674a48857b3de526e8847cf5c0df5882 Mon Sep 17 00:00:00 2001 From: kadarakos Date: Wed, 7 Sep 2022 11:41:22 +0000 Subject: [PATCH] walk directories --- spacy/cli/_util.py | 28 ++++++++++- spacy/cli/apply.py | 110 ++++++++++++++++++++++++------------------- spacy/cli/convert.py | 31 +----------- 3 files changed, 91 insertions(+), 78 deletions(-) diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index ae43b991b..84e611841 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -404,7 +404,10 @@ def git_checkout( if not is_subpath_of(tmp_dir, source_path): err = f"'{subpath}' is a path outside of the cloned repository." msg.fail(err, repo, exits=1) - shutil.copytree(str(source_path), str(dest)) + if source_path.is_dir(): + shutil.copytree(str(source_path), str(dest)) + else: + shutil.copy(str(source_path), str(dest)) except FileNotFoundError: err = f"Can't clone {subpath}. Make sure the directory exists in the repo (branch '{branch}')" msg.fail(err, repo, exits=1) @@ -573,3 +576,26 @@ def setup_gpu(use_gpu: int, silent=None) -> None: local_msg.info("Using CPU") if gpu_is_available(): local_msg.info("To switch to GPU 0, use the option: --gpu-id 0") + + +def walk_directory(path: Path, suffix: str) -> List[Path]: + if not path.is_dir(): + return [path] + paths = [path] + locs = [] + seen = set() + for path in paths: + if str(path) in seen: + continue + seen.add(str(path)) + if path.parts[-1].startswith("."): + continue + elif path.is_dir(): + paths.extend(path.iterdir()) + elif not path.parts[-1].endswith(suffix): + continue + else: + locs.append(path) + # It's good to sort these, in case the ordering messes up cache. + locs.sort() + return locs diff --git a/spacy/cli/apply.py b/spacy/cli/apply.py index 692f78e9e..26f4e2907 100644 --- a/spacy/cli/apply.py +++ b/spacy/cli/apply.py @@ -1,14 +1,15 @@ import tqdm -import sys - -from ._util import app, Arg, Opt, setup_gpu, import_code -from typing import Optional, Generator, Union +from itertools import chain from pathlib import Path +from typing import Optional, Generator, Union + from wasabi import msg +from ._util import app, Arg, Opt, setup_gpu, import_code, walk_directory + from ..tokens import Doc, DocBin from ..vocab import Vocab -from .. import util +from ..util import ensure_path, load_model path_help = ("Location of the documents to predict on." @@ -23,52 +24,46 @@ code_help = ("Path to Python file with additional " gold_help = "Use gold preprocessing provided in the .spacy files" -def _stream_data( - data_path: Path, - vocab: Vocab, - suffix: Optional[str] = None -) -> Generator[Union[str, Doc], None, None]: +def _stream_file(path: Path, vocab: Vocab) -> Generator[Union[Doc, str], None, None]: """ - Load data which is either in a single file - in .spacy or plain text format or multiple - text files in a directory. If a directory - is provided skip subdirectories and undecodeable - files. + Stream data from a single file. If the path points to + a .spacy file then yield from the DocBin otherwise + yield each line of a text file. If a decoding error + is encountered during reading the file exit. """ - if not data_path.is_dir(): + if not path.is_dir(): # Yield from DocBin. - if data_path.suffix == ".spacy": - docbin = DocBin().from_disk(data_path) + if path.suffix == ".spacy": + docbin = DocBin().from_disk(path) for doc in docbin.get_docs(vocab): yield doc # Yield from text file else: try: - with open(data_path, 'r') as fin: + with open(path, 'r') as fin: for line in fin: yield line except UnicodeDecodeError as e: print(e) msg.warn( - f"{data_path} could not be decoded.", + f"{path} could not be decoded.", exits=True ) - else: - # Yield per one file in directory - for path in data_path.iterdir(): - if path.is_dir(): - msg.warn(f"Skipping directory {path}") - elif suffix is not None and path.suffix != suffix: - print(suffix, path.suffix) - msg.warn(f"Skipping file {path}") - else: - with open(path, 'r') as fin: - try: - text = fin.read() - yield text - except UnicodeDecodeError as e: - msg.warn(f"Skipping file {path}") - print(e) + + +def _maybe_read(path: Path) -> Union[str, None]: + """ + Try to read the text file from the provided path. + When encoutering a decoding error just warn and pass. + """ + with open(path, 'r') as fin: + try: + text = fin.read() + return text + except UnicodeDecodeError as e: + msg.warn(f"Skipping file {path}") + print(e) + return None @app.command("apply") @@ -78,20 +73,27 @@ def apply_cli( data_path: Path = Arg(..., help=path_help, exists=True), output: Path = Arg(..., help=out_help, dir_okay=False), code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help), - use_gpu: Optional[int] = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."), - batch_size: Optional[int] = Opt(1, "--batch-size", "-b", help="Batch size."), - n_process: Optional[int] = Opt(1, "--n-process", "-n", help="number of processors to use."), - suffix: Optional[str] = Opt(None, "--suffix", "-n", help="Only read files with file.suffix.") + use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."), + batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."), + n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use."), + suffix: str = Opt("", "--suffix", "-n", help="Only read files with file.suffix.") ): """ Apply a trained pipeline to documents to get predictions. Expects a loadable spaCy pipeline and some data as input. - The input can be provided multiple formats. It can be a .spacy - file, a single text file with one document per line or a directory - where each file is assumed to be plain text document. + The data can be provided multiple formats. It can be a single + .spacy file or a single text file with one document per line. + A directory can also be provided in which case the 'suffix' + argument is required. All paths pointing to files with the + provided suffix will be recursively collected and processed. DOCS: https://spacy.io/api/cli#tba """ + if data_path.is_dir() and suffix == "": + raise ValueError( + "When the provided 'data_path' is a directory " + "the --suffix argument has to be provided as well." + ) if suffix is not None: if not suffix.startswith("."): suffix = "." + suffix @@ -106,17 +108,29 @@ def apply( model: str, batch_size: int, n_process: int, - suffix: Optional[str] + suffix: str ): - data_path = util.ensure_path(data_path) - output_path = util.ensure_path(output) + data_path = ensure_path(data_path) + output_path = ensure_path(output) if not data_path.exists(): msg.fail("Couldn't find data path.", data_path, exits=1) - nlp = util.load_model(model) + nlp = load_model(model) msg.good(f"Loaded model {model}") vocab = nlp.vocab docbin = DocBin() - datagen = _stream_data(data_path, vocab, suffix) + datagen: Union[ + Generator[Union[Doc, str], None, None], + chain[Union[Doc, str]], + filter[str] + ] + if not data_path.is_dir(): + datagen = _stream_file(data_path, vocab) + else: + paths = walk_directory(data_path, suffix) + if suffix == ".spacy": + datagen = chain(*[_stream_file(path, vocab) for path in paths]) + else: + datagen = filter(None, (_maybe_read(path) for path in paths)) for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)): docbin.add(doc) if output_path.is_dir(): diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py index 04eb7078f..7f365ae2c 100644 --- a/spacy/cli/convert.py +++ b/spacy/cli/convert.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Mapping, Optional, Any, List, Union +from typing import Callable, Iterable, Mapping, Optional, Any, Union from enum import Enum from pathlib import Path from wasabi import Printer @@ -7,7 +7,7 @@ import re import sys import itertools -from ._util import app, Arg, Opt +from ._util import app, Arg, Opt, walk_directory from ..training import docs_to_json from ..tokens import Doc, DocBin from ..training.converters import iob_to_docs, conll_ner_to_docs, json_to_docs @@ -189,33 +189,6 @@ def autodetect_ner_format(input_data: str) -> Optional[str]: return None -def walk_directory(path: Path, converter: str) -> List[Path]: - if not path.is_dir(): - return [path] - paths = [path] - locs = [] - seen = set() - for path in paths: - if str(path) in seen: - continue - seen.add(str(path)) - if path.parts[-1].startswith("."): - continue - elif path.is_dir(): - paths.extend(path.iterdir()) - elif converter == "json" and not path.parts[-1].endswith("json"): - continue - elif converter == "conll" and not path.parts[-1].endswith("conll"): - continue - elif converter == "iob" and not path.parts[-1].endswith("iob"): - continue - else: - locs.append(path) - # It's good to sort these, in case the ordering messes up cache. - locs.sort() - return locs - - def verify_cli_args( msg: Printer, input_path: Path,