typing fix

commit cf2c073fcc
parent 8030393ecc
Author: kadarakos
Date:   2022-09-15 12:44:58 +00:00

@@ -3,7 +3,7 @@ import srsly
 from itertools import chain
 from pathlib import Path
-from typing import Optional, Generator, Union, List
+from typing import Optional, Generator, Union, List, Iterable, cast
 from wasabi import msg
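
The import change backs the annotation fixes below: `Generator[T, None, None]` spells out the yield, send, and return types, while `Iterable[T]` is the simpler supertype that suffices when callers only iterate. A minimal sketch (hypothetical `squares` functions, not from this file) showing both annotations accept the same generator body:

    from typing import Generator, Iterable

    # Full generator type: yields int, accepts no sent value, returns None.
    def squares_verbose(n: int) -> Generator[int, None, None]:
        for i in range(n):
            yield i * i

    # Generator is a subtype of Iterable, so the same body type-checks
    # with the shorter annotation when only iteration is needed.
    def squares(n: int) -> Iterable[int]:
        for i in range(n):
            yield i * i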
@@ -26,7 +26,7 @@ code_help = ("Path to Python file with additional "
 gold_help = "Use gold preprocessing provided in the .spacy files"
 
 
-def _stream_docbin(path: Path, vocab: Vocab) -> Generator[Doc, None, None]:
+def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]:
     """
     Stream Doc objects from DocBin.
     """
@@ -36,7 +36,7 @@ def _stream_docbin(path: Path, vocab: Vocab) -> Generator[Doc, None, None]:
         yield doc
 
 
-def _stream_jsonl(path: Path) -> Generator[str, None, None]:
+def _stream_jsonl(path: Path) -> Iterable[str]:
     """
     Stream "text" field from JSONL. If the field "text" is
     not found it raises error.
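
Likewise, one way the JSONL streamer might look, assuming it reads rows with `srsly.read_jsonl` (the file imports `srsly` at the top); the exact error raised is an assumption:

    import srsly

    def _stream_jsonl(path: Path) -> Iterable[str]:
        """Stream the "text" field from JSONL, raising if it is missing."""
        # srsly.read_jsonl yields one parsed dict per line without
        # loading the whole file into memory.
        for line in srsly.read_jsonl(path):
            if "text" not in line:
                raise KeyError(f"no 'text' field in line of {path}")  # assumed message
            yield line["text"]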
@@ -65,7 +65,7 @@ def _maybe_read_text(path: Path) -> Union[str, None]:
     return None
 
 
-def _stream_texts(paths: List[Path]) -> Generator[Union[str, None], None, None]:
+def _stream_texts(paths: List[Path]) -> Iterable[Union[str, None]]:
     """
     Yields strings or None when decoding error is encountered.
     """
@@ -115,26 +115,19 @@ def apply(
     msg.good(f"Loaded model {model}")
     vocab = nlp.vocab
     docbin = DocBin()
-    datagen: Union[
-        Generator[Union[Doc, str], None, None],
-        chain[Union[Doc, str]],
-        filter[str]
-    ]
     paths = walk_directory(data_path)
-    streams = []
+    streams: List[Union[Iterable[str], Iterable[Doc]]] = []
     text_files = []
     for path in paths:
         if path.suffix == ".spacy":
-            stream = _stream_docbin(path, vocab)
-            streams.append(stream)
+            streams.append(_stream_docbin(path, vocab))
         elif path.suffix == ".jsonl":
             streams.append(_stream_jsonl(path))
         else:
             text_files.append(path)
     if len(text_files) > 0:
-        stream = filter(None, _stream_texts(text_files))
-        streams.append(stream)
-    datagen = chain(*streams)
+        streams.append(filter(None, _stream_texts(text_files)))
+    datagen = cast(Iterable[Union[str, Doc]], chain(*streams))
     for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
         docbin.add(doc)
     if output_path.is_dir():
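
The removed `datagen` annotation tried to enumerate every concrete stream type (`Generator`, `chain`, `filter`); typing all streams as `Iterable` makes that unnecessary. Two details carry the fix: `filter(None, ...)` has a dedicated typeshed overload that narrows `Iterable[Union[str, None]]` to an iterator of plain `str`, and mypy does not reliably infer a precise element type for `chain(*streams)` when `streams` holds a union of iterable types, hence the `cast`. A standalone sketch with hypothetical stand-in streams:

    from itertools import chain
    from typing import Iterable, List, Optional, Union, cast

    strs: Iterable[Optional[str]] = iter(["a", None, "b"])
    ints: Iterable[int] = iter([1, 2])

    streams: List[Union[Iterable[str], Iterable[int]]] = []
    # filter(None, ...) drops falsy items, and typeshed narrows
    # filter(None, Iterable[Optional[str]]) to an iterator of str.
    streams.append(filter(None, strs))
    streams.append(ints)

    # The cast pins the element type the consumer expects;
    # cast() is a no-op at runtime.
    merged = cast(Iterable[Union[str, int]], chain(*streams))
    print(list(merged))  # ['a', 'b', 1, 2]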