Implement group_expiry on database channel backend

This commit is contained in:
Andrew Godwin 2016-03-28 11:59:25 +01:00
parent 59198ea93e
commit 1ab757fffb
2 changed files with 14 additions and 7 deletions

View File

@ -27,8 +27,9 @@ class DatabaseChannelLayer(object):
it's not a valid Unicode character, so it should be safe enough.
"""
def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60):
def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60, group_expiry=86400):
self.expiry = expiry
self.group_expiry = group_expiry
self.db_alias = db_alias
# ASGI API
@ -165,13 +166,14 @@ class DatabaseChannelLayer(object):
class Group(models.Model):
group = models.CharField(max_length=200)
channel = models.CharField(max_length=200)
created = models.DateTimeField(db_index=True, auto_now_add=True)
class Meta:
apps = Apps()
app_label = "channels"
db_table = "django_channel_groups"
unique_together = [["group", "channel"]]
# Ensure its table exists
# Ensure its table exists with the right schema
if Group._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()):
with self.connection.schema_editor() as editor:
editor.create_model(Group)
@ -187,8 +189,11 @@ class DatabaseChannelLayer(object):
old_messages = self.channel_model.objects.filter(expiry__lt=target)
channels_to_ungroup = old_messages.values_list("channel", flat=True).distinct()
old_messages.delete()
# Now, remove channel membership from channels that expired
self.group_model.objects.filter(channel__in=channels_to_ungroup).delete()
# Now, remove channel membership from channels that expired and ones that just expired
self.group_model.objects.filter(
models.Q(channel__in=channels_to_ungroup) |
models.Q(created__lte=target - datetime.timedelta(seconds=self.group_expiry))
).delete()
def __str__(self):
return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias)

View File

@ -1,6 +1,8 @@
from __future__ import unicode_literals
from asgiref.conformance import make_tests
from asgiref.conformance import ConformanceTestCase
from ..database_layer import DatabaseChannelLayer
channel_layer = DatabaseChannelLayer(expiry=1)
DatabaseLayerTests = make_tests(channel_layer, expiry_delay=2.1)
class DatabaseLayerTests(ConformanceTestCase):
channel_layer = DatabaseChannelLayer(expiry=1, group_expiry=3)
expiry_delay = 2.1