mirror of
https://github.com/django/daphne.git
synced 2025-07-29 16:39:46 +03:00
Added a decorator that checks origin headers (#593)
Adds a new allowed_hosts_only decorator and extensible base class to allow for checking the incoming Origin header on WebSocket requests, using the Django `ALLOWED_HOSTS` setting by default.
This commit is contained in:
parent
4063ac03ed
commit
7383280776
0
channels/security/__init__.py
Normal file
0
channels/security/__init__.py
Normal file
90
channels/security/websockets.py
Normal file
90
channels/security/websockets.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
from functools import update_wrapper
|
||||
|
||||
from django.conf import settings
|
||||
from django.http.request import validate_host
|
||||
|
||||
from ..exceptions import DenyConnection
|
||||
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
except ImportError:
|
||||
from urlparse import urlparse
|
||||
|
||||
|
||||
class BaseOriginValidator(object):
|
||||
"""
|
||||
Base class-based decorator for origin validation of WebSocket connect
|
||||
messages.
|
||||
|
||||
This base class handles parsing of the origin header. When the origin header
|
||||
is missing, empty or contains non-ascii characters, it raises a
|
||||
DenyConnection exception to reject the connection.
|
||||
|
||||
Subclasses must overwrite the method validate_origin(self, message, origin)
|
||||
to return True when a message should be accepted, False otherwise.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
update_wrapper(self, func)
|
||||
self.func = func
|
||||
|
||||
def __call__(self, message, *args, **kwargs):
|
||||
origin = self.get_origin(message)
|
||||
if not self.validate_origin(message, origin):
|
||||
raise DenyConnection
|
||||
return self.func(message, *args, **kwargs)
|
||||
|
||||
def get_header(self, message, name):
|
||||
headers = message.content['headers']
|
||||
for header in headers:
|
||||
try:
|
||||
if header[0] == name:
|
||||
return header[1:]
|
||||
except IndexError:
|
||||
continue
|
||||
raise KeyError('No header named "{}"'.format(name))
|
||||
|
||||
def get_origin(self, message):
|
||||
"""
|
||||
Returns the origin of a WebSocket connect message.
|
||||
|
||||
Raises DenyConnection for messages with missing or non-ascii Origin
|
||||
header.
|
||||
"""
|
||||
try:
|
||||
header = self.get_header(message, b'origin')[0]
|
||||
except (IndexError, KeyError):
|
||||
raise DenyConnection
|
||||
try:
|
||||
origin = header.decode('ascii')
|
||||
except UnicodeDecodeError:
|
||||
raise DenyConnection
|
||||
return origin
|
||||
|
||||
def validate_origin(self, message, origin):
|
||||
"""
|
||||
Validates the origin of a WebSocket connect message.
|
||||
|
||||
Must be overwritten by subclasses.
|
||||
"""
|
||||
raise NotImplemented('You must overwrite this method.')
|
||||
|
||||
|
||||
class AllowedHostsOnlyOriginValidator(BaseOriginValidator):
|
||||
"""
|
||||
Class-based decorator for websocket consumers that checks that
|
||||
the origin is allowed according to the ALLOWED_HOSTS settings.
|
||||
"""
|
||||
|
||||
def validate_origin(self, message, origin):
|
||||
allowed_hosts = settings.ALLOWED_HOSTS
|
||||
if settings.DEBUG and not allowed_hosts:
|
||||
allowed_hosts = ['localhost', '127.0.0.1', '[::1]']
|
||||
|
||||
origin_hostname = urlparse(origin).hostname
|
||||
valid = (origin_hostname and
|
||||
validate_host(origin_hostname, allowed_hosts))
|
||||
return valid
|
||||
|
||||
|
||||
allowed_hosts_only = AllowedHostsOnlyOriginValidator
|
|
@ -518,6 +518,51 @@ responses can set cookies, it needs a backend it can write to to separately
|
|||
store state.
|
||||
|
||||
|
||||
Security
|
||||
--------
|
||||
|
||||
Unlike AJAX requests, WebSocket requests are not limited by the Same-Origin
|
||||
policy. This means you don't have to take any extra steps when you have an HTML
|
||||
page served by host A containing JavaScript code wanting to connect to a
|
||||
WebSocket on Host B.
|
||||
|
||||
While this can be convenient, it also implies that by default any third-party
|
||||
site can connect to your WebSocket application. When you are using the
|
||||
``http_session_user`` or the ``channel_session_user_from_http`` decorator, this
|
||||
connection would be authenticated.
|
||||
|
||||
The WebSocket specification requires browsers to send the origin of a WebSocket
|
||||
request in the HTTP header named ``Origin``, but validating that header is left
|
||||
to the server.
|
||||
|
||||
You can use the decorator ``channels.security.websockets.allowed_hosts_only``
|
||||
on a ``websocket.connect`` consumer to only allow requests originating
|
||||
from hosts listed in the ``ALLOWED_HOSTS`` setting::
|
||||
|
||||
# In consumers.py
|
||||
from channels import Channel, Group
|
||||
from channels.sessions import channel_session
|
||||
from channels.auth import channel_session_user, channel_session_user_from_http
|
||||
from channels.security.websockets import allowed_hosts_only.
|
||||
|
||||
# Connected to websocket.connect
|
||||
@allowed_hosts_only
|
||||
@channel_session_user_from_http
|
||||
def ws_add(message):
|
||||
# Accept connection
|
||||
...
|
||||
|
||||
Requests from other hosts or requests with missing or invalid origin header
|
||||
are now rejected.
|
||||
|
||||
The name ``allowed_hosts_only`` is an alias for the class-based decorator
|
||||
``AllowedHostsOnlyOriginValidator``, which inherits from
|
||||
``BaseOriginValidator``. If you have custom requirements for origin validation,
|
||||
create a subclass and overwrite the method
|
||||
``validate_origin(self, message, origin)``. It must return True when a message
|
||||
should be accepted, False otherwise.
|
||||
|
||||
|
||||
Routing
|
||||
-------
|
||||
|
||||
|
|
|
@ -172,7 +172,7 @@ directly, but there are two useful ways you can call it:
|
|||
Decorators
|
||||
----------
|
||||
|
||||
Channels provides decorators to assist with persisting data.
|
||||
Channels provides decorators to assist with persisting data and security.
|
||||
|
||||
* ``channel_session``: Provides a session-like object called "channel_session" to consumers
|
||||
as a message attribute that will auto-persist across consumers with
|
||||
|
@ -200,3 +200,10 @@ Channels provides decorators to assist with persisting data.
|
|||
|
||||
Stores the http session key in the channel_session on websocket.connect messages.
|
||||
It will then hydrate the http_session from that same key on subsequent messages.
|
||||
|
||||
* ``allowed_hosts_only``: Wraps a WebSocket connect consumer and ensures the
|
||||
request originates from an allowed host.
|
||||
|
||||
Reads the Origin header and only passes request originating from a host
|
||||
listed in ``ALLOWED_HOSTS`` to the consumer. Requests from other hosts or
|
||||
with a missing or invalid Origin headers are rejected.
|
||||
|
|
44
tests/test_security.py
Normal file
44
tests/test_security.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.test import override_settings
|
||||
from channels.exceptions import DenyConnection
|
||||
from channels.security.websockets import allowed_hosts_only
|
||||
from channels.message import Message
|
||||
from channels.test import ChannelTestCase
|
||||
|
||||
|
||||
@allowed_hosts_only
|
||||
def connect(message):
|
||||
return True
|
||||
|
||||
|
||||
class OriginValidationTestCase(ChannelTestCase):
|
||||
|
||||
@override_settings(ALLOWED_HOSTS=['example.com'])
|
||||
def test_valid_origin(self):
|
||||
content = {
|
||||
'headers': [[b'origin', b'http://example.com']]
|
||||
}
|
||||
message = Message(content, 'websocket.connect', None)
|
||||
self.assertTrue(connect(message))
|
||||
|
||||
@override_settings(ALLOWED_HOSTS=['example.com'])
|
||||
def test_invalid_origin(self):
|
||||
content = {
|
||||
'headers': [[b'origin', b'http://example.org']]
|
||||
}
|
||||
message = Message(content, 'websocket.connect', None)
|
||||
self.assertRaises(DenyConnection, connect, message)
|
||||
|
||||
def test_invalid_origin_header(self):
|
||||
invalid_headers = [
|
||||
[], # origin header missing
|
||||
[b'origin', b''], # origin header empty
|
||||
[b'origin', b'\xc3\xa4'] # non-ascii
|
||||
]
|
||||
for headers in invalid_headers:
|
||||
content = {
|
||||
'headers': [headers]
|
||||
}
|
||||
message = Message(content, 'websocket.connect', None)
|
||||
self.assertRaises(DenyConnection, connect, message)
|
Loading…
Reference in New Issue
Block a user