diff --git a/channels/interfaces/websocket_autobahn.py b/channels/interfaces/websocket_autobahn.py index 1445ad1..eff1bc6 100644 --- a/channels/interfaces/websocket_autobahn.py +++ b/channels/interfaces/websocket_autobahn.py @@ -1,5 +1,7 @@ import time +from django.http import parse_cookie + from channels import DEFAULT_CHANNEL_BACKEND, Channel, channel_backends @@ -16,6 +18,7 @@ def get_protocol(base): self.request_info = { "path": request.path, "get": request.params, + "cookies": parse_cookie(request.headers.get('cookie', '')) } def onOpen(self): diff --git a/channels/tests/test_interfaces.py b/channels/tests/test_interfaces.py new file mode 100644 index 0000000..fd830d7 --- /dev/null +++ b/channels/tests/test_interfaces.py @@ -0,0 +1,53 @@ +from django.test import TestCase + +from channels.interfaces.websocket_autobahn import get_protocol + +try: + from unittest import mock +except ImportError: + import mock + + +def generate_connection_request(path, params, headers): + request = mock.Mock() + request.path = path + request.params = params + request.headers = headers + return request + + +class WebsocketAutobahnInterfaceProtocolTestCase(TestCase): + def test_on_connect_cookie(self): + protocol = get_protocol(object)() + session = "123cat" + cookie = "somethingelse=test; sessionid={0}".format(session) + headers = { + "cookie": cookie + } + + test_request = generate_connection_request("path", {}, headers) + protocol.onConnect(test_request) + self.assertEqual(session, protocol.request_info["cookies"]["sessionid"]) + + def test_on_connect_no_cookie(self): + protocol = get_protocol(object)() + test_request = generate_connection_request("path", {}, {}) + protocol.onConnect(test_request) + self.assertEqual({}, protocol.request_info["cookies"]) + + def test_on_connect_params(self): + protocol = get_protocol(object)() + params = { + "session_key": ["123cat"] + } + + test_request = generate_connection_request("path", params, {}) + protocol.onConnect(test_request) + self.assertEqual(params, protocol.request_info["get"]) + + def test_on_connect_path(self): + protocol = get_protocol(object)() + path = "path" + test_request = generate_connection_request(path, {}, {}) + protocol.onConnect(test_request) + self.assertEqual(path, protocol.request_info["path"]) diff --git a/tox.ini b/tox.ini index f42ff19..4cde18b 100644 --- a/tox.ini +++ b/tox.ini @@ -10,8 +10,10 @@ envlist = setenv = PYTHONPATH = {toxinidir}:{toxinidir} deps = + autobahn six redis==2.10.5 + py27: mock flake8: flake8 isort: isort django-16: Django>=1.6,<1.7