mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Accept Doc input in pipelines (#9069)
* Accept Doc input in pipelines Allow `Doc` input to `Language.__call__` and `Language.pipe`, which skips `Language.make_doc` and passes the doc directly to the pipeline. * ensure_doc helper function * avoid running multiple processes on GPU * Update spacy/tests/test_language.py Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
parent
cd75f96501
commit
2f0bb77920
|
@ -521,6 +521,7 @@ class Errors:
|
|||
E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.")
|
||||
|
||||
# New errors added in v3.x
|
||||
E866 = ("Expected a string or 'Doc' as input, but got: {type}.")
|
||||
E867 = ("The 'textcat' component requires at least two labels because it "
|
||||
"uses mutually exclusive classes where exactly one label is True "
|
||||
"for each doc. For binary classification tasks, you can use two "
|
||||
|
|
|
@ -968,7 +968,7 @@ class Language:
|
|||
|
||||
def __call__(
|
||||
self,
|
||||
text: str,
|
||||
text: Union[str, Doc],
|
||||
*,
|
||||
disable: Iterable[str] = SimpleFrozenList(),
|
||||
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
|
@ -977,7 +977,9 @@ class Language:
|
|||
and can contain arbitrary whitespace. Alignment into the original string
|
||||
is preserved.
|
||||
|
||||
text (str): The text to be processed.
|
||||
text (Union[str, Doc]): If `str`, the text to be processed. If `Doc`,
|
||||
the doc will be passed directly to the pipeline, skipping
|
||||
`Language.make_doc`.
|
||||
disable (list): Names of the pipeline components to disable.
|
||||
component_cfg (Dict[str, dict]): An optional dictionary with extra
|
||||
keyword arguments for specific components.
|
||||
|
@ -985,7 +987,7 @@ class Language:
|
|||
|
||||
DOCS: https://spacy.io/api/language#call
|
||||
"""
|
||||
doc = self.make_doc(text)
|
||||
doc = self._ensure_doc(text)
|
||||
if component_cfg is None:
|
||||
component_cfg = {}
|
||||
for name, proc in self.pipeline:
|
||||
|
@ -1069,6 +1071,14 @@ class Language:
|
|||
)
|
||||
return self.tokenizer(text)
|
||||
|
||||
def _ensure_doc(self, doc_like: Union[str, Doc]) -> Doc:
|
||||
"""Create a Doc if need be, or raise an error if the input is not a Doc or a string."""
|
||||
if isinstance(doc_like, Doc):
|
||||
return doc_like
|
||||
if isinstance(doc_like, str):
|
||||
return self.make_doc(doc_like)
|
||||
raise ValueError(Errors.E866.format(type=type(doc_like)))
|
||||
|
||||
def update(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
|
@ -1437,7 +1447,7 @@ class Language:
|
|||
@overload
|
||||
def pipe(
|
||||
self,
|
||||
texts: Iterable[Tuple[str, _AnyContext]],
|
||||
texts: Iterable[Tuple[Union[str, Doc], _AnyContext]],
|
||||
*,
|
||||
as_tuples: bool = ...,
|
||||
batch_size: Optional[int] = ...,
|
||||
|
@ -1449,7 +1459,7 @@ class Language:
|
|||
|
||||
def pipe( # noqa: F811
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
texts: Iterable[Union[str, Doc]],
|
||||
*,
|
||||
as_tuples: bool = False,
|
||||
batch_size: Optional[int] = None,
|
||||
|
@ -1459,7 +1469,8 @@ class Language:
|
|||
) -> Iterator[Doc]:
|
||||
"""Process texts as a stream, and yield `Doc` objects in order.
|
||||
|
||||
texts (Iterable[str]): A sequence of texts to process.
|
||||
texts (Iterable[Union[str, Doc]]): A sequence of texts or docs to
|
||||
process.
|
||||
as_tuples (bool): If set to True, inputs should be a sequence of
|
||||
(text, context) tuples. Output will then be a sequence of
|
||||
(doc, context) tuples. Defaults to False.
|
||||
|
@ -1515,7 +1526,7 @@ class Language:
|
|||
docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size)
|
||||
else:
|
||||
# if n_process == 1, no processes are forked.
|
||||
docs = (self.make_doc(text) for text in texts)
|
||||
docs = (self._ensure_doc(text) for text in texts)
|
||||
for pipe in pipes:
|
||||
docs = pipe(docs)
|
||||
for doc in docs:
|
||||
|
@ -1549,7 +1560,7 @@ class Language:
|
|||
procs = [
|
||||
mp.Process(
|
||||
target=_apply_pipes,
|
||||
args=(self.make_doc, pipes, rch, sch, Underscore.get_state()),
|
||||
args=(self._ensure_doc, pipes, rch, sch, Underscore.get_state()),
|
||||
)
|
||||
for rch, sch in zip(texts_q, bytedocs_send_ch)
|
||||
]
|
||||
|
@ -2084,7 +2095,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:
|
|||
|
||||
|
||||
def _apply_pipes(
|
||||
make_doc: Callable[[str], Doc],
|
||||
ensure_doc: Callable[[Union[str, Doc]], Doc],
|
||||
pipes: Iterable[Callable[[Doc], Doc]],
|
||||
receiver,
|
||||
sender,
|
||||
|
@ -2092,7 +2103,8 @@ def _apply_pipes(
|
|||
) -> None:
|
||||
"""Worker for Language.pipe
|
||||
|
||||
make_doc (Callable[[str,] Doc]): Function to create Doc from text.
|
||||
ensure_doc (Callable[[Union[str, Doc]], Doc]): Function to create Doc from text
|
||||
or raise an error if the input is neither a Doc nor a string.
|
||||
pipes (Iterable[Callable[[Doc], Doc]]): The components to apply.
|
||||
receiver (multiprocessing.Connection): Pipe to receive text. Usually
|
||||
created by `multiprocessing.Pipe()`
|
||||
|
@ -2105,7 +2117,7 @@ def _apply_pipes(
|
|||
while True:
|
||||
try:
|
||||
texts = receiver.get()
|
||||
docs = (make_doc(text) for text in texts)
|
||||
docs = (ensure_doc(text) for text in texts)
|
||||
for pipe in pipes:
|
||||
docs = pipe(docs)
|
||||
# Connection does not accept unpickable objects, so send list.
|
||||
|
|
|
@ -528,3 +528,29 @@ def test_language_source_and_vectors(nlp2):
|
|||
assert long_string in nlp2.vocab.strings
|
||||
# vectors should remain unmodified
|
||||
assert nlp.vocab.vectors.to_bytes() == vectors_bytes
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_process", [1, 2])
|
||||
def test_pass_doc_to_pipeline(nlp, n_process):
|
||||
texts = ["cats", "dogs", "guinea pigs"]
|
||||
docs = [nlp.make_doc(text) for text in texts]
|
||||
assert not any(len(doc.cats) for doc in docs)
|
||||
doc = nlp(docs[0])
|
||||
assert doc.text == texts[0]
|
||||
assert len(doc.cats) > 0
|
||||
if isinstance(get_current_ops(), NumpyOps) or n_process < 2:
|
||||
docs = nlp.pipe(docs, n_process=n_process)
|
||||
assert [doc.text for doc in docs] == texts
|
||||
assert all(len(doc.cats) for doc in docs)
|
||||
|
||||
|
||||
def test_invalid_arg_to_pipeline(nlp):
|
||||
str_list = ["This is a text.", "This is another."]
|
||||
with pytest.raises(ValueError):
|
||||
nlp(str_list) # type: ignore
|
||||
assert len(list(nlp.pipe(str_list))) == 2
|
||||
int_list = [1, 2, 3]
|
||||
with pytest.raises(ValueError):
|
||||
list(nlp.pipe(int_list)) # type: ignore
|
||||
with pytest.raises(ValueError):
|
||||
nlp(int_list) # type: ignore
|
||||
|
|
Loading…
Reference in New Issue
Block a user