multiprocessing pipe (#1303) (#4371)

* refactor: separate formatting docs and golds in Language.update

* fix return typo

* add pipe test

* unpickleable object cannot be assigned to p.map

* passed test pipe

* passed test!

* pipe terminate

* try pipe

* passed test

* fix ch

* add comments

* fix len(texts)

* add comment

* add comment

* fix: multiprocessing of pipe is not supported in Python 2

* test: use assert_docs_equal

* fix: is_python3 -> is_python2

* fix: change _pipe arg to use functools.partial

* test: add vector modification test

* test: add sample ner_pipe and user_data pipe

* add warnings test

* test: fix user warnings

* test: fix warnings capture

* fix: remove islice import

* test: remove warnings test

* test: add stream test

* test: rename

* fix: multiproc stream

* fix: stream pipe

* add comment

* mp.Pipe seems to be usable with relatively small data

* test: skip stream test in python2

* sort imports

* test: add reason to skiptest

* fix: use pipe for docs communication

* add comments

* add comment
tamuhey 2019-10-08 19:20:55 +09:00 committed by Matthew Honnibal
parent 14841d0aa6
commit 650cbfe82d
3 changed files with 198 additions and 10 deletions
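
In practice, everything below is driven by the new `n_process` argument on `Language.pipe`. A minimal usage sketch (the blank model and the texts are illustrative, not part of this commit):

import spacy

nlp = spacy.blank("en")  # illustrative; any pipeline works
texts = ["Hello world.", "This is spacy."] * 100
# n_process=2 forks two worker processes (Python 3 only); -1 uses all CPUs.
for doc in nlp.pipe(texts, n_process=2, batch_size=8):
    print(doc[0].text)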

View File

@@ -95,6 +95,7 @@ class Warnings(object):
"you can ignore this warning by setting SPACY_WARNING_IGNORE=W022. "
"If this is surprising, make sure you have the spacy-lookups-data "
"package installed.")
W023 = ("Multiprocessing of Language.pipe is not supported in Python2. "
"'n_process' will be set to 1.")
@add_codes
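
For context, the `@add_codes` decorator seen above is what prefixes each message with its code on attribute lookup, so `Warnings.W023` resolves to a string starting with "[W023]". A simplified sketch of the idea, not spaCy's exact implementation:

def add_codes(err_cls):
    # Return a proxy whose attribute lookups prepend the code to the message.
    class ErrorsWithCodes(object):
        def __getattribute__(self, code):
            msg = getattr(err_cls, code)
            return "[{}] {}".format(code, msg)
    return ErrorsWithCodes()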

View File

@@ -1,8 +1,11 @@
# coding: utf8
from __future__ import absolute_import, unicode_literals
import atexit
import random
import itertools
from warnings import warn
from spacy.util import minibatch
import weakref
import functools
from collections import OrderedDict
@@ -10,6 +13,8 @@ from contextlib import contextmanager
from copy import copy, deepcopy
from thinc.neural import Model
import srsly
import multiprocessing as mp
from itertools import chain, cycle
from .tokenizer import Tokenizer
from .vocab import Vocab
@@ -21,7 +26,7 @@ from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
from .pipeline import EntityRuler
from .pipeline import Morphologizer
from .compat import izip, basestring_
from .compat import izip, basestring_, is_python2
from .gold import GoldParse
from .scorer import Scorer
from ._ml import link_vectors_to_models, create_default_optimizer
@@ -30,8 +35,9 @@ from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
from .lang.punctuation import TOKENIZER_INFIXES
from .lang.tokenizer_exceptions import TOKEN_MATCH
from .lang.tag_map import TAG_MAP
from .tokens import Doc
from .lang.lex_attrs import LEX_ATTRS, is_stop
from .errors import Errors, Warnings, deprecation_warning
from .errors import Errors, Warnings, deprecation_warning, user_warning
from . import util
from . import about
@@ -733,6 +739,7 @@ class Language(object):
disable=[],
cleanup=False,
component_cfg=None,
n_process=1,
):
"""Process texts as a stream, and yield `Doc` objects in order.
@@ -746,12 +753,20 @@
use. Experimental.
component_cfg (dict): An optional dictionary with extra keyword
arguments for specific components.
n_process (int): Number of processes to use when processing texts; only
supported in Python 3. If set to -1, `multiprocessing.cpu_count()` is used.
YIELDS (Doc): Documents in the order of the original text.
DOCS: https://spacy.io/api/language#pipe
"""
# raw_texts is used later to stop iteration.
texts, raw_texts = itertools.tee(texts)
if is_python2 and n_process != 1:
user_warning(Warnings.W023)
n_process = 1
if n_threads != -1:
deprecation_warning(Warnings.W016)
if n_process == -1:
n_process = mp.cpu_count()
if as_tuples:
text_context1, text_context2 = itertools.tee(texts)
texts = (tc[0] for tc in text_context1)
@@ -765,9 +780,12 @@
for doc, context in izip(docs, contexts):
yield (doc, context)
return
docs = (self.make_doc(text) for text in texts)
if component_cfg is None:
component_cfg = {}
pipes = []  # Contains functools.partial objects so that multiprocess workers can be created easily.
for name, proc in self.pipeline:
if name in disable:
continue
@@ -775,10 +793,20 @@
# Allow component_cfg to overwrite the top-level kwargs.
kwargs.setdefault("batch_size", batch_size)
if hasattr(proc, "pipe"):
docs = proc.pipe(docs, **kwargs)
f = functools.partial(proc.pipe, **kwargs)
else:
# Apply the function, but yield the doc
docs = _pipe(proc, docs, kwargs)
f = functools.partial(_pipe, proc=proc, kwargs=kwargs)
pipes.append(f)
if n_process != 1:
docs = self._multiprocessing_pipe(texts, pipes, n_process, batch_size)
else:
# If n_process == 1, no extra processes are forked.
docs = (self.make_doc(text) for text in texts)
for pipe in pipes:
docs = pipe(docs)
# Track weakrefs of "recent" documents, so that we can see when they
# expire from memory. When they do, we know we don't need old strings.
# This way, we avoid maintaining an unbounded growth in string entries
@@ -809,6 +837,46 @@
self.tokenizer._reset_cache(keys)
nr_seen = 0
def _multiprocessing_pipe(self, texts, pipes, n_process, batch_size):
# raw_texts is used later to stop iteration.
texts, raw_texts = itertools.tee(texts)
# Queues for sending texts to the workers.
texts_q = [mp.Queue() for _ in range(n_process)]
# Channels for receiving byte-encoded docs from the workers.
bytedocs_recv_ch, bytedocs_send_ch = zip(
*[mp.Pipe(False) for _ in range(n_process)]
)
batch_texts = minibatch(texts, batch_size)
# The sender dispatches texts to the workers in chunks.
# This is necessary to properly handle texts of infinite length,
# in which case all the data cannot be sent to the workers at once.
sender = _Sender(batch_texts, texts_q, chunk_size=n_process)
# Send twice up front so that the worker processes are kept busy.
sender.send()
sender.send()
procs = [
mp.Process(target=_apply_pipes, args=(self.make_doc, pipes, rch, sch))
for rch, sch in zip(texts_q, bytedocs_send_ch)
]
for proc in procs:
proc.start()
# Cycle over the channels so that the order of the docs is preserved.
# Each received object is a batch of byte-encoded docs, so flatten them with chain.from_iterable.
byte_docs = chain.from_iterable(recv.recv() for recv in cycle(bytedocs_recv_ch))
docs = (Doc(self.vocab).from_bytes(byte_doc) for byte_doc in byte_docs)
try:
for i, (_, doc) in enumerate(zip(raw_texts, docs), 1):
yield doc
if i % batch_size == 0:
# Tell the sender that one batch was consumed.
sender.step()
finally:
for proc in procs:
proc.terminate()
def to_disk(self, path, exclude=tuple(), disable=None):
"""Save the current state to a directory. If a model is loaded, this
will include the model.
@@ -987,12 +1055,55 @@ class DisabledPipes(list):
self[:] = []
def _pipe(func, docs, kwargs):
def _pipe(docs, proc, kwargs):
# We added some args for pipe that __call__ doesn't expect.
kwargs = dict(kwargs)
for arg in ["n_threads", "batch_size"]:
if arg in kwargs:
kwargs.pop(arg)
for doc in docs:
doc = func(doc, **kwargs)
doc = proc(doc, **kwargs)
yield doc
def _apply_pipes(make_doc, pipes, receiver, sender):
"""Worker for Language.pipe
Args:
receiver (multiprocessing.Queue): Queue over which batches of texts are received. Created by `multiprocessing.Queue()`.
sender (multiprocessing.Connection): Pipe through which byte-encoded docs are sent back. Created by `multiprocessing.Pipe()`.
"""
while True:
texts = receiver.get()
docs = (make_doc(text) for text in texts)
for pipe in pipes:
docs = pipe(docs)
# Connection does not accept unpicklable objects, so send a list of byte-encoded docs.
sender.send([doc.to_bytes() for doc in docs])
class _Sender(object):
"""Util for sending data to multiprocessing workers in Language.pipe"""
def __init__(self, data, queues, chunk_size):
self.data = iter(data)
self.queues = iter(cycle(queues))
self.chunk_size = chunk_size
self.count = 0
def send(self):
"""Send chunk_size items from self.data to the queues."""
for item, q in itertools.islice(zip(self.data, self.queues), self.chunk_size):
# self.queues is an endless cycle, so the texts are distributed over the workers evenly.
q.put(item)
def step(self):
"""Tell the sender that one item was consumed.
Data is sent to the workers after every chunk_size calls."""
self.count += 1
if self.count >= self.chunk_size:
self.count = 0
self.send()
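
Taken together, `_multiprocessing_pipe`, `_apply_pipes`, and `_Sender` form a fan-out/fan-in pattern: batches of texts fan out over per-worker queues, and byte-encoded docs fan back in over pipes that are read in round-robin order, which preserves the input order. A standalone sketch of the same pattern (all names here are illustrative, none of this is spaCy API):

import multiprocessing as mp
from itertools import chain, cycle, islice

def worker(queue, conn):
    # Stand-in for the real pipeline: upper-case every text in a batch.
    while True:
        batch = queue.get()
        conn.send([text.upper() for text in batch])

if __name__ == "__main__":
    n_process = 2
    queues = [mp.Queue() for _ in range(n_process)]
    recv_ch, send_ch = zip(*[mp.Pipe(False) for _ in range(n_process)])
    procs = [mp.Process(target=worker, args=(q, s)) for q, s in zip(queues, send_ch)]
    for proc in procs:
        proc.start()
    batches = [["a", "b"], ["c", "d"], ["e", "f"], ["g", "h"]]
    # Fan out: distribute batches over the queues in round-robin order.
    for batch, q in zip(batches, cycle(queues)):
        q.put(batch)
    # Fan in: read the receiving ends in the same round-robin order, so
    # results come back in the order the batches were sent.
    results = chain.from_iterable(r.recv() for r in islice(cycle(recv_ch), len(batches)))
    print(list(results))  # ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    for proc in procs:
        proc.terminate()

As one of the commit messages above notes, mp.Pipe copes with relatively small payloads, which is why the real implementation streams batches through `_Sender` instead of enqueueing the whole input up front.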

View File

@@ -1,11 +1,16 @@
# coding: utf-8
from __future__ import unicode_literals
import itertools
import pytest
from spacy.vocab import Vocab
from spacy.language import Language
from spacy.tokens import Doc
from spacy.compat import is_python2
from spacy.gold import GoldParse
from spacy.language import Language
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from .util import add_vecs_to_vocab, assert_docs_equal
@pytest.fixture
@@ -58,3 +63,74 @@ def test_language_evaluate(nlp):
# Evaluate badly
with pytest.raises(Exception):
nlp.evaluate([text, gold])
def vector_modification_pipe(doc):
doc.vector += 1
return doc
def userdata_pipe(doc):
doc.user_data["foo"] = "bar"
return doc
def ner_pipe(doc):
span = Span(doc, 0, 1, label="FIRST")
doc.ents += (span,)
return doc
@pytest.fixture
def sample_vectors():
return [
("spacy", [-0.1, -0.2, -0.3]),
("world", [-0.2, -0.3, -0.4]),
("pipe", [0.7, 0.8, 0.9]),
]
@pytest.fixture
def nlp2(nlp, sample_vectors):
add_vecs_to_vocab(nlp.vocab, sample_vectors)
nlp.add_pipe(vector_modification_pipe)
nlp.add_pipe(ner_pipe)
nlp.add_pipe(userdata_pipe)
return nlp
@pytest.fixture
def texts():
data = [
"Hello world.",
"This is spacy.",
"You can use multiprocessing with pipe method.",
"Please try!",
]
return data
@pytest.mark.parametrize("n_process", [1, 2])
def test_language_pipe(nlp2, n_process, texts):
texts = texts * 10
expecteds = [nlp2(text) for text in texts]
docs = nlp2.pipe(texts, n_process=n_process, batch_size=2)
for doc, expected_doc in zip(docs, expecteds):
assert_docs_equal(doc, expected_doc)
@pytest.mark.skipif(
is_python2, reason="Python 2 seems to be unable to handle iterators properly"
)
@pytest.mark.parametrize("n_process", [1, 2])
def test_language_pipe_stream(nlp2, n_process, texts):
# Check that nlp.pipe can handle an iterator of infinite length properly.
stream_texts = itertools.cycle(texts)
texts0, texts1 = itertools.tee(stream_texts)
expecteds = (nlp2(text) for text in texts0)
docs = nlp2.pipe(texts1, n_process=n_process, batch_size=2)
n_fetch = 20
for doc, expected_doc in itertools.islice(zip(docs, expecteds), n_fetch):
assert_docs_equal(doc, expected_doc)
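
The stream test leans on two itertools patterns: `tee` duplicates the infinite iterator so that expected and actual docs are built from identical inputs, and `islice` bounds the otherwise endless comparison loop. A minimal illustration of the combination (illustrative only, no spaCy involved):

import itertools

stream = itertools.cycle(["a", "b", "c"])
left, right = itertools.tee(stream)
# zip consumes both copies in lockstep, so tee's internal buffer stays small
# even though the underlying iterator never terminates.
pairs = itertools.islice(zip(left, right), 7)
assert all(x == y for x, y in pairs)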