mirror of
https://github.com/django/daphne.git
synced 2025-10-23 12:04:19 +03:00
128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
import time
|
|
import json
|
|
import datetime
|
|
|
|
from django.apps.registry import Apps
|
|
from django.db import models, connections, DEFAULT_DB_ALIAS
|
|
from django.utils.functional import cached_property
|
|
from django.utils.timezone import now
|
|
|
|
from .base import BaseChannelBackend
|
|
|
|
|
|
class DatabaseChannelBackend(BaseChannelBackend):
    """
    ORM-backed channel environment. For development use only; it will span
    multiple processes fine, but it's going to be pretty bad at throughput.

    Messages are stored in the "django_channels" table and group memberships
    in "django_channel_groups", both on the database selected by ``db_alias``.
    Each table is created on first use if it does not already exist.
    """

    def __init__(self, expiry=60, db_alias=DEFAULT_DB_ALIAS):
        """
        :param expiry: default lifetime in seconds for sent messages and for
            group memberships added without an explicit expiry. Forwarded to
            BaseChannelBackend and read back here as ``self.expiry``.
        :param db_alias: key into ``django.db.connections`` naming the
            database that holds the channel tables.
        """
        super(DatabaseChannelBackend, self).__init__(expiry)
        self.db_alias = db_alias

    @property
    def connection(self):
        """
        Returns the correct connection for the current thread.
        """
        return connections[self.db_alias]

    def _ensure_table(self, model):
        """
        Creates the database table for ``model`` if it does not exist yet.
        """
        # Close the introspection cursor as soon as the table list is read.
        with self.connection.cursor() as cursor:
            existing_tables = self.connection.introspection.table_names(cursor)
        if model._meta.db_table not in existing_tables:
            with self.connection.schema_editor() as editor:
                editor.create_model(model)

    @cached_property
    def channel_model(self):
        """
        Initialises a new model to store messages; not done as part of a
        models.py as we don't want to make it for most installs.

        The model lives in its own ``Apps()`` registry rather than the
        project's global one. Its table is created on first access, and the
        class is cached per backend instance.
        """
        class Message(models.Model):
            # We assume an autoincrementing PK for message order
            channel = models.CharField(max_length=200, db_index=True)
            content = models.TextField()
            expiry = models.DateTimeField(db_index=True)

            class Meta:
                apps = Apps()
                app_label = "channels"
                db_table = "django_channels"

        self._ensure_table(Message)
        return Message

    @cached_property
    def group_model(self):
        """
        Initialises a new model to store groups; not done as part of a
        models.py as we don't want to make it for most installs.

        Each (group, channel) pair is unique. As with ``channel_model``, the
        model uses a private ``Apps()`` registry and its table is created on
        first access.
        """
        class Group(models.Model):
            group = models.CharField(max_length=200)
            channel = models.CharField(max_length=200)
            expiry = models.DateTimeField(db_index=True)

            class Meta:
                apps = Apps()
                app_label = "channels"
                db_table = "django_channel_groups"
                unique_together = [["group", "channel"]]

        self._ensure_table(Group)
        return Group

    def send(self, channel, message):
        """
        Stores ``message`` (JSON-serialised) on ``channel``. The message
        expires ``self.expiry`` seconds from now.
        """
        self.channel_model.objects.create(
            channel=channel,
            content=json.dumps(message),
            expiry=now() + datetime.timedelta(seconds=self.expiry),
        )

    def receive_many(self, channels):
        """
        Removes and returns the oldest pending message on any of ``channels``.

        Returns a ``(channel, message)`` tuple, with ``message`` decoded from
        JSON, or ``(None, None)`` if no listed channel has a pending message.
        Raises ValueError if ``channels`` is empty.
        """
        # Materialise once. A generator would otherwise always pass the
        # emptiness check, and it would be exhausted after the first query
        # in the retry loop below.
        channels = list(channels or ())
        if not channels:
            raise ValueError("Cannot receive on empty channel list!")
        self._clean_expired()
        while True:
            # Oldest first; relies on the autoincrementing PK for ordering.
            message = (
                self.channel_model.objects
                .filter(channel__in=channels)
                .order_by("id")
                .first()
            )
            if message is None:
                return None, None
            deleted = self.channel_model.objects.filter(pk=message.pk).delete()
            # QuerySet.delete() returns (count, per_model_counts) on Django
            # 1.9+ and None on older versions. A zero count means another
            # process removed (i.e. received) this message between our read
            # and our delete. Delivering it again would duplicate it, so try
            # the next message instead.
            if deleted is not None and deleted[0] == 0:
                continue
            return message.channel, json.loads(message.content)

    def _clean_expired(self):
        """
        Cleans out expired groups and messages.
        """
        # A 10-second grace period, presumably to tolerate clock skew between
        # the processes that write and check expiry times. One cutoff is
        # computed so both tables are cleaned against the same instant.
        cutoff = now() - datetime.timedelta(seconds=10)
        self.channel_model.objects.filter(expiry__lt=cutoff).delete()
        self.group_model.objects.filter(expiry__lt=cutoff).delete()

    def group_add(self, group, channel, expiry=None):
        """
        Adds the channel to the named group for at least 'expiry'
        seconds (expiry defaults to message expiry if not provided).
        Re-adding an existing member refreshes its expiry.
        """
        # Only fall back when expiry is omitted. An explicit 0 is honoured
        # rather than being treated as "not provided".
        if expiry is None:
            expiry = self.expiry
        self.group_model.objects.update_or_create(
            group=group,
            channel=channel,
            defaults={"expiry": now() + datetime.timedelta(seconds=expiry)},
        )

    def group_discard(self, group, channel):
        """
        Removes the channel from the named group if it is in the group;
        does nothing otherwise (does not error)
        """
        self.group_model.objects.filter(group=group, channel=channel).delete()

    def group_channels(self, group):
        """
        Returns a list of all channels in the group. Expired memberships
        are purged first.
        """
        self._clean_expired()
        return list(
            self.group_model.objects.filter(group=group).values_list("channel", flat=True)
        )

    def __str__(self):
        return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
|