Mirror of https://github.com/django/daphne.git (synced 2025-07-04 11:53:06 +03:00)
Rework database channel layer and tests, to use ASGI conformance suite

parent 5fa357e403
commit a9810014ca
@@ -1,24 +1,122 @@
 import datetime
 import json
+import random
+import string
+import time
 
 from django.apps.registry import Apps
 from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections, models
+from django.utils import six
 from django.utils.functional import cached_property
 from django.utils.timezone import now
 
-from .base import BaseChannelBackend
-
 
-class DatabaseChannelBackend(BaseChannelBackend):
+class DatabaseChannelLayer(object):
     """
-    ORM-backed channel environment. For development use only; it will span
-    multiple processes fine, but it's going to be pretty bad at throughput.
+    ORM-backed ASGI channel layer.
+
+    For development use only; it will span multiple processes fine,
+    but it's going to be pretty bad at throughput. If you're reading this and
+    running it in production, PLEASE STOP.
+
+    Also uses JSON for serialization, as we don't want to make Django depend
+    on msgpack for the built-in backend. The JSON format uses \uffff as first
+    character to signify a byte string rather than a text string.
     """
 
-    def __init__(self, routing, expiry=60, db_alias=DEFAULT_DB_ALIAS):
-        super(DatabaseChannelBackend, self).__init__(routing=routing, expiry=expiry)
+    def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60):
+        self.expiry = expiry
         self.db_alias = db_alias
 
+    ### ASGI API ###
+
+    extensions = ["groups", "flush"]
+
+    def send(self, channel, message):
+        # Typecheck
+        assert isinstance(message, dict), "message is not a dict"
+        assert isinstance(channel, six.text_type), "%s is not unicode" % channel
+        # Write message to messages table
+        self.channel_model.objects.create(
+            channel=channel,
+            content=self.serialize(message),
+            expiry=now() + datetime.timedelta(seconds=self.expiry)
+        )
+
+    def receive_many(self, channels, block=False):
+        if not channels:
+            return None, None
+        assert all(isinstance(channel, six.text_type) for channel in channels)
+        # Shuffle channels
+        channels = list(channels)
+        random.shuffle(channels)
+        # Clean out expired messages
+        self._clean_expired()
+        # Get a message from one of our channels
+        while True:
+            message = self.channel_model.objects.filter(channel__in=channels).order_by("id").first()
+            if message:
+                self.channel_model.objects.filter(pk=message.pk).delete()
+                return message.channel, self.deserialize(message.content)
+            else:
+                if block:
+                    time.sleep(1)
+                else:
+                    return None, None
+
+    def new_channel(self, pattern):
+        assert isinstance(pattern, six.text_type)
+        # Keep making channel names till one isn't present.
+        while True:
+            random_string = "".join(random.choice(string.ascii_letters) for i in range(8))
+            new_name = pattern.replace(b"?", random_string)
+            if not self.channel_model.objects.filter(channel=new_name).exists():
+                return new_name
+
+    ### ASGI Group extension ###
+
+    def group_add(self, group, channel, expiry=None):
+        """
+        Adds the channel to the named group for at least 'expiry'
+        seconds (expiry defaults to message expiry if not provided).
+        """
+        self.group_model.objects.update_or_create(
+            group=group,
+            channel=channel,
+            defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)},
+        )
+
+    def group_discard(self, group, channel):
+        """
+        Removes the channel from the named group if it is in the group;
+        does nothing otherwise (does not error)
+        """
+        self.group_model.objects.filter(group=group, channel=channel).delete()
+
+    def send_group(self, group, message):
+        """
+        Sends a message to the entire group.
+        """
+        self._clean_expired()
+        for channel in self.group_model.objects.filter(group=group).values_list("channel", flat=True):
+            self.send(channel, message)
+
+    ### ASGI Flush extension ###
+
+    def flush(self):
+        self.channel_model.objects.all().delete()
+        self.group_model.objects.all().delete()
+
+    ### Serialization ###
+
+    def serialize(self, message):
+        return json.dumps(message)
+
+    def deserialize(self, message):
+        return json.loads(message)
+
+    ### Database state mgmt ###
+
     @property
     def connection(self):
         """
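The methods added in this hunk make up the whole ASGI surface the conformance suite will exercise: plain send/receive_many plus the advertised "groups" and "flush" extensions. A minimal usage sketch, assuming a configured Django database and the channels.database_layer module path implied by the new test file further down (this snippet is not part of the commit):

    from __future__ import unicode_literals  # channel names must be text (six.text_type)

    from channels.database_layer import DatabaseChannelLayer

    layer = DatabaseChannelLayer(expiry=60)

    # Point-to-point: queue a dict on a named channel, then pull it back off.
    layer.send("http.request", {"path": "/", "method": "GET"})
    channel, message = layer.receive_many(["http.request"], block=False)

    # new_channel() fills the "?" in the pattern with a random unique suffix.
    reply_channel = layer.new_channel("http.response.?")

    # Groups extension: fan one message out to every channel in the group.
    layer.group_add("chat", reply_channel)
    layer.send_group("chat", {"text": "hello"})
    layer.group_discard("chat", reply_channel)

    # Flush extension: drop all pending messages and group memberships.
    layer.flush()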
@@ -72,46 +170,6 @@ class DatabaseChannelBackend(BaseChannelBackend):
                 editor.create_model(Group)
         return Group
 
-    @cached_property
-    def lock_model(self):
-        """
-        Initialises a new model to store groups; not done as part of a
-        models.py as we don't want to make it for most installs.
-        """
-        # Make the model class
-        class Lock(models.Model):
-            channel = models.CharField(max_length=200, unique=True)
-            expiry = models.DateTimeField(db_index=True)
-
-            class Meta:
-                apps = Apps()
-                app_label = "channels"
-                db_table = "django_channel_locks"
-        # Ensure its table exists
-        if Lock._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()):
-            with self.connection.schema_editor() as editor:
-                editor.create_model(Lock)
-        return Lock
-
-    def send(self, channel, message):
-        self.channel_model.objects.create(
-            channel=channel,
-            content=json.dumps(message),
-            expiry=now() + datetime.timedelta(seconds=self.expiry)
-        )
-
-    def receive_many(self, channels):
-        if not channels:
-            raise ValueError("Cannot receive on empty channel list!")
-        self._clean_expired()
-        # Get a message from one of our channels
-        message = self.channel_model.objects.filter(channel__in=channels).order_by("id").first()
-        if message:
-            self.channel_model.objects.filter(pk=message.pk).delete()
-            return message.channel, json.loads(message.content)
-        else:
-            return None, None
-
     def _clean_expired(self):
         """
         Cleans out expired groups and messages.
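One behavioural change worth noting: the receive_many() removed in this hunk raised ValueError on an empty channel list and always returned immediately, while its replacement in the first hunk returns (None, None) and, when called with block=True, polls the table once a second until a message arrives. A sketch of a worker loop written against the new semantics (the handler callable is hypothetical, not part of this commit):

    def run_worker(layer, channels, handle):
        # Hypothetical consumer loop; `handle` is whatever processes a message.
        while True:
            # With block=True the layer sleeps 1s between queries instead of
            # returning (None, None) straight away when nothing is queued.
            channel, message = layer.receive_many(channels, block=True)
            if channel is None:
                continue  # defensive: only reachable when block=False
            handle(channel, message)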
@@ -119,57 +177,6 @@ class DatabaseChannelBackend(BaseChannelBackend):
         # Include a 10-second grace period because that solves some clock sync
         self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
         self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
-        self.lock_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
 
-    def group_add(self, group, channel, expiry=None):
-        """
-        Adds the channel to the named group for at least 'expiry'
-        seconds (expiry defaults to message expiry if not provided).
-        """
-        self.group_model.objects.update_or_create(
-            group=group,
-            channel=channel,
-            defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)},
-        )
-
-    def group_discard(self, group, channel):
-        """
-        Removes the channel from the named group if it is in the group;
-        does nothing otherwise (does not error)
-        """
-        self.group_model.objects.filter(group=group, channel=channel).delete()
-
-    def group_channels(self, group):
-        """
-        Returns an iterable of all channels in the group.
-        """
-        self._clean_expired()
-        return list(self.group_model.objects.filter(group=group).values_list("channel", flat=True))
-
-    def lock_channel(self, channel, expiry=None):
-        """
-        Attempts to get a lock on the named channel. Returns True if lock
-        obtained, False if lock not obtained.
-        """
-        # We rely on the UNIQUE constraint for only-one-thread-wins on locks
-        try:
-            self.lock_model.objects.create(
-                channel=channel,
-                expiry=now() + datetime.timedelta(seconds=expiry or self.expiry),
-            )
-        except IntegrityError:
-            return False
-        else:
-            return True
-
-    def unlock_channel(self, channel):
-        """
-        Unlocks the named channel. Always succeeds.
-        """
-        self.lock_model.objects.filter(channel=channel).delete()
-
     def __str__(self):
         return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
-
-    def flush(self):
-        pass
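The class docstring in the first hunk describes a JSON format where a leading \uffff character flags a byte string, but the serialize()/deserialize() added there are still plain json.dumps/json.loads. A sketch of what that convention could look like, purely illustrative (Python 3 spelling, flat message dicts only; none of these helpers exist in this commit):

    import json

    BYTES_PREFIX = "\uffff"  # sentinel: the JSON string holds bytes, not text

    def _encode_value(value):
        # Byte strings are smuggled through JSON as latin-1 text behind the sentinel.
        if isinstance(value, bytes):
            return BYTES_PREFIX + value.decode("latin1")
        return value

    def _decode_value(value):
        if isinstance(value, str) and value.startswith(BYTES_PREFIX):
            return value[len(BYTES_PREFIX):].encode("latin1")
        return value

    def serialize(message):
        return json.dumps({key: _encode_value(value) for key, value in message.items()})

    def deserialize(data):
        return {key: _decode_value(value) for key, value in json.loads(data).items()}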
@@ -1,95 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import unicode_literals
-
-from django.test import TestCase
-
-from ..backends.database import DatabaseChannelBackend
-from ..backends.memory import InMemoryChannelBackend
-from ..backends.redis_py import RedisChannelBackend
-
-
-class MemoryBackendTests(TestCase):
-
-    backend_class = InMemoryChannelBackend
-
-    def setUp(self):
-        self.backend = self.backend_class(routing={})
-        self.backend.flush()
-
-    def test_send_recv(self):
-        """
-        Tests that channels can send and receive messages.
-        """
-        self.backend.send("test", {"value": "blue"})
-        self.backend.send("test", {"value": "green"})
-        self.backend.send("test2", {"value": "red"})
-        # Get just one first
-        channel, message = self.backend.receive_many(["test"])
-        self.assertEqual(channel, "test")
-        self.assertEqual(message, {"value": "blue"})
-        # And the second
-        channel, message = self.backend.receive_many(["test"])
-        self.assertEqual(channel, "test")
-        self.assertEqual(message, {"value": "green"})
-        # And the other channel with multi select
-        channel, message = self.backend.receive_many(["test", "test2"])
-        self.assertEqual(channel, "test2")
-        self.assertEqual(message, {"value": "red"})
-
-    def test_message_expiry(self):
-        self.backend = self.backend_class(routing={}, expiry=-100)
-        self.backend.send("test", {"value": "blue"})
-        channel, message = self.backend.receive_many(["test"])
-        self.assertIs(channel, None)
-        self.assertIs(message, None)
-
-    def test_groups(self):
-        """
-        Tests that group addition and removal and listing works
-        """
-        self.backend.group_add("tgroup", "test")
-        self.backend.group_add("tgroup", "test2€")
-        self.backend.group_add("tgroup2", "test3")
-        self.assertEqual(
-            set(self.backend.group_channels("tgroup")),
-            {"test", "test2€"},
-        )
-        self.backend.group_discard("tgroup", "test2€")
-        self.backend.group_discard("tgroup", "test2€")
-        self.assertEqual(
-            list(self.backend.group_channels("tgroup")),
-            ["test"],
-        )
-
-    def test_group_send(self):
-        """
-        Tests sending to groups.
-        """
-        self.backend.group_add("tgroup", "test")
-        self.backend.group_add("tgroup", "test2")
-        self.backend.send_group("tgroup", {"value": "orange"})
-        channel, message = self.backend.receive_many(["test"])
-        self.assertEqual(channel, "test")
-        self.assertEqual(message, {"value": "orange"})
-        channel, message = self.backend.receive_many(["test2"])
-        self.assertEqual(channel, "test2")
-        self.assertEqual(message, {"value": "orange"})
-
-    def test_group_expiry(self):
-        self.backend = self.backend_class(routing={}, expiry=-100)
-        self.backend.group_add("tgroup", "test")
-        self.backend.group_add("tgroup", "test2")
-        self.assertEqual(
-            list(self.backend.group_channels("tgroup")),
-            [],
-        )
-
-
-class RedisBackendTests(MemoryBackendTests):
-
-    backend_class = RedisChannelBackend
-
-
-class DatabaseBackendTests(MemoryBackendTests):
-
-    backend_class = DatabaseChannelBackend
channels/tests/test_database_layer.py (new file, 6 lines)

@@ -0,0 +1,6 @@
+from __future__ import unicode_literals
+from asgiref.conformance import make_tests
+from ..database_layer import DatabaseChannelLayer
+
+channel_layer = DatabaseChannelLayer(expiry=1)
+DatabaseLayerTests = make_tests(channel_layer, expiry_delay=1.1)
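With the conformance suite, the whole test module reduces to the six lines above: instantiate a layer with a short expiry and let asgiref generate the TestCase. Hooking another layer into the same suite would presumably follow the identical pattern; the in-memory import path below is an assumption for illustration, not part of this commit:

    from __future__ import unicode_literals

    from asgiref.conformance import make_tests

    from channels.memory_layer import InMemoryChannelLayer  # hypothetical path

    # Short expiry so the suite's expiry tests run quickly; expiry_delay must
    # exceed it so expired messages are really gone before the suite re-checks.
    channel_layer = InMemoryChannelLayer(expiry=1)
    InMemoryLayerTests = make_tests(channel_layer, expiry_delay=1.1)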
@@ -1,53 +0,0 @@
-from django.test import TestCase
-
-from channels.interfaces.websocket_autobahn import get_protocol
-
-try:
-    from unittest import mock
-except ImportError:
-    import mock
-
-
-def generate_connection_request(path, params, headers):
-    request = mock.Mock()
-    request.path = path
-    request.params = params
-    request.headers = headers
-    return request
-
-
-class WebsocketAutobahnInterfaceProtocolTestCase(TestCase):
-    def test_on_connect_cookie(self):
-        protocol = get_protocol(object)()
-        session = "123cat"
-        cookie = "somethingelse=test; sessionid={0}".format(session)
-        headers = {
-            "cookie": cookie
-        }
-
-        test_request = generate_connection_request("path", {}, headers)
-        protocol.onConnect(test_request)
-        self.assertEqual(session, protocol.request_info["cookies"]["sessionid"])
-
-    def test_on_connect_no_cookie(self):
-        protocol = get_protocol(object)()
-        test_request = generate_connection_request("path", {}, {})
-        protocol.onConnect(test_request)
-        self.assertEqual({}, protocol.request_info["cookies"])
-
-    def test_on_connect_params(self):
-        protocol = get_protocol(object)()
-        params = {
-            "session_key": ["123cat"]
-        }
-
-        test_request = generate_connection_request("path", params, {})
-        protocol.onConnect(test_request)
-        self.assertEqual(params, protocol.request_info["get"])
-
-    def test_on_connect_path(self):
-        protocol = get_protocol(object)()
-        path = "path"
-        test_request = generate_connection_request(path, {}, {})
-        protocol.onConnect(test_request)
-        self.assertEqual(path, protocol.request_info["path"])