Added a decorator that checks origin headers (#593)

Adds a new allowed_hosts_only decorator and extensible base class to allow for checking the incoming Origin header on WebSocket requests, using the Django `ALLOWED_HOSTS` setting by default.
This commit is contained in:
Daniel Hepper 2017-04-07 15:05:34 +02:00 committed by Andrew Godwin
parent 4063ac03ed
commit 7383280776
5 changed files with 187 additions and 1 deletions

View File

View File

@ -0,0 +1,90 @@
from functools import update_wrapper
from django.conf import settings
from django.http.request import validate_host
from ..exceptions import DenyConnection
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
class BaseOriginValidator(object):
"""
Base class-based decorator for origin validation of WebSocket connect
messages.
This base class handles parsing of the origin header. When the origin header
is missing, empty or contains non-ascii characters, it raises a
DenyConnection exception to reject the connection.
Subclasses must overwrite the method validate_origin(self, message, origin)
to return True when a message should be accepted, False otherwise.
"""
def __init__(self, func):
update_wrapper(self, func)
self.func = func
def __call__(self, message, *args, **kwargs):
origin = self.get_origin(message)
if not self.validate_origin(message, origin):
raise DenyConnection
return self.func(message, *args, **kwargs)
def get_header(self, message, name):
headers = message.content['headers']
for header in headers:
try:
if header[0] == name:
return header[1:]
except IndexError:
continue
raise KeyError('No header named "{}"'.format(name))
def get_origin(self, message):
"""
Returns the origin of a WebSocket connect message.
Raises DenyConnection for messages with missing or non-ascii Origin
header.
"""
try:
header = self.get_header(message, b'origin')[0]
except (IndexError, KeyError):
raise DenyConnection
try:
origin = header.decode('ascii')
except UnicodeDecodeError:
raise DenyConnection
return origin
def validate_origin(self, message, origin):
"""
Validates the origin of a WebSocket connect message.
Must be overwritten by subclasses.
"""
raise NotImplemented('You must overwrite this method.')
class AllowedHostsOnlyOriginValidator(BaseOriginValidator):
"""
Class-based decorator for websocket consumers that checks that
the origin is allowed according to the ALLOWED_HOSTS settings.
"""
def validate_origin(self, message, origin):
allowed_hosts = settings.ALLOWED_HOSTS
if settings.DEBUG and not allowed_hosts:
allowed_hosts = ['localhost', '127.0.0.1', '[::1]']
origin_hostname = urlparse(origin).hostname
valid = (origin_hostname and
validate_host(origin_hostname, allowed_hosts))
return valid
allowed_hosts_only = AllowedHostsOnlyOriginValidator

View File

@ -518,6 +518,51 @@ responses can set cookies, it needs a backend it can write to to separately
store state.
Security
--------
Unlike AJAX requests, WebSocket requests are not limited by the Same-Origin
policy. This means you don't have to take any extra steps when you have an HTML
page served by host A containing JavaScript code wanting to connect to a
WebSocket on Host B.
While this can be convenient, it also implies that by default any third-party
site can connect to your WebSocket application. When you are using the
``http_session_user`` or the ``channel_session_user_from_http`` decorator, this
connection would be authenticated.
The WebSocket specification requires browsers to send the origin of a WebSocket
request in the HTTP header named ``Origin``, but validating that header is left
to the server.
You can use the decorator ``channels.security.websockets.allowed_hosts_only``
on a ``websocket.connect`` consumer to only allow requests originating
from hosts listed in the ``ALLOWED_HOSTS`` setting::
# In consumers.py
from channels import Channel, Group
from channels.sessions import channel_session
from channels.auth import channel_session_user, channel_session_user_from_http
from channels.security.websockets import allowed_hosts_only.
# Connected to websocket.connect
@allowed_hosts_only
@channel_session_user_from_http
def ws_add(message):
# Accept connection
...
Requests from other hosts or requests with missing or invalid origin header
are now rejected.
The name ``allowed_hosts_only`` is an alias for the class-based decorator
``AllowedHostsOnlyOriginValidator``, which inherits from
``BaseOriginValidator``. If you have custom requirements for origin validation,
create a subclass and overwrite the method
``validate_origin(self, message, origin)``. It must return True when a message
should be accepted, False otherwise.
Routing
-------

View File

@ -172,7 +172,7 @@ directly, but there are two useful ways you can call it:
Decorators
----------
Channels provides decorators to assist with persisting data.
Channels provides decorators to assist with persisting data and security.
* ``channel_session``: Provides a session-like object called "channel_session" to consumers
as a message attribute that will auto-persist across consumers with
@ -200,3 +200,10 @@ Channels provides decorators to assist with persisting data.
Stores the http session key in the channel_session on websocket.connect messages.
It will then hydrate the http_session from that same key on subsequent messages.
* ``allowed_hosts_only``: Wraps a WebSocket connect consumer and ensures the
request originates from an allowed host.
Reads the Origin header and only passes request originating from a host
listed in ``ALLOWED_HOSTS`` to the consumer. Requests from other hosts or
with a missing or invalid Origin headers are rejected.

44
tests/test_security.py Normal file
View File

@ -0,0 +1,44 @@
from __future__ import unicode_literals
from django.test import override_settings
from channels.exceptions import DenyConnection
from channels.security.websockets import allowed_hosts_only
from channels.message import Message
from channels.test import ChannelTestCase
@allowed_hosts_only
def connect(message):
return True
class OriginValidationTestCase(ChannelTestCase):
@override_settings(ALLOWED_HOSTS=['example.com'])
def test_valid_origin(self):
content = {
'headers': [[b'origin', b'http://example.com']]
}
message = Message(content, 'websocket.connect', None)
self.assertTrue(connect(message))
@override_settings(ALLOWED_HOSTS=['example.com'])
def test_invalid_origin(self):
content = {
'headers': [[b'origin', b'http://example.org']]
}
message = Message(content, 'websocket.connect', None)
self.assertRaises(DenyConnection, connect, message)
def test_invalid_origin_header(self):
invalid_headers = [
[], # origin header missing
[b'origin', b''], # origin header empty
[b'origin', b'\xc3\xa4'] # non-ascii
]
for headers in invalid_headers:
content = {
'headers': [headers]
}
message = Message(content, 'websocket.connect', None)
self.assertRaises(DenyConnection, connect, message)