Accept Doc input in pipelines (#9069)

* Accept Doc input in pipelines

Allow `Doc` input to `Language.__call__` and `Language.pipe`, which
skips `Language.make_doc` and passes the doc directly to the pipeline.

* ensure_doc helper function

* avoid running multiple processes on GPU

* Update spacy/tests/test_language.py

Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
Adriane Boyd 2021-09-22 09:41:05 +02:00 committed by GitHub
parent cd75f96501
commit 2f0bb77920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 11 deletions

View File

@ -521,6 +521,7 @@ class Errors:
E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.") E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.")
# New errors added in v3.x # New errors added in v3.x
E866 = ("Expected a string or 'Doc' as input, but got: {type}.")
E867 = ("The 'textcat' component requires at least two labels because it " E867 = ("The 'textcat' component requires at least two labels because it "
"uses mutually exclusive classes where exactly one label is True " "uses mutually exclusive classes where exactly one label is True "
"for each doc. For binary classification tasks, you can use two " "for each doc. For binary classification tasks, you can use two "

View File

@ -968,7 +968,7 @@ class Language:
def __call__( def __call__(
self, self,
text: str, text: Union[str, Doc],
*, *,
disable: Iterable[str] = SimpleFrozenList(), disable: Iterable[str] = SimpleFrozenList(),
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
@ -977,7 +977,9 @@ class Language:
and can contain arbitrary whitespace. Alignment into the original string and can contain arbitrary whitespace. Alignment into the original string
is preserved. is preserved.
text (str): The text to be processed. text (Union[str, Doc]): If `str`, the text to be processed. If `Doc`,
the doc will be passed directly to the pipeline, skipping
`Language.make_doc`.
disable (list): Names of the pipeline components to disable. disable (list): Names of the pipeline components to disable.
component_cfg (Dict[str, dict]): An optional dictionary with extra component_cfg (Dict[str, dict]): An optional dictionary with extra
keyword arguments for specific components. keyword arguments for specific components.
@ -985,7 +987,7 @@ class Language:
DOCS: https://spacy.io/api/language#call DOCS: https://spacy.io/api/language#call
""" """
doc = self.make_doc(text) doc = self._ensure_doc(text)
if component_cfg is None: if component_cfg is None:
component_cfg = {} component_cfg = {}
for name, proc in self.pipeline: for name, proc in self.pipeline:
@ -1069,6 +1071,14 @@ class Language:
) )
return self.tokenizer(text) return self.tokenizer(text)
def _ensure_doc(self, doc_like: Union[str, Doc]) -> Doc:
    """Coerce the input to a Doc.

    A string is tokenized via `make_doc`; a `Doc` is returned unchanged;
    any other type raises E866.

    doc_like (Union[str, Doc]): The input to normalize.
    RETURNS (Doc): The resulting doc.
    """
    if isinstance(doc_like, str):
        return self.make_doc(doc_like)
    if isinstance(doc_like, Doc):
        return doc_like
    raise ValueError(Errors.E866.format(type=type(doc_like)))
def update( def update(
self, self,
examples: Iterable[Example], examples: Iterable[Example],
@ -1437,7 +1447,7 @@ class Language:
@overload @overload
def pipe( def pipe(
self, self,
texts: Iterable[Tuple[str, _AnyContext]], texts: Iterable[Tuple[Union[str, Doc], _AnyContext]],
*, *,
as_tuples: bool = ..., as_tuples: bool = ...,
batch_size: Optional[int] = ..., batch_size: Optional[int] = ...,
@ -1449,7 +1459,7 @@ class Language:
def pipe( # noqa: F811 def pipe( # noqa: F811
self, self,
texts: Iterable[str], texts: Iterable[Union[str, Doc]],
*, *,
as_tuples: bool = False, as_tuples: bool = False,
batch_size: Optional[int] = None, batch_size: Optional[int] = None,
@ -1459,7 +1469,8 @@ class Language:
) -> Iterator[Doc]: ) -> Iterator[Doc]:
"""Process texts as a stream, and yield `Doc` objects in order. """Process texts as a stream, and yield `Doc` objects in order.
texts (Iterable[str]): A sequence of texts to process. texts (Iterable[Union[str, Doc]]): A sequence of texts or docs to
process.
as_tuples (bool): If set to True, inputs should be a sequence of as_tuples (bool): If set to True, inputs should be a sequence of
(text, context) tuples. Output will then be a sequence of (text, context) tuples. Output will then be a sequence of
(doc, context) tuples. Defaults to False. (doc, context) tuples. Defaults to False.
@ -1515,7 +1526,7 @@ class Language:
docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size) docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size)
else: else:
# if n_process == 1, no processes are forked. # if n_process == 1, no processes are forked.
docs = (self.make_doc(text) for text in texts) docs = (self._ensure_doc(text) for text in texts)
for pipe in pipes: for pipe in pipes:
docs = pipe(docs) docs = pipe(docs)
for doc in docs: for doc in docs:
@ -1549,7 +1560,7 @@ class Language:
procs = [ procs = [
mp.Process( mp.Process(
target=_apply_pipes, target=_apply_pipes,
args=(self.make_doc, pipes, rch, sch, Underscore.get_state()), args=(self._ensure_doc, pipes, rch, sch, Underscore.get_state()),
) )
for rch, sch in zip(texts_q, bytedocs_send_ch) for rch, sch in zip(texts_q, bytedocs_send_ch)
] ]
@ -2084,7 +2095,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:
def _apply_pipes( def _apply_pipes(
make_doc: Callable[[str], Doc], ensure_doc: Callable[[Union[str, Doc]], Doc],
pipes: Iterable[Callable[[Doc], Doc]], pipes: Iterable[Callable[[Doc], Doc]],
receiver, receiver,
sender, sender,
@ -2092,7 +2103,8 @@ def _apply_pipes(
) -> None: ) -> None:
"""Worker for Language.pipe """Worker for Language.pipe
make_doc (Callable[[str,] Doc]): Function to create Doc from text. ensure_doc (Callable[[Union[str, Doc]], Doc]): Function to create Doc from text
or raise an error if the input is neither a Doc nor a string.
pipes (Iterable[Callable[[Doc], Doc]]): The components to apply. pipes (Iterable[Callable[[Doc], Doc]]): The components to apply.
receiver (multiprocessing.Connection): Pipe to receive text. Usually receiver (multiprocessing.Connection): Pipe to receive text. Usually
created by `multiprocessing.Pipe()` created by `multiprocessing.Pipe()`
@ -2105,7 +2117,7 @@ def _apply_pipes(
while True: while True:
try: try:
texts = receiver.get() texts = receiver.get()
docs = (make_doc(text) for text in texts) docs = (ensure_doc(text) for text in texts)
for pipe in pipes: for pipe in pipes:
docs = pipe(docs) docs = pipe(docs)
# Connection does not accept unpickable objects, so send list. # Connection does not accept unpickable objects, so send list.

View File

@ -528,3 +528,29 @@ def test_language_source_and_vectors(nlp2):
assert long_string in nlp2.vocab.strings assert long_string in nlp2.vocab.strings
# vectors should remain unmodified # vectors should remain unmodified
assert nlp.vocab.vectors.to_bytes() == vectors_bytes assert nlp.vocab.vectors.to_bytes() == vectors_bytes
@pytest.mark.parametrize("n_process", [1, 2])
def test_pass_doc_to_pipeline(nlp, n_process):
    """Doc inputs to nlp() and nlp.pipe() skip make_doc and are annotated."""
    texts = ["cats", "dogs", "guinea pigs"]
    docs = [nlp.make_doc(text) for text in texts]
    # Freshly tokenized docs carry no textcat annotation yet.
    assert not any(len(doc.cats) for doc in docs)
    doc = nlp(docs[0])
    assert doc.text == texts[0]
    assert len(doc.cats) > 0
    # Avoid running multiple processes on GPU (see commit message).
    if isinstance(get_current_ops(), NumpyOps) or n_process < 2:
        # Materialize the generator: iterating nlp.pipe()'s output twice
        # yields nothing the second time, which previously made the
        # all(...) assertion below vacuously true.
        docs = list(nlp.pipe(docs, n_process=n_process))
        assert [doc.text for doc in docs] == texts
        assert all(len(doc.cats) for doc in docs)
def test_invalid_arg_to_pipeline(nlp):
    """Inputs that are neither str nor Doc raise ValueError."""
    valid_texts = ["This is a text.", "This is another."]
    # A whole list passed to nlp() directly is rejected: nlp() accepts a
    # single str or Doc, not a sequence.
    with pytest.raises(ValueError):
        nlp(valid_texts)  # type: ignore
    # The same list is fine as a stream for nlp.pipe().
    assert len(list(nlp.pipe(valid_texts))) == 2
    bad_inputs = [1, 2, 3]
    # Non-string, non-Doc items fail in both entry points.
    with pytest.raises(ValueError):
        list(nlp.pipe(bad_inputs))  # type: ignore
    with pytest.raises(ValueError):
        nlp(bad_inputs)  # type: ignore