From a9810014caaf357d75545d1d7ebbaf88f2341307 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Sun, 17 Jan 2016 14:28:01 -0800 Subject: [PATCH] Rework database channel layer and tests, to use ASGI conformance suite --- channels/backends/__init__.py | 0 .../database.py => database_layer.py} | 203 +++++++++--------- channels/tests/test_backends.py | 95 -------- channels/tests/test_database_layer.py | 6 + channels/tests/test_interfaces.py | 53 ----- 5 files changed, 111 insertions(+), 246 deletions(-) delete mode 100644 channels/backends/__init__.py rename channels/{backends/database.py => database_layer.py} (59%) delete mode 100644 channels/tests/test_backends.py create mode 100644 channels/tests/test_database_layer.py delete mode 100644 channels/tests/test_interfaces.py diff --git a/channels/backends/__init__.py b/channels/backends/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/channels/backends/database.py b/channels/database_layer.py similarity index 59% rename from channels/backends/database.py rename to channels/database_layer.py index d0b7f69..ea9e5e1 100644 --- a/channels/backends/database.py +++ b/channels/database_layer.py @@ -1,24 +1,122 @@ import datetime import json +import random +import string +import time from django.apps.registry import Apps from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections, models +from django.utils import six from django.utils.functional import cached_property from django.utils.timezone import now -from .base import BaseChannelBackend - -class DatabaseChannelBackend(BaseChannelBackend): +class DatabaseChannelLayer(object): """ - ORM-backed channel environment. For development use only; it will span - multiple processes fine, but it's going to be pretty bad at throughput. + ORM-backed ASGI channel layer. + + For development use only; it will span multiple processes fine, + but it's going to be pretty bad at throughput. If you're reading this and + running it in production, PLEASE STOP. + + Also uses JSON for serialization, as we don't want to make Django depend + on msgpack for the built-in backend. The JSON format uses \uffff as first + character to signify a byte string rather than a text string. """ - def __init__(self, routing, expiry=60, db_alias=DEFAULT_DB_ALIAS): - super(DatabaseChannelBackend, self).__init__(routing=routing, expiry=expiry) + def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60): + self.expiry = expiry self.db_alias = db_alias + ### ASGI API ### + + extensions = ["groups", "flush"] + + def send(self, channel, message): + # Typecheck + assert isinstance(message, dict), "message is not a dict" + assert isinstance(channel, six.text_type), "%s is not unicode" % channel + # Write message to messages table + self.channel_model.objects.create( + channel=channel, + content=self.serialize(message), + expiry=now() + datetime.timedelta(seconds=self.expiry) + ) + + def receive_many(self, channels, block=False): + if not channels: + return None, None + assert all(isinstance(channel, six.text_type) for channel in channels) + # Shuffle channels + channels = list(channels) + random.shuffle(channels) + # Clean out expired messages + self._clean_expired() + # Get a message from one of our channels + while True: + message = self.channel_model.objects.filter(channel__in=channels).order_by("id").first() + if message: + self.channel_model.objects.filter(pk=message.pk).delete() + return message.channel, self.deserialize(message.content) + else: + if block: + time.sleep(1) + else: + return None, None + + def new_channel(self, pattern): + assert isinstance(pattern, six.text_type) + # Keep making channel names till one isn't present. + while True: + random_string = "".join(random.choice(string.ascii_letters) for i in range(8)) + new_name = pattern.replace(b"?", random_string) + if not self.channel_model.objects.filter(channel=new_name).exists(): + return new_name + + ### ASGI Group extension ### + + def group_add(self, group, channel, expiry=None): + """ + Adds the channel to the named group for at least 'expiry' + seconds (expiry defaults to message expiry if not provided). + """ + self.group_model.objects.update_or_create( + group=group, + channel=channel, + defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)}, + ) + + def group_discard(self, group, channel): + """ + Removes the channel from the named group if it is in the group; + does nothing otherwise (does not error) + """ + self.group_model.objects.filter(group=group, channel=channel).delete() + + def send_group(self, group, message): + """ + Sends a message to the entire group. + """ + self._clean_expired() + for channel in self.group_model.objects.filter(group=group).values_list("channel", flat=True): + self.send(channel, message) + + ### ASGI Flush extension ### + + def flush(self): + self.channel_model.objects.all().delete() + self.group_model.objects.all().delete() + + ### Serialization ### + + def serialize(self, message): + return json.dumps(message) + + def deserialize(self, message): + return json.loads(message) + + ### Database state mgmt ### + @property def connection(self): """ @@ -72,46 +170,6 @@ class DatabaseChannelBackend(BaseChannelBackend): editor.create_model(Group) return Group - @cached_property - def lock_model(self): - """ - Initialises a new model to store groups; not done as part of a - models.py as we don't want to make it for most installs. - """ - # Make the model class - class Lock(models.Model): - channel = models.CharField(max_length=200, unique=True) - expiry = models.DateTimeField(db_index=True) - - class Meta: - apps = Apps() - app_label = "channels" - db_table = "django_channel_locks" - # Ensure its table exists - if Lock._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()): - with self.connection.schema_editor() as editor: - editor.create_model(Lock) - return Lock - - def send(self, channel, message): - self.channel_model.objects.create( - channel=channel, - content=json.dumps(message), - expiry=now() + datetime.timedelta(seconds=self.expiry) - ) - - def receive_many(self, channels): - if not channels: - raise ValueError("Cannot receive on empty channel list!") - self._clean_expired() - # Get a message from one of our channels - message = self.channel_model.objects.filter(channel__in=channels).order_by("id").first() - if message: - self.channel_model.objects.filter(pk=message.pk).delete() - return message.channel, json.loads(message.content) - else: - return None, None - def _clean_expired(self): """ Cleans out expired groups and messages. @@ -119,57 +177,6 @@ class DatabaseChannelBackend(BaseChannelBackend): # Include a 10-second grace period because that solves some clock sync self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() - self.lock_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() - - def group_add(self, group, channel, expiry=None): - """ - Adds the channel to the named group for at least 'expiry' - seconds (expiry defaults to message expiry if not provided). - """ - self.group_model.objects.update_or_create( - group=group, - channel=channel, - defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)}, - ) - - def group_discard(self, group, channel): - """ - Removes the channel from the named group if it is in the group; - does nothing otherwise (does not error) - """ - self.group_model.objects.filter(group=group, channel=channel).delete() - - def group_channels(self, group): - """ - Returns an iterable of all channels in the group. - """ - self._clean_expired() - return list(self.group_model.objects.filter(group=group).values_list("channel", flat=True)) - - def lock_channel(self, channel, expiry=None): - """ - Attempts to get a lock on the named channel. Returns True if lock - obtained, False if lock not obtained. - """ - # We rely on the UNIQUE constraint for only-one-thread-wins on locks - try: - self.lock_model.objects.create( - channel=channel, - expiry=now() + datetime.timedelta(seconds=expiry or self.expiry), - ) - except IntegrityError: - return False - else: - return True - - def unlock_channel(self, channel): - """ - Unlocks the named channel. Always succeeds. - """ - self.lock_model.objects.filter(channel=channel).delete() def __str__(self): return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias) - - def flush(self): - pass diff --git a/channels/tests/test_backends.py b/channels/tests/test_backends.py deleted file mode 100644 index b25750f..0000000 --- a/channels/tests/test_backends.py +++ /dev/null @@ -1,95 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from django.test import TestCase - -from ..backends.database import DatabaseChannelBackend -from ..backends.memory import InMemoryChannelBackend -from ..backends.redis_py import RedisChannelBackend - - -class MemoryBackendTests(TestCase): - - backend_class = InMemoryChannelBackend - - def setUp(self): - self.backend = self.backend_class(routing={}) - self.backend.flush() - - def test_send_recv(self): - """ - Tests that channels can send and receive messages. - """ - self.backend.send("test", {"value": "blue"}) - self.backend.send("test", {"value": "green"}) - self.backend.send("test2", {"value": "red"}) - # Get just one first - channel, message = self.backend.receive_many(["test"]) - self.assertEqual(channel, "test") - self.assertEqual(message, {"value": "blue"}) - # And the second - channel, message = self.backend.receive_many(["test"]) - self.assertEqual(channel, "test") - self.assertEqual(message, {"value": "green"}) - # And the other channel with multi select - channel, message = self.backend.receive_many(["test", "test2"]) - self.assertEqual(channel, "test2") - self.assertEqual(message, {"value": "red"}) - - def test_message_expiry(self): - self.backend = self.backend_class(routing={}, expiry=-100) - self.backend.send("test", {"value": "blue"}) - channel, message = self.backend.receive_many(["test"]) - self.assertIs(channel, None) - self.assertIs(message, None) - - def test_groups(self): - """ - Tests that group addition and removal and listing works - """ - self.backend.group_add("tgroup", "test") - self.backend.group_add("tgroup", "test2€") - self.backend.group_add("tgroup2", "test3") - self.assertEqual( - set(self.backend.group_channels("tgroup")), - {"test", "test2€"}, - ) - self.backend.group_discard("tgroup", "test2€") - self.backend.group_discard("tgroup", "test2€") - self.assertEqual( - list(self.backend.group_channels("tgroup")), - ["test"], - ) - - def test_group_send(self): - """ - Tests sending to groups. - """ - self.backend.group_add("tgroup", "test") - self.backend.group_add("tgroup", "test2") - self.backend.send_group("tgroup", {"value": "orange"}) - channel, message = self.backend.receive_many(["test"]) - self.assertEqual(channel, "test") - self.assertEqual(message, {"value": "orange"}) - channel, message = self.backend.receive_many(["test2"]) - self.assertEqual(channel, "test2") - self.assertEqual(message, {"value": "orange"}) - - def test_group_expiry(self): - self.backend = self.backend_class(routing={}, expiry=-100) - self.backend.group_add("tgroup", "test") - self.backend.group_add("tgroup", "test2") - self.assertEqual( - list(self.backend.group_channels("tgroup")), - [], - ) - - -class RedisBackendTests(MemoryBackendTests): - - backend_class = RedisChannelBackend - - -class DatabaseBackendTests(MemoryBackendTests): - - backend_class = DatabaseChannelBackend diff --git a/channels/tests/test_database_layer.py b/channels/tests/test_database_layer.py new file mode 100644 index 0000000..146af36 --- /dev/null +++ b/channels/tests/test_database_layer.py @@ -0,0 +1,6 @@ +from __future__ import unicode_literals +from asgiref.conformance import make_tests +from ..database_layer import DatabaseChannelLayer + +channel_layer = DatabaseChannelLayer(expiry=1) +DatabaseLayerTests = make_tests(channel_layer, expiry_delay=1.1) diff --git a/channels/tests/test_interfaces.py b/channels/tests/test_interfaces.py deleted file mode 100644 index fd830d7..0000000 --- a/channels/tests/test_interfaces.py +++ /dev/null @@ -1,53 +0,0 @@ -from django.test import TestCase - -from channels.interfaces.websocket_autobahn import get_protocol - -try: - from unittest import mock -except ImportError: - import mock - - -def generate_connection_request(path, params, headers): - request = mock.Mock() - request.path = path - request.params = params - request.headers = headers - return request - - -class WebsocketAutobahnInterfaceProtocolTestCase(TestCase): - def test_on_connect_cookie(self): - protocol = get_protocol(object)() - session = "123cat" - cookie = "somethingelse=test; sessionid={0}".format(session) - headers = { - "cookie": cookie - } - - test_request = generate_connection_request("path", {}, headers) - protocol.onConnect(test_request) - self.assertEqual(session, protocol.request_info["cookies"]["sessionid"]) - - def test_on_connect_no_cookie(self): - protocol = get_protocol(object)() - test_request = generate_connection_request("path", {}, {}) - protocol.onConnect(test_request) - self.assertEqual({}, protocol.request_info["cookies"]) - - def test_on_connect_params(self): - protocol = get_protocol(object)() - params = { - "session_key": ["123cat"] - } - - test_request = generate_connection_request("path", params, {}) - protocol.onConnect(test_request) - self.assertEqual(params, protocol.request_info["get"]) - - def test_on_connect_path(self): - protocol = get_protocol(object)() - path = "path" - test_request = generate_connection_request(path, {}, {}) - protocol.onConnect(test_request) - self.assertEqual(path, protocol.request_info["path"])