mirror of
https://github.com/django/daphne.git
synced 2025-06-13 17:43:25 +03:00
246 lines
8.5 KiB
Python
246 lines
8.5 KiB
Python
import base64
|
|
import datetime
|
|
import json
|
|
import random
|
|
import string
|
|
import time
|
|
|
|
from django.apps.registry import Apps
|
|
from django.db import DEFAULT_DB_ALIAS, connections, models, transaction
|
|
from django.db.utils import OperationalError
|
|
from django.utils import six
|
|
from django.utils.functional import cached_property
|
|
from django.utils.timezone import now
|
|
|
|
|
|
class DatabaseChannelLayer(object):
    """
    ORM-backed ASGI channel layer.

    For development use only; it will span multiple processes fine,
    but it's going to be pretty bad at throughput. If you're reading this and
    running it in production, PLEASE STOP.

    Messages and group memberships live in two tables that are created on
    demand on the database named by ``db_alias``; every query is pinned to
    that alias explicitly so database routers cannot send it elsewhere.

    Also uses JSON for serialization, as we don't want to make Django depend
    on msgpack for the built-in backend. The JSON format uses U+FFFF as the
    first character to mark a base64-encoded byte string rather than a text
    string. Ugly, but U+FFFF is a Unicode noncharacter that should never
    appear in real text, so it should be safe enough.
    """

    def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60, group_expiry=86400):
        """
        :param db_alias: Django database alias that holds the layer's tables.
        :param expiry: seconds an undelivered message is kept before
            ``_clean_expired`` discards it.
        :param group_expiry: seconds a group membership lasts after it was
            last added before ``_clean_expired`` discards it.
        """
        self.expiry = expiry
        self.group_expiry = group_expiry
        self.db_alias = db_alias

    # ASGI API

    extensions = ["groups", "flush"]

    def send(self, channel, message):
        """
        Queue ``message`` (a dict) on ``channel`` (a text string).

        The stored row expires ``self.expiry`` seconds from now.
        """
        # Typecheck
        assert isinstance(message, dict), "message is not a dict"
        # Wrap in a 1-tuple so a tuple-valued channel cannot break the format.
        assert isinstance(channel, six.text_type), "%s is not unicode" % (channel,)
        # Write message to messages table
        self._messages().create(
            channel=channel,
            content=self.serialize(message),
            expiry=now() + datetime.timedelta(seconds=self.expiry),
        )

    def receive_many(self, channels, block=False):
        """
        Pop the oldest pending message from any of ``channels``.

        Returns a ``(channel, message)`` tuple, or ``(None, None)`` when no
        channels were given or, with ``block=False``, when nothing is waiting.
        With ``block=True`` this polls roughly once per second until a message
        arrives.
        """
        # Materialise first: a one-shot iterable (e.g. a generator) would
        # otherwise look truthy to the emptiness check and then be exhausted
        # by the type check, leaving nothing to query.
        channels = list(channels or ())
        if not channels:
            return None, None
        assert all(isinstance(channel, six.text_type) for channel in channels)
        # Clean out expired messages
        self._clean_expired()
        while True:
            try:
                with transaction.atomic(using=self.db_alias):
                    # Lowest id first gives FIFO order across all requested
                    # channels, so the order the channels are listed in is
                    # irrelevant.
                    message = self._messages().filter(
                        channel__in=channels,
                    ).order_by("id").first()
                    if message is not None:
                        deleted = self._messages().filter(pk=message.pk).delete()
                        # Django >= 1.9 returns (rows_deleted, per_model_counts);
                        # older versions return None. Zero rows means another
                        # consumer claimed this message between our SELECT and
                        # DELETE, so look again instead of delivering it twice.
                        if deleted is None or deleted[0]:
                            return message.channel, self.deserialize(message.content)
                        continue
            except OperationalError:
                # The database is probably trying to prevent a deadlock
                time.sleep(0.1)
                continue
            if not block:
                return None, None
            time.sleep(1)

    def new_channel(self, pattern):
        """
        Return ``pattern`` (which must end in "!") followed by ten random
        ASCII letters, retrying until no queued message uses that name.
        """
        assert isinstance(pattern, six.text_type)
        assert pattern.endswith("!")
        # Keep making channel names till one isn't present.
        while True:
            random_string = "".join(random.choice(string.ascii_letters) for _ in range(10))
            new_name = pattern + random_string
            if not self._messages().filter(channel=new_name).exists():
                return new_name

    # ASGI Group extension

    def group_add(self, group, channel):
        """
        Adds the channel to the named group.

        Re-adding an existing member refreshes its timestamp, so membership
        lasts ``group_expiry`` seconds after the most recent add (it is also
        dropped early by ``_clean_expired`` if a message on the channel
        expires undelivered).
        """
        self._groups().update_or_create(
            group=group,
            channel=channel,
            # auto_now_add only fires on INSERT; set the timestamp explicitly
            # so an existing membership is refreshed as well.
            defaults={"created": now()},
        )

    def group_discard(self, group, channel):
        """
        Removes the channel from the named group if it is in the group;
        does nothing otherwise (does not error)
        """
        self._groups().filter(group=group, channel=channel).delete()

    def send_group(self, group, message):
        """
        Sends a message to every channel currently in the group.
        """
        self._clean_expired()
        members = self._groups().filter(group=group).values_list("channel", flat=True)
        for channel in members:
            self.send(channel, message)

    # ASGI Flush extension

    def flush(self):
        """Delete every stored message and group membership."""
        self._messages().delete()
        self._groups().delete()

    # Serialization

    def serialize(self, message):
        """Encode a message dict to the JSON text stored in the table."""
        return AsgiJsonEncoder().encode(message)

    def deserialize(self, message):
        """Decode stored JSON text back into a message dict."""
        return AsgiJsonDecoder().decode(message)

    # Database state mgmt

    @property
    def connection(self):
        """
        Returns the correct connection for the current thread.
        """
        return connections[self.db_alias]

    def _messages(self):
        """Queryset over stored messages, pinned to ``self.db_alias``."""
        return self.channel_model.objects.using(self.db_alias)

    def _groups(self):
        """Queryset over group memberships, pinned to ``self.db_alias``."""
        return self.group_model.objects.using(self.db_alias)

    def _ensure_table(self, model):
        """Create ``model``'s table on this layer's database if it is missing."""
        connection = self.connection
        # Use the cursor as a context manager so it is closed promptly.
        with connection.cursor() as cursor:
            existing_tables = connection.introspection.table_names(cursor)
        if model._meta.db_table not in existing_tables:
            with connection.schema_editor() as editor:
                editor.create_model(model)

    @cached_property
    def channel_model(self):
        """
        Initialises a new model to store messages; not done as part of a
        models.py as we don't want to make it for most installs.
        """
        # Make the model class
        class Message(models.Model):
            # We assume an autoincrementing PK for message order
            channel = models.CharField(max_length=200, db_index=True)
            content = models.TextField()
            expiry = models.DateTimeField(db_index=True)

            class Meta:
                # Registered in a private Apps registry rather than the
                # project-wide one.
                apps = Apps()
                app_label = "channels"
                db_table = "django_channels"

        # Ensure its table exists
        self._ensure_table(Message)
        return Message

    @cached_property
    def group_model(self):
        """
        Initialises a new model to store groups; not done as part of a
        models.py as we don't want to make it for most installs.
        """
        # Make the model class
        class Group(models.Model):
            group = models.CharField(max_length=200)
            channel = models.CharField(max_length=200)
            created = models.DateTimeField(db_index=True, auto_now_add=True)

            class Meta:
                # Registered in a private Apps registry rather than the
                # project-wide one.
                apps = Apps()
                app_label = "channels"
                db_table = "django_channel_groups"
                unique_together = [["group", "channel"]]

        # Ensure its table exists
        self._ensure_table(Group)
        return Group

    def _clean_expired(self):
        """
        Cleans out expired messages, then group memberships that are older
        than ``group_expiry`` or whose channel just had a message expire.
        """
        # Include a 1-second grace period for clock sync drift
        target = now() - datetime.timedelta(seconds=1)
        # First, go through old messages and pick out channels that got expired.
        # Evaluate the channel list *before* deleting: left as a lazy queryset
        # it would run as a subquery after the delete and always be empty.
        old_messages = self._messages().filter(expiry__lt=target)
        channels_to_ungroup = list(
            old_messages.values_list("channel", flat=True).distinct()
        )
        old_messages.delete()
        # Now remove memberships that are past group_expiry, plus any belonging
        # to channels that just had a message expire.
        stale = models.Q(created__lte=target - datetime.timedelta(seconds=self.group_expiry))
        if channels_to_ungroup:
            stale |= models.Q(channel__in=channels_to_ungroup)
        self._groups().filter(stale).delete()

    def __str__(self):
        return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
