From 6649afce8e9f88c639db8e3eaa7385918f543117 Mon Sep 17 00:00:00 2001 From: Luke Hodkinson Date: Sun, 21 Aug 2016 10:53:54 +1000 Subject: [PATCH] Use a mixin for common test-case code. This way we can have both (#305) a regular channels test-case, and a transaction test-case, too. --- channels/tests/__init__.py | 2 +- channels/tests/base.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/channels/tests/__init__.py b/channels/tests/__init__.py index 36481f0..0c957f3 100644 --- a/channels/tests/__init__.py +++ b/channels/tests/__init__.py @@ -1,2 +1,2 @@ -from .base import ChannelTestCase, Client, apply_routes # NOQA isort:skip +from .base import TransactionChannelTestCase, ChannelTestCase, Client, apply_routes # NOQA isort:skip from .http import HttpClient # NOQA isort:skip diff --git a/channels/tests/base.py b/channels/tests/base.py index 248a837..5a90eca 100644 --- a/channels/tests/base.py +++ b/channels/tests/base.py @@ -5,7 +5,7 @@ import random import string from functools import wraps -from django.test.testcases import TestCase +from django.test.testcases import TestCase, TransactionTestCase from .. import DEFAULT_CHANNEL_LAYER from ..channel import Group from ..routing import Router, include @@ -14,7 +14,7 @@ from ..message import Message from asgiref.inmemory import ChannelLayer as InMemoryChannelLayer -class ChannelTestCase(TestCase): +class ChannelTestCaseMixin(object): """ TestCase subclass that provides easy methods for testing channels using an in-memory backend to capture messages, and assertion methods to allow @@ -31,7 +31,7 @@ class ChannelTestCase(TestCase): """ Initialises in memory channel layer for the duration of the test """ - super(ChannelTestCase, self)._pre_setup() + super(ChannelTestCaseMixin, self)._pre_setup() self._old_layers = {} for alias in self.test_channel_aliases: # Swap in an in memory layer wrapper and keep the old one around @@ -52,7 +52,7 @@ class ChannelTestCase(TestCase): # Swap in an in memory layer wrapper and keep the old one around channel_layers.set(alias, self._old_layers[alias]) del self._old_layers - super(ChannelTestCase, self)._post_teardown() + super(ChannelTestCaseMixin, self)._post_teardown() def get_next_message(self, channel, alias=DEFAULT_CHANNEL_LAYER, require=False): """ @@ -70,6 +70,14 @@ class ChannelTestCase(TestCase): return Message(content, recv_channel, channel_layers[alias]) +class ChannelTestCase(ChannelTestCaseMixin, TestCase): + pass + + +class TransactionChannelTestCase(ChannelTestCaseMixin, TransactionTestCase): + pass + + class Client(object): """ Channel client abstraction that provides easy methods for testing full live cycle of message in channels