handle file suffixes

kadarakos 2022-09-06 15:24:39 +00:00
parent 7a72491b78
commit e84b295279


@@ -25,51 +25,50 @@ gold_help = "Use gold preprocessing provided in the .spacy files"
 def _stream_data(
     data_path: Path,
-    vocab: Vocab
+    vocab: Vocab,
+    suffix: Optional[str] = None
 ) -> Generator[Union[str, Doc], None, None]:
     """
     Load data which is either in a single file
     in .spacy or plain text format or multiple
-    text files in a directory.
+    text files in a directory. If a directory
+    is provided skip subdirectories and undecodeable
+    files.
     """
-    # XXX I know that we have it in the developer guidelines
-    # to don't try/except, but I thought its appropriate here.
-    # because we are not sure exactly what input we are getting.
     if not data_path.is_dir():
         # Yield from DocBin.
-        try:
+        if data_path.suffix == ".spacy":
             docbin = DocBin().from_disk(data_path)
             for doc in docbin.get_docs(vocab):
                 yield doc
-        # Yield from text file.
-        except ValueError:
+        # Yield from text file
+        else:
             try:
                 with open(data_path, 'r') as fin:
                     for line in fin:
                         yield line
-            except UnicodeDecodeError:
-                print(
-                    f"file {data_path} does not seem "
-                    "to be a plain text file"
+            except UnicodeDecodeError as e:
+                print(e)
+                msg.warn(
+                    f"{data_path} could not be decoded.",
+                    exits=True
                 )
-                sys.exit()
     else:
         # Yield per one file in directory
         for path in data_path.iterdir():
             if path.is_dir():
-                raise ValueError(
-                    "All files should be text files."
-                )
-            with open(path, 'r') as fin:
-                try:
-                    text = fin.read()
-                    yield text
-                except UnicodeDecodeError:
-                    print(
-                        f"file {path} does not seem "
-                        "to be a plain text file"
-                    )
-                    sys.exit()
+                msg.warn(f"Skipping directory {path}")
+            elif suffix is not None and path.suffix != suffix:
+                print(suffix, path.suffix)
+                msg.warn(f"Skipping file {path}")
+            else:
+                with open(path, 'r') as fin:
+                    try:
+                        text = fin.read()
+                        yield text
+                    except UnicodeDecodeError as e:
+                        msg.warn(f"Skipping file {path}")
+                        print(e)
 
 
 @app.command("apply")
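
The rewritten directory branch above trades the old hard ValueError for warnings: subdirectories, files whose suffix does not match the new optional suffix argument, and undecodeable files are all skipped instead of aborting the run. A standalone sketch of that filtering logic, using a hypothetical helper name that is not part of the commit:

from pathlib import Path
from typing import Iterator, Optional

def iter_text_files(data_dir: Path, suffix: Optional[str] = None) -> Iterator[str]:
    """Mirror of the directory branch of _stream_data (illustrative only)."""
    for path in data_dir.iterdir():
        if path.is_dir():
            continue  # skip subdirectories instead of raising
        if suffix is not None and path.suffix != suffix:
            continue  # skip files whose suffix does not match the filter
        try:
            yield path.read_text()
        except UnicodeDecodeError:
            continue  # skip files that are not decodeable as text
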
@@ -79,9 +78,10 @@ def apply_cli(
     data_path: Path = Arg(..., help=path_help, exists=True),
     output: Path = Arg(..., help=out_help, dir_okay=False),
     code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
-    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
-    batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size"),
-    n_process: int = Opt(1, "--n-process", "-n", help="Number of processors to use")
+    use_gpu: Optional[int] = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
+    batch_size: Optional[int] = Opt(1, "--batch-size", "-b", help="Batch size."),
+    n_process: Optional[int] = Opt(1, "--n-process", "-n", help="number of processors to use."),
+    suffix: Optional[str] = Opt(None, "--suffix", "-n", help="Only read files with file.suffix.")
 ):
     """
     Apply a trained pipeline to documents to get predictions.
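
For context, Arg and Opt in spaCy's CLI modules are thin aliases for Typer's typer.Argument and typer.Option, so the added parameter surfaces as a --suffix command-line flag. A hypothetical invocation, with positional arguments assumed from the surrounding signature (the docs entry is still "tba"):

# e.g. python -m spacy apply en_core_web_sm ./corpus predictions.spacy --suffix txt

Note that the short form "-n" assigned to --suffix is already taken by --n-process, so only the long flag is unambiguous as written.
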
@@ -92,9 +92,12 @@ def apply_cli(
     DOCS: https://spacy.io/api/cli#tba
     """
+    if suffix is not None:
+        if not suffix.startswith("."):
+            suffix = "." + suffix
     import_code(code_path)
     setup_gpu(use_gpu)
-    apply(data_path, output, model, batch_size, n_process)
+    apply(data_path, output, model, batch_size, n_process, suffix)
 
 
 def apply(
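
The normalization added above is needed because pathlib's Path.suffix always includes the leading dot, so a bare "txt" passed on the command line would never match any file. A minimal illustration:

from pathlib import Path

print(Path("story.txt").suffix)  # prints ".txt": Path.suffix keeps the leading dot

suffix = "txt"                   # as a user might pass --suffix without a dot
if not suffix.startswith("."):
    suffix = "." + suffix        # same normalization as in apply_cli above
print(suffix)                    # prints ".txt", now comparable to Path.suffix
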
@@ -102,7 +105,8 @@ def apply(
     output: Path,
     model: str,
     batch_size: int,
-    n_process: int
+    n_process: int,
+    suffix: Optional[str]
 ):
     data_path = util.ensure_path(data_path)
     output_path = util.ensure_path(output)
@@ -112,7 +116,7 @@ def apply(
     msg.good(f"Loaded model {model}")
     vocab = nlp.vocab
     docbin = DocBin()
-    datagen = _stream_data(data_path, vocab)
+    datagen = _stream_data(data_path, vocab, suffix)
     for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
         docbin.add(doc)
     if output_path.is_dir():
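
Taken together, apply() streams whatever _stream_data yields through nlp.pipe and collects the predictions into a DocBin. A condensed sketch of that flow, assuming a trained en_core_web_sm pipeline and an illustrative corpus/ directory of .txt files (it reuses the hypothetical iter_text_files helper sketched earlier):

from pathlib import Path

import spacy
from spacy.tokens import DocBin

nlp = spacy.load("en_core_web_sm")
docbin = DocBin()
texts = iter_text_files(Path("corpus"), suffix=".txt")
for doc in nlp.pipe(texts, batch_size=1, n_process=1):
    docbin.add(doc)
docbin.to_disk("predictions.spacy")  # same serialization target as apply()
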