Avoid pickling Doc inputs passed to Language.pipe() (#10864)

* `Language.pipe()`: Serialize `Doc` objects to bytes when using multiprocessing to avoid pickling overhead

* `Doc.to_dict()`: Serialize `_context` attribute (in line with `(un)pickle_doc()`)

* Correct type annotations

* Fix typo

* `Doc`: Do not serialize `_context`

* `Language.pipe`: Send context objects to child processes; simplify `as_tuples` handling

* Fix type annotation

* `Language.pipe`: Simplify `as_tuples` multiprocessing handling

* Clean up code, fix typos

* MyPy fixes

* Move doc preparation function into `_multiprocessing_pipe`
Whitespace changes

* Remove superfluous comma

* Rename `prepare_doc` to `prepare_input`

* Update spacy/errors.py

* Undo renaming for error

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
Author: Madeesh Kannan, 2022-06-02 20:06:49 +02:00 (committed by GitHub)
parent 430592b3ce
commit 41389ffe1e
4 changed files with 43 additions and 20 deletions
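
For orientation, a minimal sketch of the call path this change optimizes: passing pre-made `Doc` objects to `Language.pipe()` with `n_process > 1`. The blank pipeline and texts below are illustrative assumptions, not part of the diff.

```python
import spacy

if __name__ == "__main__":
    # Illustrative setup: any pipeline and texts work here.
    nlp = spacy.blank("en")
    docs = [nlp.make_doc(text) for text in ["first text", "second text"]]

    # With n_process > 1, each Doc is now shipped to the worker processes as
    # bytes (Doc.to_bytes()) rather than being pickled wholesale.
    for doc in nlp.pipe(docs, n_process=2):
        print(doc.text)
```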

spacy/errors.py

@@ -931,6 +931,7 @@ class Errors(metaclass=ErrorsWithCodes):
              "could not be aligned to token boundaries.")
     E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
              "Some tokens do not contain annotation for: {partial_attrs}")
+    E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")

     # Deprecated model shortcuts, only used in errors and warnings

spacy/language.py

@@ -1090,16 +1090,21 @@ class Language:
             )
         return self.tokenizer(text)

-    def _ensure_doc(self, doc_like: Union[str, Doc]) -> Doc:
-        """Create a Doc if need be, or raise an error if the input is not a Doc or a string."""
+    def _ensure_doc(self, doc_like: Union[str, Doc, bytes]) -> Doc:
+        """Create a Doc if need be, or raise an error if the input is not
+        a Doc, string, or a byte array (generated by Doc.to_bytes())."""
         if isinstance(doc_like, Doc):
             return doc_like
         if isinstance(doc_like, str):
             return self.make_doc(doc_like)
-        raise ValueError(Errors.E866.format(type=type(doc_like)))
+        if isinstance(doc_like, bytes):
+            return Doc(self.vocab).from_bytes(doc_like)
+        raise ValueError(Errors.E1041.format(type=type(doc_like)))

-    def _ensure_doc_with_context(self, doc_like: Union[str, Doc], context: Any) -> Doc:
-        """Create a Doc if need be and add as_tuples context, or raise an error if the input is not a Doc or a string."""
+    def _ensure_doc_with_context(
+        self, doc_like: Union[str, Doc, bytes], context: _AnyContext
+    ) -> Doc:
+        """Call _ensure_doc to generate a Doc and set its context object."""
         doc = self._ensure_doc(doc_like)
         doc._context = context
         return doc
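
A short sketch of the byte round-trip that the extended `_ensure_doc` performs for `bytes` input (assuming a blank English pipeline; not part of the diff):

```python
import spacy
from spacy.tokens import Doc

nlp = spacy.blank("en")
doc = nlp("serialize me")

# What the parent process now sends: the Doc reduced to a msgpack byte string.
doc_bytes = doc.to_bytes()

# What _ensure_doc does with bytes input: rebuild the Doc against the vocab.
restored = Doc(nlp.vocab).from_bytes(doc_bytes)
assert restored.text == "serialize me"
```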
@@ -1519,7 +1524,6 @@ class Language:

         DOCS: https://spacy.io/api/language#pipe
         """
-        # Handle texts with context as tuples
         if as_tuples:
             texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
             docs_with_contexts = (
@@ -1597,8 +1601,21 @@ class Language:
         n_process: int,
         batch_size: int,
     ) -> Iterator[Doc]:
+        def prepare_input(
+            texts: Iterable[Union[str, Doc]]
+        ) -> Iterable[Tuple[Union[str, bytes], _AnyContext]]:
+            # Serialize Doc inputs to bytes to avoid incurring pickling
+            # overhead when they are passed to child processes. Also yield
+            # any context objects they might have separately (as they are not serialized).
+            for doc_like in texts:
+                if isinstance(doc_like, Doc):
+                    yield (doc_like.to_bytes(), cast(_AnyContext, doc_like._context))
+                else:
+                    yield (doc_like, cast(_AnyContext, None))
+
+        serialized_texts_with_ctx = prepare_input(texts)  # type: ignore
         # raw_texts is used later to stop iteration.
-        texts, raw_texts = itertools.tee(texts)
+        texts, raw_texts = itertools.tee(serialized_texts_with_ctx)  # type: ignore
         # for sending texts to worker
         texts_q: List[mp.Queue] = [mp.Queue() for _ in range(n_process)]
         # for receiving byte-encoded docs from worker
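
Roughly what `prepare_input` yields for a mix of plain strings and `Doc` objects; this standalone sketch mirrors the generator above and is not part of the diff:

```python
import spacy
from spacy.tokens import Doc

nlp = spacy.blank("en")
doc = nlp("a premade doc")
doc._context = {"id": 1}  # the context object travels outside the byte payload

inputs = ["a plain string", doc]
prepared = [
    (item.to_bytes(), item._context) if isinstance(item, Doc) else (item, None)
    for item in inputs
]
# prepared[0] == ("a plain string", None)
# prepared[1] == (b"...", {"id": 1})  # serialized Doc plus its context
```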
@@ -1618,7 +1635,13 @@ class Language:
         procs = [
             mp.Process(
                 target=_apply_pipes,
-                args=(self._ensure_doc, pipes, rch, sch, Underscore.get_state()),
+                args=(
+                    self._ensure_doc_with_context,
+                    pipes,
+                    rch,
+                    sch,
+                    Underscore.get_state(),
+                ),
             )
             for rch, sch in zip(texts_q, bytedocs_send_ch)
         ]
@@ -1631,12 +1654,12 @@ class Language:
             recv.recv() for recv in cycle(bytedocs_recv_ch)
         )
         try:
-            for i, (_, (byte_doc, byte_context, byte_error)) in enumerate(
+            for i, (_, (byte_doc, context, byte_error)) in enumerate(
                 zip(raw_texts, byte_tuples), 1
             ):
                 if byte_doc is not None:
                     doc = Doc(self.vocab).from_bytes(byte_doc)
-                    doc._context = byte_context
+                    doc._context = context
                     yield doc
                 elif byte_error is not None:
                     error = srsly.msgpack_loads(byte_error)
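
On the parent side the context object received from the worker is reattached to the rebuilt `Doc`, which is what keeps `as_tuples=True` working across processes. A hedged usage sketch (pipeline and data are assumptions, not from the diff):

```python
import spacy

if __name__ == "__main__":
    nlp = spacy.blank("en")
    data = [("first text", {"id": 0}), ("second text", {"id": 1})]

    # Each (text, context) pair is split apart, the text is processed in a
    # worker, and the context is reattached to the Doc before being yielded.
    for doc, context in nlp.pipe(data, as_tuples=True, n_process=2):
        print(context["id"], doc.text)
```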
@@ -2163,7 +2186,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:

 def _apply_pipes(
-    ensure_doc: Callable[[Union[str, Doc]], Doc],
+    ensure_doc: Callable[[Union[str, Doc, bytes], _AnyContext], Doc],
     pipes: Iterable[Callable[..., Iterator[Doc]]],
     receiver,
     sender,
@@ -2184,17 +2207,19 @@ def _apply_pipes(
     Underscore.load_state(underscore_state)
     while True:
         try:
-            texts = receiver.get()
-            docs = (ensure_doc(text) for text in texts)
+            texts_with_ctx = receiver.get()
+            docs = (
+                ensure_doc(doc_like, context) for doc_like, context in texts_with_ctx
+            )
             for pipe in pipes:
                 docs = pipe(docs)  # type: ignore[arg-type, assignment]
             # Connection does not accept unpickable objects, so send list.
             byte_docs = [(doc.to_bytes(), doc._context, None) for doc in docs]
-            padding = [(None, None, None)] * (len(texts) - len(byte_docs))
+            padding = [(None, None, None)] * (len(texts_with_ctx) - len(byte_docs))
             sender.send(byte_docs + padding)  # type: ignore[operator]
         except Exception:
             error_msg = [(None, None, srsly.msgpack_dumps(traceback.format_exc()))]
-            padding = [(None, None, None)] * (len(texts) - 1)
+            padding = [(None, None, None)] * (len(texts_with_ctx) - 1)
             sender.send(error_msg + padding)
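
The padding keeps every message the worker sends back the same length as the batch it received, so the parent's `zip(raw_texts, byte_tuples)` stays aligned even when processing fails partway through a batch. A toy illustration of that invariant (not spaCy code):

```python
# Toy example: a batch of three inputs where only one was processed before an
# error; padding restores the expected message length.
batch = ["a", "b", "c"]
processed = [("doc-bytes-for-a", None, None)]
padding = [(None, None, None)] * (len(batch) - len(processed))
message = processed + padding
assert len(message) == len(batch)
```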

spacy/tests/doc/test_pickle_doc.py

@@ -5,11 +5,9 @@ from spacy.compat import pickle

 def test_pickle_single_doc():
     nlp = Language()
     doc = nlp("pickle roundtrip")
-    doc._context = 3
     data = pickle.dumps(doc, 1)
     doc2 = pickle.loads(data)
     assert doc2.text == "pickle roundtrip"
-    assert doc2._context == 3


 def test_list_of_docs_pickles_efficiently():

spacy/tokens/doc.pyx

@@ -1880,18 +1880,17 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):

 def pickle_doc(doc):
     bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
     hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,
-                      doc.user_token_hooks, doc._context)
+                      doc.user_token_hooks)
     return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data))


 def unpickle_doc(vocab, hooks_and_data, bytes_data):
-    user_data, doc_hooks, span_hooks, token_hooks, _context = srsly.pickle_loads(hooks_and_data)
+    user_data, doc_hooks, span_hooks, token_hooks = srsly.pickle_loads(hooks_and_data)
     doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data, exclude=["user_data"])
     doc.user_hooks.update(doc_hooks)
     doc.user_span_hooks.update(span_hooks)
     doc.user_token_hooks.update(token_hooks)
-    doc._context = _context
     return doc
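
With `_context` dropped from `pickle_doc`/`unpickle_doc`, pickling no longer round-trips the context object (matching the updated test above). A small sketch of the behaviour after this commit (blank pipeline assumed):

```python
import spacy
from spacy.compat import pickle

nlp = spacy.blank("en")
doc = nlp("pickle roundtrip")
doc._context = 3

doc2 = pickle.loads(pickle.dumps(doc))
assert doc2.text == "pickle roundtrip"
# doc2._context is intentionally not restored by unpickle_doc after this change.
```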