mirror of
https://github.com/django/daphne.git
synced 2025-05-26 00:33:12 +03:00
230 lines
7.7 KiB
Python
230 lines
7.7 KiB
Python
from __future__ import unicode_literals
|
|
|
|
import copy
|
|
import random
|
|
import string
|
|
from functools import wraps
|
|
|
|
from asgiref.inmemory import ChannelLayer as InMemoryChannelLayer
|
|
from django.db import close_old_connections
|
|
from django.test.testcases import TestCase, TransactionTestCase
|
|
|
|
from .. import DEFAULT_CHANNEL_LAYER
|
|
from ..asgi import ChannelLayerWrapper, channel_layers
|
|
from ..channel import Group
|
|
from ..message import Message
|
|
from ..routing import Router, include
|
|
from ..signals import consumer_finished, consumer_started
|
|
|
|
|
|
class ChannelTestCaseMixin(object):
    """
    Mixin for Django test case classes that captures channel messages in
    memory so tests can inspect what was sent.

    Before each test every alias listed in ``test_channel_aliases`` is
    pointed at a fresh in-memory channel layer; after the test the layers
    that were in place beforehand are put back.

    Combine it with a Django test case base class (it relies on the
    ``_pre_setup`` / ``_post_teardown`` hooks and on ``self.fail``).
    """

    # Customizable so users can test multi-layer setups
    test_channel_aliases = [DEFAULT_CHANNEL_LAYER]

    def _pre_setup(self):
        """
        Installs an in-memory channel layer for each tested alias, for the
        duration of the test.
        """
        super(ChannelTestCaseMixin, self)._pre_setup()
        self._old_layers = {}
        for layer_alias in self.test_channel_aliases:
            memory_layer = InMemoryChannelLayer()
            # Copy the routing list so the replacement wrapper does not
            # share the list object with the layer it stands in for.
            routing_snapshot = channel_layers[layer_alias].routing[:]
            replacement = ChannelLayerWrapper(
                memory_layer,
                layer_alias,
                routing_snapshot,
            )
            # Remember what ``set`` hands back so teardown can restore it.
            self._old_layers[layer_alias] = channel_layers.set(
                layer_alias,
                replacement,
            )

    def _post_teardown(self):
        """
        Restores the channel layers that were replaced in ``_pre_setup``.
        """
        for layer_alias in self.test_channel_aliases:
            # Put the layer saved during setup back under its alias.
            previous_layer = self._old_layers[layer_alias]
            channel_layers.set(layer_alias, previous_layer)
        del self._old_layers
        super(ChannelTestCaseMixin, self)._post_teardown()

    def get_next_message(self, channel, alias=DEFAULT_CHANNEL_LAYER, require=False):
        """
        Returns the next message sent to ``channel`` on the layer registered
        under ``alias``, wrapped in a ``Message``.

        When nothing is waiting, returns None, unless ``require`` is true,
        in which case the test is failed instead.
        """
        layer = channel_layers[alias]
        received_on, payload = layer.receive_many([channel])
        if received_on is None:
            if not require:
                return None
            self.fail("Expected a message on channel %s, got none" % channel)
        return Message(payload, received_on, channel_layers[alias])
|
|
|
|
|
|
class ChannelTestCase(ChannelTestCaseMixin, TestCase):
    """
    ChannelTestCaseMixin combined with Django's ``TestCase``.
    """
|
|
|
|
|
|
class TransactionChannelTestCase(ChannelTestCaseMixin, TransactionTestCase):
    """
    ChannelTestCaseMixin combined with Django's ``TransactionTestCase``.
    """
|
|
|
|
|
|
class Client(object):
    """
    Channel client abstraction for tests.

    Provides helpers that drive the full life cycle of a message (send it,
    route it to its consumer, run the consumer, read the reply) using a
    reply channel name that is fixed for the lifetime of the client.
    """

    def __init__(self, alias=DEFAULT_CHANNEL_LAYER):
        # Reply channel name: the layer alias followed by five random ASCII
        # letters, so separate clients get separate reply channels.
        suffix = ''.join(random.choice(string.ascii_letters) for _ in range(5))
        self.reply_channel = alias + suffix
        self.alias = alias

    @property
    def channel_layer(self):
        """Channel layer as lazy property"""
        return channel_layers[self.alias]

    def get_next_message(self, channel):
        """
        Gets the next message that was sent to the channel during the test,
        or None if no message is available.
        """
        layer = self.channel_layer
        recv_channel, content = layer.receive_many([channel])
        if recv_channel is None:
            return None
        return Message(content, recv_channel, layer)

    def get_consumer_by_channel(self, channel):
        """
        Returns the consumer the router matches for a dummy text message on
        ``channel``, or None when the router reports no match.
        """
        message = Message({'text': ''}, channel, self.channel_layer)
        match = self.channel_layer.router.match(message)
        if not match:
            return None
        consumer, _kwargs = match
        return consumer

    def send(self, to, content=None):
        """
        Send a message to a channel.
        Adds reply_channel name to the message.

        ``content`` defaults to an empty message. The caller's dict is
        deep-copied, so it is never modified.
        """
        content = {} if content is None else copy.deepcopy(content)
        content.setdefault('reply_channel', self.reply_channel)
        self.channel_layer.send(to, content)

    def consume(self, channel, fail_on_none=True):
        """
        Get next message for channel name and run appointed consumer.

        Returns whatever the consumer returns. When there is no pending
        message, or no consumer matches it, raises AssertionError if
        ``fail_on_none`` is true and returns None otherwise.
        """
        message = self.get_next_message(channel)
        if not message:
            if fail_on_none:
                raise AssertionError("No message for channel %s" % channel)
            return None

        match = self.channel_layer.router.match(message)
        if not match:
            if fail_on_none:
                raise AssertionError("Can't find consumer for message %s" % message)
            return None

        consumer, kwargs = match
        try:
            consumer_started.send(sender=self.__class__)
            return consumer(message, **kwargs)
        finally:
            # Copy Django's workaround so we don't actually close DB conns
            consumer_finished.disconnect(close_old_connections)
            try:
                consumer_finished.send(sender=self.__class__)
            finally:
                # Re-attach the handler even if a consumer_finished receiver
                # raised, so later code sees the signal wired as before.
                consumer_finished.connect(close_old_connections)

    def send_and_consume(self, channel, content=None, fail_on_none=True):
        """
        Reproduce full life cycle of the message: send ``content`` to
        ``channel`` and immediately run the consumer for it.
        """
        self.send(channel, content)
        return self.consume(channel, fail_on_none=fail_on_none)

    def receive(self):
        """
        Get content of next message for reply channel if message exists,
        otherwise None.
        """
        message = self.get_next_message(self.reply_channel)
        if message:
            return message.content
        return None

    def join_group(self, group_name):
        """
        Adds this client's reply channel to the named group.
        """
        # NOTE(review): Group is built without this client's alias, so it
        # presumably uses the default channel layer - confirm whether
        # self.alias should be passed through for non-default layers.
        Group(group_name).add(self.reply_channel)
|
|
|
|
|
|
class apply_routes(object):
    """
    Decorator/ContextManager that replaces channel layer routes for the
    duration of a context.
    Helpful for testing a group of routes/consumers as an isolated
    application.

    ``routes`` may be a single route, a list/tuple of routes, or a
    list/tuple of route lists (each inner list is wrapped with ``include``).
    ``aliases`` names the channel layers to reroute; it defaults to the
    default channel layer only.

    Usable as ``with apply_routes(...):``, as a function decorator, or as a
    class decorator (wrapping the class's ``setUp``/``tearDown``).
    """

    def __init__(self, routes, aliases=None):
        # A None sentinel avoids sharing one mutable default list between
        # every instance.
        self._aliases = [DEFAULT_CHANNEL_LAYER] if aliases is None else aliases
        self.routes = routes
        self._old_routing = {}

    def _resolve_routes(self):
        """Normalises ``self.routes`` into the sequence handed to Router."""
        if not isinstance(self.routes, (list, tuple)):
            return [self.routes]
        # Guard against an empty sequence before peeking at its first item.
        if self.routes and isinstance(self.routes[0], (list, tuple)):
            return list(map(include, self.routes))
        return self.routes

    def enter(self):
        """
        Store old routes and apply new one
        """
        # The resolved routes do not depend on the alias, so build them once.
        routes = self._resolve_routes()
        for alias in self._aliases:
            # Reroute the layer for *this* alias, not always the default one.
            channel_layer = channel_layers[alias]
            self._old_routing[alias] = channel_layer.routing
            channel_layer.routing = routes
            channel_layer.router = Router(routes)

    def exit(self, exc_type=None, exc_val=None, exc_tb=None):
        """
        Undoes rerouting.

        Returns None, so an exception raised inside the ``with`` body is
        never suppressed.
        """
        for alias in self._aliases:
            channel_layer = channel_layers[alias]
            channel_layer.routing = self._old_routing[alias]
            channel_layer.router = Router(self._old_routing[alias])

    __enter__ = enter
    __exit__ = exit

    def __call__(self, test_func):
        """
        Applies the rerouting to a test class or a test function.

        For a class, ``setUp``/``tearDown`` are wrapped in place and the same
        class is returned. For a function, a wrapper that runs it inside
        ``with self:`` is returned.
        """
        if isinstance(test_func, type):
            old_setup = test_func.setUp
            old_teardown = test_func.tearDown

            def new_setup(this):
                self.enter()
                try:
                    old_setup(this)
                except Exception:
                    # unittest skips tearDown when setUp raises, so undo the
                    # rerouting here or it would leak into later tests.
                    self.exit()
                    raise

            def new_teardown(this):
                self.exit()
                old_teardown(this)

            test_func.setUp = new_setup
            test_func.tearDown = new_teardown
            return test_func
        else:
            @wraps(test_func)
            def inner(*args, **kwargs):
                with self:
                    return test_func(*args, **kwargs)
            return inner
|