mirror of
https://github.com/django/daphne.git
synced 2025-05-08 09:33:49 +03:00
138 lines
4.3 KiB
Python
138 lines
4.3 KiB
Python
from __future__ import unicode_literals
|
|
|
|
import threading
|
|
|
|
from channels import DEFAULT_CHANNEL_LAYER, Channel, route
|
|
from channels.asgi import channel_layers
|
|
from channels.exceptions import ConsumeLater
|
|
from channels.signals import worker_ready
|
|
from channels.tests import ChannelTestCase
|
|
from channels.worker import Worker, WorkerGroup
|
|
|
|
try:
|
|
from unittest import mock
|
|
except ImportError:
|
|
import mock
|
|
|
|
|
|
class PatchedWorker(Worker):
    """
    Worker subclass whose ``termed`` flag acts as a loop countdown.

    Assigning ``worker.termed = n`` stores ``n``. Each subsequent read of
    ``termed`` returns False while the stored value is truthy, decrementing
    it by one. Once it is falsy, reads return True. The run loop therefore
    sees "not terminated" a fixed number of times.
    """

    def get_termed(self):
        # Name-mangled (_PatchedWorker__iters) so it cannot clash with any
        # attribute the base Worker defines.
        remaining = self.__iters
        if remaining:
            self.__iters = remaining - 1
            return False
        return True

    def set_termed(self, value):
        # Store the countdown (or whatever flag value is assigned) verbatim.
        self.__iters = value

    termed = property(fget=get_termed, fset=set_termed)
|
|
|
|
|
|
class WorkerTests(ChannelTestCase):
    """
    Tests for the single Worker: channel include/exclude filtering and the
    message-dispatch loop (normal consumption and ConsumeLater re-queueing).
    """

    def _prepare_test_channel(self, consumer):
        """
        Queue one message on the ``test`` channel, route it to ``consumer``
        and wrap the default layer's ``send`` in a counting mock.

        The message is queued *before* ``send`` is wrapped, so the mock only
        counts sends made by the worker itself. The patch is undone
        automatically when the test finishes.

        Returns a ``(channel_layer, send_mock)`` tuple.
        """
        Channel('test').send({'test': 'test'}, immediately=True)
        channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER]
        channel_layer.router.add_route(route('test', consumer))
        # wraps= keeps the real send behaviour while recording calls.
        patcher = mock.patch.object(channel_layer, 'send', wraps=channel_layer.send)
        send_mock = patcher.start()
        self.addCleanup(patcher.stop)
        return channel_layer, send_mock

    def test_channel_filters(self):
        """
        Tests that the include/exclude logic works
        """
        # Include: only channels matching one of the patterns are kept
        worker = Worker(None, only_channels=["yes.*", "maybe.*"])
        self.assertEqual(
            worker.apply_channel_filters(["yes.1", "no.1"]),
            ["yes.1"],
        )
        # A bare "yes" does not match "yes.*"
        self.assertEqual(
            worker.apply_channel_filters(["yes.1", "no.1", "maybe.2", "yes"]),
            ["yes.1", "maybe.2"],
        )
        # Exclude: channels matching any pattern are dropped
        worker = Worker(None, exclude_channels=["no.*", "maybe.*"])
        self.assertEqual(
            worker.apply_channel_filters(["yes.1", "no.1", "maybe.2", "yes"]),
            ["yes.1", "yes"],
        )
        # Both: a channel must match only_channels and not exclude_channels
        worker = Worker(None, exclude_channels=["no.*"], only_channels=["yes.*"])
        self.assertEqual(
            worker.apply_channel_filters(["yes.1", "no.1", "maybe.2", "yes"]),
            ["yes.1"],
        )

    def test_run_with_consume_later_error(self):
        """
        A consumer raising ConsumeLater has its message sent back to the
        layer once and receives it again on the next loop.
        """
        # First call defers the message, second call handles it.
        consumer = mock.Mock(side_effect=[ConsumeLater(), None])
        channel_layer, send = self._prepare_test_channel(consumer)

        worker = PatchedWorker(channel_layer)
        worker.termed = 2  # first loop with error, second with sending

        worker.run()
        self.assertEqual(consumer.call_count, 2)
        self.assertEqual(send.call_count, 1)

    def test_normal_run(self):
        """
        A consumer that returns normally is called once and the worker
        sends nothing back to the layer.
        """
        consumer = mock.Mock()
        channel_layer, send = self._prepare_test_channel(consumer)

        worker = PatchedWorker(channel_layer)
        worker.termed = 2

        worker.run()
        self.assertEqual(consumer.call_count, 1)
        self.assertEqual(send.call_count, 0)
|
|
|
|
|
|
class WorkerGroupTests(ChannelTestCase):
    """
    Test threaded workers.
    """

    # Seconds to wait for each subworker thread to exit after
    # sigterm_handler. Bounds the test so a broken shutdown fails instead
    # of hanging the suite.
    JOIN_TIMEOUT = 5

    def setUp(self):
        super(WorkerGroupTests, self).setUp()
        self.channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER]
        self.worker = WorkerGroup(self.channel_layer, n_threads=4)
        self.subworkers = self.worker.workers

    def test_subworkers_created(self):
        # n_threads=4 yields 3 subworkers; the group itself accounts for
        # the remaining thread (the ready() test below expects 4 signals).
        self.assertEqual(len(self.subworkers), 3)

    def test_subworkers_no_sigterm(self):
        for wrk in self.subworkers:
            self.assertFalse(wrk.signal_handlers)

    def test_ready_signals_sent(self):
        senders = []

        def handle_signal(sender, *args, **kwargs):
            senders.append(sender)

        worker_ready.connect(handle_signal)
        # Disconnect afterwards so this receiver cannot observe signals sent
        # by other tests. addCleanup also keeps a strong reference to it.
        self.addCleanup(worker_ready.disconnect, handle_signal)
        self.worker.ready()
        self.assertEqual(len(senders), 4)

    def test_sigterm_handler(self):
        threads = []
        for wkr in self.subworkers:
            t = threading.Thread(target=wkr.run)
            # Daemon so a subworker that ignores shutdown cannot keep the
            # test process alive.
            t.daemon = True
            t.start()
            threads.append(t)
        self.worker.sigterm_handler(None, None)
        for t in threads:
            t.join(self.JOIN_TIMEOUT)
            self.assertFalse(
                t.is_alive(),
                "subworker thread did not stop after sigterm_handler",
            )
|