diff --git a/channels/management/commands/runserver.py b/channels/management/commands/runserver.py index 45debb1..3b68220 100644 --- a/channels/management/commands/runserver.py +++ b/channels/management/commands/runserver.py @@ -151,5 +151,6 @@ class WorkerThread(threading.Thread): def run(self): self.logger.debug("Worker thread running") worker = Worker(channel_layer=self.channel_layer, signal_handlers=False) + worker.ready() worker.run() self.logger.debug("Worker thread exited") diff --git a/channels/management/commands/runworker.py b/channels/management/commands/runworker.py index e082b4a..84454a1 100644 --- a/channels/management/commands/runworker.py +++ b/channels/management/commands/runworker.py @@ -6,8 +6,8 @@ from django.core.management import BaseCommand, CommandError from channels import DEFAULT_CHANNEL_LAYER, channel_layers from channels.log import setup_logger from channels.staticfiles import StaticFilesConsumer -from channels.worker import Worker -from channels.signals import worker_ready +from channels.worker import Worker, WorkerGroup +from channels.signals import worker_process_ready class Command(BaseCommand): @@ -28,12 +28,18 @@ class Command(BaseCommand): '--exclude-channels', action='append', dest='exclude_channels', help='Prevents this worker from listening on the provided channels (supports globbing).', ) + parser.add_argument( + '--threads', action='store', dest='threads', + default=1, type=int, + help='Number of threads to execute.' + ) def handle(self, *args, **options): # Get the backend to use self.verbosity = options.get("verbosity", 1) self.logger = setup_logger('django.channels', self.verbosity) self.channel_layer = channel_layers[options.get("layer", DEFAULT_CHANNEL_LAYER)] + self.n_threads = options.get('threads', 1) # Check that handler isn't inmemory if self.channel_layer.local_only(): raise CommandError( @@ -46,21 +52,30 @@ class Command(BaseCommand): self.channel_layer.router.check_default(http_consumer=StaticFilesConsumer()) else: self.channel_layer.router.check_default() - # Launch a worker - self.logger.info("Running worker against channel layer %s", self.channel_layer) # Optionally provide an output callback callback = None if self.verbosity > 1: callback = self.consumer_called + self.callback = callback + self.options = options + # Choose an appropriate worker. + if self.n_threads == 1: + self.logger.info("Using single-threaded worker.") + worker_cls = Worker + else: + self.logger.info("Using multi-threaded worker, {} thread(s).".format(self.n_threads)) + worker_cls = WorkerGroup # Run the worker + self.logger.info("Running worker against channel layer %s", self.channel_layer) try: - worker = Worker( + worker = worker_cls( channel_layer=self.channel_layer, - callback=callback, - only_channels=options.get("only_channels", None), - exclude_channels=options.get("exclude_channels", None), + callback=self.callback, + only_channels=self.options.get("only_channels", None), + exclude_channels=self.options.get("exclude_channels", None), ) - worker_ready.send(sender=worker) + worker_process_ready.send(sender=worker) + worker.ready() worker.run() except KeyboardInterrupt: pass diff --git a/channels/signals.py b/channels/signals.py index 8c33b96..dc83b94 100644 --- a/channels/signals.py +++ b/channels/signals.py @@ -5,6 +5,7 @@ from django.dispatch import Signal consumer_started = Signal(providing_args=["environ"]) consumer_finished = Signal() worker_ready = Signal() +worker_process_ready = Signal() # Connect connection closer to consumer finished as well consumer_finished.connect(close_old_connections) diff --git a/channels/tests/test_worker.py b/channels/tests/test_worker.py index 5cff5b7..bc6b5d4 100644 --- a/channels/tests/test_worker.py +++ b/channels/tests/test_worker.py @@ -4,12 +4,14 @@ try: from unittest import mock except ImportError: import mock +import threading from channels import Channel, route, DEFAULT_CHANNEL_LAYER from channels.asgi import channel_layers from channels.tests import ChannelTestCase -from channels.worker import Worker +from channels.worker import Worker, WorkerGroup from channels.exceptions import ConsumeLater +from channels.signals import worker_ready class PatchedWorker(Worker): @@ -93,3 +95,42 @@ class WorkerTests(ChannelTestCase): worker.run() self.assertEqual(consumer.call_count, 1) self.assertEqual(channel_layer.send.call_count, 0) + + +class WorkerGroupTests(ChannelTestCase): + """ + Test threaded workers. + """ + + def setUp(self): + self.channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER] + self.worker = WorkerGroup(self.channel_layer, n_threads=4) + self.subworkers = self.worker.workers + + def test_subworkers_created(self): + self.assertEqual(len(self.subworkers), 3) + + def test_subworkers_no_sigterm(self): + for wrk in self.subworkers: + self.assertFalse(wrk.signal_handlers) + + def test_ready_signals_sent(self): + self.in_signal = 0 + + def handle_signal(sender, *args, **kwargs): + self.in_signal += 1 + + worker_ready.connect(handle_signal) + WorkerGroup(self.channel_layer, n_threads=4) + self.worker.ready() + self.assertEqual(self.in_signal, 4) + + def test_sigterm_handler(self): + threads = [] + for wkr in self.subworkers: + t = threading.Thread(target=wkr.run) + t.start() + threads.append(t) + self.worker.sigterm_handler(None, None) + for t in threads: + t.join() diff --git a/channels/worker.py b/channels/worker.py index f6d93d7..3b67e92 100644 --- a/channels/worker.py +++ b/channels/worker.py @@ -5,11 +5,14 @@ import logging import signal import sys import time +import multiprocessing +import threading from .signals import consumer_started, consumer_finished from .exceptions import ConsumeLater from .message import Message from .utils import name_that_thing +from .signals import worker_ready logger = logging.getLogger('django.channels') @@ -66,6 +69,12 @@ class Worker(object): ] return channels + def ready(self): + """ + Called once worker setup is complete. + """ + worker_ready.send(sender=self) + def run(self): """ Tries to continually dispatch messages to consumers. @@ -134,3 +143,41 @@ class Worker(object): else: # Send consumer finished so DB conns close etc. consumer_finished.send(sender=self.__class__) + + +class WorkerGroup(Worker): + """ + Group several workers together in threads. Manages the sub-workers, + terminating them if a signal is received. + """ + + def __init__(self, *args, **kwargs): + n_threads = kwargs.pop('n_threads', multiprocessing.cpu_count()) - 1 + super(WorkerGroup, self).__init__(*args, **kwargs) + kwargs['signal_handlers'] = False + self.workers = [Worker(*args, **kwargs) for ii in range(n_threads)] + + def sigterm_handler(self, signo, stack_frame): + self.termed = True + for wkr in self.workers: + wkr.termed = True + logger.info("Shutdown signal received while busy, waiting for " + "loop termination") + + def ready(self): + super(WorkerGroup, self).ready() + for wkr in self.workers: + wkr.ready() + + def run(self): + """ + Launch sub-workers before running. + """ + self.threads = [threading.Thread(target=self.workers[ii].run) + for ii in range(len(self.workers))] + for t in self.threads: + t.start() + super(WorkerGroup, self).run() + # Join threads once completed. + for t in self.threads: + t.join()