update tests and bugfix

kadarakos 2022-11-22 15:02:07 +00:00
parent 8cbf06c00e
commit 248388728b
2 changed files with 22 additions and 25 deletions

View File

@@ -47,13 +47,15 @@ def _stream_jsonl(path: Path, field) -> Iterable[str]:
     not found it raises error.
     """
     for entry in srsly.read_jsonl(path):
+        print(entry)
+        print(field)
         if field not in entry:
             raise msg.fail(
                 f"{path} does not contain the required '{field}' field.",
                 exits=1
             )
         else:
-            yield entry["text"]
+            yield entry[field]


 def _stream_texts(paths: Iterable[Path]) -> Iterable[str]:
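
The fix in this hunk is that the helper now yields the requested field instead of hardcoding "text". A minimal standalone sketch of that behavior (the function and file names here are illustrative, not the module's; the debug prints are omitted):

    from pathlib import Path
    from typing import Iterable

    import srsly
    from wasabi import msg


    def stream_jsonl(path: Path, field: str) -> Iterable[str]:
        # Yield the requested field from every entry, exiting with an
        # error message as soon as an entry is missing it.
        for entry in srsly.read_jsonl(path):
            if field not in entry:
                msg.fail(f"{path} does not contain the required '{field}' field.", exits=1)
            else:
                yield entry[field]


    # Example: a corpus that keeps its text under "body" rather than "text".
    srsly.write_jsonl("corpus.jsonl", [{"body": "First doc."}, {"body": "Second doc."}])
    for text in stream_jsonl(Path("corpus.jsonl"), "body"):
        print(text)
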
@@ -73,7 +75,7 @@ def apply_cli(
     data_path: Path = Arg(..., help=path_help, exists=True),
     output_file: Path = Arg(..., help=out_help, dir_okay=False),
     code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
-    field: str = Opt("text", "--field", "-f", help="Field to grab from .jsonl"),
+    json_field: str = Opt("text", "--field", "-f", help="Field to grab from .jsonl"),
     use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
     batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."),
     n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.")
@@ -91,18 +93,19 @@ def apply_cli(
     """
     import_code(code_path)
     setup_gpu(use_gpu)
-    apply(data_path, output, model, batch_size, n_process)
+    apply(data_path, output_file, model, json_field, batch_size, n_process)


 def apply(
     data_path: Path,
-    output: Path,
+    output_file: Path,
     model: str,
+    json_field: str,
     batch_size: int,
     n_process: int,
 ):
     data_path = ensure_path(data_path)
-    output_path = ensure_path(output)
+    output_file = ensure_path(output_file)
     if not data_path.exists():
         msg.fail("Couldn't find data path.", data_path, exits=1)
     nlp = load_model(model)
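
This hunk fixes the actual bug: apply_cli previously passed an undefined name and dropped the field argument, and apply assigned the normalized path to output_path while later code used output_file. With the renamed parameters, a direct call to apply mirrors what the CLI wrapper now forwards. A hedged usage sketch (the paths are placeholders):

    from pathlib import Path

    apply(
        data_path=Path("corpus/"),              # directory scanned for .spacy/.jsonl/text inputs
        output_file=Path("predictions.spacy"),
        model="blank:en",                       # any loadable pipeline name or path
        json_field="text",                      # which JSONL field holds the text
        batch_size=32,
        n_process=1,
    )
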
@@ -116,7 +119,7 @@ def apply(
         if path.suffix == ".spacy":
             streams.append(_stream_docbin(path, vocab))
         elif path.suffix == ".jsonl":
-            streams.append(_stream_jsonl(path, field))
+            streams.append(_stream_jsonl(path, json_field))
         else:
             text_files.append(path)
     if len(text_files) > 0:
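
The routing itself is one stream per .spacy or .jsonl file, with everything else collected as plain text. A stripped-down sketch of just that dispatch, with stub tuples standing in for the real stream helpers:

    from pathlib import Path
    from typing import Iterable, List, Tuple

    def route(paths: Iterable[Path]) -> Tuple[List[Tuple[str, Path]], List[Path]]:
        streams: List[Tuple[str, Path]] = []
        text_files: List[Path] = []
        for path in paths:
            if path.suffix == ".spacy":
                streams.append(("docbin", path))   # stands in for _stream_docbin(path, vocab)
            elif path.suffix == ".jsonl":
                streams.append(("jsonl", path))    # stands in for _stream_jsonl(path, json_field)
            else:
                text_files.append(path)            # batched through _stream_texts later
        return streams, text_files
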
@@ -124,6 +127,6 @@ def apply(
     datagen = cast(DocOrStrStream, chain(*streams))
     for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
         docbin.add(doc)
-    if not output_file.endswith(".spacy"):
+    if output_file.suffix == "":
         output_file += ".spacy"
     docbin.to_disk(output_file)
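
The new check fixes the crash on Path inputs (str.endswith does not exist on pathlib.Path), but note a remaining caveat: ensure_path turns output_file into a Path, and Path objects do not support += with a string, so the append line would raise TypeError whenever the suffix check actually fires. A sketch of the same intent using with_suffix instead:

    from pathlib import Path

    output_file = Path("predictions")
    if output_file.suffix == "":
        output_file = output_file.with_suffix(".spacy")
    assert str(output_file) == "predictions.spacy"
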

View File

@@ -864,7 +864,7 @@ def test_span_length_freq_dist_output_must_be_correct():
 def test_applycli_empty_dir():
     with make_tempdir() as data_path:
         output = data_path / "test.spacy"
-        apply(data_path, output, "blank:en", 1, 1)
+        apply(data_path, output, "blank:en", "text", 1, 1)


 def test_applycli_docbin():
@@ -875,35 +875,29 @@ def test_applycli_docbin():
         # test empty DocBin case
         docbin = DocBin()
         docbin.to_disk(data_path / "testin.spacy")
-        apply(data_path, output, "blank:en", 1, 1)
+        apply(data_path, output, "blank:en", "text", 1, 1)
         docbin.add(doc)
         docbin.to_disk(data_path / "testin.spacy")
-        apply(data_path, output, "blank:en", 1, 1)
+        apply(data_path, output, "blank:en", "text", 1, 1)


 def test_applycli_jsonl():
     with make_tempdir() as data_path:
         output = data_path / "testout.spacy"
-        data = [{"text": "Testing apply cli.", "key": 234}]
+        data = [{"field": "Testing apply cli.", "key": 234}]
         srsly.write_jsonl(data_path / "test.jsonl", data)
-        apply(data_path, output, "blank:en", 1, 1)
-        data = [{"key": 234}]
+        apply(data_path, output, "blank:en", "field", 1, 1)
+        data = [{"field": "234"}]
         srsly.write_jsonl(data_path / "test2.jsonl", data)
-        # test no "text" field case
-        with pytest.raises(ValueError, match="test2.jsonl"):
-            apply(data_path, output, "blank:en", 1, 1)
+        apply(data_path, output, "blank:en", "field", 1, 1)


 def test_applycli_txt():
     with make_tempdir() as data_path:
         output = data_path / "testout.spacy"
-        data = [{"text": "Testing apply cli.", "key": 234}]
-        srsly.write_jsonl(data_path / "test.jsonl", data)
-        apply(data_path, output, "blank:en", 1, 1)
-        data = [{"key": 234}]
-        srsly.write_jsonl(data_path / "test2.jsonl", data)
-        with pytest.raises(ValueError, match="test2.jsonl"):
-            apply(data_path, output, "blank:en", 1, 1)
+        with open(data_path / "test.foo", "w") as ftest:
+            ftest.write("Testing apply cli.")
+        apply(data_path, output, "blank:en", "text", 1, 1)


 def test_applycli_mixed():
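
The updated JSONL test now exercises a custom field name end to end instead of asserting a ValueError for a missing "text" key. A self-contained variant, using pytest's built-in tmp_path fixture in place of spaCy's make_tempdir helper, and assuming apply is imported from the CLI module changed above:

    import spacy
    import srsly
    from spacy.tokens import DocBin


    def test_apply_custom_field(tmp_path):
        # Corpus whose text lives under "body" rather than the default "text".
        srsly.write_jsonl(tmp_path / "test.jsonl", [{"body": "Testing apply cli.", "key": 234}])
        output = tmp_path / "testout.spacy"
        apply(tmp_path, output, "blank:en", "body", 1, 1)
        # The output DocBin should contain the streamed field as a Doc.
        docs = list(DocBin().from_disk(output).get_docs(spacy.blank("en").vocab))
        assert [doc.text for doc in docs] == ["Testing apply cli."]
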
@@ -919,7 +913,7 @@ def test_applycli_mixed():
         docbin.to_disk(data_path / "testin.spacy")
         with open(data_path / "test.txt", "w") as ftest:
             ftest.write(text)
-        apply(data_path, output, "blank:en", 1, 1)
+        apply(data_path, output, "blank:en", "text", 1, 1)
         # Check whether it worked
         result = list(DocBin().from_disk(output).get_docs(nlp.vocab))
         assert len(result) == 3