typing fix

This commit is contained in:
kadarakos 2022-09-15 12:44:58 +00:00
parent 8030393ecc
commit cf2c073fcc

View File

@@ -3,7 +3,7 @@ import srsly
 from itertools import chain
 from pathlib import Path
-from typing import Optional, Generator, Union, List
+from typing import Optional, Generator, Union, List, Iterable, cast
 from wasabi import msg
@@ -26,7 +26,7 @@ code_help = ("Path to Python file with additional "
 gold_help = "Use gold preprocessing provided in the .spacy files"
-def _stream_docbin(path: Path, vocab: Vocab) -> Generator[Doc, None, None]:
+def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]:
     """
     Stream Doc objects from DocBin.
     """
@@ -36,7 +36,7 @@ def _stream_docbin(path: Path, vocab: Vocab) -> Generator[Doc, None, None]:
         yield doc

-def _stream_jsonl(path: Path) -> Generator[str, None, None]:
+def _stream_jsonl(path: Path) -> Iterable[str]:
     """
     Stream "text" field from JSONL. If the field "text" is
     not found it raises error.
@@ -65,7 +65,7 @@ def _maybe_read_text(path: Path) -> Union[str, None]:
         return None

-def _stream_texts(paths: List[Path]) -> Generator[Union[str, None], None, None]:
+def _stream_texts(paths: List[Path]) -> Iterable[Union[str, None]]:
     """
     Yields strings or None when decoding error is encountered.
     """
@@ -115,26 +115,19 @@ def apply(
     msg.good(f"Loaded model {model}")
     vocab = nlp.vocab
     docbin = DocBin()
-    datagen: Union[
-        Generator[Union[Doc, str], None, None],
-        chain[Union[Doc, str]],
-        filter[str]
-    ]
     paths = walk_directory(data_path)
-    streams = []
+    streams: List[Union[Iterable[str], Iterable[Doc]]] = []
     text_files = []
     for path in paths:
         if path.suffix == ".spacy":
-            stream = _stream_docbin(path, vocab)
             streams.append(_stream_docbin(path, vocab))
         elif path.suffix == ".jsonl":
             streams.append(_stream_jsonl(path))
         else:
             text_files.append(path)
     if len(text_files) > 0:
-        stream = filter(None, _stream_texts(text_files))
-        streams.append(stream)
-    datagen = chain(*streams)
+        streams.append(filter(None, _stream_texts(text_files)))
+    datagen = cast(Iterable[Union[str, Doc]], chain(*streams))
     for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
         docbin.add(doc)
     if output_path.is_dir():