Accept Doc input in pipelines (#9069)

* Accept Doc input in pipelines

Allow `Doc` input to `Language.__call__` and `Language.pipe`, which
skips `Language.make_doc` and passes the doc directly to the pipeline.

* ensure_doc helper function

* avoid running multiple processes on GPU

* Update spacy/tests/test_language.py

Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
Adriane Boyd 2021-09-22 09:41:05 +02:00 committed by GitHub
parent cd75f96501
commit 2f0bb77920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 11 deletions

View File

@ -521,6 +521,7 @@ class Errors:
E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.")
# New errors added in v3.x
E866 = ("Expected a string or 'Doc' as input, but got: {type}.")
E867 = ("The 'textcat' component requires at least two labels because it "
"uses mutually exclusive classes where exactly one label is True "
"for each doc. For binary classification tasks, you can use two "

View File

@ -968,7 +968,7 @@ class Language:
def __call__(
self,
text: str,
text: Union[str, Doc],
*,
disable: Iterable[str] = SimpleFrozenList(),
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
@ -977,7 +977,9 @@ class Language:
and can contain arbitrary whitespace. Alignment into the original string
is preserved.
text (str): The text to be processed.
text (Union[str, Doc]): If `str`, the text to be processed. If `Doc`,
the doc will be passed directly to the pipeline, skipping
`Language.make_doc`.
disable (list): Names of the pipeline components to disable.
component_cfg (Dict[str, dict]): An optional dictionary with extra
keyword arguments for specific components.
@ -985,7 +987,7 @@ class Language:
DOCS: https://spacy.io/api/language#call
"""
doc = self.make_doc(text)
doc = self._ensure_doc(text)
if component_cfg is None:
component_cfg = {}
for name, proc in self.pipeline:
@ -1069,6 +1071,14 @@ class Language:
)
return self.tokenizer(text)
def _ensure_doc(self, doc_like: Union[str, Doc]) -> Doc:
    """Coerce input to a Doc.

    doc_like (Union[str, Doc]): A string to tokenize via `make_doc`, or an
        already-constructed Doc, which is returned unchanged.
    RETURNS (Doc): The resulting doc.
    RAISES (ValueError): E866 if the input is neither a string nor a Doc.
    """
    if isinstance(doc_like, str):
        return self.make_doc(doc_like)
    if isinstance(doc_like, Doc):
        return doc_like
    raise ValueError(Errors.E866.format(type=type(doc_like)))
def update(
self,
examples: Iterable[Example],
@ -1437,7 +1447,7 @@ class Language:
@overload
def pipe(
self,
texts: Iterable[Tuple[str, _AnyContext]],
texts: Iterable[Tuple[Union[str, Doc], _AnyContext]],
*,
as_tuples: bool = ...,
batch_size: Optional[int] = ...,
@ -1449,7 +1459,7 @@ class Language:
def pipe( # noqa: F811
self,
texts: Iterable[str],
texts: Iterable[Union[str, Doc]],
*,
as_tuples: bool = False,
batch_size: Optional[int] = None,
@ -1459,7 +1469,8 @@ class Language:
) -> Iterator[Doc]:
"""Process texts as a stream, and yield `Doc` objects in order.
texts (Iterable[str]): A sequence of texts to process.
texts (Iterable[Union[str, Doc]]): A sequence of texts or docs to
process.
as_tuples (bool): If set to True, inputs should be a sequence of
(text, context) tuples. Output will then be a sequence of
(doc, context) tuples. Defaults to False.
@ -1515,7 +1526,7 @@ class Language:
docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size)
else:
# if n_process == 1, no processes are forked.
docs = (self.make_doc(text) for text in texts)
docs = (self._ensure_doc(text) for text in texts)
for pipe in pipes:
docs = pipe(docs)
for doc in docs:
@ -1549,7 +1560,7 @@ class Language:
procs = [
mp.Process(
target=_apply_pipes,
args=(self.make_doc, pipes, rch, sch, Underscore.get_state()),
args=(self._ensure_doc, pipes, rch, sch, Underscore.get_state()),
)
for rch, sch in zip(texts_q, bytedocs_send_ch)
]
@ -2084,7 +2095,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:
def _apply_pipes(
make_doc: Callable[[str], Doc],
ensure_doc: Callable[[Union[str, Doc]], Doc],
pipes: Iterable[Callable[[Doc], Doc]],
receiver,
sender,
@ -2092,7 +2103,8 @@ def _apply_pipes(
) -> None:
"""Worker for Language.pipe
make_doc (Callable[[str,] Doc]): Function to create Doc from text.
ensure_doc (Callable[[Union[str, Doc]], Doc]): Function to create Doc from text
or raise an error if the input is neither a Doc nor a string.
pipes (Iterable[Callable[[Doc], Doc]]): The components to apply.
receiver (multiprocessing.Connection): Pipe to receive text. Usually
created by `multiprocessing.Pipe()`
@ -2105,7 +2117,7 @@ def _apply_pipes(
while True:
try:
texts = receiver.get()
docs = (make_doc(text) for text in texts)
docs = (ensure_doc(text) for text in texts)
for pipe in pipes:
docs = pipe(docs)
# Connection does not accept unpickable objects, so send list.

View File

@ -528,3 +528,29 @@ def test_language_source_and_vectors(nlp2):
assert long_string in nlp2.vocab.strings
# vectors should remain unmodified
assert nlp.vocab.vectors.to_bytes() == vectors_bytes
@pytest.mark.parametrize("n_process", [1, 2])
def test_pass_doc_to_pipeline(nlp, n_process):
    """Doc inputs should skip make_doc and still run through the pipeline."""
    texts = ["cats", "dogs", "guinea pigs"]
    docs = [nlp.make_doc(text) for text in texts]
    # make_doc alone must not run the textcat component.
    assert not any(len(doc.cats) for doc in docs)
    doc = nlp(docs[0])
    assert doc.text == texts[0]
    assert len(doc.cats) > 0
    # Skip multiprocessing on GPU ops; only NumpyOps supports n_process > 1 here.
    if isinstance(get_current_ops(), NumpyOps) or n_process < 2:
        # Materialize the generator: iterating nlp.pipe output twice would
        # exhaust it after the first assertion, leaving the all(...) check
        # vacuously true over an empty iterator.
        docs = list(nlp.pipe(docs, n_process=n_process))
        assert [doc.text for doc in docs] == texts
        assert all(len(doc.cats) for doc in docs)
def test_invalid_arg_to_pipeline(nlp):
    """Inputs that are neither str nor Doc raise ValueError (E866)."""
    str_list = ["This is a text.", "This is another."]
    int_list = [1, 2, 3]
    # A list is not a valid single input for Language.__call__ ...
    with pytest.raises(ValueError):
        nlp(str_list)  # type: ignore
    with pytest.raises(ValueError):
        nlp(int_list)  # type: ignore
    # ... but an iterable of strings is fine for Language.pipe,
    # while an iterable of ints is rejected during iteration.
    assert len(list(nlp.pipe(str_list))) == 2
    with pytest.raises(ValueError):
        list(nlp.pipe(int_list))  # type: ignore