don't warn but raise

This commit is contained in:
kadarakos 2022-10-26 14:01:17 +00:00
parent 91e72f8abe
commit 9b404ea33c

View File

@ -3,7 +3,7 @@ import srsly
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
from typing import Optional, Generator, Union, List, Iterable, cast from typing import Optional, List, Iterable, cast, Union
from wasabi import msg from wasabi import msg
@ -25,6 +25,8 @@ code_help = ("Path to Python file with additional "
"code (registered functions) to be imported") "code (registered functions) to be imported")
gold_help = "Use gold preprocessing provided in the .spacy files" gold_help = "Use gold preprocessing provided in the .spacy files"
DocOrStrStream = Union[Iterable[str], Iterable[Doc]]
def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]: def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]:
""" """
@ -49,26 +51,14 @@ def _stream_jsonl(path: Path) -> Iterable[str]:
yield entry["text"] yield entry["text"]
def _maybe_read_text(path: Path) -> Union[str, None]: def _stream_texts(paths: Iterable[Path]) -> Iterable[str]:
""" """
Try to read the text file from the provided path. Yields strings from text files in paths.
    When encountering a decoding error just warn and pass.
"""
with open(path, 'r') as fin:
try:
text = fin.read()
return text
except UnicodeDecodeError as e:
msg.warn(f"Skipping file {path}")
return None
def _stream_texts(paths: List[Path]) -> Iterable[Union[str, None]]:
"""
Yields strings or None when decoding error is encountered.
""" """
for path in paths: for path in paths:
yield _maybe_read_text(path) with open(path, 'r') as fin:
text = fin.read()
yield text
@app.command("apply") @app.command("apply")
@ -114,7 +104,7 @@ def apply(
vocab = nlp.vocab vocab = nlp.vocab
docbin = DocBin() docbin = DocBin()
paths = walk_directory(data_path) paths = walk_directory(data_path)
streams: List[Union[Iterable[str], Iterable[Doc]]] = [] streams: List[DocOrStrStream] = []
text_files = [] text_files = []
for path in paths: for path in paths:
if path.suffix == ".spacy": if path.suffix == ".spacy":
@ -124,8 +114,8 @@ def apply(
else: else:
text_files.append(path) text_files.append(path)
if len(text_files) > 0: if len(text_files) > 0:
streams.append(filter(None, _stream_texts(text_files))) streams.append(_stream_texts(text_files))
datagen = cast(Iterable[Union[str, Doc]], chain(*streams)) datagen = cast(DocOrStrStream, chain(*streams))
for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)): for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
docbin.add(doc) docbin.add(doc)
if output_path.is_dir(): if output_path.is_dir():