mirror of https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00

Avoid pickling Doc inputs passed to Language.pipe() (#10864)

* `Language.pipe()`: Serialize `Doc` objects to bytes when using multiprocessing to avoid pickling overhead
* `Doc.to_dict()`: Serialize `_context` attribute (keeping in line with `(un)pickle_doc()`)
* Correct type annotations
* Fix typo
* `Doc`: Do not serialize `_context`
* `Language.pipe`: Send context objects to child processes, simplify `as_tuples` handling
* Fix type annotation
* `Language.pipe`: Simplify `as_tuples` multiprocessing handling
* Clean up code, fix typos
* MyPy fixes
* Move doc preparation function into `_multiprocessing_pipe`; whitespace changes
* Remove superfluous comma
* Rename `prepare_doc` to `prepare_input`
* Update spacy/errors.py
* Undo renaming for error

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

This commit is contained in:
parent 430592b3ce
commit 41389ffe1e
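For orientation, a minimal sketch of the call pattern this commit optimizes: Doc objects passed to Language.pipe() with n_process > 1, here assuming a blank English pipeline (the __main__ guard matters because multiprocessing may spawn fresh interpreters):

    import spacy

    if __name__ == "__main__":
        nlp = spacy.blank("en")
        # Doc inputs previously had to be pickled to reach the worker
        # processes; after this change they travel as Doc.to_bytes() output.
        docs = [nlp.make_doc(text) for text in ["one text", "another text"]]
        for doc in nlp.pipe(docs, n_process=2):
            print(doc.text)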
spacy/errors.py
@@ -931,6 +931,7 @@ class Errors(metaclass=ErrorsWithCodes):
              "could not be aligned to token boundaries.")
     E1040 = ("Doc.from_json requires all tokens to have the same attributes. "
             "Some tokens do not contain annotation for: {partial_attrs}")
+    E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")


     # Deprecated model shortcuts, only used in errors and warnings
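The new error string is formatted with the offending input's type when an input is rejected; a small sketch of the resulting message (assuming the import path spacy.errors):

    from spacy.errors import Errors

    # A list is not a valid pipe() input, so E1041 would report its type:
    print(Errors.E1041.format(type=type([])))
    # Expected a string, Doc, or bytes as input, but got: <class 'list'>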
spacy/language.py
@@ -1090,16 +1090,21 @@ class Language:
         )
         return self.tokenizer(text)

-    def _ensure_doc(self, doc_like: Union[str, Doc]) -> Doc:
-        """Create a Doc if need be, or raise an error if the input is not a Doc or a string."""
+    def _ensure_doc(self, doc_like: Union[str, Doc, bytes]) -> Doc:
+        """Create a Doc if need be, or raise an error if the input is not
+        a Doc, string, or a byte array (generated by Doc.to_bytes())."""
         if isinstance(doc_like, Doc):
             return doc_like
         if isinstance(doc_like, str):
             return self.make_doc(doc_like)
-        raise ValueError(Errors.E866.format(type=type(doc_like)))
+        if isinstance(doc_like, bytes):
+            return Doc(self.vocab).from_bytes(doc_like)
+        raise ValueError(Errors.E1041.format(type=type(doc_like)))

-    def _ensure_doc_with_context(self, doc_like: Union[str, Doc], context: Any) -> Doc:
-        """Create a Doc if need be and add as_tuples context, or raise an error if the input is not a Doc or a string."""
+    def _ensure_doc_with_context(
+        self, doc_like: Union[str, Doc, bytes], context: _AnyContext
+    ) -> Doc:
+        """Call _ensure_doc to generate a Doc and set its context object."""
         doc = self._ensure_doc(doc_like)
         doc._context = context
         return doc
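The new bytes branch leans on the existing Doc serialization API; a minimal roundtrip sketch of what _ensure_doc now does for bytes input:

    import spacy
    from spacy.tokens import Doc

    nlp = spacy.blank("en")
    data = nlp("serialize me").to_bytes()   # what the parent process sends
    doc = Doc(nlp.vocab).from_bytes(data)   # what _ensure_doc reconstructs
    assert doc.text == "serialize me"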
@@ -1519,7 +1524,6 @@ class Language:

         DOCS: https://spacy.io/api/language#pipe
         """
-        # Handle texts with context as tuples
         if as_tuples:
             texts = cast(Iterable[Tuple[Union[str, Doc], _AnyContext]], texts)
             docs_with_contexts = (
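For reference, as_tuples pairs each text with an arbitrary context object that rides along through the pipeline; a small usage sketch:

    import spacy

    nlp = spacy.blank("en")
    data = [("text one", {"id": 1}), ("text two", {"id": 2})]
    for doc, context in nlp.pipe(data, as_tuples=True):
        print(doc.text, context["id"])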
@@ -1597,8 +1601,21 @@ class Language:
         n_process: int,
         batch_size: int,
     ) -> Iterator[Doc]:
+        def prepare_input(
+            texts: Iterable[Union[str, Doc]]
+        ) -> Iterable[Tuple[Union[str, bytes], _AnyContext]]:
+            # Serialize Doc inputs to bytes to avoid incurring pickling
+            # overhead when they are passed to child processes. Also yield
+            # any context objects they might have separately (as they are not serialized).
+            for doc_like in texts:
+                if isinstance(doc_like, Doc):
+                    yield (doc_like.to_bytes(), cast(_AnyContext, doc_like._context))
+                else:
+                    yield (doc_like, cast(_AnyContext, None))
+
+        serialized_texts_with_ctx = prepare_input(texts)  # type: ignore
         # raw_texts is used later to stop iteration.
-        texts, raw_texts = itertools.tee(texts)
+        texts, raw_texts = itertools.tee(serialized_texts_with_ctx)  # type: ignore
         # for sending texts to worker
         texts_q: List[mp.Queue] = [mp.Queue() for _ in range(n_process)]
         # for receiving byte-encoded docs from worker
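The itertools.tee() call duplicates the now generator-based input stream so that one copy feeds the workers while the other bounds the result loop; a standalone sketch of the idiom, with hypothetical names:

    import itertools

    def prepared():
        # Stands in for prepare_input(): (payload, context) tuples.
        for payload in ["a", "b", "c"]:
            yield (payload, None)

    to_workers, raw = itertools.tee(prepared())
    results = (payload.upper() for payload, _ in to_workers)
    # zip() stops exactly when the finite input is exhausted, which is
    # how raw_texts later bounds an otherwise endless cycle of channels.
    for _, result in zip(raw, results):
        print(result)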
@@ -1618,7 +1635,13 @@
         procs = [
             mp.Process(
                 target=_apply_pipes,
-                args=(self._ensure_doc, pipes, rch, sch, Underscore.get_state()),
+                args=(
+                    self._ensure_doc_with_context,
+                    pipes,
+                    rch,
+                    sch,
+                    Underscore.get_state(),
+                ),
             )
             for rch, sch in zip(texts_q, bytedocs_send_ch)
         ]
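The process wiring itself is unchanged: one input Queue and one receive/send Pipe pair per worker. A reduced, self-contained sketch of that topology, with a hypothetical stand-in for _apply_pipes:

    import multiprocessing as mp

    def worker(receiver, sender):
        # Stand-in for _apply_pipes: read one batch, send results back.
        batch = receiver.get()
        sender.send([item.upper() for item in batch])

    if __name__ == "__main__":
        texts_q = mp.Queue()
        recv_ch, send_ch = mp.Pipe(duplex=False)
        proc = mp.Process(target=worker, args=(texts_q, send_ch))
        proc.start()
        texts_q.put(["spacy", "pipe"])
        print(recv_ch.recv())  # ['SPACY', 'PIPE']
        proc.join()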
@@ -1631,12 +1654,12 @@ class Language:
             recv.recv() for recv in cycle(bytedocs_recv_ch)
         )
         try:
-            for i, (_, (byte_doc, byte_context, byte_error)) in enumerate(
+            for i, (_, (byte_doc, context, byte_error)) in enumerate(
                 zip(raw_texts, byte_tuples), 1
             ):
                 if byte_doc is not None:
                     doc = Doc(self.vocab).from_bytes(byte_doc)
-                    doc._context = byte_context
+                    doc._context = context
                     yield doc
                 elif byte_error is not None:
                     error = srsly.msgpack_loads(byte_error)
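On the receiving side, worker exceptions arrive as msgpack-encoded tracebacks over the same channel; a minimal sketch of that encode/decode pair:

    import traceback

    import srsly

    try:
        raise ValueError("boom inside a worker")
    except Exception:
        byte_error = srsly.msgpack_dumps(traceback.format_exc())

    # Parent side: decode and surface the formatted traceback.
    print(srsly.msgpack_loads(byte_error))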
@@ -2163,7 +2186,7 @@ def _copy_examples(examples: Iterable[Example]) -> List[Example]:


 def _apply_pipes(
-    ensure_doc: Callable[[Union[str, Doc]], Doc],
+    ensure_doc: Callable[[Union[str, Doc, bytes], _AnyContext], Doc],
     pipes: Iterable[Callable[..., Iterator[Doc]]],
     receiver,
     sender,
@@ -2184,17 +2207,19 @@ def _apply_pipes(
     Underscore.load_state(underscore_state)
     while True:
         try:
-            texts = receiver.get()
-            docs = (ensure_doc(text) for text in texts)
+            texts_with_ctx = receiver.get()
+            docs = (
+                ensure_doc(doc_like, context) for doc_like, context in texts_with_ctx
+            )
             for pipe in pipes:
                 docs = pipe(docs)  # type: ignore[arg-type, assignment]
             # Connection does not accept unpickable objects, so send list.
             byte_docs = [(doc.to_bytes(), doc._context, None) for doc in docs]
-            padding = [(None, None, None)] * (len(texts) - len(byte_docs))
+            padding = [(None, None, None)] * (len(texts_with_ctx) - len(byte_docs))
             sender.send(byte_docs + padding)  # type: ignore[operator]
         except Exception:
             error_msg = [(None, None, srsly.msgpack_dumps(traceback.format_exc()))]
-            padding = [(None, None, None)] * (len(texts) - 1)
+            padding = [(None, None, None)] * (len(texts_with_ctx) - 1)
             sender.send(error_msg + padding)


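The padding keeps the parent's zip() over raw_texts aligned with the inputs even when a batch yields fewer results (or, on error, a single error tuple); a worked sketch of the arithmetic:

    batch = [("text one", None), ("text two", None), ("text three", None)]
    byte_docs = [(b"...", None, None)]  # e.g. only one result came back
    padding = [(None, None, None)] * (len(batch) - len(byte_docs))
    assert len(byte_docs + padding) == len(batch)  # one slot per input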
spacy/tests/doc/test_pickle_doc.py
@@ -5,11 +5,9 @@ from spacy.compat import pickle
 def test_pickle_single_doc():
     nlp = Language()
     doc = nlp("pickle roundtrip")
-    doc._context = 3
     data = pickle.dumps(doc, 1)
     doc2 = pickle.loads(data)
     assert doc2.text == "pickle roundtrip"
-    assert doc2._context == 3


 def test_list_of_docs_pickles_efficiently():
spacy/tokens/doc.pyx
@@ -1880,18 +1880,17 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
 def pickle_doc(doc):
     bytes_data = doc.to_bytes(exclude=["vocab", "user_data", "user_hooks"])
     hooks_and_data = (doc.user_data, doc.user_hooks, doc.user_span_hooks,
-                      doc.user_token_hooks, doc._context)
+                      doc.user_token_hooks)
     return (unpickle_doc, (doc.vocab, srsly.pickle_dumps(hooks_and_data), bytes_data))


 def unpickle_doc(vocab, hooks_and_data, bytes_data):
-    user_data, doc_hooks, span_hooks, token_hooks, _context = srsly.pickle_loads(hooks_and_data)
+    user_data, doc_hooks, span_hooks, token_hooks = srsly.pickle_loads(hooks_and_data)

     doc = Doc(vocab, user_data=user_data).from_bytes(bytes_data, exclude=["user_data"])
     doc.user_hooks.update(doc_hooks)
     doc.user_span_hooks.update(span_hooks)
     doc.user_token_hooks.update(token_hooks)
-    doc._context = _context
     return doc


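With _context dropped from pickle_doc()/unpickle_doc(), plain pickling no longer round-trips contexts; only the multiprocessing path carries them, explicitly, next to the serialized payload. A sketch of the remaining behaviour, assuming _context defaults to None on a fresh Doc:

    import spacy
    from spacy.compat import pickle

    nlp = spacy.blank("en")
    doc = nlp("pickle roundtrip")
    doc._context = 3
    doc2 = pickle.loads(pickle.dumps(doc))
    assert doc2.text == "pickle roundtrip"
    # _context is intentionally not restored anymore; doc2._context
    # is back at its default rather than 3.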