|
|
|
|
|
|
class AsgiJsonEncoder(json.JSONEncoder):
    """
    JSON encoder for ASGI messages that may carry bytestrings.

    Before encoding, every bytestring found anywhere in the value (dict keys
    included) is replaced by a text string made of U+FFFF followed by the
    base64 form of the bytes. Tuples are emitted as JSON arrays.
    """

    def transform(self, o):
        """Return a copy of ``o`` with each bytestring replaced by its marked base64 text."""
        if isinstance(o, dict):
            return dict(
                (self.transform(key), self.transform(value))
                for key, value in o.items()
            )
        if isinstance(o, (list, tuple)):
            return list(map(self.transform, o))
        if not isinstance(o, bytes):
            # Text, numbers, booleans and None pass through untouched.
            return o
        payload = base64.b64encode(o).decode("ascii")
        return u"\uffff" + payload

    def encode(self, o):
        """Serialise ``o`` to a JSON string after rewriting its bytestrings."""
        prepared = self.transform(o)
        return super(AsgiJsonEncoder, self).encode(prepared)
|
|
|
|
|
|
class AsgiJsonDecoder(json.JSONDecoder):
    """
    JSON decoder that reverses AsgiJsonEncoder.

    After normal JSON decoding, any text string (value or dict key) whose
    first character is U+FFFF has the remainder base64-decoded back into a
    bytestring. Note that a genuine text string starting with U+FFFF cannot
    be distinguished from an encoded bytestring.
    """

    def transform(self, o):
        """Recursively restore U+FFFF-marked base64 strings inside ``o`` to bytes."""
        if isinstance(o, (list, tuple)):
            return [self.transform(item) for item in o]
        if isinstance(o, dict):
            return {
                self.transform(key): self.transform(value)
                for key, value in o.items()
            }
        if isinstance(o, six.text_type) and o.startswith(u"\uffff"):
            return base64.b64decode(o[1:].encode("ascii"))
        return o

    def decode(self, o):
        """Parse the JSON text ``o`` and restore any encoded bytestrings."""
        return self.transform(super(AsgiJsonDecoder, self).decode(o))
|