mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 05:40:20 +03:00
typing fix
This commit is contained in:
parent
8030393ecc
commit
cf2c073fcc
|
@@ -3,7 +3,7 @@ import srsly
|
|||
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Optional, Generator, Union, List
|
||||
from typing import Optional, Generator, Union, List, Iterable, cast
|
||||
|
||||
from wasabi import msg
|
||||
|
||||
|
@@ -26,7 +26,7 @@ code_help = ("Path to Python file with additional "
|
|||
gold_help = "Use gold preprocessing provided in the .spacy files"
|
||||
|
||||
|
||||
def _stream_docbin(path: Path, vocab: Vocab) -> Generator[Doc, None, None]:
|
||||
def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]:
|
||||
"""
|
||||
Stream Doc objects from DocBin.
|
||||
"""
|
||||
|
@@ -36,7 +36,7 @@ def _stream_docbin(path: Path, vocab: Vocab) -> Generator[Doc, None, None]:
|
|||
yield doc
|
||||
|
||||
|
||||
def _stream_jsonl(path: Path) -> Generator[str, None, None]:
|
||||
def _stream_jsonl(path: Path) -> Iterable[str]:
|
||||
"""
|
||||
Stream "text" field from JSONL. If the field "text" is
|
||||
not found it raises error.
|
||||
|
@@ -65,7 +65,7 @@ def _maybe_read_text(path: Path) -> Union[str, None]:
|
|||
return None
|
||||
|
||||
|
||||
def _stream_texts(paths: List[Path]) -> Generator[Union[str, None], None, None]:
|
||||
def _stream_texts(paths: List[Path]) -> Iterable[Union[str, None]]:
|
||||
"""
|
||||
Yields strings or None when decoding error is encountered.
|
||||
"""
|
||||
|
@@ -115,26 +115,19 @@ def apply(
|
|||
msg.good(f"Loaded model {model}")
|
||||
vocab = nlp.vocab
|
||||
docbin = DocBin()
|
||||
datagen: Union[
|
||||
Generator[Union[Doc, str], None, None],
|
||||
chain[Union[Doc, str]],
|
||||
filter[str]
|
||||
]
|
||||
paths = walk_directory(data_path)
|
||||
streams = []
|
||||
streams: List[Union[Iterable[str], Iterable[Doc]]] = []
|
||||
text_files = []
|
||||
for path in paths:
|
||||
if path.suffix == ".spacy":
|
||||
stream = _stream_docbin(path, vocab)
|
||||
streams.append(_stream_docbin(path, vocab))
|
||||
elif path.suffix == ".jsonl":
|
||||
streams.append(_stream_jsonl(path))
|
||||
else:
|
||||
text_files.append(path)
|
||||
if len(text_files) > 0:
|
||||
stream = filter(None, _stream_texts(text_files))
|
||||
streams.append(stream)
|
||||
datagen = chain(*streams)
|
||||
streams.append(filter(None, _stream_texts(text_files)))
|
||||
datagen = cast(Iterable[Union[str, Doc]], chain(*streams))
|
||||
for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
|
||||
docbin.add(doc)
|
||||
if output_path.is_dir():
|
||||
|
|
Loading…
Reference in New Issue
Block a user