mirror of
https://github.com/django/daphne.git
synced 2025-07-11 08:22:17 +03:00
Update database channel backend to pass conformance
This commit is contained in:
parent
5df99c9cfd
commit
17e9824f71
|
@ -3,6 +3,7 @@ import json
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
|
import base64
|
||||||
|
|
||||||
from django.apps.registry import Apps
|
from django.apps.registry import Apps
|
||||||
from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections, models
|
from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections, models
|
||||||
|
@ -21,7 +22,8 @@ class DatabaseChannelLayer(object):
|
||||||
|
|
||||||
Also uses JSON for serialization, as we don't want to make Django depend
|
Also uses JSON for serialization, as we don't want to make Django depend
|
||||||
on msgpack for the built-in backend. The JSON format uses \uffff as first
|
on msgpack for the built-in backend. The JSON format uses \uffff as first
|
||||||
character to signify a byte string rather than a text string.
|
character to signify a b64 byte string rather than a text string. Ugly, but
|
||||||
|
it's not a valid Unicode character, so it should be safe enough.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60):
|
def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60):
|
||||||
|
@ -75,7 +77,7 @@ class DatabaseChannelLayer(object):
|
||||||
|
|
||||||
### ASGI Group extension ###
|
### ASGI Group extension ###
|
||||||
|
|
||||||
def group_add(self, group, channel, expiry=None):
|
def group_add(self, group, channel):
|
||||||
"""
|
"""
|
||||||
Adds the channel to the named group for at least 'expiry'
|
Adds the channel to the named group for at least 'expiry'
|
||||||
seconds (expiry defaults to message expiry if not provided).
|
seconds (expiry defaults to message expiry if not provided).
|
||||||
|
@ -83,7 +85,6 @@ class DatabaseChannelLayer(object):
|
||||||
self.group_model.objects.update_or_create(
|
self.group_model.objects.update_or_create(
|
||||||
group=group,
|
group=group,
|
||||||
channel=channel,
|
channel=channel,
|
||||||
defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def group_discard(self, group, channel):
|
def group_discard(self, group, channel):
|
||||||
|
@ -110,10 +111,10 @@ class DatabaseChannelLayer(object):
|
||||||
### Serialization ###
|
### Serialization ###
|
||||||
|
|
||||||
def serialize(self, message):
|
def serialize(self, message):
|
||||||
return json.dumps(message)
|
return AsgiJsonEncoder().encode(message)
|
||||||
|
|
||||||
def deserialize(self, message):
|
def deserialize(self, message):
|
||||||
return json.loads(message)
|
return AsgiJsonDecoder().decode(message)
|
||||||
|
|
||||||
### Database state mgmt ###
|
### Database state mgmt ###
|
||||||
|
|
||||||
|
@ -157,7 +158,6 @@ class DatabaseChannelLayer(object):
|
||||||
class Group(models.Model):
|
class Group(models.Model):
|
||||||
group = models.CharField(max_length=200)
|
group = models.CharField(max_length=200)
|
||||||
channel = models.CharField(max_length=200)
|
channel = models.CharField(max_length=200)
|
||||||
expiry = models.DateTimeField(db_index=True)
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
apps = Apps()
|
apps = Apps()
|
||||||
|
@ -174,9 +174,60 @@ class DatabaseChannelLayer(object):
|
||||||
"""
|
"""
|
||||||
Cleans out expired groups and messages.
|
Cleans out expired groups and messages.
|
||||||
"""
|
"""
|
||||||
# Include a 10-second grace period because that solves some clock sync
|
# Include a 1-second grace period for clock sync drift
|
||||||
self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
|
target = now() - datetime.timedelta(seconds=1)
|
||||||
self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete()
|
# First, go through old messages and pick out channels that got expired
|
||||||
|
old_messages = self.channel_model.objects.filter(expiry__lt=target)
|
||||||
|
channels_to_ungroup = old_messages.values_list("channel", flat=True).distinct()
|
||||||
|
old_messages.delete()
|
||||||
|
# Now, remove channel membership from channels that expired
|
||||||
|
self.group_model.objects.filter(channel__in=channels_to_ungroup).delete()
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
|
return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
|
||||||
|
|
||||||
|
|
||||||
|
class AsgiJsonEncoder(json.JSONEncoder):
|
||||||
|
"""
|
||||||
|
Special encoder that transforms bytestrings into unicode strings
|
||||||
|
prefixed with u+ffff
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform(self, o):
|
||||||
|
if isinstance(o, (list, tuple)):
|
||||||
|
return [self.transform(x) for x in o]
|
||||||
|
elif isinstance(o, dict):
|
||||||
|
return {
|
||||||
|
self.transform(k): self.transform(v)
|
||||||
|
for k, v in o.items()
|
||||||
|
}
|
||||||
|
elif isinstance(o, six.binary_type):
|
||||||
|
return u"\uffff" + base64.b64encode(o).decode("ascii")
|
||||||
|
else:
|
||||||
|
return o
|
||||||
|
|
||||||
|
def encode(self, o):
|
||||||
|
return super(AsgiJsonEncoder, self).encode(self.transform(o))
|
||||||
|
|
||||||
|
|
||||||
|
class AsgiJsonDecoder(json.JSONDecoder):
|
||||||
|
"""
|
||||||
|
Special encoder that transforms bytestrings into unicode strings
|
||||||
|
prefixed with u+ffff
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transform(self, o):
|
||||||
|
if isinstance(o, (list, tuple)):
|
||||||
|
return [self.transform(x) for x in o]
|
||||||
|
elif isinstance(o, dict):
|
||||||
|
return {
|
||||||
|
self.transform(k): self.transform(v)
|
||||||
|
for k, v in o.items()
|
||||||
|
}
|
||||||
|
elif isinstance(o, six.text_type) and o and o[0] == u"\uffff":
|
||||||
|
return base64.b64decode(o[1:].encode("ascii"))
|
||||||
|
else:
|
||||||
|
return o
|
||||||
|
|
||||||
|
def decode(self, o):
|
||||||
|
return self.transform(super(AsgiJsonDecoder, self).decode(o))
|
||||||
|
|
|
@ -3,4 +3,4 @@ from asgiref.conformance import make_tests
|
||||||
from ..database_layer import DatabaseChannelLayer
|
from ..database_layer import DatabaseChannelLayer
|
||||||
|
|
||||||
channel_layer = DatabaseChannelLayer(expiry=1)
|
channel_layer = DatabaseChannelLayer(expiry=1)
|
||||||
DatabaseLayerTests = make_tests(channel_layer, expiry_delay=1.1)
|
DatabaseLayerTests = make_tests(channel_layer, expiry_delay=2.1)
|
||||||
|
|
|
@ -367,7 +367,8 @@ clients reconnect will immediately resolve the problem.
|
||||||
|
|
||||||
If a channel layer implements the ``groups`` extension, it must persist group
|
If a channel layer implements the ``groups`` extension, it must persist group
|
||||||
membership until at least the time when the member channel has a message
|
membership until at least the time when the member channel has a message
|
||||||
expire due to non-consumption.
|
expire due to non-consumption. It should drop membership after a while to
|
||||||
|
prevent collision of old messages with new clients with the same random ID.
|
||||||
|
|
||||||
|
|
||||||
Message Formats
|
Message Formats
|
||||||
|
|
Loading…
Reference in New Issue
Block a user