From 73832807766811b4b21c0eb4835f6708a984fb72 Mon Sep 17 00:00:00 2001 From: Daniel Hepper Date: Fri, 7 Apr 2017 15:05:34 +0200 Subject: [PATCH] Added a decorator that checks origin headers (#593) Adds a new allowed_hosts_only decorator and extensible base class to allow for checking the incoming Origin header on WebSocket requests, using the Django `ALLOWED_HOSTS` setting by default. --- channels/security/__init__.py | 0 channels/security/websockets.py | 90 +++++++++++++++++++++++++++++++++ docs/getting-started.rst | 45 +++++++++++++++++ docs/reference.rst | 9 +++- tests/test_security.py | 44 ++++++++++++++++ 5 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 channels/security/__init__.py create mode 100644 channels/security/websockets.py create mode 100644 tests/test_security.py diff --git a/channels/security/__init__.py b/channels/security/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/channels/security/websockets.py b/channels/security/websockets.py new file mode 100644 index 0000000..d1dac88 --- /dev/null +++ b/channels/security/websockets.py @@ -0,0 +1,90 @@ +from functools import update_wrapper + +from django.conf import settings +from django.http.request import validate_host + +from ..exceptions import DenyConnection + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + + +class BaseOriginValidator(object): + """ + Base class-based decorator for origin validation of WebSocket connect + messages. + + This base class handles parsing of the origin header. When the origin header + is missing, empty or contains non-ascii characters, it raises a + DenyConnection exception to reject the connection. + + Subclasses must overwrite the method validate_origin(self, message, origin) + to return True when a message should be accepted, False otherwise. + """ + + def __init__(self, func): + update_wrapper(self, func) + self.func = func + + def __call__(self, message, *args, **kwargs): + origin = self.get_origin(message) + if not self.validate_origin(message, origin): + raise DenyConnection + return self.func(message, *args, **kwargs) + + def get_header(self, message, name): + headers = message.content['headers'] + for header in headers: + try: + if header[0] == name: + return header[1:] + except IndexError: + continue + raise KeyError('No header named "{}"'.format(name)) + + def get_origin(self, message): + """ + Returns the origin of a WebSocket connect message. + + Raises DenyConnection for messages with missing or non-ascii Origin + header. + """ + try: + header = self.get_header(message, b'origin')[0] + except (IndexError, KeyError): + raise DenyConnection + try: + origin = header.decode('ascii') + except UnicodeDecodeError: + raise DenyConnection + return origin + + def validate_origin(self, message, origin): + """ + Validates the origin of a WebSocket connect message. + + Must be overwritten by subclasses. + """ + raise NotImplemented('You must overwrite this method.') + + +class AllowedHostsOnlyOriginValidator(BaseOriginValidator): + """ + Class-based decorator for websocket consumers that checks that + the origin is allowed according to the ALLOWED_HOSTS settings. + """ + + def validate_origin(self, message, origin): + allowed_hosts = settings.ALLOWED_HOSTS + if settings.DEBUG and not allowed_hosts: + allowed_hosts = ['localhost', '127.0.0.1', '[::1]'] + + origin_hostname = urlparse(origin).hostname + valid = (origin_hostname and + validate_host(origin_hostname, allowed_hosts)) + return valid + + +allowed_hosts_only = AllowedHostsOnlyOriginValidator diff --git a/docs/getting-started.rst b/docs/getting-started.rst index acef373..95ebf4d 100644 --- a/docs/getting-started.rst +++ b/docs/getting-started.rst @@ -518,6 +518,51 @@ responses can set cookies, it needs a backend it can write to to separately store state. +Security +-------- + +Unlike AJAX requests, WebSocket requests are not limited by the Same-Origin +policy. This means you don't have to take any extra steps when you have an HTML +page served by host A containing JavaScript code wanting to connect to a +WebSocket on Host B. + +While this can be convenient, it also implies that by default any third-party +site can connect to your WebSocket application. When you are using the +``http_session_user`` or the ``channel_session_user_from_http`` decorator, this +connection would be authenticated. + +The WebSocket specification requires browsers to send the origin of a WebSocket +request in the HTTP header named ``Origin``, but validating that header is left +to the server. + +You can use the decorator ``channels.security.websockets.allowed_hosts_only`` +on a ``websocket.connect`` consumer to only allow requests originating +from hosts listed in the ``ALLOWED_HOSTS`` setting:: + + # In consumers.py + from channels import Channel, Group + from channels.sessions import channel_session + from channels.auth import channel_session_user, channel_session_user_from_http + from channels.security.websockets import allowed_hosts_only. + + # Connected to websocket.connect + @allowed_hosts_only + @channel_session_user_from_http + def ws_add(message): + # Accept connection + ... + +Requests from other hosts or requests with missing or invalid origin header +are now rejected. + +The name ``allowed_hosts_only`` is an alias for the class-based decorator +``AllowedHostsOnlyOriginValidator``, which inherits from +``BaseOriginValidator``. If you have custom requirements for origin validation, +create a subclass and overwrite the method +``validate_origin(self, message, origin)``. It must return True when a message +should be accepted, False otherwise. + + Routing ------- diff --git a/docs/reference.rst b/docs/reference.rst index 3a1f7e7..c90e8f4 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -172,7 +172,7 @@ directly, but there are two useful ways you can call it: Decorators ---------- -Channels provides decorators to assist with persisting data. +Channels provides decorators to assist with persisting data and security. * ``channel_session``: Provides a session-like object called "channel_session" to consumers as a message attribute that will auto-persist across consumers with @@ -200,3 +200,10 @@ Channels provides decorators to assist with persisting data. Stores the http session key in the channel_session on websocket.connect messages. It will then hydrate the http_session from that same key on subsequent messages. + +* ``allowed_hosts_only``: Wraps a WebSocket connect consumer and ensures the + request originates from an allowed host. + + Reads the Origin header and only passes request originating from a host + listed in ``ALLOWED_HOSTS`` to the consumer. Requests from other hosts or + with a missing or invalid Origin headers are rejected. diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..2805481 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,44 @@ +from __future__ import unicode_literals + +from django.test import override_settings +from channels.exceptions import DenyConnection +from channels.security.websockets import allowed_hosts_only +from channels.message import Message +from channels.test import ChannelTestCase + + +@allowed_hosts_only +def connect(message): + return True + + +class OriginValidationTestCase(ChannelTestCase): + + @override_settings(ALLOWED_HOSTS=['example.com']) + def test_valid_origin(self): + content = { + 'headers': [[b'origin', b'http://example.com']] + } + message = Message(content, 'websocket.connect', None) + self.assertTrue(connect(message)) + + @override_settings(ALLOWED_HOSTS=['example.com']) + def test_invalid_origin(self): + content = { + 'headers': [[b'origin', b'http://example.org']] + } + message = Message(content, 'websocket.connect', None) + self.assertRaises(DenyConnection, connect, message) + + def test_invalid_origin_header(self): + invalid_headers = [ + [], # origin header missing + [b'origin', b''], # origin header empty + [b'origin', b'\xc3\xa4'] # non-ascii + ] + for headers in invalid_headers: + content = { + 'headers': [headers] + } + message = Message(content, 'websocket.connect', None) + self.assertRaises(DenyConnection, connect, message)