mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-06 05:10:21 +03:00
update tests and bugfix
This commit is contained in:
parent
8cbf06c00e
commit
248388728b
|
@ -47,13 +47,15 @@ def _stream_jsonl(path: Path, field) -> Iterable[str]:
|
||||||
not found it raises error.
|
not found it raises error.
|
||||||
"""
|
"""
|
||||||
for entry in srsly.read_jsonl(path):
|
for entry in srsly.read_jsonl(path):
|
||||||
|
print(entry)
|
||||||
|
print(field)
|
||||||
if field not in entry:
|
if field not in entry:
|
||||||
raise msg.fail(
|
raise msg.fail(
|
||||||
f"{path} does not contain the required '{field}' field.",
|
f"{path} does not contain the required '{field}' field.",
|
||||||
exits=1
|
exits=1
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield entry["text"]
|
yield entry[field]
|
||||||
|
|
||||||
|
|
||||||
def _stream_texts(paths: Iterable[Path]) -> Iterable[str]:
|
def _stream_texts(paths: Iterable[Path]) -> Iterable[str]:
|
||||||
|
@ -73,7 +75,7 @@ def apply_cli(
|
||||||
data_path: Path = Arg(..., help=path_help, exists=True),
|
data_path: Path = Arg(..., help=path_help, exists=True),
|
||||||
output_file: Path = Arg(..., help=out_help, dir_okay=False),
|
output_file: Path = Arg(..., help=out_help, dir_okay=False),
|
||||||
code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
|
code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
|
||||||
field: str = Opt("text", "--field", "-f", help="Field to grab from .jsonl"),
|
json_field: str = Opt("text", "--field", "-f", help="Field to grab from .jsonl"),
|
||||||
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
|
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
|
||||||
batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."),
|
batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."),
|
||||||
n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.")
|
n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.")
|
||||||
|
@ -91,18 +93,19 @@ def apply_cli(
|
||||||
"""
|
"""
|
||||||
import_code(code_path)
|
import_code(code_path)
|
||||||
setup_gpu(use_gpu)
|
setup_gpu(use_gpu)
|
||||||
apply(data_path, output, model, batch_size, n_process)
|
apply(data_path, output_file, model, json_field, batch_size, n_process)
|
||||||
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
data_path: Path,
|
data_path: Path,
|
||||||
output: Path,
|
output_file: Path,
|
||||||
model: str,
|
model: str,
|
||||||
|
json_field: str,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
n_process: int,
|
n_process: int,
|
||||||
):
|
):
|
||||||
data_path = ensure_path(data_path)
|
data_path = ensure_path(data_path)
|
||||||
output_path = ensure_path(output)
|
output_file = ensure_path(output_file)
|
||||||
if not data_path.exists():
|
if not data_path.exists():
|
||||||
msg.fail("Couldn't find data path.", data_path, exits=1)
|
msg.fail("Couldn't find data path.", data_path, exits=1)
|
||||||
nlp = load_model(model)
|
nlp = load_model(model)
|
||||||
|
@ -116,7 +119,7 @@ def apply(
|
||||||
if path.suffix == ".spacy":
|
if path.suffix == ".spacy":
|
||||||
streams.append(_stream_docbin(path, vocab))
|
streams.append(_stream_docbin(path, vocab))
|
||||||
elif path.suffix == ".jsonl":
|
elif path.suffix == ".jsonl":
|
||||||
streams.append(_stream_jsonl(path, field))
|
streams.append(_stream_jsonl(path, json_field))
|
||||||
else:
|
else:
|
||||||
text_files.append(path)
|
text_files.append(path)
|
||||||
if len(text_files) > 0:
|
if len(text_files) > 0:
|
||||||
|
@ -124,6 +127,6 @@ def apply(
|
||||||
datagen = cast(DocOrStrStream, chain(*streams))
|
datagen = cast(DocOrStrStream, chain(*streams))
|
||||||
for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
|
for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
|
||||||
docbin.add(doc)
|
docbin.add(doc)
|
||||||
if not output_file.endswith(".spacy"):
|
if output_file.suffix == "":
|
||||||
output_file += ".spacy"
|
output_file += ".spacy"
|
||||||
docbin.to_disk(output_file)
|
docbin.to_disk(output_file)
|
||||||
|
|
|
@ -864,7 +864,7 @@ def test_span_length_freq_dist_output_must_be_correct():
|
||||||
def test_applycli_empty_dir():
|
def test_applycli_empty_dir():
|
||||||
with make_tempdir() as data_path:
|
with make_tempdir() as data_path:
|
||||||
output = data_path / "test.spacy"
|
output = data_path / "test.spacy"
|
||||||
apply(data_path, output, "blank:en", 1, 1)
|
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||||
|
|
||||||
|
|
||||||
def test_applycli_docbin():
|
def test_applycli_docbin():
|
||||||
|
@ -875,35 +875,29 @@ def test_applycli_docbin():
|
||||||
# test empty DocBin case
|
# test empty DocBin case
|
||||||
docbin = DocBin()
|
docbin = DocBin()
|
||||||
docbin.to_disk(data_path / "testin.spacy")
|
docbin.to_disk(data_path / "testin.spacy")
|
||||||
apply(data_path, output, "blank:en", 1, 1)
|
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||||
docbin.add(doc)
|
docbin.add(doc)
|
||||||
docbin.to_disk(data_path / "testin.spacy")
|
docbin.to_disk(data_path / "testin.spacy")
|
||||||
apply(data_path, output, "blank:en", 1, 1)
|
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||||
|
|
||||||
|
|
||||||
def test_applycli_jsonl():
|
def test_applycli_jsonl():
|
||||||
with make_tempdir() as data_path:
|
with make_tempdir() as data_path:
|
||||||
output = data_path / "testout.spacy"
|
output = data_path / "testout.spacy"
|
||||||
data = [{"text": "Testing apply cli.", "key": 234}]
|
data = [{"field": "Testing apply cli.", "key": 234}]
|
||||||
srsly.write_jsonl(data_path / "test.jsonl", data)
|
srsly.write_jsonl(data_path / "test.jsonl", data)
|
||||||
apply(data_path, output, "blank:en", 1, 1)
|
apply(data_path, output, "blank:en", "field", 1, 1)
|
||||||
data = [{"key": 234}]
|
data = [{"field": "234"}]
|
||||||
srsly.write_jsonl(data_path / "test2.jsonl", data)
|
srsly.write_jsonl(data_path / "test2.jsonl", data)
|
||||||
# test no "text" field case
|
apply(data_path, output, "blank:en", "field", 1, 1)
|
||||||
with pytest.raises(ValueError, match="test2.jsonl"):
|
|
||||||
apply(data_path, output, "blank:en", 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_applycli_txt():
|
def test_applycli_txt():
|
||||||
with make_tempdir() as data_path:
|
with make_tempdir() as data_path:
|
||||||
output = data_path / "testout.spacy"
|
output = data_path / "testout.spacy"
|
||||||
data = [{"text": "Testing apply cli.", "key": 234}]
|
with open(data_path / "test.foo", "w") as ftest:
|
||||||
srsly.write_jsonl(data_path / "test.jsonl", data)
|
ftest.write("Testing apply cli.")
|
||||||
apply(data_path, output, "blank:en", 1, 1)
|
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||||
data = [{"key": 234}]
|
|
||||||
srsly.write_jsonl(data_path / "test2.jsonl", data)
|
|
||||||
with pytest.raises(ValueError, match="test2.jsonl"):
|
|
||||||
apply(data_path, output, "blank:en", 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_applycli_mixed():
|
def test_applycli_mixed():
|
||||||
|
@ -919,7 +913,7 @@ def test_applycli_mixed():
|
||||||
docbin.to_disk(data_path / "testin.spacy")
|
docbin.to_disk(data_path / "testin.spacy")
|
||||||
with open(data_path / "test.txt", "w") as ftest:
|
with open(data_path / "test.txt", "w") as ftest:
|
||||||
ftest.write(text)
|
ftest.write(text)
|
||||||
apply(data_path, output, "blank:en", 1, 1)
|
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||||
# Check whether it worked
|
# Check whether it worked
|
||||||
result = list(DocBin().from_disk(output).get_docs(nlp.vocab))
|
result = list(DocBin().from_disk(output).get_docs(nlp.vocab))
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
|
|
Loading…
Reference in New Issue
Block a user