Avoid pickling Doc inputs passed to Language.pipe() (#10864)

* `Language.pipe()`: Serialize `Doc` objects to bytes when using multiprocessing to avoid pickling overhead

* `Doc.to_dict()`: Serialize `_context` attribute (keeping in line with `(un)pickle_doc()`

* Correct type annotations

* Fix typo

* `Doc`: Do not serialize `_context`

* `Language.pipe`: Send context objects to child processes, Simplify `as_tuples` handling

* Fix type annotation

* `Language.pipe`: Simplify `as_tuple` multiprocessor handling

* Cleanup code, fix typos

* MyPy fixes

* Move doc preparation function into `_multiprocessing_pipe`
Whitespace changes

* Remove superfluous comma

* Rename `prepare_doc` to `prepare_input`

* Update spacy/errors.py

* Undo renaming for error

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Madeesh Kannan 2022-06-02 20:06:49 +02:00 committed by Adriane Boyd
parent 0bf367dfc2
commit 0926a0993a
4 changed files with 43 additions and 20 deletions

View File

@ -927,6 +927,7 @@ class Errors(metaclass=ErrorsWithCodes):
"could not be aligned to token boundaries.")
E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
"Some tokens do not contain annotation for: {partial_attrs}")
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -1090,16 +1090,21 @@ class Language:
)
return self.tokenizer(text)
def _ensure_doc(self, doc_like: Union[str, Doc]) -> Doc:
"""Create a Doc if need be, or raise an error if the input is not a Doc or a string."""
def _ensure_doc(self, doc_like: Union[str, Doc, bytes]) -> Doc:
"""Create a Doc if need be, or raise an error if the input is not
a Doc, string, or a byte array (generated by Doc.to_bytes())."""
if isinstance(doc_like, Doc):
return doc_like
if isinstance(doc_like, str):
return self.make_doc(doc_like)
raise ValueError(Errors.E866.format(type=type(doc_like)))
if isinstance(doc_like, bytes):
return Doc(self.vocab).from_bytes(doc_like)
raise ValueError(Errors.E1041.format(type=type(doc_like)))
def _ensure_doc_with_context(self, doc_like: Union[str, Doc], context: Any) -> Doc:
"""Create a Doc if need be and add as_tuples context, or raise an error if the input is not a Doc or a string."""
def _ensure_doc_with_context(
self, doc_like: Union[str, Doc, bytes], context: _AnyContext
) -> Doc:
"""Call _ensure_doc to generate a Doc and set its context object."""
doc = self._ensure_doc(doc_like)
doc._context = context
return doc
@ -1519,7 +1524,6 @@ class Language:
DOCS: https://spacy.io/api/language#pipe
"""
# Handle texts with context as tuples
if as_tuples:
texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
docs_with_contexts = (
@ -1597,8 +1601,21 @@ class Language:
n_process: int,
batch_size: int,
) -> Iterator[Doc]:
def prepare_input(
texts: Iterable[Union[str, Doc]]
) -> Iterable[Tuple[Union[str, bytes], _AnyContext]]:
# Serialize Doc inputs to bytes to avoid incurring pickling
# overhead when they are passed to child processes. Also yield
# any context objects they might have separately (as they are not serialized).
for doc_like in texts:
if isinstance(doc_like, Doc):
yield (doc_like.to_bytes(), cast(_AnyContext, doc_like._context))
else:
yield (doc_like, cast(_AnyContext, None))
serialized_texts_with_ctx = prepare_input(texts) # type: ignore
# raw_texts is used later to stop iteration.
texts, raw_texts = itertools.tee(texts)
texts, raw_texts = itertools.tee(serialized_texts_with_ctx) # type: ignore
# for sending texts to worker
texts_q: List[mp.Queue] = [mp.Queue() for _ in range(n_process)]
# for receiving byte-encoded docs from worker
@ -1618,7 +1635,13 @@ class Language:
procs = [
mp.Process(
target=_apply_pipes,
args=(self._ensure_doc, pipes, rch, sch, Underscore.get_state()),
args=(
self._ensure_doc_with_context,
pipes,
rch,
sch,
Underscore.get_state(),
),
)
for rch, sch in zip(texts_q, bytedocs_send_ch)
]
@ -1631,12 +1654,12 @@ class Language:
recv.recv() for recv in cycle(bytedocs_recv_ch)
)
try:
for i, (_, (byte_doc, byte_context, byte_error)) in enumerate(
for i, (_, (byte_doc, context, byte_error)) in enumerate(
zip(raw_texts, byte_tuples), 1
):
if byte_doc is not None:
doc = Doc(self.vocab).from_bytes(byte_doc)
doc._context = byte_context
doc._context = context
yield doc
elif byte_error is not None:
error = srsly.msgpack_loads(byte_error)
@ -2163,7 +2186,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:
def _apply_pipes(
ensure_doc: Callable[[Union[str, Doc]], Doc],
ensure_doc: Callable[[Union[str, Doc, bytes], _AnyContext], Doc],
pipes: Iterable[Callable[..., Iterator[Doc]]],
receiver,
sender,
@ -2184,17 +2207,19 @@ def _apply_pipes(
Underscore.load_state(underscore_state)
while True:
try:
texts = receiver.get()
docs = (ensure_doc(text) for text in texts)
texts_with_ctx = receiver.get()
docs = (
ensure_doc(doc_like, context) for doc_like, context in texts_with_ctx
)
for pipe in pipes:
docs = pipe(docs) # type: ignore[arg-type, assignment]
# Connection does not accept unpickable objects, so send list.
byte_docs = [(doc.to_bytes(), doc._context, None) for doc in docs]
padding = [(None, None, None)] * (len(texts) - len(byte_docs))
padding = [(None, None, None)] * (len(texts_with_ctx) - len(byte_docs))
sender.send(byte_docs + padding) # type: ignore[operator]
except Exception:
error_msg = [(None, None, srsly.msgpack_dumps(traceback.format_exc()))]
padding = [(None, None, None)] * (len(texts) - 1)
padding = [(None, None, None)] * (len(texts_with_ctx) - 1)
sender.send(error_msg + padding)

View File

@ -5,11 +5,9 @@ from spacy.compat import pickle
def test_pickle_single_doc():
nlp = Language()
doc = nlp("pickle roundtrip")
doc._context = 3
data = pickle.dumps(doc, 1)
doc2 = pickle.loads(data)
assert doc2.text == "pickle roundtrip"
assert doc2._context == 3
def test_list_of_docs_pickles_efficiently():

View File

@ -1880,18 +1880,17 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
def pickle_doc(doc):
bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,
doc.user_token_hooks, doc._context)
doc.user_token_hooks)
return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data))
def unpickle_doc(vocab, hooks_and_data, bytes_data):
user_data, doc_hooks, span_hooks, token_hooks, _context = srsly.pickle_loads(hooks_and_data)
user_data, doc_hooks, span_hooks, token_hooks = srsly.pickle_loads(hooks_and_data)
doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data, exclude=["user_data"])
doc.user_hooks.update(doc_hooks)
doc.user_span_hooks.update(span_hooks)
doc.user_token_hooks.update(token_hooks)
doc._context = _context
return doc