mirror of
https://github.com/django/daphne.git
synced 2025-04-21 01:02:06 +03:00
Rework database channel layer and tests, to use ASGI conformance suite
This commit is contained in:
parent
5fa357e403
commit
a9810014ca
|
@ -1,24 +1,122 @@
|
|||
import datetime
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections, models
|
||||
from django.utils import six
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.timezone import now
|
||||
|
||||
from .base import BaseChannelBackend
|
||||
|
||||
|
||||
class DatabaseChannelBackend(BaseChannelBackend):
|
||||
class DatabaseChannelLayer(object):
|
||||
"""
|
||||
ORM-backed channel environment. For development use only; it will span
|
||||
multiple processes fine, but it's going to be pretty bad at throughput.
|
||||
ORM-backed ASGI channel layer.
|
||||
|
||||
For development use only; it will span multiple processes fine,
|
||||
but it's going to be pretty bad at throughput. If you're reading this and
|
||||
running it in production, PLEASE STOP.
|
||||
|
||||
Also uses JSON for serialization, as we don't want to make Django depend
|
||||
on msgpack for the built-in backend. The JSON format uses \uffff as first
|
||||
character to signify a byte string rather than a text string.
|
||||
"""
|
||||
|
||||
def __init__(self, routing, expiry=60, db_alias=DEFAULT_DB_ALIAS):
|
||||
super(DatabaseChannelBackend, self).__init__(routing=routing, expiry=expiry)
|
||||
def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60):
|
||||
self.expiry = expiry
|
||||
self.db_alias = db_alias
|
||||
|
||||
### ASGI API ###
|
||||
|
||||
extensions = ["groups", "flush"]
|
||||
|
||||
def send(self, channel, message):
|
||||
# Typecheck
|
||||
assert isinstance(message, dict), "message is not a dict"
|
||||
assert isinstance(channel, six.text_type), "%s is not unicode" % channel
|
||||
# Write message to messages table
|
||||
self.channel_model.objects.create(
|
||||
channel=channel,
|
||||
content=self.serialize(message),
|
||||
expiry=now() + datetime.timedelta(seconds=self.expiry)
|
||||
)
|
||||
|
||||
def receive_many(self, channels, block=False):
|
||||
if not channels:
|
||||
return None, None
|
||||
assert all(isinstance(channel, six.text_type) for channel in channels)
|
||||
# Shuffle channels
|
||||
channels = list(channels)
|
||||
random.shuffle(channels)
|
||||
# Clean out expired messages
|
||||
self._clean_expired()
|
||||
# Get a message from one of our channels
|
||||
while True:
|
||||
message = self.channel_model.objects.filter(channel__in=channels).order_by("id").first()
|
||||
if message:
|
||||
self.channel_model.objects.filter(pk=message.pk).delete()
|
||||
return message.channel, self.deserialize(message.content)
|
||||
else:
|
||||
if block:
|
||||
time.sleep(1)
|
||||
else:
|
||||
return None, None
|
||||
|
||||
def new_channel(self, pattern):
|
||||
assert isinstance(pattern, six.text_type)
|
||||
# Keep making channel names till one isn't present.
|
||||
while True:
|
||||
random_string = "".join(random.choice(string.ascii_letters) for i in range(8))
|
||||
new_name = pattern.replace(b"?", random_string)
|
||||
if not self.channel_model.objects.filter(channel=new_name).exists():
|
||||
return new_name
|
||||
|
||||
### ASGI Group extension ###
|
||||
|
||||
def group_add(self, group, channel, expiry=None):
|
||||
"""
|
||||
Adds the channel to the named group for at least 'expiry'
|
||||
seconds (expiry defaults to message expiry if not provided).
|
||||
"""
|
||||
self.group_model.objects.update_or_create(
|
||||
group=group,
|
||||
channel=channel,
|
||||
defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)},
|
||||
)
|
||||
|
||||
def group_discard(self, group, channel):
|
||||
"""
|
||||
Removes the channel from the named group if it is in the group;
|
||||
does nothing otherwise (does not error)
|
||||
"""
|
||||
self.group_model.objects.filter(group=group, channel=channel).delete()
|
||||
|
||||
def send_group(self, group, message):
|
||||
"""
|
||||
Sends a message to the entire group.
|
||||
"""
|
||||
self._clean_expired()
|
||||
for channel in self.group_model.objects.filter(group=group).values_list("channel", flat=True):
|
||||
self.send(channel, message)
|
||||
|
||||
### ASGI Flush extension ###
|
||||
|
||||
def flush(self):
|
||||
self.channel_model.objects.all().delete()
|
||||
self.group_model.objects.all().delete()
|
||||
|
||||
### Serialization ###
|
||||
|
||||
def serialize(self, message):
|
||||
return json.dumps(message)
|
||||
|
||||
def deserialize(self, message):
|
||||
return json.loads(message)
|
||||
|
||||
### Database state mgmt ###
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
"""
|
||||
|
@ -72,46 +170,6 @@ class DatabaseChannelBackend(BaseChannelBackend):
|
|||
editor.create_model(Group)
|
||||
return Group
|
||||
|
||||
@cached_property
|
||||
def lock_model(self):
|
||||
"""
|
||||
Initialises a new model to store groups; not done as part of a
|
||||
models.py as we don't want to make it for most installs.
|
||||
"""
|
||||
# Make the model class
|
||||
class Lock(models.Model):
|
||||
channel = models.CharField(max_length=200, unique=True)
|
||||
expiry = models.DateTimeField(db_index=True)
|
||||
|
||||
class Meta:
|
||||
apps = Apps()
|
||||
app_label = "channels"
|
||||
db_table = "django_channel_locks"
|
||||
# Ensure its table exists
|
||||
if Lock._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()):
|
||||
with self.connection.schema_editor() as editor:
|
||||
editor.create_model(Lock)
|
||||
return Lock
|
||||
|
||||
def send(self, channel, message):
|
||||
self.channel_model.objects.create(
|
||||
channel=channel,
|
||||
content=json.dumps(message),
|
||||
expiry=now() + datetime.timedelta(seconds=self.expiry)
|
||||
)
|
||||
|
||||
def receive_many(self, channels):
|
||||
if not channels:
|
||||
raise ValueError("Cannot receive on empty channel list!")
|
||||
self._clean_expired()
|
||||
# Get a message from one of our channels
|
||||
message = self.channel_model.objects.filter(channel__in=channels).order_by("id").first()
|
||||
if message:
|
||||
self.channel_model.objects.filter(pk=message.pk).delete()
|
||||
return message.channel, json.loads(message.content)
|
||||
else:
|
||||
return None, None
|
||||
|
||||
def _clean_expired(self):
|
||||
"""
|
||||
Cleans out expired groups and messages.
|
||||
|
@ -119,57 +177,6 @@ class DatabaseChannelBackend(BaseChannelBackend):
|
|||
# Include a 10-second grace period because that solves some clock sync
|
||||
self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
|
||||
self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
|
||||
self.lock_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
|
||||
|
||||
def group_add(self, group, channel, expiry=None):
|
||||
"""
|
||||
Adds the channel to the named group for at least 'expiry'
|
||||
seconds (expiry defaults to message expiry if not provided).
|
||||
"""
|
||||
self.group_model.objects.update_or_create(
|
||||
group=group,
|
||||
channel=channel,
|
||||
defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)},
|
||||
)
|
||||
|
||||
def group_discard(self, group, channel):
|
||||
"""
|
||||
Removes the channel from the named group if it is in the group;
|
||||
does nothing otherwise (does not error)
|
||||
"""
|
||||
self.group_model.objects.filter(group=group, channel=channel).delete()
|
||||
|
||||
def group_channels(self, group):
|
||||
"""
|
||||
Returns an iterable of all channels in the group.
|
||||
"""
|
||||
self._clean_expired()
|
||||
return list(self.group_model.objects.filter(group=group).values_list("channel", flat=True))
|
||||
|
||||
def lock_channel(self, channel, expiry=None):
|
||||
"""
|
||||
Attempts to get a lock on the named channel. Returns True if lock
|
||||
obtained, False if lock not obtained.
|
||||
"""
|
||||
# We rely on the UNIQUE constraint for only-one-thread-wins on locks
|
||||
try:
|
||||
self.lock_model.objects.create(
|
||||
channel=channel,
|
||||
expiry=now() + datetime.timedelta(seconds=expiry or self.expiry),
|
||||
)
|
||||
except IntegrityError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def unlock_channel(self, channel):
|
||||
"""
|
||||
Unlocks the named channel. Always succeeds.
|
||||
"""
|
||||
self.lock_model.objects.filter(channel=channel).delete()
|
||||
|
||||
def __str__(self):
|
||||
return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
|
||||
|
||||
def flush(self):
|
||||
pass
|
|
@ -1,95 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from ..backends.database import DatabaseChannelBackend
|
||||
from ..backends.memory import InMemoryChannelBackend
|
||||
from ..backends.redis_py import RedisChannelBackend
|
||||
|
||||
|
||||
class MemoryBackendTests(TestCase):
|
||||
|
||||
backend_class = InMemoryChannelBackend
|
||||
|
||||
def setUp(self):
|
||||
self.backend = self.backend_class(routing={})
|
||||
self.backend.flush()
|
||||
|
||||
def test_send_recv(self):
|
||||
"""
|
||||
Tests that channels can send and receive messages.
|
||||
"""
|
||||
self.backend.send("test", {"value": "blue"})
|
||||
self.backend.send("test", {"value": "green"})
|
||||
self.backend.send("test2", {"value": "red"})
|
||||
# Get just one first
|
||||
channel, message = self.backend.receive_many(["test"])
|
||||
self.assertEqual(channel, "test")
|
||||
self.assertEqual(message, {"value": "blue"})
|
||||
# And the second
|
||||
channel, message = self.backend.receive_many(["test"])
|
||||
self.assertEqual(channel, "test")
|
||||
self.assertEqual(message, {"value": "green"})
|
||||
# And the other channel with multi select
|
||||
channel, message = self.backend.receive_many(["test", "test2"])
|
||||
self.assertEqual(channel, "test2")
|
||||
self.assertEqual(message, {"value": "red"})
|
||||
|
||||
def test_message_expiry(self):
|
||||
self.backend = self.backend_class(routing={}, expiry=-100)
|
||||
self.backend.send("test", {"value": "blue"})
|
||||
channel, message = self.backend.receive_many(["test"])
|
||||
self.assertIs(channel, None)
|
||||
self.assertIs(message, None)
|
||||
|
||||
def test_groups(self):
|
||||
"""
|
||||
Tests that group addition and removal and listing works
|
||||
"""
|
||||
self.backend.group_add("tgroup", "test")
|
||||
self.backend.group_add("tgroup", "test2€")
|
||||
self.backend.group_add("tgroup2", "test3")
|
||||
self.assertEqual(
|
||||
set(self.backend.group_channels("tgroup")),
|
||||
{"test", "test2€"},
|
||||
)
|
||||
self.backend.group_discard("tgroup", "test2€")
|
||||
self.backend.group_discard("tgroup", "test2€")
|
||||
self.assertEqual(
|
||||
list(self.backend.group_channels("tgroup")),
|
||||
["test"],
|
||||
)
|
||||
|
||||
def test_group_send(self):
|
||||
"""
|
||||
Tests sending to groups.
|
||||
"""
|
||||
self.backend.group_add("tgroup", "test")
|
||||
self.backend.group_add("tgroup", "test2")
|
||||
self.backend.send_group("tgroup", {"value": "orange"})
|
||||
channel, message = self.backend.receive_many(["test"])
|
||||
self.assertEqual(channel, "test")
|
||||
self.assertEqual(message, {"value": "orange"})
|
||||
channel, message = self.backend.receive_many(["test2"])
|
||||
self.assertEqual(channel, "test2")
|
||||
self.assertEqual(message, {"value": "orange"})
|
||||
|
||||
def test_group_expiry(self):
|
||||
self.backend = self.backend_class(routing={}, expiry=-100)
|
||||
self.backend.group_add("tgroup", "test")
|
||||
self.backend.group_add("tgroup", "test2")
|
||||
self.assertEqual(
|
||||
list(self.backend.group_channels("tgroup")),
|
||||
[],
|
||||
)
|
||||
|
||||
|
||||
class RedisBackendTests(MemoryBackendTests):
|
||||
|
||||
backend_class = RedisChannelBackend
|
||||
|
||||
|
||||
class DatabaseBackendTests(MemoryBackendTests):
|
||||
|
||||
backend_class = DatabaseChannelBackend
|
6
channels/tests/test_database_layer.py
Normal file
6
channels/tests/test_database_layer.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
from asgiref.conformance import make_tests
|
||||
from ..database_layer import DatabaseChannelLayer
|
||||
|
||||
channel_layer = DatabaseChannelLayer(expiry=1)
|
||||
DatabaseLayerTests = make_tests(channel_layer, expiry_delay=1.1)
|
|
@ -1,53 +0,0 @@
|
|||
from django.test import TestCase
|
||||
|
||||
from channels.interfaces.websocket_autobahn import get_protocol
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
import mock
|
||||
|
||||
|
||||
def generate_connection_request(path, params, headers):
|
||||
request = mock.Mock()
|
||||
request.path = path
|
||||
request.params = params
|
||||
request.headers = headers
|
||||
return request
|
||||
|
||||
|
||||
class WebsocketAutobahnInterfaceProtocolTestCase(TestCase):
|
||||
def test_on_connect_cookie(self):
|
||||
protocol = get_protocol(object)()
|
||||
session = "123cat"
|
||||
cookie = "somethingelse=test; sessionid={0}".format(session)
|
||||
headers = {
|
||||
"cookie": cookie
|
||||
}
|
||||
|
||||
test_request = generate_connection_request("path", {}, headers)
|
||||
protocol.onConnect(test_request)
|
||||
self.assertEqual(session, protocol.request_info["cookies"]["sessionid"])
|
||||
|
||||
def test_on_connect_no_cookie(self):
|
||||
protocol = get_protocol(object)()
|
||||
test_request = generate_connection_request("path", {}, {})
|
||||
protocol.onConnect(test_request)
|
||||
self.assertEqual({}, protocol.request_info["cookies"])
|
||||
|
||||
def test_on_connect_params(self):
|
||||
protocol = get_protocol(object)()
|
||||
params = {
|
||||
"session_key": ["123cat"]
|
||||
}
|
||||
|
||||
test_request = generate_connection_request("path", params, {})
|
||||
protocol.onConnect(test_request)
|
||||
self.assertEqual(params, protocol.request_info["get"])
|
||||
|
||||
def test_on_connect_path(self):
|
||||
protocol = get_protocol(object)()
|
||||
path = "path"
|
||||
test_request = generate_connection_request(path, {}, {})
|
||||
protocol.onConnect(test_request)
|
||||
self.assertEqual(path, protocol.request_info["path"])
|
Loading…
Reference in New Issue
Block a user