diff --git a/channels/database_layer.py b/channels/database_layer.py index 9caef7e..4f23142 100644 --- a/channels/database_layer.py +++ b/channels/database_layer.py @@ -27,8 +27,9 @@ class DatabaseChannelLayer(object): it's not a valid Unicode character, so it should be safe enough. """ - def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60): + def __init__(self, db_alias=DEFAULT_DB_ALIAS, expiry=60, group_expiry=86400): self.expiry = expiry + self.group_expiry = group_expiry self.db_alias = db_alias # ASGI API @@ -165,13 +166,14 @@ class DatabaseChannelLayer(object): class Group(models.Model): group = models.CharField(max_length=200) channel = models.CharField(max_length=200) + created = models.DateTimeField(db_index=True, auto_now_add=True) class Meta: apps = Apps() app_label = "channels" db_table = "django_channel_groups" unique_together = [["group", "channel"]] - # Ensure its table exists + # Ensure its table exists with the right schema if Group._meta.db_table not in self.connection.introspection.table_names(self.connection.cursor()): with self.connection.schema_editor() as editor: editor.create_model(Group) @@ -187,8 +189,11 @@ class DatabaseChannelLayer(object): old_messages = self.channel_model.objects.filter(expiry__lt=target) channels_to_ungroup = old_messages.values_list("channel", flat=True).distinct() old_messages.delete() - # Now, remove channel membership from channels that expired - self.group_model.objects.filter(channel__in=channels_to_ungroup).delete() + # Now, remove channel membership from channels that expired and ones that just expired + self.group_model.objects.filter( + models.Q(channel__in=channels_to_ungroup) | + models.Q(created__lte=target - datetime.timedelta(seconds=self.group_expiry)) + ).delete() def __str__(self): return "%s(alias=%s)" % (self.__class__.__name__, self.connection.alias) diff --git a/channels/tests/test_database_layer.py b/channels/tests/test_database_layer.py index c501c15..4878157 100644 --- a/channels/tests/test_database_layer.py +++ b/channels/tests/test_database_layer.py @@ -1,6 +1,8 @@ from __future__ import unicode_literals -from asgiref.conformance import make_tests +from asgiref.conformance import ConformanceTestCase from ..database_layer import DatabaseChannelLayer -channel_layer = DatabaseChannelLayer(expiry=1) -DatabaseLayerTests = make_tests(channel_layer, expiry_delay=2.1) + +class DatabaseLayerTests(ConformanceTestCase): + channel_layer = DatabaseChannelLayer(expiry=1, group_expiry=3) + expiry_delay = 2.1