diff --git a/channels/tests/base.py b/channels/tests/base.py index 7e7ddc9..46faf85 100644 --- a/channels/tests/base.py +++ b/channels/tests/base.py @@ -7,6 +7,7 @@ from functools import wraps from asgiref.inmemory import ChannelLayer as InMemoryChannelLayer from django.test.testcases import TestCase, TransactionTestCase +from django.db import close_old_connections from .. import DEFAULT_CHANNEL_LAYER from ..asgi import ChannelLayerWrapper, channel_layers @@ -134,7 +135,10 @@ class Client(object): consumer_started.send(sender=self.__class__) return consumer(message, **kwargs) finally: + # Copy Django's workaround so we don't actually close DB conns + consumer_finished.disconnect(close_old_connections) consumer_finished.send(sender=self.__class__) + consumer_finished.connect(close_old_connections) elif fail_on_none: raise AssertionError("Can't find consumer for message %s" % message) elif fail_on_none: