Rework database channel layer and tests, to use ASGI conformance suite

This commit is contained in:
Andrew Godwin 2016-01-17 14:28:01 -08:00
parent 5fa357e403
commit a9810014ca
5 changed files with 111 additions and 246 deletions

View File

@ -1,24 +1,122 @@
import datetime
import json
import random
import string
import time
from django.apps.registry import Apps
from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections, models
from django.utils import six
from django.utils.functional import cached_property
from django.utils.timezone import now
from .base import BaseChannelBackend
class DatabaseChannelBackend(BaseChannelBackend):
class DatabaseChannelLayer(object):
"""
ORM-backed channel environment. For development use only; it will span
multiple processes fine, but it's going to be pretty bad at throughput.
ORM-backed ASGI channel layer.
For development use only; it will span multiple processes fine,
but it's going to be pretty bad at throughput. If you're reading this and
running it in production, PLEASE STOP.
Also uses JSON for serialization, as we don't want to make Django depend
on msgpack for the built-in backend. The JSON format uses \uffff as first
character to signify a byte string rather than a text string.
"""
def __init__(self, routing, expiry=60, db_alias=DEFAULT_DB_ALIAS):
super(DatabaseChannelBackend, self).__init__(routing=routing, expiry=expiry)
def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60):
self.expiry = expiry
self.db_alias = db_alias
### ASGI API ###
extensions = ["groups", "flush"]
def send(self, channel, message):
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert isinstance(channel, six.text_type), "%s is not unicode" % channel
# Write message to messages table
self.channel_model.objects.create(
channel=channel,
content=self.serialize(message),
expiry=now() + datetime.timedelta(seconds=self.expiry)
)
def receive_many(self, channels, block=False):
if not channels:
return None, None
assert all(isinstance(channel, six.text_type) for channel in channels)
# Shuffle channels
channels = list(channels)
random.shuffle(channels)
# Clean out expired messages
self._clean_expired()
# Get a message from one of our channels
while True:
message = self.channel_model.objects.filter(channel__in=channels).order_by("id").first()
if message:
self.channel_model.objects.filter(pk=message.pk).delete()
return message.channel, self.deserialize(message.content)
else:
if block:
time.sleep(1)
else:
return None, None
def new_channel(self, pattern):
assert isinstance(pattern, six.text_type)
# Keep making channel names till one isn't present.
while True:
random_string = "".join(random.choice(string.ascii_letters) for i in range(8))
new_name = pattern.replace(b"?", random_string)
if not self.channel_model.objects.filter(channel=new_name).exists():
return new_name
### ASGI Group extension ###
def group_add(self, group, channel, expiry=None):
"""
Adds the channel to the named group for at least 'expiry'
seconds (expiry defaults to message expiry if not provided).
"""
self.group_model.objects.update_or_create(
group=group,
channel=channel,
defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)},
)
def group_discard(self, group, channel):
"""
Removes the channel from the named group if it is in the group;
does nothing otherwise (does not error)
"""
self.group_model.objects.filter(group=group, channel=channel).delete()
def send_group(self, group, message):
"""
Sends a message to the entire group.
"""
self._clean_expired()
for channel in self.group_model.objects.filter(group=group).values_list("channel", flat=True):
self.send(channel, message)
### ASGI Flush extension ###
def flush(self):
self.channel_model.objects.all().delete()
self.group_model.objects.all().delete()
### Serialization ###
def serialize(self, message):
return json.dumps(message)
def deserialize(self, message):
return json.loads(message)
### Database state mgmt ###
@property
def connection(self):
"""
@ -72,46 +170,6 @@ class DatabaseChannelBackend(BaseChannelBackend):
editor.create_model(Group)
return Group
@cached_property
def lock_model(self):
"""
Initialises a new model to store groups; not done as part of a
models.py as we don't want to make it for most installs.
"""
# Make the model class
class Lock(models.Model):
channel = models.CharField(max_length=200, unique=True)
expiry = models.DateTimeField(db_index=True)
class Meta:
apps = Apps()
app_label = "channels"
db_table = "django_channel_locks"
# Ensure its table exists
if Lock._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()):
with self.connection.schema_editor() as editor:
editor.create_model(Lock)
return Lock
def send(self, channel, message):
self.channel_model.objects.create(
channel=channel,
content=json.dumps(message),
expiry=now() + datetime.timedelta(seconds=self.expiry)
)
def receive_many(self, channels):
if not channels:
raise ValueError("Cannot receive on empty channel list!")
self._clean_expired()
# Get a message from one of our channels
message = self.channel_model.objects.filter(channel__in=channels).order_by("id").first()
if message:
self.channel_model.objects.filter(pk=message.pk).delete()
return message.channel, json.loads(message.content)
else:
return None, None
def _clean_expired(self):
"""
Cleans out expired groups and messages.
@ -119,57 +177,6 @@ class DatabaseChannelBackend(BaseChannelBackend):
# Include a 10-second grace period because that solves some clock sync
self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
self.lock_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
def group_add(self, group, channel, expiry=None):
"""
Adds the channel to the named group for at least 'expiry'
seconds (expiry defaults to message expiry if not provided).
"""
self.group_model.objects.update_or_create(
group=group,
channel=channel,
defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)},
)
def group_discard(self, group, channel):
"""
Removes the channel from the named group if it is in the group;
does nothing otherwise (does not error)
"""
self.group_model.objects.filter(group=group, channel=channel).delete()
def group_channels(self, group):
"""
Returns an iterable of all channels in the group.
"""
self._clean_expired()
return list(self.group_model.objects.filter(group=group).values_list("channel", flat=True))
def lock_channel(self, channel, expiry=None):
"""
Attempts to get a lock on the named channel. Returns True if lock
obtained, False if lock not obtained.
"""
# We rely on the UNIQUE constraint for only-one-thread-wins on locks
try:
self.lock_model.objects.create(
channel=channel,
expiry=now() + datetime.timedelta(seconds=expiry or self.expiry),
)
except IntegrityError:
return False
else:
return True
def unlock_channel(self, channel):
"""
Unlocks the named channel. Always succeeds.
"""
self.lock_model.objects.filter(channel=channel).delete()
def __str__(self):
return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
def flush(self):
pass

View File

@ -1,95 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.test import TestCase
from ..backends.database import DatabaseChannelBackend
from ..backends.memory import InMemoryChannelBackend
from ..backends.redis_py import RedisChannelBackend
class MemoryBackendTests(TestCase):
backend_class = InMemoryChannelBackend
def setUp(self):
self.backend = self.backend_class(routing={})
self.backend.flush()
def test_send_recv(self):
"""
Tests that channels can send and receive messages.
"""
self.backend.send("test", {"value": "blue"})
self.backend.send("test", {"value": "green"})
self.backend.send("test2", {"value": "red"})
# Get just one first
channel, message = self.backend.receive_many(["test"])
self.assertEqual(channel, "test")
self.assertEqual(message, {"value": "blue"})
# And the second
channel, message = self.backend.receive_many(["test"])
self.assertEqual(channel, "test")
self.assertEqual(message, {"value": "green"})
# And the other channel with multi select
channel, message = self.backend.receive_many(["test", "test2"])
self.assertEqual(channel, "test2")
self.assertEqual(message, {"value": "red"})
def test_message_expiry(self):
self.backend = self.backend_class(routing={}, expiry=-100)
self.backend.send("test", {"value": "blue"})
channel, message = self.backend.receive_many(["test"])
self.assertIs(channel, None)
self.assertIs(message, None)
def test_groups(self):
"""
Tests that group addition and removal and listing works
"""
self.backend.group_add("tgroup", "test")
self.backend.group_add("tgroup", "test2€")
self.backend.group_add("tgroup2", "test3")
self.assertEqual(
set(self.backend.group_channels("tgroup")),
{"test", "test2€"},
)
self.backend.group_discard("tgroup", "test2€")
self.backend.group_discard("tgroup", "test2€")
self.assertEqual(
list(self.backend.group_channels("tgroup")),
["test"],
)
def test_group_send(self):
"""
Tests sending to groups.
"""
self.backend.group_add("tgroup", "test")
self.backend.group_add("tgroup", "test2")
self.backend.send_group("tgroup", {"value": "orange"})
channel, message = self.backend.receive_many(["test"])
self.assertEqual(channel, "test")
self.assertEqual(message, {"value": "orange"})
channel, message = self.backend.receive_many(["test2"])
self.assertEqual(channel, "test2")
self.assertEqual(message, {"value": "orange"})
def test_group_expiry(self):
self.backend = self.backend_class(routing={}, expiry=-100)
self.backend.group_add("tgroup", "test")
self.backend.group_add("tgroup", "test2")
self.assertEqual(
list(self.backend.group_channels("tgroup")),
[],
)
class RedisBackendTests(MemoryBackendTests):
backend_class = RedisChannelBackend
class DatabaseBackendTests(MemoryBackendTests):
backend_class = DatabaseChannelBackend

View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from asgiref.conformance import make_tests
from ..database_layer import DatabaseChannelLayer
channel_layer = DatabaseChannelLayer(expiry=1)
DatabaseLayerTests = make_tests(channel_layer, expiry_delay=1.1)

View File

@ -1,53 +0,0 @@
from django.test import TestCase
from channels.interfaces.websocket_autobahn import get_protocol
try:
from unittest import mock
except ImportError:
import mock
def generate_connection_request(path, params, headers):
request = mock.Mock()
request.path = path
request.params = params
request.headers = headers
return request
class WebsocketAutobahnInterfaceProtocolTestCase(TestCase):
def test_on_connect_cookie(self):
protocol = get_protocol(object)()
session = "123cat"
cookie = "somethingelse=test; sessionid={0}".format(session)
headers = {
"cookie": cookie
}
test_request = generate_connection_request("path", {}, headers)
protocol.onConnect(test_request)
self.assertEqual(session, protocol.request_info["cookies"]["sessionid"])
def test_on_connect_no_cookie(self):
protocol = get_protocol(object)()
test_request = generate_connection_request("path", {}, {})
protocol.onConnect(test_request)
self.assertEqual({}, protocol.request_info["cookies"])
def test_on_connect_params(self):
protocol = get_protocol(object)()
params = {
"session_key": ["123cat"]
}
test_request = generate_connection_request("path", params, {})
protocol.onConnect(test_request)
self.assertEqual(params, protocol.request_info["get"])
def test_on_connect_path(self):
protocol = get_protocol(object)()
path = "path"
test_request = generate_connection_request(path, {}, {})
protocol.onConnect(test_request)
self.assertEqual(path, protocol.request_info["path"])