mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 05:40:20 +03:00
don't warn but raise
This commit is contained in:
parent
91e72f8abe
commit
9b404ea33c
|
@ -3,7 +3,7 @@ import srsly
|
|||
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Optional, Generator, Union, List, Iterable, cast
|
||||
from typing import Optional, List, Iterable, cast, Union
|
||||
|
||||
from wasabi import msg
|
||||
|
||||
|
@ -25,6 +25,8 @@ code_help = ("Path to Python file with additional "
|
|||
"code (registered functions) to be imported")
|
||||
gold_help = "Use gold preprocessing provided in the .spacy files"
|
||||
|
||||
DocOrStrStream = Union[Iterable[str], Iterable[Doc]]
|
||||
|
||||
|
||||
def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]:
|
||||
"""
|
||||
|
@ -49,26 +51,14 @@ def _stream_jsonl(path: Path) -> Iterable[str]:
|
|||
yield entry["text"]
|
||||
|
||||
|
||||
def _maybe_read_text(path: Path) -> Union[str, None]:
|
||||
def _stream_texts(paths: Iterable[Path]) -> Iterable[str]:
|
||||
"""
|
||||
Try to read the text file from the provided path.
|
||||
When encoutering a decoding error just warn and pass.
|
||||
"""
|
||||
with open(path, 'r') as fin:
|
||||
try:
|
||||
text = fin.read()
|
||||
return text
|
||||
except UnicodeDecodeError as e:
|
||||
msg.warn(f"Skipping file {path}")
|
||||
return None
|
||||
|
||||
|
||||
def _stream_texts(paths: List[Path]) -> Iterable[Union[str, None]]:
|
||||
"""
|
||||
Yields strings or None when decoding error is encountered.
|
||||
Yields strings from text files in paths.
|
||||
"""
|
||||
for path in paths:
|
||||
yield _maybe_read_text(path)
|
||||
with open(path, 'r') as fin:
|
||||
text = fin.read()
|
||||
yield text
|
||||
|
||||
|
||||
@app.command("apply")
|
||||
|
@ -114,7 +104,7 @@ def apply(
|
|||
vocab = nlp.vocab
|
||||
docbin = DocBin()
|
||||
paths = walk_directory(data_path)
|
||||
streams: List[Union[Iterable[str], Iterable[Doc]]] = []
|
||||
streams: List[DocOrStrStream] = []
|
||||
text_files = []
|
||||
for path in paths:
|
||||
if path.suffix == ".spacy":
|
||||
|
@ -124,8 +114,8 @@ def apply(
|
|||
else:
|
||||
text_files.append(path)
|
||||
if len(text_files) > 0:
|
||||
streams.append(filter(None, _stream_texts(text_files)))
|
||||
datagen = cast(Iterable[Union[str, Doc]], chain(*streams))
|
||||
streams.append(_stream_texts(text_files))
|
||||
datagen = cast(DocOrStrStream, chain(*streams))
|
||||
for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
|
||||
docbin.add(doc)
|
||||
if output_path.is_dir():
|
||||
|
|
Loading…
Reference in New Issue
Block a user