From cf2c073fccb6b0c9ca658b71589c214acb0b9694 Mon Sep 17 00:00:00 2001
From: kadarakos
Date: Thu, 15 Sep 2022 12:44:58 +0000
Subject: [PATCH] typing fix

---
 spacy/cli/apply.py | 21 +++++++--------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/spacy/cli/apply.py b/spacy/cli/apply.py
index 7acf82f2d..3c50b131b 100644
--- a/spacy/cli/apply.py
+++ b/spacy/cli/apply.py
@@ -3,7 +3,7 @@ import srsly
 
 from itertools import chain
 from pathlib import Path
-from typing import Optional, Generator, Union, List
+from typing import Optional, Generator, Union, List, Iterable, cast
 
 from wasabi import msg
 
@@ -26,7 +26,7 @@ code_help = ("Path to Python file with additional "
 gold_help = "Use gold preprocessing provided in the .spacy files"
 
 
-def _stream_docbin(path: Path, vocab: Vocab) -> Generator[Doc, None, None]:
+def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]:
     """
     Stream Doc objects from DocBin.
     """
@@ -36,7 +36,7 @@
         yield doc
 
 
-def _stream_jsonl(path: Path) -> Generator[str, None, None]:
+def _stream_jsonl(path: Path) -> Iterable[str]:
     """
     Stream "text" field from JSONL. If the field "text" is not found
     it raises error.
@@ -65,7 +65,7 @@ def _maybe_read_text(path: Path) -> Union[str, None]:
             return None
 
 
-def _stream_texts(paths: List[Path]) -> Generator[Union[str, None], None, None]:
+def _stream_texts(paths: List[Path]) -> Iterable[Union[str, None]]:
     """
     Yields strings or None when decoding error is encountered.
     """
@@ -115,26 +115,19 @@ def apply(
     msg.good(f"Loaded model {model}")
     vocab = nlp.vocab
     docbin = DocBin()
-    datagen: Union[
-        Generator[Union[Doc, str], None, None],
-        chain[Union[Doc, str]],
-        filter[str]
-    ]
     paths = walk_directory(data_path)
-    streams = []
+    streams: List[Union[Iterable[str], Iterable[Doc]]] = []
     text_files = []
     for path in paths:
         if path.suffix == ".spacy":
-            stream = _stream_docbin(path, vocab)
             streams.append(_stream_docbin(path, vocab))
         elif path.suffix == ".jsonl":
             streams.append(_stream_jsonl(path))
         else:
             text_files.append(path)
     if len(text_files) > 0:
-        stream = filter(None, _stream_texts(text_files))
-        streams.append(stream)
-    datagen = chain(*streams)
+        streams.append(filter(None, _stream_texts(text_files)))
+    datagen = cast(Iterable[Union[str, Doc]], chain(*streams))
     for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
         docbin.add(doc)
     if output_path.is_dir():
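
A minimal standalone sketch of the typing pattern this commit applies (not spaCy code; stream_squares, stream_words, and the sample data are invented for illustration): generator functions are annotated with the broader Iterable[...] instead of Generator[..., None, None], mixed streams are collected into an explicitly typed list, and a single cast restates the element type after chain(), since the consumer only needs something it can iterate over.

    from itertools import chain
    from typing import Iterable, List, Union, cast

    def stream_squares(n: int) -> Iterable[int]:
        # Annotated as Iterable[int] rather than Generator[int, None, None]:
        # callers only iterate, and the broader type also matches the
        # filter()/chain() objects collected into the same list below.
        for i in range(n):
            yield i * i

    def stream_words(text: str) -> Iterable[Union[str, None]]:
        # Yield None for items that cannot be handled, in the spirit of
        # _stream_texts(); filter(None, ...) drops them downstream.
        for tok in text.split():
            yield tok if tok.isalpha() else None

    # Collect heterogeneous streams, typed up front as in the patch.
    streams: List[Union[Iterable[int], Iterable[str]]] = []
    streams.append(stream_squares(3))
    streams.append(filter(None, stream_words("spacy 101 apply")))

    # chain(*streams) does not preserve a precise element type, so a cast
    # restates it for the type checker, mirroring the patched datagen line.
    datagen = cast(Iterable[Union[int, str]], chain(*streams))
    print(list(datagen))  # [0, 1, 4, 'spacy', 'apply']

This is also why the exact Union[Generator[...], chain[...], filter[str]] annotation could be dropped from apply(): nlp.pipe only needs an iterable of Doc or str, so the single cast is both simpler and sufficient.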