Update database channel backend to pass conformance

This commit is contained in:
Andrew Godwin 2016-02-05 14:16:20 -08:00
parent 5df99c9cfd
commit 17e9824f71
3 changed files with 63 additions and 11 deletions

View File

@ -3,6 +3,7 @@ import json
import random import random
import string import string
import time import time
import base64
from django.apps.registry import Apps from django.apps.registry import Apps
from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections, models from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections, models
@ -21,7 +22,8 @@ class DatabaseChannelLayer(object):
Also uses JSON for serialization, as we don't want to make Django depend Also uses JSON for serialization, as we don't want to make Django depend
on msgpack for the built-in backend. The JSON format uses \uffff as first on msgpack for the built-in backend. The JSON format uses \uffff as first
character to signify a byte string rather than a text string. character to signify a b64 byte string rather than a text string. Ugly, but
it's not a valid Unicode character, so it should be safe enough.
""" """
def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60): def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60):
@ -75,7 +77,7 @@ class DatabaseChannelLayer(object):
### ASGI Group extension ### ### ASGI Group extension ###
def group_add(self, group, channel, expiry=None): def group_add(self, group, channel):
""" """
Adds the channel to the named group for at least 'expiry' Adds the channel to the named group for at least 'expiry'
seconds (expiry defaults to message expiry if not provided). seconds (expiry defaults to message expiry if not provided).
@ -83,7 +85,6 @@ class DatabaseChannelLayer(object):
self.group_model.objects.update_or_create( self.group_model.objects.update_or_create(
group=group, group=group,
channel=channel, channel=channel,
defaults={"expiry": now() + datetime.timedelta(seconds=expiry or self.expiry)},
) )
def group_discard(self, group, channel): def group_discard(self, group, channel):
@ -110,10 +111,10 @@ class DatabaseChannelLayer(object):
### Serialization ### ### Serialization ###
def serialize(self, message): def serialize(self, message):
return json.dumps(message) return AsgiJsonEncoder().encode(message)
def deserialize(self, message): def deserialize(self, message):
return json.loads(message) return AsgiJsonDecoder().decode(message)
### Database state mgmt ### ### Database state mgmt ###
@ -157,7 +158,6 @@ class DatabaseChannelLayer(object):
class Group(models.Model): class Group(models.Model):
group = models.CharField(max_length=200) group = models.CharField(max_length=200)
channel = models.CharField(max_length=200) channel = models.CharField(max_length=200)
expiry = models.DateTimeField(db_index=True)
class Meta: class Meta:
apps = Apps() apps = Apps()
@ -174,9 +174,60 @@ class DatabaseChannelLayer(object):
""" """
Cleans out expired groups and messages. Cleans out expired groups and messages.
""" """
# Include a 10-second grace period because that solves some clock sync # Include a 1-second grace period for clock sync drift
self.channel_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() target = now() - datetime.timedelta(seconds=1)
self.group_model.objects.filter(expiry__lt=now() - datetime.timedelta(seconds=10)).delete() # First, go through old messages and pick out channels that got expired
old_messages = self.channel_model.objects.filter(expiry__lt=target)
channels_to_ungroup = old_messages.values_list("channel", flat=True).distinct()
old_messages.delete()
# Now, remove channel membership from channels that expired
self.group_model.objects.filter(channel__in=channels_to_ungroup).delete()
def __str__(self): def __str__(self):
return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias) return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)
class AsgiJsonEncoder(json.JSONEncoder):
"""
Special encoder that transforms bytestrings into unicode strings
prefixed with u+ffff
"""
def transform(self, o):
if isinstance(o, (list, tuple)):
return [self.transform(x) for x in o]
elif isinstance(o, dict):
return {
self.transform(k): self.transform(v)
for k, v in o.items()
}
elif isinstance(o, six.binary_type):
return u"\uffff" + base64.b64encode(o).decode("ascii")
else:
return o
def encode(self, o):
return super(AsgiJsonEncoder, self).encode(self.transform(o))
class AsgiJsonDecoder(json.JSONDecoder):
"""
Special encoder that transforms bytestrings into unicode strings
prefixed with u+ffff
"""
def transform(self, o):
if isinstance(o, (list, tuple)):
return [self.transform(x) for x in o]
elif isinstance(o, dict):
return {
self.transform(k): self.transform(v)
for k, v in o.items()
}
elif isinstance(o, six.text_type) and o and o[0] == u"\uffff":
return base64.b64decode(o[1:].encode("ascii"))
else:
return o
def decode(self, o):
return self.transform(super(AsgiJsonDecoder, self).decode(o))

View File

@ -3,4 +3,4 @@ from asgiref.conformance import make_tests
from ..database_layer import DatabaseChannelLayer from ..database_layer import DatabaseChannelLayer
channel_layer = DatabaseChannelLayer(expiry=1) channel_layer = DatabaseChannelLayer(expiry=1)
DatabaseLayerTests = make_tests(channel_layer, expiry_delay=1.1) DatabaseLayerTests = make_tests(channel_layer, expiry_delay=2.1)

View File

@ -367,7 +367,8 @@ clients reconnect will immediately resolve the problem.
If a channel layer implements the ``groups`` extension, it must persist group If a channel layer implements the ``groups`` extension, it must persist group
membership until at least the time when the member channel has a message membership until at least the time when the member channel has a message
expire due to non-consumption. expire due to non-consumption. It should drop membership after a while to
prevent collision of old messages with new clients with the same random ID.
Message Formats Message Formats