support jsonl

kadarakos 2022-09-13 13:43:25 +00:00
parent fd7e299967
commit 8030393ecc

@@ -1,7 +1,9 @@
 import tqdm
+import srsly
+from itertools import chain
 from pathlib import Path
-from typing import Optional, Generator, Union
+from typing import Optional, Generator, Union, List
 from wasabi import msg
@@ -24,34 +26,31 @@ code_help = ("Path to Python file with additional "
 gold_help = "Use gold preprocessing provided in the .spacy files"


-def _stream_file(path: Path, vocab: Vocab) -> Generator[Union[Doc, str], None, None]:
+def _stream_docbin(path: Path, vocab: Vocab) -> Generator[Doc, None, None]:
     """
-    Stream data from a single file. If the path points to
-    a .spacy file then yield from the DocBin otherwise
-    yield each line of a text file. If a decoding error
-    is encountered during reading the file exit.
+    Stream Doc objects from DocBin.
     """
-    if not path.is_dir():
-        # Yield from DocBin.
-        if path.suffix == ".spacy":
-            docbin = DocBin().from_disk(path)
-            for doc in docbin.get_docs(vocab):
-                yield doc
-        # Yield from text file
-        else:
-            try:
-                with open(path, 'r') as fin:
-                    for line in fin:
-                        yield line
-            except UnicodeDecodeError as e:
-                print(e)
-                msg.warn(
-                    f"{path} could not be decoded.",
-                    exits=True
-                )
+    input(path)
+    docbin = DocBin().from_disk(path)
+    for doc in docbin.get_docs(vocab):
+        yield doc
+
+
+def _stream_jsonl(path: Path) -> Generator[str, None, None]:
+    """
+    Stream the "text" field from JSONL. If the field "text"
+    is not found, it raises an error.
+    """
+    for entry in srsly.read_jsonl(path):
+        if "text" not in entry:
+            raise ValueError(
+                "JSONL files have to contain 'text' field."
+            )
+        else:
+            yield entry["text"]


-def _maybe_read(path: Path) -> Union[str, None]:
+def _maybe_read_text(path: Path) -> Union[str, None]:
     """
     Try to read the text file from the provided path.
     When encountering a decoding error just warn and pass.
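The new `_stream_jsonl` helper yields only the "text" value of each record and raises a ValueError for any record that lacks the field, rather than skipping it. A minimal input file (hypothetical content) would look like:

    {"text": "First document to process."}
    {"text": "Second document to process."}

Note that `_stream_docbin` still carries a stray `input(path)` call, which looks like a debugging leftover and will block waiting for stdin.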
@@ -66,6 +65,14 @@ def _maybe_read(path: Path) -> Union[str, None]:
     return None


+def _stream_texts(paths: List[Path]) -> Generator[Union[str, None], None, None]:
+    """
+    Yields strings, or None when a decoding error is encountered.
+    """
+    for path in paths:
+        yield _maybe_read_text(path)
+
+
 @app.command("apply")
 def apply_cli(
     # fmt: off
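`_stream_texts` deliberately yields None for files that fail to decode, so callers are expected to wrap it in `filter(None, ...)` to drop unreadable files, as the rewritten `apply` body below does. A minimal sketch of that composition, with hypothetical paths:

    # None entries (decoding failures) are filtered out before piping.
    texts = filter(None, _stream_texts([Path("a.txt"), Path("b.txt")]))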
@@ -75,8 +82,7 @@ def apply_cli(
     code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
     use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
     batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."),
-    n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use."),
-    suffix: str = Opt("", "--suffix", "-n", help="Only read files with file.suffix.")
+    n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.")
 ):
     """
     Apply a trained pipeline to documents to get predictions.
@@ -89,17 +95,9 @@ def apply_cli(
     DOCS: https://spacy.io/api/cli#tba
     """
-    if data_path.is_dir() and suffix == "":
-        raise ValueError(
-            "When the provided 'data_path' is a directory "
-            "the --suffix argument has to be provided as well."
-        )
-    if suffix is not None:
-        if not suffix.startswith("."):
-            suffix = "." + suffix
     import_code(code_path)
     setup_gpu(use_gpu)
-    apply(data_path, output, model, batch_size, n_process, suffix)
+    apply(data_path, output, model, batch_size, n_process)


 def apply(
@@ -108,7 +106,6 @@ def apply(
     model: str,
     batch_size: int,
     n_process: int,
-    suffix: str
 ):
     data_path = ensure_path(data_path)
     output_path = ensure_path(output)
@@ -123,14 +120,21 @@ def apply(
         chain[Union[Doc, str]],
         filter[str]
     ]
-    if not data_path.is_dir():
-        datagen = _stream_file(data_path, vocab)
-    else:
-        paths = walk_directory(data_path, suffix)
-        if suffix == ".spacy":
-            datagen = chain(*[_stream_file(path, vocab) for path in paths])
-        else:
-            datagen = filter(None, (_maybe_read(path) for path in paths))
+    paths = walk_directory(data_path)
+    streams = []
+    text_files = []
+    for path in paths:
+        if path.suffix == ".spacy":
+            stream = _stream_docbin(path, vocab)
+            streams.append(_stream_docbin(path, vocab))
+        elif path.suffix == ".jsonl":
+            streams.append(_stream_jsonl(path))
+        else:
+            text_files.append(path)
+    if len(text_files) > 0:
+        stream = filter(None, _stream_texts(text_files))
+        streams.append(stream)
+    datagen = chain(*streams)
     for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
         docbin.add(doc)
     if output_path.is_dir():
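With --suffix gone, the file type is now decided per file from its extension, so one directory can mix .spacy, .jsonl, and plain-text inputs in a single run. A hypothetical invocation (the positional argument order is not visible in this diff and is assumed here):

    python -m spacy apply en_core_web_sm ./input_dir ./predictions.spacy --batch-size 16 --n-process 2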