mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Avoid pickling Doc
inputs passed to Language.pipe()
(#10864)
* `Language.pipe()`: Serialize `Doc` objects to bytes when using multiprocessing to avoid pickling overhead * `Doc.to_dict()`: Serialize `_context` attribute (keeping in line with `(un)pickle_doc()` * Correct type annotations * Fix typo * `Doc`: Do not serialize `_context` * `Language.pipe`: Send context objects to child processes, Simplify `as_tuples` handling * Fix type annotation * `Language.pipe`: Simplify `as_tuple` multiprocessor handling * Cleanup code, fix typos * MyPy fixes * Move doc preparation function into `_multiprocessing_pipe` Whitespace changes * Remove superfluous comma * Rename `prepare_doc` to `prepare_input` * Update spacy/errors.py * Undo renaming for error Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
0bf367dfc2
commit
0926a0993a
|
@ -927,6 +927,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"could not be aligned to token boundaries.")
|
||||
E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
|
||||
"Some tokens do not contain annotation for: {partial_attrs}")
|
||||
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -1090,16 +1090,21 @@ class Language:
|
|||
)
|
||||
return self.tokenizer(text)
|
||||
|
||||
def _ensure_doc(self, doc_like: Union[str, Doc]) -> Doc:
|
||||
"""Create a Doc if need be, or raise an error if the input is not a Doc or a string."""
|
||||
def _ensure_doc(self, doc_like: Union[str, Doc, bytes]) -> Doc:
|
||||
"""Create a Doc if need be, or raise an error if the input is not
|
||||
a Doc, string, or a byte array (generated by Doc.to_bytes())."""
|
||||
if isinstance(doc_like, Doc):
|
||||
return doc_like
|
||||
if isinstance(doc_like, str):
|
||||
return self.make_doc(doc_like)
|
||||
raise ValueError(Errors.E866.format(type=type(doc_like)))
|
||||
if isinstance(doc_like, bytes):
|
||||
return Doc(self.vocab).from_bytes(doc_like)
|
||||
raise ValueError(Errors.E1041.format(type=type(doc_like)))
|
||||
|
||||
def _ensure_doc_with_context(self, doc_like: Union[str, Doc], context: Any) -> Doc:
|
||||
"""Create a Doc if need be and add as_tuples context, or raise an error if the input is not a Doc or a string."""
|
||||
def _ensure_doc_with_context(
|
||||
self, doc_like: Union[str, Doc, bytes], context: _AnyContext
|
||||
) -> Doc:
|
||||
"""Call _ensure_doc to generate a Doc and set its context object."""
|
||||
doc = self._ensure_doc(doc_like)
|
||||
doc._context = context
|
||||
return doc
|
||||
|
@ -1519,7 +1524,6 @@ class Language:
|
|||
|
||||
DOCS: https://spacy.io/api/language#pipe
|
||||
"""
|
||||
# Handle texts with context as tuples
|
||||
if as_tuples:
|
||||
texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
|
||||
docs_with_contexts = (
|
||||
|
@ -1597,8 +1601,21 @@ class Language:
|
|||
n_process: int,
|
||||
batch_size: int,
|
||||
) -> Iterator[Doc]:
|
||||
def prepare_input(
|
||||
texts: Iterable[Union[str, Doc]]
|
||||
) -> Iterable[Tuple[Union[str, bytes], _AnyContext]]:
|
||||
# Serialize Doc inputs to bytes to avoid incurring pickling
|
||||
# overhead when they are passed to child processes. Also yield
|
||||
# any context objects they might have separately (as they are not serialized).
|
||||
for doc_like in texts:
|
||||
if isinstance(doc_like, Doc):
|
||||
yield (doc_like.to_bytes(), cast(_AnyContext, doc_like._context))
|
||||
else:
|
||||
yield (doc_like, cast(_AnyContext, None))
|
||||
|
||||
serialized_texts_with_ctx = prepare_input(texts) # type: ignore
|
||||
# raw_texts is used later to stop iteration.
|
||||
texts, raw_texts = itertools.tee(texts)
|
||||
texts, raw_texts = itertools.tee(serialized_texts_with_ctx) # type: ignore
|
||||
# for sending texts to worker
|
||||
texts_q: List[mp.Queue] = [mp.Queue() for _ in range(n_process)]
|
||||
# for receiving byte-encoded docs from worker
|
||||
|
@ -1618,7 +1635,13 @@ class Language:
|
|||
procs = [
|
||||
mp.Process(
|
||||
target=_apply_pipes,
|
||||
args=(self._ensure_doc, pipes, rch, sch, Underscore.get_state()),
|
||||
args=(
|
||||
self._ensure_doc_with_context,
|
||||
pipes,
|
||||
rch,
|
||||
sch,
|
||||
Underscore.get_state(),
|
||||
),
|
||||
)
|
||||
for rch, sch in zip(texts_q, bytedocs_send_ch)
|
||||
]
|
||||
|
@ -1631,12 +1654,12 @@ class Language:
|
|||
recv.recv() for recv in cycle(bytedocs_recv_ch)
|
||||
)
|
||||
try:
|
||||
for i, (_, (byte_doc, byte_context, byte_error)) in enumerate(
|
||||
for i, (_, (byte_doc, context, byte_error)) in enumerate(
|
||||
zip(raw_texts, byte_tuples), 1
|
||||
):
|
||||
if byte_doc is not None:
|
||||
doc = Doc(self.vocab).from_bytes(byte_doc)
|
||||
doc._context = byte_context
|
||||
doc._context = context
|
||||
yield doc
|
||||
elif byte_error is not None:
|
||||
error = srsly.msgpack_loads(byte_error)
|
||||
|
@ -2163,7 +2186,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:
|
|||
|
||||
|
||||
def _apply_pipes(
|
||||
ensure_doc: Callable[[Union[str, Doc]], Doc],
|
||||
ensure_doc: Callable[[Union[str, Doc, bytes], _AnyContext], Doc],
|
||||
pipes: Iterable[Callable[..., Iterator[Doc]]],
|
||||
receiver,
|
||||
sender,
|
||||
|
@ -2184,17 +2207,19 @@ def _apply_pipes(
|
|||
Underscore.load_state(underscore_state)
|
||||
while True:
|
||||
try:
|
||||
texts = receiver.get()
|
||||
docs = (ensure_doc(text) for text in texts)
|
||||
texts_with_ctx = receiver.get()
|
||||
docs = (
|
||||
ensure_doc(doc_like, context) for doc_like, context in texts_with_ctx
|
||||
)
|
||||
for pipe in pipes:
|
||||
docs = pipe(docs) # type: ignore[arg-type, assignment]
|
||||
# Connection does not accept unpickable objects, so send list.
|
||||
byte_docs = [(doc.to_bytes(), doc._context, None) for doc in docs]
|
||||
padding = [(None, None, None)] * (len(texts) - len(byte_docs))
|
||||
padding = [(None, None, None)] * (len(texts_with_ctx) - len(byte_docs))
|
||||
sender.send(byte_docs + padding) # type: ignore[operator]
|
||||
except Exception:
|
||||
error_msg = [(None, None, srsly.msgpack_dumps(traceback.format_exc()))]
|
||||
padding = [(None, None, None)] * (len(texts) - 1)
|
||||
padding = [(None, None, None)] * (len(texts_with_ctx) - 1)
|
||||
sender.send(error_msg + padding)
|
||||
|
||||
|
||||
|
|
|
@ -5,11 +5,9 @@ from spacy.compat import pickle
|
|||
def test_pickle_single_doc():
|
||||
nlp = Language()
|
||||
doc = nlp("pickle roundtrip")
|
||||
doc._context = 3
|
||||
data = pickle.dumps(doc, 1)
|
||||
doc2 = pickle.loads(data)
|
||||
assert doc2.text == "pickle roundtrip"
|
||||
assert doc2._context == 3
|
||||
|
||||
|
||||
def test_list_of_docs_pickles_efficiently():
|
||||
|
|
|
@ -1880,18 +1880,17 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
|
|||
def pickle_doc(doc):
|
||||
bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
|
||||
hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,
|
||||
doc.user_token_hooks, doc._context)
|
||||
doc.user_token_hooks)
|
||||
return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data))
|
||||
|
||||
|
||||
def unpickle_doc(vocab, hooks_and_data, bytes_data):
|
||||
user_data, doc_hooks, span_hooks, token_hooks, _context = srsly.pickle_loads(hooks_and_data)
|
||||
user_data, doc_hooks, span_hooks, token_hooks = srsly.pickle_loads(hooks_and_data)
|
||||
|
||||
doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data, exclude=["user_data"])
|
||||
doc.user_hooks.update(doc_hooks)
|
||||
doc.user_span_hooks.update(span_hooks)
|
||||
doc.user_token_hooks.update(token_hooks)
|
||||
doc._context = _context
|
||||
return doc
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user