Compare commits


No commits in common. "main" and "0.15.0" have entirely different histories.
main ... 0.15.0

42 changed files with 619 additions and 4115 deletions

.flake8
@@ -1,11 +0,0 @@
[flake8]
exclude =
.venv,
.tox,
docs,
testproject,
js_client,
.eggs
extend-ignore = E123, E128, E266, E402, W503, E731, W601, B036
max-line-length = 120

@@ -1,14 +0,0 @@
Issues are for **concrete, actionable bugs and feature requests** only - if you're just asking for debugging help or technical support we have to direct you elsewhere. If you just have questions or support requests please use:
- Stack Overflow
- The Django Users mailing list django-users@googlegroups.com (https://groups.google.com/forum/#!forum/django-users)
We have to limit this because of limited volunteer time to respond to issues!
Please also try to include, if you can:
- Your OS and runtime environment, and browser if applicable
- A `pip freeze` output showing your package versions
- What you expected to happen vs. what actually happened
- How you're running Channels (runserver? daphne/runworker? Nginx/Apache in front?)
- Console logs and full tracebacks of any errors

@@ -1,6 +0,0 @@
version: 2
updates:
- package-ecosystem: github-actions
directory: "/"
schedule:
interval: weekly

@@ -1,43 +0,0 @@
name: Tests
on:
push:
branches:
- main
pull_request:
workflow_dispatch:
permissions:
contents: read
jobs:
tests:
runs-on: ${{ matrix.os }}-latest
strategy:
fail-fast: false
matrix:
os:
- ubuntu
- windows
python-version:
- "3.9"
- "3.10"
- "3.11"
- "3.12"
- "3.13"
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
python -m pip install --upgrade tox
- name: Run tox targets for ${{ matrix.python-version }}
run: tox run -f py$(echo ${{ matrix.python-version }} | tr -d .)
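The last workflow step strips the dot from the matrix Python version to select the matching tox factor; a quick sketch of that substitution (the version value is illustrative):

```shell
# For python-version "3.13", `tr -d .` removes the dot, so the
# workflow ends up invoking `tox run -f py313`.
ver="3.13"
factor="py$(echo "$ver" | tr -d .)"
echo "$factor"
# → py313
```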

.gitignore
@@ -1,16 +1,5 @@
.idea/
*.egg-info
*.pyc
__pycache__
dist/
build/
/.tox
.hypothesis
.cache
.eggs
test_layer*
test_consumer*
.python-version
.pytest_cache/
.vscode
.coverage

@@ -1,23 +0,0 @@
repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/isort
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.2.0
hooks:
- id: flake8
additional_dependencies:
- flake8-bugbear
ci:
autoupdate_schedule: quarterly

.travis.yml
@@ -0,0 +1,9 @@
sudo: false
language: python
python:
- "2.7"
- "3.5"
install:
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install unittest2; fi
- pip install asgiref twisted autobahn
script: if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then python -m unittest2; else python -m unittest; fi

@@ -1,382 +1,3 @@
4.2.0 (to be released)
------------------
* Added support for Python 3.13.
* Drop support for EOL Python 3.8.
* Removed unused pytest-runner
* Fixed sdist file to ensure it includes all tests
4.1.2 (2024-04-11)
------------------
* Fixed a setuptools configuration error in 4.1.1.
4.1.1 (2024-04-10)
------------------
* Fixed a twisted.plugin packaging error in 4.1.0.
Thanks to sdc50.
4.1.0 (2024-02-10)
------------------
* Added support for Python 3.12.
* Dropped support for EOL Python 3.7.
* Handled root path for websocket scopes.
* Validate HTTP header names as per RFC 9110.
4.0.0 (2022-10-07)
------------------
Major versioning targeting use with Channels 4.0 and beyond. Except where
noted, it should remain usable with Channels v3 projects, but updating Channels to the latest version is recommended.
* Added a ``runserver`` command to run an ASGI Django development server.
Added ``"daphne"`` to the ``INSTALLED_APPS`` setting, before
``"django.contrib.staticfiles"`` to enable:
INSTALLED_APPS = [
"daphne",
...
]
This replaces the Channels implementation of ``runserver``, which is removed
in Channels 4.0.
* Made the ``DaphneProcess`` tests helper class compatible with the ``spawn``
process start method, which is used on macOS and Windows.
Note that this requires Channels v4 if used with ``ChannelsLiveServerTestCase``.
* Dropped support for Python 3.6.
* Updated dependencies to the latest versions.
Previously a range of Twisted versions have been supported. Recent Twisted
releases (22.2, 22.4) have issued security fixes, so those are now the
minimum supported version. Given the stability of Twisted, supporting a
range of versions does not represent a good use of maintainer time. Going
forward the latest Twisted version will be required.
* Set ``daphne`` as default ``Server`` header.
This can be configured with the ``--server-name`` CLI argument.
Added the new ``--no-server-name`` CLI argument to disable the ``Server``
header, which is equivalent to ``--server-name=`` (an empty name).
* Added ``--log-fmt`` CLI argument.
* Added support for ``ASGI_THREADS`` environment variable, setting the maximum
number of workers used by a ``SyncToAsync`` thread-pool executor.
Set e.g. ``ASGI_THREADS=4 daphne ...`` when running to limit the number of
workers.
* Removed deprecated ``--ws_protocols`` CLI option.
3.0.2 (2021-04-07)
------------------
* Fixed a bug where ``send`` passed to applications wasn't a true async
function but a lambda wrapper, preventing it from being used with
``asgiref.sync.async_to_sync()``.
3.0.1 (2020-11-12)
------------------
* Fixed a bug where ``asyncio.CancelledError`` was not correctly handled on
Python 3.8+, resulting in incorrect protocol application cleanup.
3.0.0 (2020-10-28)
------------------
* Updates internals to use ASGI v3 throughout. ``asgiref.compatibility`` is
used for older applications.
* Consequently, the `--asgi-protocol` command-line option is removed.
* HTTP request bodies are now read, and passed to the application, in chunks.
* Added support for Python 3.9.
* Dropped support for Python 3.5.
2.5.0 (2020-04-15)
------------------
* Fixes compatibility for Twisted when running Python 3.8+ on Windows, by
setting ``asyncio.WindowsSelectorEventLoopPolicy`` as the event loop policy
in this case.
* The internal ``daphne.testing.TestApplication`` now requires an additional
``lock`` argument to ``__init__()``. This is expected to be an instance of
``multiprocessing.Lock``.
2.4.1 (2019-12-18)
------------------
* Avoids Twisted using the default event loop, for compatibility with Django
3.0's ``async_unsafe()`` decorator in threaded contexts, such as using the
auto-reloader.
2.4.0 (2019-11-20)
------------------
* Adds CI testing against and support for Python 3.8.
* Adds support for ``raw_path`` in ASGI scope.
* Ensures an error response is sent to the client if the application sends
malformed headers.
* Resolves an asyncio + multiprocessing problem when testing that would cause
the test suite to fail/hang on macOS.
* Requires installing Twisted's TLS extras, via ``install_requires``.
* Adds missing LICENSE to distribution.
2.3.0 (2019-04-09)
------------------
* Added support for ASGI v3.
2.2.5 (2019-01-31)
------------------
* WebSocket handshakes are now affected by the websocket connect timeout, so
you can limit them from the command line.
* Server name can now be set using --server-name
2.2.4 (2018-12-15)
------------------
* No longer listens on port 8000 when a file descriptor is provided with --fd
* Fixed a memory leak with WebSockets
2.2.3 (2018-11-06)
------------------
* Enforce that response headers are only bytestrings, rather than allowing
unicode strings and coercing them into bytes.
* New command-line options to set proxy header names: --proxy-headers-host and
--proxy-headers-port.
2.2.2 (2018-08-16)
------------------
* X-Forwarded-Proto support is now present and enabled if you turn on the
--proxy-headers flag
* ASGI applications are no longer instantiated in a thread (the ASGI spec
was finalised to say all constructors must be non-blocking on the main thread)
2.2.1 (2018-07-22)
------------------
* Python 3.7 compatibility is flagged and ensured by using Twisted 18.7 and
above as a dependency.
* The send() awaitable in applications no longer blocks if the connection is
closed.
* Fixed a race condition where applications would be cleaned up before they
had even started.
2.2.0 (2018-06-13)
------------------
* HTTP timeouts have been removed by default, as they were only needed
with ASGI/Channels 1. You can re-enable them with the --http-timeout
argument to Daphne.
* Occasional errors on application timeout for non-fully-opened sockets
and for trying to read closed requests under high load are fixed.
* X-Forwarded-For headers are now correctly decoded in all environments
and no longer have unicode matching issues.
2.1.2 (2018-05-24)
------------------
* Fixed spurious errors caused by websockets disconnecting before their
application was instantiated.
* Stronger checking for type-safety of headers as bytestrings
2.1.1 (2018-04-18)
------------------
* ASGI application constructors are now run in a threadpool as they might
contain blocking synchronous code.
2.1.0 (2018-03-05)
------------------
* Removed subprotocol support from server, as it never really worked. Subprotocols
can instead be negotiated by ASGI applications now.
* Non-ASCII query strings now raise a 400 Bad Request error rather than silently
breaking the logger
2.0.4 (2018-02-21)
------------------
* Ping timeouts no longer reset on outgoing data, only incoming data
* No more errors when connections close prematurely
2.0.3 (2018-02-07)
------------------
* Unix socket listening no longer errors during startup (introduced in 2.0.2)
* ASGI Applications are now not immediately killed on disconnection but instead
given --application-close-timeout seconds to exit (defaults to 10)
2.0.2 (2018-02-04)
------------------
* WebSockets are no longer closed after the duration of http_timeout
2.0.1 (2018-02-03)
------------------
* Updated logging to correctly route exceptions through the main Daphne logger
2.0.0 (2018-02-01)
------------------
* Major rewrite to the new async-based ASGI specification and to support
Channels 2. Not backwards compatible.
1.3.0 (2017-06-16)
------------------
* Ability to set the websocket connection timeout
* Server no longer reveals the exact Autobahn version number for security
* A few unicode fixes for Python 2/3 compatibility
* Stopped logging messages to already-closed connections as ERROR
1.2.0 (2017-04-01)
------------------
* The new process-specific channel support is now implemented, resulting in
significantly less traffic to your channel backend.
* Native twisted blocking support for channel layers that support it is now
used. While it is a lot more efficient, it is also sometimes slightly more
latent; you can disable it using --force-sync.
* Native SSL termination is now correctly reflected in the ASGI-HTTP `scheme`
key.
* accept: False is now a valid way to deny a connection, as well as close: True.
* HTTP version is now correctly sent as one of "1.0", "1.1" or "2".
* More command line options for websocket timeouts
1.1.0 (2017-03-18)
------------------
* HTTP/2 termination is now supported natively. The Twisted dependency has been
increased to at least 17.1 as a result; for more information about setting up
HTTP/2, see the README.
* X-Forwarded-For decoding support understands IPv6 addresses, and picks the
most remote (leftmost) entry if there are multiple relay hosts.
* Fixed an error where `disconnect` messages would still try and get sent even
if the client never finished a request.
1.0.3 (2017-02-12)
------------------
* IPv6 addresses are correctly accepted as bind targets on the command line
* Twisted 17.1 compatibility fixes for WebSocket receiving/keepalive and
proxy header detection.
1.0.2 (2017-02-01)
------------------
* The "null" WebSocket origin (including file:// and no value) is now accepted
by Daphne and passed onto the application to accept/deny.
* Listening on file descriptors works properly again.
* The DeprecationError caused by not passing endpoints into a Server class
directly is now a warning instead.
1.0.1 (2017-01-09)
------------------
* Endpoint unicode strings now work correctly on Python 2 and Python 3
1.0.0 (2017-01-08)
------------------
* BREAKING CHANGE: Daphne now requires acceptance of WebSocket connections
before it finishes the socket handshake and relays incoming packets.
You must upgrade to at least Channels 1.0.0 as well; see
http://channels.readthedocs.io/en/latest/releases/1.0.0.html for more.
* http.disconnect now has a `path` key
* WebSockets can now be closed with a specific code
* X-Forwarded-For header support; defaults to X-Forwarded-For, override with
--proxy-headers on the commandline.
* Twisted endpoint description string support with `-e` on the command line
(allowing for SNI/ACME support, among other things)
* Logging/error verbosity fixes and access log flushes properly
0.15.0 (2016-08-28)
-------------------
@@ -566,4 +187,4 @@ noted should remain usable with Channels v3 projects, but updating Channels to t
* http.disconnect messages are now sent
* Request handling speed significantly improved
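The 3.0.2 entry above turns on a subtle distinction: a lambda that wraps an async function returns a coroutine when called, but is not itself a coroutine function, so callers that inspect the callable (such as ``asgiref.sync.async_to_sync``) can misbehave. A stdlib-only sketch of that distinction (names are illustrative, not Daphne's internals):

```python
import asyncio

async def send(message):
    # A true async function: detectable with iscoroutinefunction().
    return message

# A lambda wrapper still returns a coroutine when called, but the
# callable itself is a plain function - the situation Daphne 3.0.2 fixed
# by passing the real async function through.
wrapped = lambda message: send(message)

print(asyncio.iscoroutinefunction(send))     # True
print(asyncio.iscoroutinefunction(wrapped))  # False
```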

LICENSE
@@ -1,27 +0,0 @@
Copyright (c) Django Software Foundation and individual contributors.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of Django nor the names of its contributors may be used
to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@@ -1,2 +0,0 @@
include LICENSE
recursive-include tests *.py

@@ -13,4 +13,4 @@ endif
git tag $(version)
git push
git push --tags
-#python setup.py sdist bdist_wheel upload
+python setup.py sdist bdist_wheel upload

@@ -1,13 +1,15 @@
daphne
======
+.. image:: https://api.travis-ci.org/andrewgodwin/daphne.svg
+   :target: https://travis-ci.org/andrewgodwin/daphne
.. image:: https://img.shields.io/pypi/v/daphne.svg
   :target: https://pypi.python.org/pypi/daphne
-Daphne is a HTTP, HTTP2 and WebSocket protocol server for
-`ASGI <https://github.com/django/asgiref/blob/main/specs/asgi.rst>`_ and
-`ASGI-HTTP <https://github.com/django/asgiref/blob/main/specs/www.rst>`_,
-developed to power Django Channels.
+Daphne is a HTTP, HTTP2 and WebSocket protocol server for
+`ASGI <http://channels.readthedocs.org/en/latest/asgi.html>`_, and developed
+to power Django Channels.
It supports automatic negotiation of protocols; there's no need for URL
prefixing to determine WebSocket endpoints versus HTTP endpoints.
@@ -16,76 +18,26 @@ prefixing to determine WebSocket endpoints versus HTTP endpoints.
Running
-------
-Simply point Daphne to your ASGI application, and optionally
+Simply point Daphne to your ASGI channel layer instance, and optionally
set a bind address and port (defaults to localhost, port 8000)::
-    daphne -b 0.0.0.0 -p 8001 django_project.asgi:application
+    daphne -b 0.0.0.0 -p 8001 django_project.asgi:channel_layer
If you intend to run daphne behind a proxy server you can use UNIX
sockets to communicate between the two::
-    daphne -u /tmp/daphne.sock django_project.asgi:application
+    daphne -u /tmp/daphne.sock django_project.asgi:channel_layer
-If daphne is being run inside a process manager, you might
+If daphne is being run inside a process manager such as
+`Circus <https://github.com/circus-tent/circus/>`_ you might
want it to bind to a file descriptor passed down from a parent process.
To achieve this you can use the --fd flag::
-    daphne --fd 5 django_project.asgi:application
+    daphne --fd 5 django_project.asgi:channel_layer
If you want more control over the port/socket bindings you can fall back to
using `twisted's endpoint description strings
<http://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#serverFromString>`_
by using the `--endpoint (-e)` flag, which can be used multiple times.
This line would start a SSL server on port 443, assuming that `key.pem` and `crt.pem`
exist in the current directory (requires pyopenssl to be installed)::
daphne -e ssl:443:privateKey=key.pem:certKey=crt.pem django_project.asgi:application
Endpoints even let you use the ``txacme`` endpoint syntax to get automatic certificates
from Let's Encrypt, which you can read more about at http://txacme.readthedocs.io/en/stable/.
To see all available command line options run daphne with the ``-h`` flag.
HTTP/2 Support
--------------
Daphne supports terminating HTTP/2 connections natively. You'll
need to do a couple of things to get it working, though. First, you need to
make sure you install the Twisted ``http2`` and ``tls`` extras::
pip install -U "Twisted[tls,http2]"
Next, because all current browsers only support HTTP/2 when using TLS, you will
need to start Daphne with TLS turned on, which can be done using the Twisted endpoint syntax::
daphne -e ssl:443:privateKey=key.pem:certKey=crt.pem django_project.asgi:application
Alternatively, you can use the ``txacme`` endpoint syntax or anything else that
enables TLS under the hood.
You will also need to be on a system that has **OpenSSL 1.0.2 or greater**; if you are
using Ubuntu, this means you need at least Ubuntu 16.04.
Now, when you start up Daphne, it should tell you this in the log::
2017-03-18 19:14:02,741 INFO Starting server at ssl:port=8000:privateKey=privkey.pem:certKey=cert.pem, channel layer django_project.asgi:channel_layer.
2017-03-18 19:14:02,742 INFO HTTP/2 support enabled
Then, connect with a browser that supports HTTP/2, and everything should be
working. It's often hard to tell that HTTP/2 is working, as the log Daphne gives you
will be identical (it's HTTP, after all), and most browsers don't make it obvious
in their network inspector windows. There are browser extensions that will let
you know clearly if it's working or not.
Daphne only supports "normal" requests over HTTP/2 at this time; there is not
yet support for extended features like Server Push. It will, however, result in
much faster connections and lower overheads.
If you have a reverse proxy in front of your site to serve static files or
similar, HTTP/2 will only work if that proxy understands and passes through the
connection correctly.
To see all available command line options run daphne with the *-h* flag.
Root Path (SCRIPT_NAME)
-----------------------
@@ -102,36 +54,4 @@ WSGI ``SCRIPT_NAME`` setting, you have two options:
The header takes precedence if both are set. As with ``SCRIPT_ALIAS``, the value
should start with a slash, but not end with one; for example::
-    daphne --root-path=/forum django_project.asgi:application
+    daphne --root-path=/forum django_project.asgi:channel_layer
Python Support
--------------
Daphne requires Python 3.9 or later.
Contributing
------------
Please refer to the
`main Channels contributing docs <https://github.com/django/channels/blob/main/CONTRIBUTING.rst>`_.
To run tests, make sure you have installed the ``tests`` extra with the package::
cd daphne/
pip install -e '.[tests]'
pytest
Maintenance and Security
------------------------
To report security issues, please contact security@djangoproject.com. For GPG
signatures and more security process information, see
https://docs.djangoproject.com/en/dev/internals/security/.
To report bugs or request new features, please open a new GitHub issue.
This repository is part of the Channels project. For the shepherd and maintenance team, please see the
`main Channels readme <https://github.com/django/channels/blob/main/README.rst>`_.
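The README's ``daphne django_project.asgi:application`` examples assume an ASGI callable at that module path. As a minimal sketch (the module path, response body, and handler behaviour are illustrative, not taken from this diff), an ASGI v3 application that Daphne could serve looks like:

```python
# Minimal ASGI v3 application: a single async callable taking
# (scope, receive, send). Daphne would be pointed at it with
# e.g. `daphne mymodule:application` (module name illustrative).
async def application(scope, receive, send):
    if scope["type"] != "http":
        return  # this sketch handles plain HTTP only
    # Drain the request body before responding.
    while True:
        message = await receive()
        if message["type"] == "http.request" and not message.get("more_body"):
            break
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"text/plain")],
    })
    await send({"type": "http.response.body", "body": b"Hello from ASGI"})
```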

@@ -1,14 +1 @@
-import sys
-__version__ = "4.1.3"
-# Windows on Python 3.8+ uses ProactorEventLoop, which is not compatible with
-# Twisted. Does not implement add_writer/add_reader.
-# See https://bugs.python.org/issue37373
-# and https://twistedmatrix.com/trac/ticket/9766
-PY38_WIN = sys.version_info >= (3, 8) and sys.platform == "win32"
-if PY38_WIN:
-    import asyncio
-    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+__version__ = "0.15.0"

@@ -1,3 +0,0 @@
from daphne.cli import CommandLineInterface
CommandLineInterface.entrypoint()

@@ -1,7 +1,7 @@
import datetime
-class AccessLogGenerator:
+class AccessLogGenerator(object):
    """
    Object that implements the Daphne "action logger" internal interface in
    order to provide an access log in something resembling NCSA format.
@@ -17,48 +17,33 @@ class AccessLogGenerator:
        # HTTP requests
        if protocol == "http" and action == "complete":
            self.write_entry(
-                host=details["client"],
+                host=details['client'],
                date=datetime.datetime.now(),
                request="%(method)s %(path)s" % details,
-                status=details["status"],
-                length=details["size"],
+                status=details['status'],
+                length=details['size'],
            )
        # Websocket requests
-        elif protocol == "websocket" and action == "connecting":
-            self.write_entry(
-                host=details["client"],
-                date=datetime.datetime.now(),
-                request="WSCONNECTING %(path)s" % details,
-            )
-        elif protocol == "websocket" and action == "rejected":
-            self.write_entry(
-                host=details["client"],
-                date=datetime.datetime.now(),
-                request="WSREJECT %(path)s" % details,
-            )
        elif protocol == "websocket" and action == "connected":
            self.write_entry(
-                host=details["client"],
+                host=details['client'],
                date=datetime.datetime.now(),
                request="WSCONNECT %(path)s" % details,
            )
        elif protocol == "websocket" and action == "disconnected":
            self.write_entry(
-                host=details["client"],
+                host=details['client'],
                date=datetime.datetime.now(),
                request="WSDISCONNECT %(path)s" % details,
            )
-    def write_entry(
-        self, host, date, request, status=None, length=None, ident=None, user=None
-    ):
+    def write_entry(self, host, date, request, status=None, length=None, ident=None, user=None):
        """
        Writes an NCSA-style entry to the log file (some liberty is taken with
        what the entries are for non-HTTP)
        """
        self.stream.write(
-            '%s %s %s [%s] "%s" %s %s\n'
-            % (
+            "%s %s %s [%s] \"%s\" %s %s\n" % (
                host,
                ident or "-",
                user or "-",
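The format string in ``write_entry`` yields an NCSA-style log line; a small sketch of the same interpolation with illustrative values:

```python
import datetime

# Reproduces write_entry's interpolation with sample values.
line = '%s %s %s [%s] "%s" %s %s\n' % (
    "127.0.0.1",   # host
    "-",           # ident (not tracked, logged as "-")
    "-",           # user (not tracked, logged as "-")
    datetime.datetime(2016, 8, 28, 12, 0, 0).strftime("%d/%b/%Y:%H:%M:%S"),
    "GET /chat/",  # request
    200,           # status
    1024,          # length
)
print(line, end="")
# → 127.0.0.1 - - [28/Aug/2016:12:00:00] "GET /chat/" 200 1024
```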

@@ -1,16 +0,0 @@
# Import the server here to ensure the reactor is installed very early on in case other
# packages import twisted.internet.reactor (e.g. raven does this).
from django.apps import AppConfig
from django.core import checks
import daphne.server # noqa: F401
from .checks import check_daphne_installed
class DaphneConfig(AppConfig):
name = "daphne"
verbose_name = "Daphne"
def ready(self):
checks.register(check_daphne_installed, checks.Tags.staticfiles)

@@ -1,21 +0,0 @@
# Django system check to ensure daphne app is listed in INSTALLED_APPS before django.contrib.staticfiles.
from django.core.checks import Error, register
@register()
def check_daphne_installed(app_configs, **kwargs):
from django.apps import apps
from django.contrib.staticfiles.apps import StaticFilesConfig
from daphne.apps import DaphneConfig
for app in apps.get_app_configs():
if isinstance(app, DaphneConfig):
return []
if isinstance(app, StaticFilesConfig):
return [
Error(
"Daphne must be listed before django.contrib.staticfiles in INSTALLED_APPS.",
id="daphne.E001",
)
]
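The check above enforces the app ordering described in the 4.0.0 changelog entry; in a project's settings this means (app list illustrative):

```python
# settings.py sketch: "daphne" listed before
# "django.contrib.staticfiles" so check daphne.E001 passes and
# Daphne's runserver command takes over.
INSTALLED_APPS = [
    "daphne",
    "django.contrib.admin",
    "django.contrib.staticfiles",
]
```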

@@ -1,167 +1,101 @@
+import sys
import argparse
import logging
-import sys
+import importlib
-from argparse import ArgumentError, Namespace
-from asgiref.compatibility import guarantee_single_callable
-from .access import AccessLogGenerator
-from .endpoints import build_endpoint_description_strings
from .server import Server
-from .utils import import_by_path
+from .access import AccessLogGenerator
logger = logging.getLogger(__name__)
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000
-class CommandLineInterface:
+class CommandLineInterface(object):
    """
    Acts as the main CLI entry point for running the server.
    """
    description = "Django HTTP/WebSocket server"
-    server_class = Server
    def __init__(self):
-        self.parser = argparse.ArgumentParser(description=self.description)
-        self.parser.add_argument(
-            "-p", "--port", type=int, help="Port number to listen on", default=None
-        )
-        self.parser.add_argument(
-            "-b",
-            "--bind",
-            dest="host",
-            help="The host/address to bind to",
-            default=None,
-        )
-        self.parser.add_argument(
-            "--websocket_timeout",
-            type=int,
-            help="Maximum time to allow a websocket to be connected. -1 for infinite.",
-            default=86400,
-        )
-        self.parser.add_argument(
-            "--websocket_connect_timeout",
-            type=int,
-            help="Maximum time to allow a connection to handshake. -1 for infinite",
-            default=5,
-        )
-        self.parser.add_argument(
-            "-u",
-            "--unix-socket",
-            dest="unix_socket",
-            help="Bind to a UNIX socket rather than a TCP host/port",
-            default=None,
-        )
-        self.parser.add_argument(
-            "--fd",
-            type=int,
-            dest="file_descriptor",
-            help="Bind to a file descriptor rather than a TCP host/port or named unix socket",
-            default=None,
-        )
-        self.parser.add_argument(
-            "-e",
-            "--endpoint",
-            dest="socket_strings",
-            action="append",
-            help="Use raw server strings passed directly to twisted",
-            default=[],
-        )
-        self.parser.add_argument(
-            "-v",
-            "--verbosity",
-            type=int,
-            help="How verbose to make the output",
-            default=1,
-        )
-        self.parser.add_argument(
-            "-t",
-            "--http-timeout",
-            type=int,
-            help="How long to wait for worker before timing out HTTP connections",
-            default=None,
-        )
-        self.parser.add_argument(
-            "--access-log",
-            help="Where to write the access log (- for stdout, the default for verbosity=1)",
-            default=None,
-        )
-        self.parser.add_argument(
-            "--log-fmt",
-            help="Log format to use",
-            default="%(asctime)-15s %(levelname)-8s %(message)s",
-        )
-        self.parser.add_argument(
-            "--ping-interval",
-            type=int,
-            help="The number of seconds a WebSocket must be idle before a keepalive ping is sent",
-            default=20,
-        )
-        self.parser.add_argument(
-            "--ping-timeout",
-            type=int,
-            help="The number of seconds before a WebSocket is closed if no response to a keepalive ping",
-            default=30,
-        )
-        self.parser.add_argument(
-            "--application-close-timeout",
-            type=int,
-            help="The number of seconds an ASGI application has to exit after client disconnect before it is killed",
-            default=10,
-        )
-        self.parser.add_argument(
-            "--root-path",
-            dest="root_path",
-            help="The setting for the ASGI root_path variable",
-            default="",
-        )
+        self.parser = argparse.ArgumentParser(
+            description=self.description,
+        )
+        self.parser.add_argument(
+            '-p',
+            '--port',
+            type=int,
+            help='Port number to listen on',
+            default=8000,
+        )
+        self.parser.add_argument(
+            '-b',
+            '--bind',
+            dest='host',
+            help='The host/address to bind to',
+            default="127.0.0.1",
+        )
+        self.parser.add_argument(
+            '-u',
+            '--unix-socket',
+            dest='unix_socket',
+            help='Bind to a UNIX socket rather than a TCP host/port',
+            default=None,
+        )
+        self.parser.add_argument(
+            '--fd',
+            type=int,
+            dest='file_descriptor',
+            help='Bind to a file descriptor rather than a TCP host/port or named unix socket',
+            default=None,
+        )
+        self.parser.add_argument(
+            '-v',
+            '--verbosity',
+            type=int,
+            help='How verbose to make the output',
+            default=1,
+        )
+        self.parser.add_argument(
+            '-t',
+            '--http-timeout',
+            type=int,
+            help='How long to wait for worker server before timing out HTTP connections',
+            default=120,
+        )
+        self.parser.add_argument(
+            '--access-log',
+            help='Where to write the access log (- for stdout, the default for verbosity=1)',
+            default=None,
+        )
+        self.parser.add_argument(
+            '--ping-interval',
+            type=int,
+            help='The number of seconds a WebSocket must be idle before a keepalive ping is sent',
+            default=20,
+        )
+        self.parser.add_argument(
+            '--ping-timeout',
+            type=int,
+            help='The number of seconds before a WeSocket is closed if no response to a keepalive ping',
+            default=30,
+        )
+        self.parser.add_argument(
+            'channel_layer',
+            help='The ASGI channel layer instance to use as path.to.module:instance.path',
+        )
+        self.parser.add_argument(
+            '--ws-protocol',
+            nargs='*',
+            dest='ws_protocols',
+            help='The WebSocket protocols you wish to support',
+            default=None,
+        )
+        self.parser.add_argument(
+            '--root-path',
+            dest='root_path',
+            help='The setting for the ASGI root_path variable',
+            default="",
+        )
self.parser.add_argument(
"--proxy-headers",
dest="proxy_headers",
help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the "
"client address",
default=False,
action="store_true",
)
self.arg_proxy_host = self.parser.add_argument(
"--proxy-headers-host",
dest="proxy_headers_host",
help="Specify which header will be used for getting the host "
"part. Can be omitted, requires --proxy-headers to be specified "
'when passed. "X-Real-IP" (when passed by your webserver) is a '
"good candidate for this.",
default=False,
action="store",
)
self.arg_proxy_port = self.parser.add_argument(
"--proxy-headers-port",
dest="proxy_headers_port",
help="Specify which header will be used for getting the port "
"part. Can be omitted, requires --proxy-headers to be specified "
"when passed.",
default=False,
action="store",
)
self.parser.add_argument(
"application",
help="The application to dispatch to as path.to.module:instance.path",
)
self.parser.add_argument(
"-s",
"--server-name",
dest="server_name",
help="specify which value should be passed to response header Server attribute",
default="daphne",
)
self.parser.add_argument(
"--no-server-name", dest="server_name", action="store_const", const=""
)
self.server = None
@classmethod
def entrypoint(cls):
@ -170,37 +104,6 @@ class CommandLineInterface:
    """
    cls().run(sys.argv[1:])
def _check_proxy_headers_passed(self, argument: str, args: Namespace):
"""Raise if the `--proxy-headers` weren't specified."""
if args.proxy_headers:
return
raise ArgumentError(
argument=argument,
message="--proxy-headers has to be passed for this parameter.",
)
def _get_forwarded_host(self, args: Namespace):
"""
Return the default host header from which the remote hostname/ip
will be extracted.
"""
if args.proxy_headers_host:
self._check_proxy_headers_passed(argument=self.arg_proxy_host, args=args)
return args.proxy_headers_host
if args.proxy_headers:
return "X-Forwarded-For"
def _get_forwarded_port(self, args: Namespace):
"""
Return the default host header from which the remote hostname/ip
will be extracted.
"""
if args.proxy_headers_port:
self._check_proxy_headers_passed(argument=self.arg_proxy_port, args=args)
return args.proxy_headers_port
if args.proxy_headers:
return "X-Forwarded-Port"
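The two helpers above reduce to one rule: an explicitly named header wins, but only if `--proxy-headers` is also set; `--proxy-headers` alone falls back to the conventional `X-Forwarded-For` / `X-Forwarded-Port` names. A standalone sketch of that decision (hypothetical helper, not part of Daphne's API):

```python
def forwarded_header(proxy_headers, explicit=None, fallback="X-Forwarded-For"):
    # Hypothetical condensed form of _get_forwarded_host/_get_forwarded_port:
    # an explicitly named header requires --proxy-headers to be set as well.
    if explicit:
        if not proxy_headers:
            raise ValueError("--proxy-headers has to be passed for this parameter.")
        return explicit
    # --proxy-headers alone falls back to the conventional header name.
    return fallback if proxy_headers else None
```

So `forwarded_header(True, explicit="X-Real-IP")` picks the custom header, while `forwarded_header(True)` yields the default.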
def run(self, args):
    """
    Pass in raw argument list and it will decode them
@ -210,13 +113,12 @@ class CommandLineInterface:
    args = self.parser.parse_args(args)
    # Set up logging
    logging.basicConfig(
        level={
            0: logging.WARN,
            1: logging.INFO,
            2: logging.DEBUG,
            3: logging.DEBUG,  # Also turns on asyncio debug
        }[args.verbosity],
        format=args.log_fmt,
    )
    # If verbosity is 1 or greater, or they told us explicitly, set up access log
    access_log_stream = None
@ -224,62 +126,31 @@ class CommandLineInterface:
    if args.access_log == "-":
        access_log_stream = sys.stdout
    else:
        access_log_stream = open(args.access_log, "a", 1)
elif args.verbosity >= 1:
    access_log_stream = sys.stdout
# Import application
sys.path.insert(0, ".")
application = import_by_path(args.application)
application = guarantee_single_callable(application)
# Set up port/host bindings
if not any(
    [
        args.host,
        args.port is not None,
        args.unix_socket,
        args.file_descriptor is not None,
        args.socket_strings,
    ]
):
    # no advanced binding options passed, patch in defaults
    args.host = DEFAULT_HOST
    args.port = DEFAULT_PORT
elif args.host and args.port is None:
    args.port = DEFAULT_PORT
elif args.port is not None and not args.host:
    args.host = DEFAULT_HOST
# Build endpoint description strings from (optional) cli arguments
endpoints = build_endpoint_description_strings(
    host=args.host,
    port=args.port,
    unix_socket=args.unix_socket,
    file_descriptor=args.file_descriptor,
)
endpoints = sorted(args.socket_strings + endpoints)
# Start the server
logger.info("Starting server at {}".format(", ".join(endpoints)))
self.server = self.server_class(
    application=application,
    endpoints=endpoints,
    http_timeout=args.http_timeout,
    ping_interval=args.ping_interval,
    ping_timeout=args.ping_timeout,
    websocket_timeout=args.websocket_timeout,
    websocket_connect_timeout=args.websocket_connect_timeout,
    websocket_handshake_timeout=args.websocket_connect_timeout,
    application_close_timeout=args.application_close_timeout,
    action_logger=(
        AccessLogGenerator(access_log_stream) if access_log_stream else None
    ),
    root_path=args.root_path,
    verbosity=args.verbosity,
    proxy_forwarded_address_header=self._get_forwarded_host(args=args),
    proxy_forwarded_port_header=self._get_forwarded_port(args=args),
    proxy_forwarded_proto_header=(
        "X-Forwarded-Proto" if args.proxy_headers else None
    ),
    server_name=args.server_name,
)
self.server.run()


@ -1,22 +0,0 @@
def build_endpoint_description_strings(
host=None, port=None, unix_socket=None, file_descriptor=None
):
"""
Build a list of twisted endpoint description strings that the server will listen on.
This is to streamline the generation of twisted endpoint description strings from easier
to use command line args such as host, port, unix sockets etc.
"""
socket_descriptions = []
if host and port is not None:
host = host.strip("[]").replace(":", r"\:")
socket_descriptions.append("tcp:port=%d:interface=%s" % (int(port), host))
elif any([host, port]):
raise ValueError("TCP binding requires both port and host kwargs.")
if unix_socket:
socket_descriptions.append("unix:%s" % unix_socket)
if file_descriptor is not None:
socket_descriptions.append("fd:fileno=%d" % int(file_descriptor))
return socket_descriptions
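For reference, the kinds of Twisted endpoint strings this produces can be sketched with a standalone copy of the function (reproduced here for illustration, so it can run outside Daphne):

```python
def build_endpoints(host=None, port=None, unix_socket=None, file_descriptor=None):
    # Standalone copy of build_endpoint_description_strings, for illustration.
    descriptions = []
    if host and port is not None:
        # Escape colons so IPv6 literals survive Twisted's endpoint syntax.
        host = host.strip("[]").replace(":", r"\:")
        descriptions.append("tcp:port=%d:interface=%s" % (int(port), host))
    elif any([host, port]):
        raise ValueError("TCP binding requires both port and host kwargs.")
    if unix_socket:
        descriptions.append("unix:%s" % unix_socket)
    if file_descriptor is not None:
        descriptions.append("fd:fileno=%d" % int(file_descriptor))
    return descriptions
```

For example, `host="127.0.0.1", port=8000` yields `tcp:port=8000:interface=127.0.0.1`, and an IPv6 host like `[::1]` has its brackets stripped and colons escaped.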


@ -1,15 +1,15 @@
import logging
import time
import traceback
from urllib.parse import unquote

from twisted.internet.defer import inlineCallbacks, maybeDeferred
from twisted.internet.interfaces import IProtocolNegotiationFactory
from twisted.protocols.policies import ProtocolWrapper
from twisted.web import http
from zope.interface import implementer

from .utils import HEADER_NAME_RE, parse_x_forwarded_for

logger = logging.getLogger(__name__)
@ -23,8 +23,7 @@ class WebRequest(http.Request):
    GET and POST out.
    """

    error_template = (
        """
        <html>
        <head>
        <title>%(title)s</title>
@ -41,64 +40,33 @@ class WebRequest(http.Request):
        <footer>Daphne</footer>
        </body>
        </html>
    """.replace(
            "\n", ""
        )
        .replace("    ", " ")
        .replace("   ", " ")
        .replace("  ", " ")
    )  # Shorten it a bit, bytes wise
def __init__(self, *args, **kwargs):
    self.client_addr = None
    self.server_addr = None
    try:
        http.Request.__init__(self, *args, **kwargs)
        # Easy server link
        self.server = self.channel.factory.server
        self.application_queue = None
        self._response_started = False
        self.server.protocol_connected(self)
    except Exception:
        logger.error(traceback.format_exc())
        raise

### Twisted progress callbacks

@inlineCallbacks
def process(self):
    try:
        self.request_start = time.time()
        # Validate header names.
        for name, _ in self.requestHeaders.getAllRawHeaders():
            if not HEADER_NAME_RE.fullmatch(name):
                self.basic_error(400, b"Bad Request", "Invalid header name")
                return
        # Get upgrade header
        upgrade_header = None
        if self.requestHeaders.hasHeader(b"Upgrade"):
            upgrade_header = self.requestHeaders.getRawHeaders(b"Upgrade")[0]
        # Get client address if possible
        if hasattr(self.client, "host") and hasattr(self.client, "port"):
            # client.host and host.host are byte strings in Python 2, but spec
            # requires unicode string.
            self.client_addr = [str(self.client.host), self.client.port]
            self.server_addr = [str(self.host.host), self.host.port]
        self.client_scheme = "https" if self.isSecure() else "http"
        # See if we need to get the address from a proxy header instead
        if self.server.proxy_forwarded_address_header:
            self.client_addr, self.client_scheme = parse_x_forwarded_for(
                self.requestHeaders,
                self.server.proxy_forwarded_address_header,
                self.server.proxy_forwarded_port_header,
                self.server.proxy_forwarded_proto_header,
                self.client_addr,
                self.client_scheme,
            )
        # Check for unicodeish path (or it'll crash when trying to parse)
        try:
            self.path.decode("ascii")
@ -110,25 +78,17 @@ class WebRequest(http.Request):
        self.query_string = b""
        if b"?" in self.uri:
            self.query_string = self.uri.split(b"?", 1)[1]
            try:
                self.query_string.decode("ascii")
            except UnicodeDecodeError:
                self.basic_error(400, b"Bad Request", "Invalid query string")
                return
        # Is it WebSocket? IS IT?!
        if upgrade_header and upgrade_header.lower() == b"websocket":
            # Make WebSocket protocol to hand off to
            protocol = self.server.ws_factory.buildProtocol(
                self.transport.getPeer()
            )
            if not protocol:
                # If protocol creation fails, we signal "internal server error"
                self.setResponseCode(500)
                logger.warn("Could not make WebSocket protocol")
                self.finish()
            # Give it the raw query string
            protocol._raw_query_string = self.query_string
            # Port across transport
            transport, self.transport = self.transport, None
            if isinstance(transport, ProtocolWrapper):
                # i.e. TLS is a wrapping protocol
@ -137,203 +97,147 @@ class WebRequest(http.Request):
                transport.protocol = protocol
            protocol.makeConnection(transport)
            # Re-inject request
            data = self.method + b" " + self.uri + b" HTTP/1.1\x0d\x0a"
            for h in self.requestHeaders.getAllRawHeaders():
                data += h[0] + b": " + b",".join(h[1]) + b"\x0d\x0a"
            data += b"\x0d\x0a"
            data += self.content.read()
            protocol.dataReceived(data)
            # Remove our HTTP reply channel association
            logger.debug("Upgraded connection %s to WebSocket", self.client_addr)
            self.server.protocol_disconnected(self)
            # Resume the producer so we keep getting data, if it's available as a method
            self.channel._networkProducer.resumeProducing()
        # Boring old HTTP.
        else:
            # Sanitize and decode headers, potentially extracting root path
            self.clean_headers = []
            self.root_path = self.server.root_path
            for name, values in self.requestHeaders.getAllRawHeaders():
                # Prevent CVE-2015-0219
                if b"_" in name:
                    continue
                for value in values:
                    if name.lower() == b"daphne-root-path":
                        self.root_path = unquote(value.decode("ascii"))
                    else:
                        self.clean_headers.append((name.lower(), value))
            logger.debug("HTTP %s request for %s", self.method, self.client_addr)
            self.content.seek(0, 0)
            # Work out the application scope and create application
            self.application_queue = yield maybeDeferred(
                self.server.create_application,
                self,
                {
                    "type": "http",
                    # TODO: Correctly say if it's 1.1 or 1.0
                    "http_version": self.clientproto.split(b"/")[-1].decode(
                        "ascii"
                    ),
                    "method": self.method.decode("ascii"),
                    "path": unquote(self.path.decode("ascii")),
                    "raw_path": self.path,
                    "root_path": self.root_path,
                    "scheme": self.client_scheme,
                    "query_string": self.query_string,
                    "headers": self.clean_headers,
                    "client": self.client_addr,
                    "server": self.server_addr,
                },
            )
            # Check they didn't close an unfinished request
            if self.application_queue is None or self.content.closed:
                # Not much we can do, the request is prematurely abandoned.
                return
            # Run application against request
            buffer_size = self.server.request_buffer_size
            while True:
                chunk = self.content.read(buffer_size)
                more_body = not (len(chunk) < buffer_size)
                payload = {
                    "type": "http.request",
                    "body": chunk,
                    "more_body": more_body,
                }
                self.application_queue.put_nowait(payload)
                if not more_body:
                    break
    except Exception:
        logger.error(traceback.format_exc())
        self.basic_error(
            500, b"Internal Server Error", "Daphne HTTP processing error"
        )

@classmethod
def unquote(cls, value, plus_as_space=False):
"""
Python 2 and 3 compat layer for utf-8 unquoting
"""
if six.PY2:
if plus_as_space:
return unquote_plus(value).decode("utf8")
else:
return unquote(value).decode("utf8")
else:
if plus_as_space:
return unquote_plus(value.decode("ascii"))
else:
return unquote(value.decode("ascii"))
def send_disconnect(self):
"""
Sends a disconnect message on the http.disconnect channel.
Useful only really for long-polling.
"""
try:
self.factory.channel_layer.send("http.disconnect", {
"reply_channel": self.reply_channel,
})
except self.factory.channel_layer.ChannelFull:
pass
def connectionLost(self, reason):
    """
    Cleans up reply channel on close.
    """
    if self.application_queue:
        self.send_disconnect()
    logger.debug("HTTP disconnect for %s", self.client_addr)
    http.Request.connectionLost(self, reason)
    self.server.protocol_disconnected(self)

def finish(self):
    """
    Cleans up reply channel on close.
    """
    if self.application_queue:
        self.send_disconnect()
    logger.debug("HTTP close for %s", self.client_addr)
    http.Request.finish(self)
    self.server.protocol_disconnected(self)
### Server reply callbacks

def handle_reply(self, message):
    """
    Handles a reply from the client
    """
    # Handle connections that are already closed
    if self.finished or self.channel is None:
        return
    # Check message validity
    if "type" not in message:
        raise ValueError("Message has no type defined")
    # Handle message
    if message["type"] == "http.response.start":
        if self._response_started:
            raise ValueError("HTTP response has already been started")
        self._response_started = True
        if "status" not in message:
            raise ValueError(
                "Specifying a status code is required for a Response message."
            )
        # Set HTTP status code
        self.setResponseCode(message["status"])
        # Write headers
        for header, value in message.get("headers", {}):
            self.responseHeaders.addRawHeader(header, value)
        if self.server.server_name and not self.responseHeaders.hasHeader("server"):
            self.setHeader(b"server", self.server.server_name.encode())
        logger.debug(
            "HTTP %s response started for %s", message["status"], self.client_addr
        )
    elif message["type"] == "http.response.body":
        if not self._response_started:
            raise ValueError(
                "HTTP response has not yet been started but got %s"
                % message["type"]
            )
        # Write out body
        http.Request.write(self, message.get("body", b""))
        # End if there's no more content
        if not message.get("more_body", False):
            self.finish()
            logger.debug("HTTP response complete for %s", self.client_addr)
            try:
                uri = self.uri.decode("ascii")
            except UnicodeDecodeError:
                # The path is malformed somehow - do our best to log something
                uri = repr(self.uri)
            try:
                self.server.log_action(
                    "http",
                    "complete",
                    {
                        "path": uri,
                        "status": self.code,
                        "method": self.method.decode("ascii", "replace"),
                        "client": (
                            "%s:%s" % tuple(self.client_addr)
                            if self.client_addr
                            else None
                        ),
                        "time_taken": self.duration(),
                        "size": self.sentLength,
                    },
                )
            except Exception:
                logger.error(traceback.format_exc())
        else:
            logger.debug("HTTP response chunk for %s", self.client_addr)
    else:
        raise ValueError("Cannot handle message type %s!" % message["type"])
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
self.basic_error(500, b"Internal Server Error", "Exception inside application.")
def check_timeouts(self):
"""
Called periodically to see if we should timeout something
"""
# Web timeout checking
if self.server.http_timeout and self.duration() > self.server.http_timeout:
if self._response_started:
logger.warning("Application timed out while sending response")
self.finish()
else:
self.basic_error(
503,
b"Service Unavailable",
"Application failed to respond within time limit.",
)
### Utility functions
def send_disconnect(self):
"""
Sends a http.disconnect message.
Useful only really for long-polling.
"""
# If we don't yet have a path, then don't send as we never opened.
if self.path:
self.application_queue.put_nowait({"type": "http.disconnect"})
def duration(self):
    """
@ -347,34 +251,25 @@ class WebRequest(http.Request):
    """
    Responds with a server-level error page (very basic)
    """
    self.handle_reply(
        {
            "type": "http.response.start",
            "status": status,
            "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
        }
    )
    self.handle_reply(
        {
            "type": "http.response.body",
            "body": (
                self.error_template
                % {
                    "title": str(status) + " " + status_text.decode("ascii"),
                    "body": body,
                }
            ).encode("utf8"),
        }
    )

def __hash__(self):
    return hash(id(self))

def __eq__(self, other):
    return id(self) == id(other)
@implementer(IProtocolNegotiationFactory)
class HTTPFactory(http.HTTPFactory):
    """
    Factory which takes care of tracking which protocol
@ -383,34 +278,70 @@ class HTTPFactory(http.HTTPFactory):
    routed appropriately.
    """
def __init__(self, server):
    http.HTTPFactory.__init__(self)
    self.server = server

def buildProtocol(self, addr):
    """
    Builds protocol instances. This override is used to ensure we use our
    own Request object instead of the default.
    """
    try:
        protocol = http.HTTPFactory.buildProtocol(self, addr)
        protocol.requestFactory = WebRequest
        return protocol
    except Exception:
        logger.error("Cannot build protocol: %s" % traceback.format_exc())
        raise

# IProtocolNegotiationFactory

def acceptableProtocols(self):
    """
    Protocols this server can speak after ALPN negotiation. Currently that
    is HTTP/1.1 and optionally HTTP/2. Websockets cannot be negotiated
    using ALPN, so that doesn't go here: anyone wanting websockets will
    negotiate HTTP/1.1 and then do the upgrade dance.
    """
    baseProtocols = [b"http/1.1"]

    if http.H2_ENABLED:
        baseProtocols.insert(0, b"h2")

    return baseProtocols


@ -1,203 +0,0 @@
import datetime
import importlib
import logging
import sys
from django.apps import apps
from django.conf import settings
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.exceptions import ImproperlyConfigured
from django.core.management import CommandError
from django.core.management.commands.runserver import Command as RunserverCommand
from daphne import __version__
from daphne.endpoints import build_endpoint_description_strings
from daphne.server import Server
logger = logging.getLogger("django.channels.server")
def get_default_application():
"""
Gets the default application, set in the ASGI_APPLICATION setting.
"""
try:
path, name = settings.ASGI_APPLICATION.rsplit(".", 1)
except (ValueError, AttributeError):
raise ImproperlyConfigured("Cannot find ASGI_APPLICATION setting.")
try:
module = importlib.import_module(path)
except ImportError:
raise ImproperlyConfigured("Cannot import ASGI_APPLICATION module %r" % path)
try:
value = getattr(module, name)
except AttributeError:
raise ImproperlyConfigured(
f"Cannot find {name!r} in ASGI_APPLICATION module {path}"
)
return value
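The `rsplit(".", 1)` means the last dotted component is treated as the attribute name and everything before it as the importable module, so nested module paths work. A minimal sketch of just that splitting rule (hypothetical helper name, not part of Daphne):

```python
def split_asgi_application(setting):
    # Mirrors the rsplit in get_default_application: the final dotted
    # component is the attribute, the rest is the importable module path.
    try:
        module_path, name = setting.rsplit(".", 1)
    except (ValueError, AttributeError):
        raise ValueError("Cannot find ASGI_APPLICATION setting.")
    return module_path, name
```

So `"myproject.asgi.application"` splits into the module `myproject.asgi` and the attribute `application`.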
class Command(RunserverCommand):
protocol = "http"
server_cls = Server
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--noasgi",
action="store_false",
dest="use_asgi",
default=True,
help="Run the old WSGI-based runserver rather than the ASGI-based one",
)
parser.add_argument(
"--http_timeout",
action="store",
dest="http_timeout",
type=int,
default=None,
help=(
"Specify the daphne http_timeout interval in seconds "
"(default: no timeout)"
),
)
parser.add_argument(
"--websocket_handshake_timeout",
action="store",
dest="websocket_handshake_timeout",
type=int,
default=5,
help=(
"Specify the daphne websocket_handshake_timeout interval in "
"seconds (default: 5)"
),
)
parser.add_argument(
"--nostatic",
action="store_false",
dest="use_static_handler",
help="Tells Django to NOT automatically serve static files at STATIC_URL.",
)
parser.add_argument(
"--insecure",
action="store_true",
dest="insecure_serving",
help="Allows serving static files even if DEBUG is False.",
)
def handle(self, *args, **options):
self.http_timeout = options.get("http_timeout", None)
self.websocket_handshake_timeout = options.get("websocket_handshake_timeout", 5)
# Check Channels is installed right
if options["use_asgi"] and not hasattr(settings, "ASGI_APPLICATION"):
raise CommandError(
"You have not set ASGI_APPLICATION, which is needed to run the server."
)
# Dispatch upward
super().handle(*args, **options)
def inner_run(self, *args, **options):
# Maybe they want the wsgi one?
if not options.get("use_asgi", True):
if hasattr(RunserverCommand, "server_cls"):
self.server_cls = RunserverCommand.server_cls
return RunserverCommand.inner_run(self, *args, **options)
# Run checks
self.stdout.write("Performing system checks...\n\n")
self.check(display_num_errors=True)
self.check_migrations()
# Print helpful text
quit_command = "CTRL-BREAK" if sys.platform == "win32" else "CONTROL-C"
now = datetime.datetime.now().strftime("%B %d, %Y - %X")
self.stdout.write(now)
self.stdout.write(
(
"Django version %(version)s, using settings %(settings)r\n"
"Starting ASGI/Daphne version %(daphne_version)s development server"
" at %(protocol)s://%(addr)s:%(port)s/\n"
"Quit the server with %(quit_command)s.\n"
)
% {
"version": self.get_version(),
"daphne_version": __version__,
"settings": settings.SETTINGS_MODULE,
"protocol": self.protocol,
"addr": "[%s]" % self.addr if self._raw_ipv6 else self.addr,
"port": self.port,
"quit_command": quit_command,
}
)
# Launch server in 'main' thread. Signals are disabled as it's still
# actually a subthread under the autoreloader.
logger.debug("Daphne running, listening on %s:%s", self.addr, self.port)
# build the endpoint description string from host/port options
endpoints = build_endpoint_description_strings(host=self.addr, port=self.port)
try:
self.server_cls(
application=self.get_application(options),
endpoints=endpoints,
signal_handlers=not options["use_reloader"],
action_logger=self.log_action,
http_timeout=self.http_timeout,
root_path=getattr(settings, "FORCE_SCRIPT_NAME", "") or "",
websocket_handshake_timeout=self.websocket_handshake_timeout,
).run()
logger.debug("Daphne exited")
except KeyboardInterrupt:
shutdown_message = options.get("shutdown_message", "")
if shutdown_message:
self.stdout.write(shutdown_message)
return
def get_application(self, options):
"""
Returns the static files serving application wrapping the default application,
if static files should be served. Otherwise just returns the default
handler.
"""
staticfiles_installed = apps.is_installed("django.contrib.staticfiles")
use_static_handler = options.get("use_static_handler", staticfiles_installed)
insecure_serving = options.get("insecure_serving", False)
if use_static_handler and (settings.DEBUG or insecure_serving):
return ASGIStaticFilesHandler(get_default_application())
else:
return get_default_application()
def log_action(self, protocol, action, details):
"""
Logs various different kinds of requests to the console.
"""
# HTTP requests
if protocol == "http" and action == "complete":
msg = "HTTP %(method)s %(path)s %(status)s [%(time_taken).2f, %(client)s]"
# Utilize terminal colors, if available
if 200 <= details["status"] < 300:
# Put 2XX first, since it should be the common case
logger.info(self.style.HTTP_SUCCESS(msg), details)
elif 100 <= details["status"] < 200:
logger.info(self.style.HTTP_INFO(msg), details)
elif details["status"] == 304:
logger.info(self.style.HTTP_NOT_MODIFIED(msg), details)
elif 300 <= details["status"] < 400:
logger.info(self.style.HTTP_REDIRECT(msg), details)
elif details["status"] == 404:
logger.warning(self.style.HTTP_NOT_FOUND(msg), details)
elif 400 <= details["status"] < 500:
logger.warning(self.style.HTTP_BAD_REQUEST(msg), details)
else:
# Any 5XX, or any other response
logger.error(self.style.HTTP_SERVER_ERROR(msg), details)
# Websocket requests
elif protocol == "websocket" and action == "connected":
logger.info("WebSocket CONNECT %(path)s [%(client)s]", details)
elif protocol == "websocket" and action == "disconnected":
logger.info("WebSocket DISCONNECT %(path)s [%(client)s]", details)
elif protocol == "websocket" and action == "connecting":
logger.info("WebSocket HANDSHAKING %(path)s [%(client)s]", details)
elif protocol == "websocket" and action == "rejected":
logger.info("WebSocket REJECT %(path)s [%(client)s]", details)
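The `log_action` method above routes each HTTP status range to a different log level (and terminal color), checking the common 2XX case first and special-casing 304 and 404 before their generic ranges. A minimal standalone sketch of just the level dispatch (the function name here is illustrative, not part of Channels; colors are omitted):

```python
def level_for_status(status):
    """Mirror of the status-range dispatch in log_action, returning the log level used."""
    if 200 <= status < 300:
        return "info"  # 2XX first, since it is the common case
    if 100 <= status < 200:
        return "info"
    if status == 304:
        return "info"  # Not Modified is checked before the generic 3XX range
    if 300 <= status < 400:
        return "info"
    if status == 404:
        return "warning"  # 404 is checked before the generic 4XX range
    if 400 <= status < 500:
        return "warning"
    # Any 5XX, or any other response
    return "error"
```

The explicit `304` and `404` branches must come before the range checks they fall inside, which is why the original preserves that ordering.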

@@ -1,342 +1,137 @@
# This has to be done first as Twisted is import-order-sensitive with reactors
import asyncio # isort:skip
import os # isort:skip
import sys # isort:skip
import warnings # isort:skip
from concurrent.futures import ThreadPoolExecutor # isort:skip
from twisted.internet import asyncioreactor # isort:skip
twisted_loop = asyncio.new_event_loop()
if "ASGI_THREADS" in os.environ:
twisted_loop.set_default_executor(
ThreadPoolExecutor(max_workers=int(os.environ["ASGI_THREADS"]))
)
current_reactor = sys.modules.get("twisted.internet.reactor", None)
if current_reactor is not None:
if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor):
warnings.warn(
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; "
+ "you can fix this warning by importing daphne.server early in your codebase or "
+ "finding the package that imports Twisted and importing it later on.",
UserWarning,
stacklevel=2,
)
del sys.modules["twisted.internet.reactor"]
asyncioreactor.install(twisted_loop)
else:
asyncioreactor.install(twisted_loop)
import logging
import time
from concurrent.futures import CancelledError
from functools import partial

from twisted.internet import defer, reactor
from twisted.internet.endpoints import serverFromString
from twisted.logger import STDLibLogObserver, globalLogBeginner
from twisted.web import http

from .http_protocol import HTTPFactory
from .ws_protocol import WebSocketFactory

logger = logging.getLogger(__name__)


class Server:
    def __init__(
        self,
        application,
        endpoints=None,
        signal_handlers=True,
        action_logger=None,
        http_timeout=None,
        request_buffer_size=8192,
        websocket_timeout=86400,
        websocket_connect_timeout=20,
        ping_interval=20,
        ping_timeout=30,
        root_path="",
        proxy_forwarded_address_header=None,
        proxy_forwarded_port_header=None,
        proxy_forwarded_proto_header=None,
        verbosity=1,
        websocket_handshake_timeout=5,
        application_close_timeout=10,
        ready_callable=None,
        server_name="daphne",
    ):
        self.application = application
        self.endpoints = endpoints or []
        self.listeners = []
        self.listening_addresses = []
        self.signal_handlers = signal_handlers
        self.action_logger = action_logger
        self.http_timeout = http_timeout
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        self.request_buffer_size = request_buffer_size
        self.proxy_forwarded_address_header = proxy_forwarded_address_header
        self.proxy_forwarded_port_header = proxy_forwarded_port_header
        self.proxy_forwarded_proto_header = proxy_forwarded_proto_header
        self.websocket_timeout = websocket_timeout
        self.websocket_connect_timeout = websocket_connect_timeout
        self.websocket_handshake_timeout = websocket_handshake_timeout
        self.application_close_timeout = application_close_timeout
        self.root_path = root_path
        self.verbosity = verbosity
        self.abort_start = False
        self.ready_callable = ready_callable
        self.server_name = server_name
        # Check our construction is actually sensible
        if not self.endpoints:
            logger.error("No endpoints. This server will not listen on anything.")
            sys.exit(1)

    def run(self):
        # A dict of protocol: {"application_instance":, "connected":, "disconnected":} dicts
        self.connections = {}
        # Make the factory
        self.http_factory = HTTPFactory(self)
        self.ws_factory = WebSocketFactory(self, server=self.server_name)
        self.ws_factory.setProtocolOptions(
            autoPingTimeout=self.ping_timeout,
            allowNullOrigin=True,
            openHandshakeTimeout=self.websocket_handshake_timeout,
        )
        if self.verbosity <= 1:
            # Redirect the Twisted log to nowhere
            globalLogBeginner.beginLoggingTo(
                [lambda _: None], redirectStandardIO=False, discardBuffer=True
            )
        else:
            globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
        # Detect what Twisted features are enabled
        if http.H2_ENABLED:
            logger.info("HTTP/2 support enabled")
        else:
            logger.info(
                "HTTP/2 support not enabled (install the http2 and tls Twisted extras)"
            )
        # Kick off the timeout loop
        reactor.callLater(1, self.application_checker)
        reactor.callLater(2, self.timeout_checker)
        for socket_description in self.endpoints:
            logger.info("Configuring endpoint %s", socket_description)
            ep = serverFromString(reactor, str(socket_description))
            listener = ep.listen(self.http_factory)
            listener.addCallback(self.listen_success)
            listener.addErrback(self.listen_error)
            self.listeners.append(listener)
        # Set the asyncio reactor's event loop as global
        # TODO: Should we instead pass the global one into the reactor?
        asyncio.set_event_loop(reactor._asyncioEventloop)
        # Verbosity 3 turns on asyncio debug to find those blocking yields
        if self.verbosity >= 3:
            asyncio.get_event_loop().set_debug(True)
        reactor.addSystemEventTrigger("before", "shutdown", self.kill_all_applications)
        if not self.abort_start:
            # Trigger the ready flag if we had one
            if self.ready_callable:
                self.ready_callable()
            # Run the reactor
            reactor.run(installSignalHandlers=self.signal_handlers)

    def listen_success(self, port):
        """
        Called when a listen succeeds so we can store port details (if there are any)
        """
        if hasattr(port, "getHost"):
            host = port.getHost()
            if hasattr(host, "host") and hasattr(host, "port"):
                self.listening_addresses.append((host.host, host.port))
                logger.info(
                    "Listening on TCP address %s:%s",
                    port.getHost().host,
                    port.getHost().port,
                )

    def listen_error(self, failure):
        logger.critical("Listen failure: %s", failure.getErrorMessage())
        self.stop()

    def stop(self):
        """
        Force-stops the server.
        """
        if reactor.running:
            reactor.stop()
        else:
            self.abort_start = True

    ### Protocol handling

    def protocol_connected(self, protocol):
        """
        Adds a protocol as a current connection.
        """
        if protocol in self.connections:
            raise RuntimeError("Protocol %r was added to main list twice!" % protocol)
        self.connections[protocol] = {"connected": time.time()}

    def protocol_disconnected(self, protocol):
        # Set its disconnected time (the loops will come and clean it up)
        # Do not set it if it is already set. Overwriting it might
        # cause it to never be cleaned up.
        # See https://github.com/django/channels/issues/1181
        if "disconnected" not in self.connections[protocol]:
            self.connections[protocol]["disconnected"] = time.time()

    ### Internal event/message handling

    def create_application(self, protocol, scope):
        """
        Creates a new application instance that fronts a Protocol instance
        for one of our supported protocols. Pass it the protocol,
        and it will work out the type, supply appropriate callables, and
        return you the application's input queue
        """
        # Make sure the protocol has not had another application made for it
        assert "application_instance" not in self.connections[protocol]
        # Make an instance of the application
        input_queue = asyncio.Queue()
        scope.setdefault("asgi", {"version": "3.0"})
        application_instance = self.application(
            scope=scope,
            receive=input_queue.get,
            send=partial(self.handle_reply, protocol),
        )
        # Run it, and stash the future for later checking
        if protocol not in self.connections:
            return None
        self.connections[protocol]["application_instance"] = asyncio.ensure_future(
            application_instance,
            loop=asyncio.get_event_loop(),
        )
        return input_queue

    async def handle_reply(self, protocol, message):
        """
        Coroutine that jumps the reply message from asyncio to Twisted
        """
        # Don't do anything if the connection is closed or does not exist
        if protocol not in self.connections or self.connections[protocol].get(
            "disconnected", None
        ):
            return
        try:
            self.check_headers_type(message)
        except ValueError:
            # Ensure to send SOME reply.
            protocol.basic_error(500, b"Server Error", "Server Error")
            raise
        # Let the protocol handle it
        protocol.handle_reply(message)

    @staticmethod
    def check_headers_type(message):
        if not message["type"] == "http.response.start":
            return
        for k, v in message.get("headers", []):
            if not isinstance(k, bytes):
                raise ValueError(
                    "Header name '{}' expected to be `bytes`, but got `{}`".format(
                        k, type(k)
                    )
                )
            if not isinstance(v, bytes):
                raise ValueError(
                    "Header value '{}' expected to be `bytes`, but got `{}`".format(
                        v, type(v)
                    )
                )

    ### Utility

    def application_checker(self):
        """
        Goes through the set of current application Futures and cleans up
        any that are done/prints exceptions for any that errored.
        """
        for protocol, details in list(self.connections.items()):
            disconnected = details.get("disconnected", None)
            application_instance = details.get("application_instance", None)
            # First, see if the protocol disconnected and the app has taken
            # too long to close up
            if (
                disconnected
                and time.time() - disconnected > self.application_close_timeout
            ):
                if application_instance and not application_instance.done():
                    logger.warning(
                        "Application instance %r for connection %s took too long to shut down and was killed.",
                        application_instance,
                        repr(protocol),
                    )
                    application_instance.cancel()
            # Then see if the app is done and we should reap it
            if application_instance and application_instance.done():
                try:
                    exception = application_instance.exception()
                except (CancelledError, asyncio.CancelledError):
                    # Future cancellation. We can ignore this.
                    pass
                else:
                    if exception:
                        if isinstance(exception, KeyboardInterrupt):
                            # Protocol is asking the server to exit (likely during test)
                            self.stop()
                        else:
                            logger.error(
                                "Exception inside application: %s",
                                exception,
                                exc_info=exception,
                            )
                            if not disconnected:
                                protocol.handle_exception(exception)
                del self.connections[protocol]["application_instance"]
                application_instance = None
            # Check to see if protocol is closed and app is closed so we can remove it
            if not application_instance and disconnected:
                del self.connections[protocol]
        reactor.callLater(1, self.application_checker)

    def kill_all_applications(self):
        """
        Kills all application coroutines before reactor exit.
        """
        # Send cancel to all coroutines
        wait_for = []
        for details in self.connections.values():
            application_instance = details["application_instance"]
            if not application_instance.done():
                application_instance.cancel()
                wait_for.append(application_instance)
        logger.info("Killed %i pending application instances", len(wait_for))
        # Make Twisted wait until they're all dead
        wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
        wait_deferred.addErrback(lambda x: None)
        return wait_deferred

    def timeout_checker(self):
        """
        Called periodically to enforce timeout rules on all connections.
        Also checks pings at the same time.
        """
        for protocol in list(self.connections.keys()):
            protocol.check_timeouts()
        reactor.callLater(2, self.timeout_checker)

    def log_action(self, protocol, action, details):
        """
        Dispatches to any registered action logger, if there is one.
        """
        if self.action_logger:
            self.action_logger(protocol, action, details)

@@ -1,309 +0,0 @@
import logging
import multiprocessing
import os
import pickle
import tempfile
import traceback
from concurrent.futures import CancelledError
class BaseDaphneTestingInstance:
"""
Launches an instance of Daphne in a subprocess, with a host and port
attribute allowing you to call it.
Works as a context manager.
"""
startup_timeout = 2
def __init__(
self, xff=False, http_timeout=None, request_buffer_size=None, *, application
):
self.xff = xff
self.http_timeout = http_timeout
self.host = "127.0.0.1"
self.request_buffer_size = request_buffer_size
self.application = application
def get_application(self):
return self.application
def __enter__(self):
# Optional Daphne features
kwargs = {}
if self.request_buffer_size:
kwargs["request_buffer_size"] = self.request_buffer_size
# Optionally enable X-Forwarded-For support.
if self.xff:
kwargs["proxy_forwarded_address_header"] = "X-Forwarded-For"
kwargs["proxy_forwarded_port_header"] = "X-Forwarded-Port"
kwargs["proxy_forwarded_proto_header"] = "X-Forwarded-Proto"
if self.http_timeout:
kwargs["http_timeout"] = self.http_timeout
# Start up process
self.process = DaphneProcess(
host=self.host,
get_application=self.get_application,
kwargs=kwargs,
setup=self.process_setup,
teardown=self.process_teardown,
)
self.process.start()
# Wait for the port
if self.process.ready.wait(self.startup_timeout):
self.port = self.process.port.value
return self
else:
if self.process.errors.empty():
raise RuntimeError("Daphne did not start up, no error caught")
else:
error, traceback = self.process.errors.get(False)
raise RuntimeError("Daphne did not start up:\n%s" % traceback)
def __exit__(self, exc_type, exc_value, traceback):
# Shut down the process
self.process.terminate()
del self.process
def process_setup(self):
"""
Called by the process just before it starts serving.
"""
pass
def process_teardown(self):
"""
Called by the process just after it stops serving
"""
pass
def get_received(self):
pass
class DaphneTestingInstance(BaseDaphneTestingInstance):
def __init__(self, *args, **kwargs):
self.lock = multiprocessing.Lock()
super().__init__(*args, **kwargs, application=TestApplication(lock=self.lock))
def __enter__(self):
# Clear result storage
TestApplication.delete_setup()
TestApplication.delete_result()
return super().__enter__()
def get_received(self):
"""
Returns the scope and messages the test application has received
so far. Note you'll get all messages since scope start, not just any
new ones since the last call.
Also checks for any exceptions in the application. If there are,
raises them.
"""
try:
with self.lock:
inner_result = TestApplication.load_result()
except FileNotFoundError:
raise ValueError("No results available yet.")
# Check for exception
if "exception" in inner_result:
raise inner_result["exception"]
return inner_result["scope"], inner_result["messages"]
def add_send_messages(self, messages):
"""
Adds messages for the application to send back.
The next time it receives an incoming message, it will reply with these.
"""
TestApplication.save_setup(response_messages=messages)
class DaphneProcess(multiprocessing.Process):
"""
Process subclass that launches and runs a Daphne instance, communicating the
port it ends up listening on back to the parent process.
"""
def __init__(self, host, get_application, kwargs=None, setup=None, teardown=None):
super().__init__()
self.host = host
self.get_application = get_application
self.kwargs = kwargs or {}
self.setup = setup
self.teardown = teardown
self.port = multiprocessing.Value("i")
self.ready = multiprocessing.Event()
self.errors = multiprocessing.Queue()
def run(self):
# OK, now we are in a forked child process, and want to use the reactor.
# However, FreeBSD systems like MacOS do not fork the underlying Kqueue,
# which asyncio (hence asyncioreactor) is built on.
# Therefore, we should uninstall the broken reactor and install a new one.
_reinstall_reactor()
from twisted.internet import reactor
from .endpoints import build_endpoint_description_strings
from .server import Server
application = self.get_application()
try:
# Create the server class
endpoints = build_endpoint_description_strings(host=self.host, port=0)
self.server = Server(
application=application,
endpoints=endpoints,
signal_handlers=False,
**self.kwargs,
)
# Set up a poller to look for the port
reactor.callLater(0.1, self.resolve_port)
# Run with setup/teardown
if self.setup is not None:
self.setup()
try:
self.server.run()
finally:
if self.teardown is not None:
self.teardown()
except BaseException as e:
# Put the error on our queue so the parent gets it
self.errors.put((e, traceback.format_exc()))
def resolve_port(self):
from twisted.internet import reactor
if self.server.listening_addresses:
self.port.value = self.server.listening_addresses[0][1]
self.ready.set()
else:
reactor.callLater(0.1, self.resolve_port)
class TestApplication:
"""
An application that receives one or more messages, sends a response,
and then quits the server. For testing.
"""
setup_storage = os.path.join(tempfile.gettempdir(), "setup.testio")
result_storage = os.path.join(tempfile.gettempdir(), "result.testio")
def __init__(self, lock):
self.lock = lock
self.messages = []
async def __call__(self, scope, receive, send):
self.scope = scope
# Receive input and send output
logging.debug("test app coroutine alive")
try:
while True:
# Receive a message and save it into the result store
self.messages.append(await receive())
self.lock.acquire()
logging.debug("test app received %r", self.messages[-1])
self.save_result(self.scope, self.messages)
self.lock.release()
# See if there are any messages to send back
setup = self.load_setup()
self.delete_setup()
for message in setup["response_messages"]:
await send(message)
logging.debug("test app sent %r", message)
except Exception as e:
if isinstance(e, CancelledError):
# Don't catch task-cancelled errors!
raise
else:
self.save_exception(e)
@classmethod
def save_setup(cls, response_messages):
"""
Stores setup information.
"""
with open(cls.setup_storage, "wb") as fh:
pickle.dump({"response_messages": response_messages}, fh)
@classmethod
def load_setup(cls):
"""
Returns setup details.
"""
try:
with open(cls.setup_storage, "rb") as fh:
return pickle.load(fh)
except FileNotFoundError:
return {"response_messages": []}
@classmethod
def save_result(cls, scope, messages):
"""
Saves details of what happened to the result storage.
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(cls.result_storage, "wb") as fh:
pickle.dump({"scope": scope, "messages": messages}, fh)
@classmethod
def save_exception(cls, exception):
"""
Saves details of what happened to the result storage.
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(cls.result_storage, "wb") as fh:
pickle.dump({"exception": exception}, fh)
@classmethod
def load_result(cls):
"""
Returns result details.
"""
with open(cls.result_storage, "rb") as fh:
return pickle.load(fh)
@classmethod
def delete_setup(cls):
"""
Clears setup storage files.
"""
try:
os.unlink(cls.setup_storage)
except OSError:
pass
@classmethod
def delete_result(cls):
"""
Clears result storage files.
"""
try:
os.unlink(cls.result_storage)
except OSError:
pass
def _reinstall_reactor():
import asyncio
import sys
from twisted.internet import asyncioreactor
# Uninstall the reactor.
if "twisted.internet.reactor" in sys.modules:
del sys.modules["twisted.internet.reactor"]
# The daphne.server module may have already installed the reactor.
# If so, using this module will use uninstalled one, thus we should
# reimport this module too.
if "daphne.server" in sys.modules:
del sys.modules["daphne.server"]
event_loop = asyncio.new_event_loop()
asyncioreactor.install(event_loop)
asyncio.set_event_loop(event_loop)
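TestApplication above hands results between the Daphne subprocess and the test process by pickling them to a fixed file in the temp directory, guarded by a multiprocessing lock. The file round-trip itself can be sketched without Daphne; it is shown single-process here for brevity, and the path and function names are illustrative rather than the module's own:

```python
import os
import pickle
import tempfile

# A throwaway path; TestApplication uses fixed names like "result.testio" in tempdir.
fd, result_path = tempfile.mkstemp(suffix=".testio")
os.close(fd)

def save_result(scope, messages):
    # Writer side: the application pickles everything it has seen so far.
    with open(result_path, "wb") as fh:
        pickle.dump({"scope": scope, "messages": messages}, fh)

def load_result():
    # Reader side: unpickle whatever the subprocess last wrote; a missing
    # file means no results yet (DaphneTestingInstance raises ValueError).
    with open(result_path, "rb") as fh:
        return pickle.load(fh)

save_result({"type": "http"}, [{"type": "http.request"}])
result = load_result()
os.unlink(result_path)
```

Pickling the whole message list each time means `get_received` always returns every message since scope start, not just new ones, exactly as its docstring warns.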

daphne/tests/test_http.py

@@ -0,0 +1,68 @@
# coding: utf8
from __future__ import unicode_literals
from unittest import TestCase
from asgiref.inmemory import ChannelLayer
from twisted.test import proto_helpers
from ..http_protocol import HTTPFactory
class TestHTTPProtocol(TestCase):
"""
Tests that the HTTP protocol class correctly generates and parses messages.
"""
def setUp(self):
self.channel_layer = ChannelLayer()
self.factory = HTTPFactory(self.channel_layer)
self.proto = self.factory.buildProtocol(('127.0.0.1', 0))
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)
def test_basic(self):
"""
Tests basic HTTP parsing
"""
# Send a simple request to the protocol
self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=+bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" +
b"\r\n"
)
# Get the resulting message off of the channel layer
_, message = self.channel_layer.receive_many(["http.request"])
self.assertEqual(message['http_version'], "1.1")
self.assertEqual(message['method'], "GET")
self.assertEqual(message['scheme'], "http")
self.assertEqual(message['path'], "/te st-à/")
self.assertEqual(message['query_string'], b"foo=+bar")
self.assertEqual(message['headers'], [(b"host", b"somewhere.com")])
self.assertFalse(message.get("body", None))
self.assertTrue(message['reply_channel'])
# Send back an example response
self.factory.dispatch_reply(
message['reply_channel'],
{
"status": 201,
"status_text": b"Created",
"content": b"OH HAI",
"headers": [[b"X-Test", b"Boom!"]],
}
)
# Make sure that comes back right on the protocol
self.assertEqual(self.tr.value(), b"HTTP/1.1 201 Created\r\nTransfer-Encoding: chunked\r\nX-Test: Boom!\r\n\r\n6\r\nOH HAI\r\n0\r\n\r\n")
def test_root_path_header(self):
"""
Tests root path header handling
"""
# Send a simple request to the protocol
self.proto.dataReceived(
b"GET /te%20st-%C3%A0/?foo=bar HTTP/1.1\r\n" +
b"Host: somewhere.com\r\n" +
b"Daphne-Root-Path: /foobar%20/bar\r\n" +
b"\r\n"
)
# Get the resulting message off of the channel layer, check root_path
_, message = self.channel_layer.receive_many(["http.request"])
self.assertEqual(message['root_path'], "/foobar /bar")
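The `test_basic` assertion above expects the 6-byte body `OH HAI` to come back framed as `6\r\nOH HAI\r\n0\r\n\r\n`: HTTP/1.1 chunked transfer coding, where each chunk is prefixed by its size in hexadecimal and a zero-size chunk terminates the stream. A sketch of that framing for a body sent as a single chunk (the helper name is illustrative):

```python
def chunk_body(body: bytes) -> bytes:
    """Frame a complete body as one HTTP/1.1 chunk plus the stream terminator."""
    # Chunk size is written in hex, then CRLF, then the data, then CRLF;
    # a zero-size chunk followed by a blank line ends the stream.
    return b"%x\r\n%s\r\n0\r\n\r\n" % (len(body), body)
```

The hex size is why a 16-byte body would be prefixed `10`, not `16` — a common surprise when hand-checking expected transport bytes in tests like the one above.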


@@ -1,24 +0,0 @@
import socket
from twisted.internet import endpoints
from twisted.internet.interfaces import IStreamServerEndpointStringParser
from twisted.plugin import IPlugin
from zope.interface import implementer
@implementer(IPlugin, IStreamServerEndpointStringParser)
class _FDParser:
prefix = "fd"
def _parseServer(self, reactor, fileno, domain=socket.AF_INET):
fileno = int(fileno)
return endpoints.AdoptedStreamServerEndpoint(reactor, fileno, domain)
def parseStreamServer(self, reactor, *args, **kwargs):
# Delegate to another function with a sane signature. This function has
# an insane signature to trick zope.interface into believing the
# interface is correctly implemented.
return self._parseServer(reactor, *args, **kwargs)
parser = _FDParser()
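The `fd:` plugin above extends Twisted's endpoint description string syntax, which is also how the Server configures its listeners via `serverFromString` (see the `build_endpoint_description_strings` call in the runserver command). A hedged sketch of how simple host/port/socket options might map to such strings — this is a simplified stand-in, not daphne's actual `build_endpoint_description_strings`:

```python
def endpoint_descriptions(host=None, port=None, unix_socket=None, file_descriptor=None):
    """Build Twisted endpoint description strings from simple options (sketch)."""
    descriptions = []
    if host and port is not None:
        # Colons are the option separator in endpoint syntax, so any colons
        # in an IPv6 host must be backslash-escaped; IPv4 hosts are unchanged.
        descriptions.append("tcp:port=%d:interface=%s" % (port, host.replace(":", r"\:")))
    if unix_socket:
        descriptions.append("unix:%s" % unix_socket)
    if file_descriptor is not None:
        # Matches the "fd:fileno=N" form the _FDParser plugin above accepts.
        descriptions.append("fd:fileno=%d" % file_descriptor)
    return descriptions
```

Each resulting string can be fed to `serverFromString(reactor, description)` the same way `Server.run` does for its configured endpoints.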


@@ -1,89 +0,0 @@
import importlib
import re
from twisted.web.http_headers import Headers
# Header name regex as per h11.
# https://github.com/python-hyper/h11/blob/a2c68948accadc3876dffcf979d98002e4a4ed27/h11/_abnf.py#L10-L21
HEADER_NAME_RE = re.compile(rb"[-!#$%&'*+.^_`|~0-9a-zA-Z]+")
def import_by_path(path):
"""
Given a dotted/colon path, like project.module:ClassName.callable,
returns the object at the end of the path.
"""
module_path, object_path = path.split(":", 1)
target = importlib.import_module(module_path)
for bit in object_path.split("."):
target = getattr(target, bit)
return target
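`import_by_path` splits its argument at the first colon, imports the module on the left, then walks dotted attributes on the right. Reproducing the function so it runs standalone, a quick demonstration against the standard library:

```python
import importlib

def import_by_path(path):
    """Resolve "package.module:Attr.sub" to the object it names (same logic as above)."""
    module_path, object_path = path.split(":", 1)
    target = importlib.import_module(module_path)
    for bit in object_path.split("."):
        target = getattr(target, bit)
    return target

# The colon separates the importable module from the attribute walk.
joiner = import_by_path("os.path:join")
ordered_dict_cls = import_by_path("collections:OrderedDict")
```

The colon is what lets the function distinguish `project.module` (importable) from `ClassName.callable` (attribute lookup), which a purely dotted path cannot do unambiguously.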
def header_value(headers, header_name):
value = headers[header_name]
if isinstance(value, list):
value = value[0]
return value.decode("utf-8")
def parse_x_forwarded_for(
headers,
address_header_name="X-Forwarded-For",
port_header_name="X-Forwarded-Port",
proto_header_name="X-Forwarded-Proto",
original_addr=None,
original_scheme=None,
):
"""
Parses an X-Forwarded-For header and returns a host/port pair as a list.
@param headers: The twisted-style object containing a request's headers
@param address_header_name: The name of the expected host header
@param port_header_name: The name of the expected port header
@param proto_header_name: The name of the expected proto header
@param original_addr: A host/port pair that should be returned if the headers are not in the request
@param original_scheme: A scheme that should be returned if the headers are not in the request
@return: A list containing a host (string) as the first entry and a port (int) as the second.
"""
if not address_header_name:
return original_addr, original_scheme
# Convert twisted-style headers into dicts
if isinstance(headers, Headers):
headers = dict(headers.getAllRawHeaders())
# Lowercase all header names in the dict
headers = {name.lower(): values for name, values in headers.items()}
# Make sure header names are bytes (values are checked in header_value)
assert all(isinstance(name, bytes) for name in headers.keys())
address_header_name = address_header_name.lower().encode("utf-8")
result_addr = original_addr
result_scheme = original_scheme
if address_header_name in headers:
address_value = header_value(headers, address_header_name)
if "," in address_value:
address_value = address_value.split(",")[0].strip()
result_addr = [address_value, 0]
if port_header_name:
# We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For
# header to avoid inconsistent results.
port_header_name = port_header_name.lower().encode("utf-8")
if port_header_name in headers:
port_value = header_value(headers, port_header_name)
try:
result_addr[1] = int(port_value)
except ValueError:
pass
if proto_header_name:
proto_header_name = proto_header_name.lower().encode("utf-8")
if proto_header_name in headers:
result_scheme = header_value(headers, proto_header_name)
return result_addr, result_scheme
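`parse_x_forwarded_for` takes only the first (client-most) entry of a comma-separated `X-Forwarded-For` list, and honors the port and proto headers only when the address header itself was present. A trimmed mirror of that precedence for plain bytes-keyed dicts, so it runs without Twisted (the function name is illustrative; behavior follows the logic above):

```python
def parse_forwarded(headers, original_addr=None, original_scheme=None):
    """Trimmed mirror of parse_x_forwarded_for for plain {bytes: bytes} headers."""
    headers = {name.lower(): value for name, value in headers.items()}
    addr, scheme = original_addr, original_scheme
    if b"x-forwarded-for" in headers:
        value = headers[b"x-forwarded-for"].decode("utf-8")
        # Only the first entry of the comma-separated proxy chain counts.
        addr = [value.split(",")[0].strip(), 0]
        # Port and proto are parsed only alongside the address header, so the
        # result never mixes values describing different hops.
        if b"x-forwarded-port" in headers:
            try:
                addr[1] = int(headers[b"x-forwarded-port"].decode("utf-8"))
            except ValueError:
                pass
        if b"x-forwarded-proto" in headers:
            scheme = headers[b"x-forwarded-proto"].decode("utf-8")
    return addr, scheme

forwarded = parse_forwarded(
    {
        b"x-forwarded-for": b"10.0.0.1, 10.0.0.2",
        b"x-forwarded-port": b"443",
        b"x-forwarded-proto": b"https",
    }
)
bare = parse_forwarded({})
```

Without the forwarded headers, the caller's `original_addr` and `original_scheme` pass through untouched, which is how the WebSocket protocol falls back to the transport's own peer address.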


@@ -1,16 +1,12 @@
import logging
import time
import traceback
from urllib.parse import unquote

from autobahn.twisted.websocket import (
    ConnectionDeny,
    WebSocketServerFactory,
    WebSocketServerProtocol,
)
from twisted.internet import defer

from .utils import parse_x_forwarded_for

logger = logging.getLogger(__name__)
@@ -21,253 +17,166 @@ class WebSocketProtocol(WebSocketServerProtocol):
    the websocket channels.
    """

    application_type = "websocket"

    # If we should send no more messages (e.g. we error-closed the socket)
    muted = False

    def onConnect(self, request):
        self.server = self.factory.server_class
        self.server.protocol_connected(self)
        self.request = request
        self.protocol_to_accept = None
        self.root_path = self.server.root_path
        self.socket_opened = time.time()
        self.last_ping = time.time()
        try:
            # Sanitize and decode headers, potentially extracting root path
            self.clean_headers = []
            for name, value in request.headers.items():
                name = name.encode("ascii")
                # Prevent CVE-2015-0219
                if b"_" in name:
                    continue
                if name.lower() == b"daphne-root-path":
                    self.root_path = unquote(value)
                else:
                    self.clean_headers.append((name.lower(), value.encode("latin1")))
            # Get client address if possible
            peer = self.transport.getPeer()
            host = self.transport.getHost()
            if hasattr(peer, "host") and hasattr(peer, "port"):
                self.client_addr = [str(peer.host), peer.port]
                self.server_addr = [str(host.host), host.port]
            else:
                self.client_addr = None
                self.server_addr = None
            if self.server.proxy_forwarded_address_header:
                self.client_addr, self.client_scheme = parse_x_forwarded_for(
                    dict(self.clean_headers),
                    self.server.proxy_forwarded_address_header,
                    self.server.proxy_forwarded_port_header,
                    self.server.proxy_forwarded_proto_header,
                    self.client_addr,
                )
            # Decode websocket subprotocol options
            subprotocols = []
            for header, value in self.clean_headers:
                if header == b"sec-websocket-protocol":
                    subprotocols = [
                        x.strip() for x in unquote(value.decode("ascii")).split(",")
                    ]
            # Make new application instance with scope
            self.path = request.path.encode("ascii")
            self.application_deferred = defer.maybeDeferred(
                self.server.create_application,
                self,
                {
                    "type": "websocket",
                    "path": unquote(self.path.decode("ascii")),
                    "raw_path": self.path,
                    "root_path": self.root_path,
                    "headers": self.clean_headers,
                    "query_string": self._raw_query_string,  # Passed by HTTP protocol
                    "client": self.client_addr,
                    "server": self.server_addr,
                    "subprotocols": subprotocols,
                },
            )
            if self.application_deferred is not None:
                self.application_deferred.addCallback(self.applicationCreateWorked)
                self.application_deferred.addErrback(self.applicationCreateFailed)
        except Exception:
            # Exceptions here are not displayed right, just 500.
            # Turn them into an ERROR log.
            logger.error(traceback.format_exc())
            raise
        # Make a deferred and return it - we'll either call it or err it later on
        self.handshake_deferred = defer.Deferred()
        return self.handshake_deferred
break
def applicationCreateWorked(self, application_queue): if ws_protocol and ws_protocol in self.factory.protocols:
""" return ws_protocol
Called when the background thread has successfully made the application
instance.
"""
# Store the application's queue
self.application_queue = application_queue
# Send over the connect message
self.application_queue.put_nowait({"type": "websocket.connect"})
self.server.log_action(
"websocket",
"connecting",
{
"path": self.request.path,
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
def applicationCreateFailed(self, failure): @classmethod
def unquote(cls, value):
""" """
Called when application creation fails. Python 2 and 3 compat layer for utf-8 unquoting
""" """
logger.error(failure) if six.PY2:
return failure return unquote(value).decode("utf8")
else:
### Twisted event handling return unquote(value.decode("ascii"))
def onOpen(self): def onOpen(self):
# Send news that this channel is open # Send news that this channel is open
logger.debug("WebSocket %s open and established", self.client_addr) logger.debug("WebSocket open for %s", self.reply_channel)
self.server.log_action( try:
"websocket", self.channel_layer.send("websocket.connect", self.request_info)
"connected", except self.channel_layer.ChannelFull:
{ # You have to consume websocket.connect according to the spec,
# so drop the connection.
self.muted = True
logger.warn("WebSocket force closed for %s due to connect backpressure", self.reply_channel)
# Send code 1013 "try again later" with close.
self.sendCloseFrame(code=1013, isReply=False)
else:
self.factory.log_action("websocket", "connected", {
"path": self.request.path, "path": self.request.path,
"client": ( "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
"%s:%s" % tuple(self.client_addr) if self.client_addr else None })
),
},
)
def onMessage(self, payload, isBinary): def onMessage(self, payload, isBinary):
# If we're muted, do nothing. # If we're muted, do nothing.
if self.muted: if self.muted:
logger.debug("Muting incoming frame on %s", self.client_addr) logger.debug("Muting incoming frame on %s", self.reply_channel)
return return
logger.debug("WebSocket incoming frame on %s", self.client_addr) logger.debug("WebSocket incoming frame on %s", self.reply_channel)
self.last_ping = time.time() self.packets_received += 1
if isBinary: self.last_data = time.time()
self.application_queue.put_nowait( try:
{"type": "websocket.receive", "bytes": payload} if isBinary:
) self.channel_layer.send("websocket.receive", {
else: "reply_channel": self.reply_channel,
self.application_queue.put_nowait( "path": self.unquote(self.path),
{"type": "websocket.receive", "text": payload.decode("utf8")} "order": self.packets_received,
) "bytes": payload,
})
def onClose(self, wasClean, code, reason):
"""
Called when Twisted closes the socket.
"""
self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted and hasattr(self, "application_queue"):
self.application_queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
)
self.server.log_action(
"websocket",
"disconnected",
{
"path": self.request.path,
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
### Internal event handling
def handle_reply(self, message):
if "type" not in message:
raise ValueError("Message has no type defined")
if message["type"] == "websocket.accept":
self.serverAccept(message.get("subprotocol", None))
elif message["type"] == "websocket.close":
if self.state == self.STATE_CONNECTING:
self.serverReject()
else: else:
self.serverClose(code=message.get("code", None)) self.channel_layer.send("websocket.receive", {
elif message["type"] == "websocket.send": "reply_channel": self.reply_channel,
if self.state == self.STATE_CONNECTING: "path": self.unquote(self.path),
raise ValueError("Socket has not been accepted, so cannot send over it") "order": self.packets_received,
if message.get("bytes", None) and message.get("text", None): "text": payload.decode("utf8"),
raise ValueError( })
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" except self.channel_layer.ChannelFull:
% (message,) # You have to consume websocket.receive according to the spec,
) # so drop the connection.
if message.get("bytes", None): self.muted = True
self.serverSend(message["bytes"], True) logger.warn("WebSocket force closed for %s due to receive backpressure", self.reply_channel)
if message.get("text", None): # Send code 1013 "try again later" with close.
self.serverSend(message["text"], False) self.sendCloseFrame(code=1013, isReply=False)
def handle_exception(self, exception):
"""
Called by the server when our application tracebacks
"""
if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one.
self.handshake_deferred.errback(
ConnectionDeny(code=500, reason="Internal server error")
)
else:
self.sendCloseFrame(code=1011)
def serverAccept(self, subprotocol=None):
"""
Called when we get a message saying to accept the connection.
"""
self.handshake_deferred.callback(subprotocol)
del self.handshake_deferred
logger.debug("WebSocket %s accepted by application", self.client_addr)
def serverReject(self):
"""
Called when we get a message saying to reject the connection.
"""
self.handshake_deferred.errback(
ConnectionDeny(code=403, reason="Access denied")
)
del self.handshake_deferred
self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action(
"websocket",
"rejected",
{
"path": self.request.path,
"client": (
"%s:%s" % tuple(self.client_addr) if self.client_addr else None
),
},
)
def serverSend(self, content, binary=False): def serverSend(self, content, binary=False):
""" """
Server-side channel message to send a message. Server-side channel message to send a message.
""" """
if self.state == self.STATE_CONNECTING: self.last_data = time.time()
self.serverAccept() logger.debug("Sent WebSocket packet to client for %s", self.reply_channel)
logger.debug("Sent WebSocket packet to client for %s", self.client_addr)
if binary: if binary:
self.sendMessage(content, binary) self.sendMessage(content, binary)
else: else:
self.sendMessage(content.encode("utf8"), binary) self.sendMessage(content.encode("utf8"), binary)
def serverClose(self, code=None): def serverClose(self):
""" """
Server-side channel message to close the socket Server-side channel message to close the socket
""" """
code = 1000 if code is None else code self.sendClose()
self.sendClose(code=code)
### Utils def onClose(self, wasClean, code, reason):
if hasattr(self, "reply_channel"):
logger.debug("WebSocket closed for %s", self.reply_channel)
del self.factory.reply_protocols[self.reply_channel]
try:
if not self.muted:
self.channel_layer.send("websocket.disconnect", {
"reply_channel": self.reply_channel,
"code": code,
"path": self.unquote(self.path),
"order": self.packets_received + 1,
})
except self.channel_layer.ChannelFull:
pass
self.factory.log_action("websocket", "disconnected", {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
else:
logger.debug("WebSocket closed before handshake established")
def duration(self): def duration(self):
""" """
@ -275,34 +184,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
""" """
return time.time() - self.socket_opened return time.time() - self.socket_opened
def check_timeouts(self): def check_ping(self):
""" """
Called periodically to see if we should timeout something Checks to see if we should send a keepalive ping.
""" """
# Web timeout checking if (time.time() - self.last_data) > self.main_factory.ping_interval:
if ( self._sendAutoPing()
self.duration() > self.server.websocket_timeout self.last_data = time.time()
and self.server.websocket_timeout >= 0
):
self.serverClose()
# Ping check
# If we're still connecting, deny the connection
if self.state == self.STATE_CONNECTING:
if self.duration() > self.server.websocket_connect_timeout:
self.serverReject()
elif self.state == self.STATE_OPEN:
if (time.time() - self.last_ping) > self.server.ping_interval:
self._sendAutoPing()
self.last_ping = time.time()
def __hash__(self):
return hash(id(self))
def __eq__(self, other):
return id(self) == id(other)
def __repr__(self):
return f"<WebSocketProtocol client={self.client_addr!r} path={self.path!r}>"
class WebSocketFactory(WebSocketServerFactory): class WebSocketFactory(WebSocketServerFactory):
@ -312,20 +200,9 @@ class WebSocketFactory(WebSocketServerFactory):
to get reply ID info. to get reply ID info.
""" """
protocol = WebSocketProtocol def __init__(self, main_factory, *args, **kwargs):
self.main_factory = main_factory
def __init__(self, server_class, *args, **kwargs):
self.server_class = server_class
WebSocketServerFactory.__init__(self, *args, **kwargs) WebSocketServerFactory.__init__(self, *args, **kwargs)
def buildProtocol(self, addr): def log_action(self, *args, **kwargs):
""" self.main_factory.log_action(*args, **kwargs)
Builds protocol instances. We use this to inject the factory object into the protocol.
"""
try:
protocol = super().buildProtocol(addr)
protocol.factory = self
return protocol
except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise
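The `-` side above no longer talks to a channel layer at all: the protocol feeds `websocket.connect` / `websocket.receive` / `websocket.disconnect` events into an application queue and relays `websocket.accept` / `websocket.send` replies. A minimal sketch of an ASGI application consuming exactly those message types, driven in-memory with an `asyncio.Queue` standing in for the protocol's `application_queue.put_nowait(...)` calls (the app and driver are illustrative, not part of either branch):

```python
import asyncio


async def echo_app(scope, receive, send):
    # Minimal ASGI WebSocket application: accept the handshake, echo text frames.
    assert scope["type"] == "websocket"
    while True:
        message = await receive()
        if message["type"] == "websocket.connect":
            await send({"type": "websocket.accept"})
        elif message["type"] == "websocket.receive":
            await send({"type": "websocket.send", "text": message["text"]})
        elif message["type"] == "websocket.disconnect":
            break


async def drive():
    # Stand-in for the server side: feed events in, collect the replies.
    inbox = asyncio.Queue()
    outbox = []

    async def send(message):
        outbox.append(message)

    scope = {"type": "websocket", "path": "/chat/", "headers": [], "query_string": b""}
    app_task = asyncio.ensure_future(echo_app(scope, inbox.get, send))
    for event in [
        {"type": "websocket.connect"},
        {"type": "websocket.receive", "text": "ping"},
        {"type": "websocket.disconnect", "code": 1000},
    ]:
        inbox.put_nowait(event)
    await app_task
    return outbox


replies = asyncio.run(drive())
print(replies)  # [{'type': 'websocket.accept'}, {'type': 'websocket.send', 'text': 'ping'}]
```

By contrast, the `+` (0.15.0) side serialized the same events onto named channels (`websocket.connect`, `websocket.receive`) with a `reply_channel` for responses.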


@@ -1,81 +0,0 @@
[project]
name = "daphne"
dynamic = ["version"]
description = "Django ASGI (HTTP/WebSocket) server"
requires-python = ">=3.9"
authors = [
{ name = "Django Software Foundation", email = "foundation@djangoproject.com" },
]
license = { text = "BSD" }
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Web Environment",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Internet :: WWW/HTTP",
]
dependencies = ["asgiref>=3.5.2,<4", "autobahn>=22.4.2", "twisted[tls]>=22.4"]
[project.optional-dependencies]
tests = [
"django",
"hypothesis",
"pytest",
"pytest-asyncio",
"pytest-cov",
"black",
"tox",
"flake8",
"flake8-bugbear",
"mypy",
]
[project.urls]
homepage = "https://github.com/django/daphne"
documentation = "https://channels.readthedocs.io"
repository = "https://github.com/django/daphne.git"
changelog = "https://github.com/django/daphne/blob/main/CHANGELOG.txt"
issues = "https://github.com/django/daphne/issues"
[project.scripts]
daphne = "daphne.cli:CommandLineInterface.entrypoint"
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = ["daphne"]
[tool.setuptools.dynamic]
version = { attr = "daphne.__version__" }
readme = { file = "README.rst", content-type = "text/x-rst" }
[tool.isort]
profile = "black"
[tool.pytest]
testpaths = ["tests"]
asyncio_mode = "strict"
filterwarnings = ["ignore::pytest.PytestDeprecationWarning"]
[tool.coverage.run]
omit = ["tests/*"]
concurrency = ["multiprocessing"]
[tool.coverage.report]
show_missing = "true"
skip_covered = "true"
[tool.coverage.html]
directory = "reports/coverage_html_report"

setup.cfg (new file)

@@ -0,0 +1,2 @@
[bdist_wheel]
universal=1

setup.py (new executable file)

@@ -0,0 +1,31 @@
import os
import sys
from setuptools import find_packages, setup
from daphne import __version__
# We use the README as the long_description
readme_path = os.path.join(os.path.dirname(__file__), "README.rst")
setup(
name='daphne',
version=__version__,
url='http://www.djangoproject.com/',
author='Django Software Foundation',
author_email='foundation@djangoproject.com',
description='Django ASGI (HTTP/WebSocket) server',
long_description=open(readme_path).read(),
license='BSD',
zip_safe=False,
packages=find_packages(),
include_package_data=True,
install_requires=[
'asgiref>=0.13',
'twisted>=16.0',
'autobahn>=0.12',
],
entry_points={'console_scripts': [
'daphne = daphne.cli:CommandLineInterface.entrypoint',
]},
)
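The `console_scripts` entry point in the setup call above is what makes pip generate a `daphne` executable that imports and invokes `CommandLineInterface.entrypoint`. A minimal sketch of that classmethod-entrypoint pattern; the class and method names mirror the setup call, but the body (and the `myproject.asgi:channel_layer` argument) are purely illustrative:

```python
import sys


class CommandLineInterface:
    """Sketch of a CLI class wired up via a console_scripts entry point."""

    @classmethod
    def entrypoint(cls):
        # The generated `daphne` script effectively does this.
        return cls().run(sys.argv[1:])

    def run(self, args):
        # Real daphne parses endpoint options and imports the application
        # object here; this stub just reports what it would serve.
        app_path = args[0] if args else "<none>"
        message = "would serve application: " + app_path
        print(message)
        return message


result = CommandLineInterface().run(["myproject.asgi:channel_layer"])
# prints "would serve application: myproject.asgi:channel_layer"
```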


@@ -1,284 +0,0 @@
import socket
import struct
import time
import unittest
from http.client import HTTPConnection
from urllib import parse
from daphne.testing import DaphneTestingInstance, TestApplication
class DaphneTestCase(unittest.TestCase):
"""
Base class for Daphne integration test cases.
Boots up a copy of Daphne on a test port, sends it a request, and
retrieves the response. Uses a custom ASGI application and temporary files
to store/retrieve the request/response messages.
"""
### Plain HTTP helpers
def run_daphne_http(
self,
method,
path,
params,
body,
responses,
headers=None,
timeout=1,
xff=False,
request_buffer_size=None,
):
"""
Runs Daphne with the given request callback (given the base URL)
and response messages.
"""
with DaphneTestingInstance(
xff=xff, request_buffer_size=request_buffer_size
) as test_app:
# Add the response messages
test_app.add_send_messages(responses)
# Send it the request. We have to do this the long way to allow
# duplicate headers.
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
if params:
path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True)
# Manually send over headers
if headers:
for header_name, header_value in headers:
conn.putheader(header_name, header_value)
# Send body if provided.
if body:
conn.putheader("Content-Length", str(len(body)))
conn.endheaders(message_body=body)
else:
conn.endheaders()
try:
response = conn.getresponse()
except socket.timeout:
# See if they left an exception for us to load
test_app.get_received()
raise RuntimeError(
"Daphne timed out handling request, no exception found."
)
# Return scope, messages, response
return test_app.get_received() + (response,)
def run_daphne_raw(self, data, *, responses=None, timeout=1):
"""
Runs Daphne and sends it the given raw bytestring over a socket.
Accepts list of response messages the application will reply with.
Returns what Daphne sends back.
"""
assert isinstance(data, bytes)
with DaphneTestingInstance() as test_app:
if responses is not None:
test_app.add_send_messages(responses)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(timeout)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.connect((test_app.host, test_app.port))
s.send(data)
try:
return s.recv(1000000)
except socket.timeout:
raise RuntimeError(
"Daphne timed out handling raw request, no exception found."
)
def run_daphne_request(
self,
method,
path,
params=None,
body=None,
headers=None,
xff=False,
request_buffer_size=None,
):
"""
Convenience method for just testing request handling.
Returns (scope, messages)
"""
scope, messages, _ = self.run_daphne_http(
method=method,
path=path,
params=params,
body=body,
headers=headers,
xff=xff,
request_buffer_size=request_buffer_size,
responses=[
{"type": "http.response.start", "status": 200},
{"type": "http.response.body", "body": b"OK"},
],
)
return scope, messages
def run_daphne_response(self, response_messages):
"""
Convenience method for just testing response handling.
Returns (scope, messages)
"""
_, _, response = self.run_daphne_http(
method="GET", path="/", params={}, body=b"", responses=response_messages
)
return response
### WebSocket helpers
def websocket_handshake(
self,
test_app,
path="/",
params=None,
headers=None,
subprotocols=None,
timeout=1,
):
"""
Runs a WebSocket handshake negotiation and returns the raw socket
object & the selected subprotocol.
You'll need to inject an accept or reject message before this
to let it complete.
"""
# Send it the request. We have to do this the long way to allow
# duplicate headers.
conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout)
if params:
path += "?" + parse.urlencode(params, doseq=True)
conn.putrequest("GET", path, skip_accept_encoding=True, skip_host=True)
# Do WebSocket handshake headers + any other headers
if headers is None:
headers = []
headers.extend(
[
(b"Host", b"example.com"),
(b"Upgrade", b"websocket"),
(b"Connection", b"Upgrade"),
(b"Sec-WebSocket-Key", b"x3JJHMbDL1EzLkh9GBhXDw=="),
(b"Sec-WebSocket-Version", b"13"),
(b"Origin", b"http://example.com"),
]
)
if subprotocols:
headers.append((b"Sec-WebSocket-Protocol", ", ".join(subprotocols)))
if headers:
for header_name, header_value in headers:
conn.putheader(header_name, header_value)
conn.endheaders()
# Read out the response
try:
response = conn.getresponse()
except socket.timeout:
# See if they left an exception for us to load
test_app.get_received()
raise RuntimeError("Daphne timed out handling request, no exception found.")
# Check we got a good response code
if response.status != 101:
raise RuntimeError("WebSocket upgrade did not result in status code 101")
# Prepare headers for subprotocol searching
response_headers = {n.lower(): v for n, v in response.getheaders()}
response.read()
assert not response.closed
# Return the raw socket and any subprotocol
return conn.sock, response_headers.get("sec-websocket-protocol", None)
def websocket_send_frame(self, sock, value):
"""
Sends a WebSocket text or binary frame. Cannot handle long frames.
"""
# Header and text opcode
if isinstance(value, str):
frame = b"\x81"
value = value.encode("utf8")
else:
frame = b"\x82"
# Length plus masking signal bit
frame += struct.pack("!B", len(value) | 0b10000000)
# Mask badly
frame += b"\0\0\0\0"
# Payload
frame += value
sock.sendall(frame)
def receive_from_socket(self, sock, length, timeout=1):
"""
Receives the given amount of bytes from the socket, or times out.
"""
buf = b""
started = time.time()
while len(buf) < length:
buf += sock.recv(length - len(buf))
time.sleep(0.001)
if time.time() - started > timeout:
raise ValueError("Timed out reading from socket")
return buf
def websocket_receive_frame(self, sock):
"""
Receives a WebSocket frame. Cannot handle long frames.
"""
# Read header byte
# TODO: Proper receive buffer handling
opcode = self.receive_from_socket(sock, 1)
if opcode in [b"\x81", b"\x82"]:
# Read length
length = struct.unpack("!B", self.receive_from_socket(sock, 1))[0]
# Read payload
payload = self.receive_from_socket(sock, length)
if opcode == b"\x81":
payload = payload.decode("utf8")
return payload
else:
raise ValueError("Unknown websocket opcode: %r" % opcode)
### Assertions and test management
def tearDown(self):
"""
Ensures any storage files are cleared.
"""
TestApplication.delete_setup()
TestApplication.delete_result()
def assert_is_ip_address(self, address):
"""
Tests whether a given address string is a valid IPv4 or IPv6 address.
"""
try:
socket.inet_aton(address)
except OSError:
self.fail("'%s' is not a valid IP address." % address)
def assert_key_sets(self, required_keys, optional_keys, actual_keys):
"""
Asserts that all required_keys are in actual_keys, and that there
are no keys in actual_keys that aren't required or optional.
"""
present_keys = set(actual_keys)
# Make sure all required keys are present
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(set(), present_keys - required_keys - optional_keys)
def assert_valid_path(self, path):
"""
Checks the path is valid and already url-decoded.
"""
self.assertIsInstance(path, str)
# Assert that it's already url decoded
self.assertEqual(path, parse.unquote(path))
def assert_valid_address_and_port(self, host):
"""
Asserts the value is a valid (host, port) tuple.
"""
address, port = host
self.assertIsInstance(address, str)
self.assert_is_ip_address(address)
self.assertIsInstance(port, int)
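The `websocket_send_frame` helper in the test base class above hand-rolls the RFC 6455 client frame layout: a FIN+opcode byte (0x81 text, 0x82 binary), the MASK bit OR'd into a 7-bit length, a four-byte masking key (all zeros here, which leaves the payload bytes unchanged), then the payload. The same framing as a standalone sketch; the `build_frame` name is ours, not the test suite's:

```python
import struct


def build_frame(value):
    # Text (0x81) or binary (0x82) frame with FIN set; short payloads only.
    if isinstance(value, str):
        frame = b"\x81"
        value = value.encode("utf8")
    else:
        frame = b"\x82"
    # MASK bit plus 7-bit payload length (so payloads must be < 126 bytes).
    frame += struct.pack("!B", len(value) | 0b10000000)
    # Four-byte masking key; an all-zero key XORs the payload with nothing.
    frame += b"\0\0\0\0"
    frame += value
    return frame


print(build_frame("hi"))  # b'\x81\x82\x00\x00\x00\x00hi'
```

Client-to-server frames must be masked per RFC 6455, which is why the helper sets the MASK bit at all; the zero key is the "mask badly" shortcut the test comment refers to.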


@@ -1,126 +0,0 @@
import string
from urllib import parse
from hypothesis import strategies
HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"]
# Unicode characters of the "Letter" category
letters = strategies.characters(
whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")
)
def http_method():
return strategies.sampled_from(HTTP_METHODS)
def _http_path_portion():
alphabet = string.ascii_letters + string.digits + "-._~"
return strategies.text(min_size=1, max_size=128, alphabet=alphabet)
def http_path():
"""
Returns a URL path (not encoded).
"""
return strategies.lists(_http_path_portion(), min_size=0, max_size=10).map(
lambda s: "/" + "/".join(s)
)
def http_body():
"""
Returns random binary body data.
"""
return strategies.binary(min_size=0, max_size=1500)
def valid_bidi(value):
"""
Rejects strings with nonsensical Unicode text-direction flags.
Relying on random Unicode characters means that some combinations don't make sense
from a text-direction point of view. This little helper just rejects those.
"""
try:
value.encode("idna")
except UnicodeError:
return False
else:
return True
def _domain_label():
return strategies.text(alphabet=letters, min_size=1, max_size=63).filter(valid_bidi)
def international_domain_name():
"""
Returns a byte string of a domain name, IDNA-encoded.
"""
return strategies.lists(_domain_label(), min_size=2).map(
lambda s: (".".join(s)).encode("idna")
)
def _query_param():
return strategies.text(alphabet=letters, min_size=1, max_size=255).map(
lambda s: s.encode("utf8")
)
def query_params():
"""
Returns a list of two-tuples of byte strings, ready for encoding with urlencode.
We're aiming for a total length of a URL below 2083 characters, so this strategy
ensures that the total urlencoded query string is not longer than 1500 characters.
"""
return strategies.lists(
strategies.tuples(_query_param(), _query_param()), min_size=0
).filter(lambda x: len(parse.urlencode(x)) < 1500)
def header_name():
"""
Strategy returning something that looks like an HTTP header field
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields suggests they are between 4
and 20 characters long
"""
return strategies.text(
alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30
).map(lambda s: s.encode("utf-8"))
def header_value():
"""
Strategy returning something that looks like an HTTP header value
"For example, the Apache 2.3 server by default limits the size of each field to 8190 bytes"
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
"""
return (
strategies.text(
alphabet=string.ascii_letters
+ string.digits
+ string.punctuation.replace(",", "")
+ " /t",
min_size=1,
max_size=8190,
)
.map(lambda s: s.encode("utf-8"))
.filter(lambda s: len(s) < 8190)
)
def headers():
"""
Strategy returning a list of tuples, containing HTTP header fields and their values.
"[Apache 2.3] there can be at most 100 header fields in a single request."
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
"""
return strategies.lists(
strategies.tuples(header_name(), header_value()), min_size=0, max_size=100
)


@@ -1,21 +0,0 @@
import django
from django.conf import settings
from django.test.utils import override_settings
from daphne.checks import check_daphne_installed
def test_check_daphne_installed():
"""
Test check error is raised if daphne is not listed before staticfiles, and vice versa.
"""
settings.configure(
INSTALLED_APPS=["daphne.apps.DaphneConfig", "django.contrib.staticfiles"]
)
django.setup()
errors = check_daphne_installed(None)
assert len(errors) == 0
with override_settings(INSTALLED_APPS=["django.contrib.staticfiles", "daphne"]):
errors = check_daphne_installed(None)
assert len(errors) == 1
assert errors[0].id == "daphne.E001"


@@ -1,267 +0,0 @@
import logging
import os
from argparse import ArgumentError
from unittest import TestCase, skipUnless
from daphne.cli import CommandLineInterface
from daphne.endpoints import build_endpoint_description_strings as build
class TestEndpointDescriptions(TestCase):
"""
Tests that the endpoint parsing/generation works as intended.
"""
def testBasics(self):
self.assertEqual(build(), [], msg="Empty list returned when no kwargs given")
def testTcpPortBindings(self):
self.assertEqual(
build(port=1234, host="example.com"),
["tcp:port=1234:interface=example.com"],
)
self.assertEqual(
build(port=8000, host="127.0.0.1"), ["tcp:port=8000:interface=127.0.0.1"]
)
self.assertEqual(
build(port=8000, host="[200a::1]"), [r"tcp:port=8000:interface=200a\:\:1"]
)
self.assertEqual(
build(port=8000, host="200a::1"), [r"tcp:port=8000:interface=200a\:\:1"]
)
# incomplete port/host kwargs raise errors
self.assertRaises(ValueError, build, port=123)
self.assertRaises(ValueError, build, host="example.com")
def testUnixSocketBinding(self):
self.assertEqual(
build(unix_socket="/tmp/daphne.sock"), ["unix:/tmp/daphne.sock"]
)
def testFileDescriptorBinding(self):
self.assertEqual(build(file_descriptor=5), ["fd:fileno=5"])
def testMultipleEndpoints(self):
self.assertEqual(
sorted(
build(
file_descriptor=123,
unix_socket="/tmp/daphne.sock",
port=8080,
host="10.0.0.1",
)
),
sorted(
[
"tcp:port=8080:interface=10.0.0.1",
"unix:/tmp/daphne.sock",
"fd:fileno=123",
]
),
)
class TestCLIInterface(TestCase):
"""
Tests the overall CLI class.
"""
class TestedCLI(CommandLineInterface):
"""
CommandLineInterface subclass that we use for testing (has a fake
server subclass).
"""
class TestedServer:
"""
Mock server object for testing.
"""
def __init__(self, **kwargs):
self.init_kwargs = kwargs
def run(self):
pass
server_class = TestedServer
def setUp(self):
logging.disable(logging.CRITICAL)
def tearDown(self):
logging.disable(logging.NOTSET)
def assertCLI(self, args, server_kwargs):
"""
Asserts that the CLI class passes the right args to the server class.
Passes in a fake application automatically.
"""
cli = self.TestedCLI()
cli.run(
args + ["daphne:__version__"]
) # We just pass something importable as app
# Check the server got all arguments as intended
for key, value in server_kwargs.items():
# Get the value and sort it if it's a list (for endpoint checking)
actual_value = cli.server.init_kwargs.get(key)
if isinstance(actual_value, list):
actual_value.sort()
# Check values
self.assertEqual(
value,
actual_value,
"Wrong value for server kwarg %s: %r != %r"
% (key, value, actual_value),
)
def testCLIBasics(self):
"""
Tests basic endpoint generation.
"""
self.assertCLI([], {"endpoints": ["tcp:port=8000:interface=127.0.0.1"]})
self.assertCLI(
["-p", "123"], {"endpoints": ["tcp:port=123:interface=127.0.0.1"]}
)
self.assertCLI(
["-b", "10.0.0.1"], {"endpoints": ["tcp:port=8000:interface=10.0.0.1"]}
)
self.assertCLI(
["-b", "200a::1"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
)
self.assertCLI(
["-b", "[200a::1]"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
)
self.assertCLI(
["-p", "8080", "-b", "example.com"],
{"endpoints": ["tcp:port=8080:interface=example.com"]},
)
def testUnixSockets(self):
self.assertCLI(
["-p", "8080", "-u", "/tmp/daphne.sock"],
{
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
"unix:/tmp/daphne.sock",
]
},
)
self.assertCLI(
["-b", "example.com", "-u", "/tmp/daphne.sock"],
{
"endpoints": [
"tcp:port=8000:interface=example.com",
"unix:/tmp/daphne.sock",
]
},
)
self.assertCLI(
["-u", "/tmp/daphne.sock", "--fd", "5"],
{"endpoints": ["fd:fileno=5", "unix:/tmp/daphne.sock"]},
)
def testMixedCLIEndpointCreation(self):
"""
Tests mixing the shortcut options with the endpoint string options.
"""
self.assertCLI(
["-p", "8080", "-e", "unix:/tmp/daphne.sock"],
{
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
"unix:/tmp/daphne.sock",
]
},
)
self.assertCLI(
["-p", "8080", "-e", "tcp:port=8080:interface=127.0.0.1"],
{
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
"tcp:port=8080:interface=127.0.0.1",
]
},
)
def testCustomEndpoints(self):
"""
Tests entirely custom endpoints
"""
self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]})
def test_default_proxyheaders(self):
"""
Passing `--proxy-headers` without a parameter will use the
`X-Forwarded-For` header.
"""
self.assertCLI(
["--proxy-headers"], {"proxy_forwarded_address_header": "X-Forwarded-For"}
)
def test_custom_proxyhost(self):
"""
Passing `--proxy-headers-host` will set the used host header to
the passed one, and `--proxy-headers` is mandatory.
"""
self.assertCLI(
["--proxy-headers", "--proxy-headers-host", "blah"],
{"proxy_forwarded_address_header": "blah"},
)
with self.assertRaises(expected_exception=ArgumentError) as exc:
self.assertCLI(
["--proxy-headers-host", "blah"],
{"proxy_forwarded_address_header": "blah"},
)
self.assertEqual(exc.exception.argument_name, "--proxy-headers-host")
self.assertEqual(
exc.exception.message,
"--proxy-headers has to be passed for this parameter.",
)
def test_custom_proxyport(self):
"""
Passing `--proxy-headers-port` will set the used port header to
the passed one, and `--proxy-headers` is mandatory.
"""
self.assertCLI(
["--proxy-headers", "--proxy-headers-port", "blah2"],
{"proxy_forwarded_port_header": "blah2"},
)
with self.assertRaises(expected_exception=ArgumentError) as exc:
self.assertCLI(
["--proxy-headers-port", "blah2"],
{"proxy_forwarded_address_header": "blah2"},
)
self.assertEqual(exc.exception.argument_name, "--proxy-headers-port")
self.assertEqual(
exc.exception.message,
"--proxy-headers has to be passed for this parameter.",
)
def test_custom_servername(self):
"""
Passing `--server-name` will set the default server header
from 'daphne' to the passed one.
"""
self.assertCLI([], {"server_name": "daphne"})
self.assertCLI(["--server-name", ""], {"server_name": ""})
self.assertCLI(["--server-name", "python"], {"server_name": "python"})
def test_no_servername(self):
"""
Passing `--no-server-name` will set server name to '' (empty string)
"""
self.assertCLI(["--no-server-name"], {"server_name": ""})
@skipUnless(os.getenv("ASGI_THREADS"), "ASGI_THREADS environment variable not set.")
class TestASGIThreads(TestCase):
def test_default_executor(self):
from daphne.server import twisted_loop
executor = twisted_loop._default_executor
self.assertEqual(executor._max_workers, int(os.getenv("ASGI_THREADS")))


@@ -1,49 +0,0 @@
import unittest
from daphne.http_protocol import WebRequest
class MockServer:
"""
Mock server object for testing.
"""
def protocol_connected(self, *args, **kwargs):
pass
class MockFactory:
"""
Mock factory object for testing.
"""
def __init__(self):
self.server = MockServer()
class MockChannel:
"""
Mock channel object for testing.
"""
def __init__(self):
self.factory = MockFactory()
self.transport = None
def getPeer(self, *args, **kwargs):
return "peer"
def getHost(self, *args, **kwargs):
return "host"
class TestHTTPProtocol(unittest.TestCase):
"""
Tests the HTTP protocol classes.
"""
def test_web_request_initialisation(self):
channel = MockChannel()
request = WebRequest(channel)
self.assertIsNone(request.client_addr)
self.assertIsNone(request.server_addr)


@@ -1,324 +0,0 @@
import collections
from urllib import parse
import http_strategies
from http_base import DaphneTestCase
from hypothesis import assume, given, settings
from hypothesis.strategies import integers
class TestHTTPRequest(DaphneTestCase):
"""
Tests the HTTP request handling.
"""
def assert_valid_http_scope(
self, scope, method, path, params=None, headers=None, scheme=None
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
and some urlencoding things.
"""
# Check overall keys
self.assert_key_sets(
required_keys={
"asgi",
"type",
"http_version",
"method",
"path",
"raw_path",
"query_string",
"headers",
},
optional_keys={"scheme", "root_path", "client", "server"},
actual_keys=scope.keys(),
)
self.assertEqual(scope["asgi"]["version"], "3.0")
# Check that it is the right type
self.assertEqual(scope["type"], "http")
# Method (uppercased unicode string)
self.assertIsInstance(scope["method"], str)
self.assertEqual(scope["method"], method.upper())
# Path
self.assert_valid_path(scope["path"])
# HTTP version
self.assertIn(scope["http_version"], ["1.0", "1.1", "1.2"])
# Scheme
self.assertIn(scope["scheme"], ["http", "https"])
if scheme:
self.assertEqual(scheme, scope["scheme"])
# Query string (byte string and still url encoded)
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
self.assertEqual(
query_string, parse.urlencode(params or []).encode("ascii")
)
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
# {name: [value1, value2, ...]} and check if they're equal.
transformed_scope_headers = collections.defaultdict(list)
for name, value in scope["headers"]:
transformed_scope_headers[name].append(value)
transformed_request_headers = collections.defaultdict(list)
for name, value in headers or []:
expected_name = name.lower().strip()
expected_value = value.strip()
transformed_request_headers[expected_name].append(expected_value)
for name, value in transformed_request_headers.items():
self.assertIn(name, transformed_scope_headers)
self.assertEqual(value, transformed_scope_headers[name])
# Root path
self.assertIsInstance(scope.get("root_path", ""), str)
# Client and server addresses
client = scope.get("client")
if client is not None:
self.assert_valid_address_and_port(client)
server = scope.get("server")
if server is not None:
self.assert_valid_address_and_port(server)
def assert_valid_http_request_message(self, message, body=None):
"""
Asserts that a message is a valid http.request message
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type"},
optional_keys={"body", "more_body"},
actual_keys=message.keys(),
)
# Check that it is the right type
self.assertEqual(message["type"], "http.request")
# If there's a body present, check its type
self.assertIsInstance(message.get("body", b""), bytes)
if body is not None:
self.assertEqual(body, message.get("body", b""))
def test_minimal_request(self):
"""
Smallest viable example. Mostly verifies that our request building works.
"""
scope, messages = self.run_daphne_request("GET", "/")
self.assert_valid_http_scope(scope, "GET", "/")
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params(),
)
@settings(max_examples=5, deadline=5000)
def test_get_request(self, request_path, request_params):
"""
Tests a typical HTTP GET request, with a path and query parameters
"""
scope, messages = self.run_daphne_request(
"GET", request_path, params=request_params
)
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(request_path=http_strategies.http_path(), chunk_size=integers(min_value=1))
@settings(max_examples=5, deadline=5000)
def test_request_body_chunking(self, request_path, chunk_size):
"""
Tests request body chunking logic.
"""
body = b"The quick brown fox jumps over the lazy dog"
_, messages = self.run_daphne_request(
"POST",
request_path,
body=body,
request_buffer_size=chunk_size,
)
# Avoid running those asserts when there's a single "http.disconnect"
if len(messages) > 1:
assert messages[0]["body"].decode() == body.decode()[:chunk_size]
assert not messages[-2]["more_body"]
assert messages[-1] == {"type": "http.disconnect"}
@given(
request_path=http_strategies.http_path(),
request_body=http_strategies.http_body(),
)
@settings(max_examples=5, deadline=5000)
def test_post_request(self, request_path, request_body):
"""
Tests a typical HTTP POST request, with a path and body.
"""
scope, messages = self.run_daphne_request(
"POST", request_path, body=request_body
)
self.assert_valid_http_scope(scope, "POST", request_path)
self.assert_valid_http_request_message(messages[0], body=request_body)
def test_raw_path(self):
"""
Tests that /foo%2Fbar produces raw_path and a decoded path
"""
scope, _ = self.run_daphne_request("GET", "/foo%2Fbar")
self.assertEqual(scope["path"], "/foo/bar")
self.assertEqual(scope["raw_path"], b"/foo%2Fbar")
@given(request_headers=http_strategies.headers())
@settings(max_examples=5, deadline=5000)
def test_headers(self, request_headers):
"""
Tests that HTTP header fields are handled as specified
"""
request_path = parse.quote("/te st-à/")
scope, messages = self.run_daphne_request(
"OPTIONS", request_path, headers=request_headers
)
self.assert_valid_http_scope(
scope, "OPTIONS", request_path, headers=request_headers
)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(request_headers=http_strategies.headers())
@settings(max_examples=5, deadline=5000)
def test_duplicate_headers(self, request_headers):
"""
Tests that duplicate header values are preserved
"""
# Make sure there's duplicate headers
assume(len(request_headers) >= 2)
header_name = request_headers[0][0]
duplicated_headers = [(header_name, header[1]) for header in request_headers]
# Run the request
request_path = parse.quote("/te st-à/")
scope, messages = self.run_daphne_request(
"OPTIONS", request_path, headers=duplicated_headers
)
self.assert_valid_http_scope(
scope, "OPTIONS", request_path, headers=duplicated_headers
)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
request_method=http_strategies.http_method(),
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params(),
request_headers=http_strategies.headers(),
request_body=http_strategies.http_body(),
)
@settings(max_examples=2, deadline=5000)
def test_kitchen_sink(
self,
request_method,
request_path,
request_params,
request_headers,
request_body,
):
"""
Throw everything at Daphne that we dare. The idea is that if a combination
of method/path/headers/body would break the spec, hypothesis will eventually find it.
"""
scope, messages = self.run_daphne_request(
request_method,
request_path,
params=request_params,
headers=request_headers,
body=request_body,
)
self.assert_valid_http_scope(
scope,
request_method,
request_path,
params=request_params,
headers=request_headers,
)
self.assert_valid_http_request_message(messages[0], body=request_body)
def test_headers_are_lowercased_and_stripped(self):
"""
Make sure headers are normalized as the spec says they are.
"""
headers = [(b"MYCUSTOMHEADER", b" foobar ")]
scope, messages = self.run_daphne_request("GET", "/", headers=headers)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
# Note that Daphne returns a list of tuples here, which is fine, because the spec
# asks to treat them interchangeably.
assert [list(x) for x in scope["headers"]] == [[b"mycustomheader", b"foobar"]]
@given(daphne_path=http_strategies.http_path())
@settings(max_examples=5, deadline=5000)
def test_root_path_header(self, daphne_path):
"""
Tests root_path handling.
"""
# Daphne-Root-Path must be URL encoded when submitting as HTTP header field
headers = [("Daphne-Root-Path", parse.quote(daphne_path.encode("utf8")))]
scope, messages = self.run_daphne_request("GET", "/", headers=headers)
# Daphne-Root-Path is not included in the returned 'headers' section. So we expect
# empty headers.
self.assert_valid_http_scope(scope, "GET", "/", headers=[])
self.assert_valid_http_request_message(messages[0], body=b"")
# And what we're looking for, root_path being set.
assert scope["root_path"] == daphne_path
def test_x_forwarded_for_ignored(self):
"""
Make sure that, by default, X-Forwarded-For is ignored.
"""
headers = [[b"X-Forwarded-For", b"10.1.2.3"], [b"X-Forwarded-Port", b"80"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
# It should NOT appear in the client scope item
self.assertNotEqual(scope["client"], ["10.1.2.3", 80])
def test_x_forwarded_for_parsed(self):
"""
When X-Forwarded-For is enabled, make sure it is respected.
"""
headers = [[b"X-Forwarded-For", b"10.1.2.3"], [b"X-Forwarded-Port", b"80"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
# It should now appear in the client scope item
self.assertEqual(scope["client"], ["10.1.2.3", 80])
def test_x_forwarded_for_no_port(self):
"""
When X-Forwarded-For is enabled but only the host is passed, make sure
that at least makes it through.
"""
headers = [[b"X-Forwarded-For", b"10.1.2.3"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
# It should now appear in the client scope item
self.assertEqual(scope["client"], ["10.1.2.3", 0])
def test_bad_requests(self):
"""
Tests that requests with invalid (non-ASCII) characters fail.
"""
# Bad path
response = self.run_daphne_raw(
b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
)
self.assertTrue(b"400 Bad Request" in response)
# Bad querystring
response = self.run_daphne_raw(
b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
)
self.assertTrue(b"400 Bad Request" in response)
def test_invalid_header_name(self):
"""
Tests that requests with invalid header names fail.
"""
# Test cases follow those used by h11
# https://github.com/python-hyper/h11/blob/a2c68948accadc3876dffcf979d98002e4a4ed27/h11/tests/test_headers.py#L24-L35
for header_name in [b"foo bar", b"foo\x00bar", b"foo\xffbar", b"foo\x01bar"]:
response = self.run_daphne_raw(
f"GET / HTTP/1.0\r\n{header_name}: baz\r\n\r\n".encode("ascii")
)
self.assertTrue(b"400 Bad Request" in response)
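The header check in `assert_valid_http_scope` above hinges on one transform: group values per lowercased, stripped name, so header-name order is ignored but per-name value order is preserved. A standalone sketch of that transform (`group_headers` is a hypothetical helper, not part of the test suite):

```python
import collections


def group_headers(headers):
    """Group (name, value) pairs into {name: [value1, value2, ...]},
    lowercasing names and stripping whitespace, mirroring the way the
    scope assertions above normalise both sides before comparing."""
    grouped = collections.defaultdict(list)
    for name, value in headers:
        grouped[name.lower().strip()].append(value.strip())
    return dict(grouped)
```

Two requests that differ only in header-name order (or casing) produce equal groupings, while reordering duplicate values of the same header does not.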


@@ -1,187 +0,0 @@
import http_strategies
from http_base import DaphneTestCase
from hypothesis import given, settings
class TestHTTPResponse(DaphneTestCase):
"""
Tests HTTP response handling.
"""
def normalize_headers(self, headers):
"""
Lowercases and sorts headers, and strips transfer-encoding ones.
"""
return sorted(
[(b"server", b"daphne")]
+ [
(name.lower(), value.strip())
for name, value in headers
if name.lower() not in (b"server", b"transfer-encoding")
]
)
def encode_headers(self, headers):
def encode(s):
return s if isinstance(s, bytes) else s.encode("utf-8")
return [[encode(k), encode(v)] for k, v in headers]
def test_minimal_response(self):
"""
Smallest viable example. Mostly verifies that our response building works.
"""
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 200},
{"type": "http.response.body", "body": b"hello world"},
]
)
self.assertEqual(response.status, 200)
self.assertEqual(response.read(), b"hello world")
def test_status_code_required(self):
"""
Asserts that passing in the 'status' key is required.
Previous versions of Daphne did not enforce this, so this test is here
to make sure it stays required.
"""
with self.assertRaises(ValueError):
self.run_daphne_response(
[
{"type": "http.response.start"},
{"type": "http.response.body", "body": b"hello world"},
]
)
def test_custom_status_code(self):
"""
Tries a non-default status code.
"""
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 201},
{"type": "http.response.body", "body": b"i made a thing!"},
]
)
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"i made a thing!")
def test_chunked_response(self):
"""
Tries sending a response in multiple parts.
"""
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 201},
{"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
{"type": "http.response.body", "body": b"chunk 2"},
]
)
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"chunk 1 chunk 2")
def test_chunked_response_empty(self):
"""
Tries sending a response in multiple parts and an empty end.
"""
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 201},
{"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
{"type": "http.response.body", "body": b"chunk 2", "more_body": True},
{"type": "http.response.body"},
]
)
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"chunk 1 chunk 2")
@given(body=http_strategies.http_body())
@settings(max_examples=5, deadline=5000)
def test_body(self, body):
"""
Tries body variants.
"""
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 200},
{"type": "http.response.body", "body": body},
]
)
self.assertEqual(response.status, 200)
self.assertEqual(response.read(), body)
@given(headers=http_strategies.headers())
@settings(max_examples=5, deadline=5000)
def test_headers(self, headers):
# The ASGI spec requires us to lowercase our header names
response = self.run_daphne_response(
[
{
"type": "http.response.start",
"status": 200,
"headers": self.normalize_headers(headers),
},
{"type": "http.response.body"},
]
)
# Check headers in a sensible way. Ignore transfer-encoding.
self.assertEqual(
self.normalize_headers(self.encode_headers(response.getheaders())),
self.normalize_headers(headers),
)
def test_headers_type(self):
"""
Headers should be `bytes`
"""
with self.assertRaises(ValueError) as context:
self.run_daphne_response(
[
{
"type": "http.response.start",
"status": 200,
"headers": [["foo", b"bar"]],
},
{"type": "http.response.body", "body": b""},
]
)
self.assertEqual(
str(context.exception),
"Header name 'foo' expected to be `bytes`, but got `<class 'str'>`",
)
with self.assertRaises(ValueError) as context:
self.run_daphne_response(
[
{
"type": "http.response.start",
"status": 200,
"headers": [[b"foo", True]],
},
{"type": "http.response.body", "body": b""},
]
)
self.assertEqual(
str(context.exception),
"Header value 'True' expected to be `bytes`, but got `<class 'bool'>`",
)
def test_headers_type_raw(self):
"""
Daphne returns a 500 error response if the application sends invalid
headers.
"""
response = self.run_daphne_raw(
b"GET / HTTP/1.0\r\n\r\n",
responses=[
{
"type": "http.response.start",
"status": 200,
"headers": [["foo", b"bar"]],
},
{"type": "http.response.body", "body": b""},
],
)
self.assertTrue(response.startswith(b"HTTP/1.0 500 Internal Server Error"))


@@ -1,84 +0,0 @@
from unittest import TestCase
from twisted.web.http_headers import Headers
from daphne.utils import parse_x_forwarded_for
class TestXForwardedForHttpParsing(TestCase):
"""
Tests that the parse_x_forwarded_for util correctly parses twisted Header.
"""
def test_basic(self):
headers = Headers(
{
b"X-Forwarded-For": [b"10.1.2.3"],
b"X-Forwarded-Port": [b"1234"],
b"X-Forwarded-Proto": [b"https"],
}
)
result = parse_x_forwarded_for(headers)
self.assertEqual(result, (["10.1.2.3", 1234], "https"))
self.assertIsInstance(result[0][0], str)
self.assertIsInstance(result[1], str)
def test_address_only(self):
headers = Headers({b"X-Forwarded-For": [b"10.1.2.3"]})
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_v6_address(self):
headers = Headers({b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]})
self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
def test_multiple_proxies(self):
headers = Headers({b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"]})
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_original(self):
headers = Headers({})
self.assertEqual(
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
(["127.0.0.1", 80], None),
)
def test_no_original(self):
headers = Headers({})
self.assertEqual(parse_x_forwarded_for(headers), (None, None))
class TestXForwardedForWsParsing(TestCase):
"""
Tests that the parse_x_forwarded_for util correctly parses dict headers.
"""
def test_basic(self):
headers = {
b"X-Forwarded-For": b"10.1.2.3",
b"X-Forwarded-Port": b"1234",
b"X-Forwarded-Proto": b"https",
}
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 1234], "https"))
def test_address_only(self):
headers = {b"X-Forwarded-For": b"10.1.2.3"}
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_v6_address(self):
headers = {b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]}
self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
def test_multiple_proxies(self):
headers = {b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4"}
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_original(self):
headers = {}
self.assertEqual(
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
(["127.0.0.1", 80], None),
)
def test_no_original(self):
headers = {}
self.assertEqual(parse_x_forwarded_for(headers), (None, None))
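The behaviour these tests pin down can be sketched as follows. This is a simplified stand-in for `daphne.utils.parse_x_forwarded_for`, covering only the dict-header cases above (bytes keys and values); `parse_xff` is a hypothetical name, not Daphne's API:

```python
def parse_xff(headers, original_addr=None):
    """Simplified X-Forwarded-For parser matching the tested behaviour:
    first entry of a comma-separated chain wins, port defaults to 0,
    and the original address is returned when no header is present."""
    address = original_addr
    proto = None
    raw = headers.get(b"X-Forwarded-For")
    if raw is not None:
        if isinstance(raw, list):  # tolerate list-wrapped values, as the tests do
            raw = raw[0]
        # The first (leftmost) entry is the closest-to-client address.
        host = raw.decode("utf-8").split(",")[0].strip()
        port = 0
        raw_port = headers.get(b"X-Forwarded-Port")
        if raw_port is not None:
            try:
                port = int(raw_port.decode("utf-8"))
            except ValueError:
                pass  # keep the default port of 0 on a malformed value
        address = [host, port]
        raw_proto = headers.get(b"X-Forwarded-Proto")
        if raw_proto is not None:
            proto = raw_proto.decode("utf-8")
    return address, proto
```

Note that the port and proto headers are only consulted when `X-Forwarded-For` itself is present, which is why `test_no_original` expects `(None, None)`.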


@@ -1,338 +0,0 @@
import collections
import time
from urllib import parse
import http_strategies
from http_base import DaphneTestCase, DaphneTestingInstance
from hypothesis import given, settings
from daphne.testing import BaseDaphneTestingInstance
class TestWebsocket(DaphneTestCase):
"""
Tests WebSocket handshake, send and receive.
"""
def assert_valid_websocket_scope(
self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None
):
"""
Checks that the passed scope is a valid ASGI WebSocket scope regarding types
and some urlencoding things.
"""
# Check overall keys
self.assert_key_sets(
required_keys={
"asgi",
"type",
"path",
"raw_path",
"query_string",
"headers",
},
optional_keys={"scheme", "root_path", "client", "server", "subprotocols"},
actual_keys=scope.keys(),
)
self.assertEqual(scope["asgi"]["version"], "3.0")
# Check that it is the right type
self.assertEqual(scope["type"], "websocket")
# Path
self.assert_valid_path(scope["path"])
# Scheme
self.assertIn(scope.get("scheme", "ws"), ["ws", "wss"])
if scheme:
self.assertEqual(scheme, scope["scheme"])
# Query string (byte string and still url encoded)
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
self.assertEqual(
query_string, parse.urlencode(params or []).encode("ascii")
)
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
# {name: [value1, value2, ...]} and check if they're equal.
transformed_scope_headers = collections.defaultdict(list)
for name, value in scope["headers"]:
transformed_scope_headers.setdefault(name, [])
# Make sure to split out any headers collapsed with commas
for bit in value.split(b","):
if bit.strip():
transformed_scope_headers[name].append(bit.strip())
transformed_request_headers = collections.defaultdict(list)
for name, value in headers or []:
expected_name = name.lower().strip()
expected_value = value.strip()
# Make sure to split out any headers collapsed with commas
transformed_request_headers.setdefault(expected_name, [])
for bit in expected_value.split(b","):
if bit.strip():
transformed_request_headers[expected_name].append(bit.strip())
for name, value in transformed_request_headers.items():
self.assertIn(name, transformed_scope_headers)
self.assertEqual(value, transformed_scope_headers[name])
# Root path
self.assertIsInstance(scope.get("root_path", ""), str)
# Client and server addresses
client = scope.get("client")
if client is not None:
self.assert_valid_address_and_port(client)
server = scope.get("server")
if server is not None:
self.assert_valid_address_and_port(server)
# Subprotocols
scope_subprotocols = scope.get("subprotocols", [])
if scope_subprotocols:
assert all(isinstance(x, str) for x in scope_subprotocols)
if subprotocols:
assert sorted(scope_subprotocols) == sorted(subprotocols)
def assert_valid_websocket_connect_message(self, message):
"""
Asserts that a message is a valid websocket.connect message
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type"}, optional_keys=set(), actual_keys=message.keys()
)
# Check that it is the right type
self.assertEqual(message["type"], "websocket.connect")
def test_accept(self):
"""
Tests we can open and accept a socket.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(scope)
self.assert_valid_websocket_connect_message(messages[0])
def test_reject(self):
"""
Tests we can reject a socket and it won't complete the handshake.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.close"}])
with self.assertRaises(RuntimeError):
self.websocket_handshake(test_app)
def test_subprotocols(self):
"""
Tests that we can ask for subprotocols and then select one.
"""
subprotocols = ["proto1", "proto2"]
with DaphneTestingInstance() as test_app:
test_app.add_send_messages(
[{"type": "websocket.accept", "subprotocol": "proto2"}]
)
_, subprotocol = self.websocket_handshake(
test_app, subprotocols=subprotocols
)
# Validate the scope and messages we got
assert subprotocol == "proto2"
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(scope, subprotocols=subprotocols)
self.assert_valid_websocket_connect_message(messages[0])
def test_xff(self):
"""
Tests that X-Forwarded-For headers get parsed right
"""
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
with DaphneTestingInstance(xff=True) as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app, headers=headers)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(scope)
self.assert_valid_websocket_connect_message(messages[0])
assert scope["client"] == ["10.1.2.3", 80]
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params(),
request_headers=http_strategies.headers(),
)
@settings(max_examples=5, deadline=2000)
def test_http_bits(self, request_path, request_params, request_headers):
"""
Tests that various HTTP-level bits (query string params, path, headers)
carry over into the scope.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(
test_app,
path=parse.quote(request_path),
params=request_params,
headers=request_headers,
)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(
scope, path=request_path, params=request_params, headers=request_headers
)
self.assert_valid_websocket_connect_message(messages[0])
def test_raw_path(self):
"""
Tests that /foo%2Fbar produces raw_path and a decoded path
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app, path="/foo%2Fbar")
# Validate the scope and messages we got
scope, _ = test_app.get_received()
self.assertEqual(scope["path"], "/foo/bar")
self.assertEqual(scope["raw_path"], b"/foo%2Fbar")
@given(daphne_path=http_strategies.http_path())
@settings(max_examples=5, deadline=2000)
def test_root_path(self, *, daphne_path):
"""
Tests root_path handling.
"""
headers = [("Daphne-Root-Path", parse.quote(daphne_path))]
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(
test_app,
path="/",
headers=headers,
)
# Validate the scope and messages we got
scope, _ = test_app.get_received()
# Daphne-Root-Path is not included in the returned 'headers' section.
self.assertNotIn(
b"daphne-root-path", (header[0].lower() for header in scope["headers"])
)
# And what we're looking for, root_path being set.
self.assertEqual(scope["root_path"], daphne_path)
def test_text_frames(self):
"""
Tests we can send and receive text frames.
"""
with DaphneTestingInstance() as test_app:
# Connect
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
test_app.add_send_messages(
[{"type": "websocket.send", "text": "here be dragons 🐉"}]
)
# Send it a frame
self.websocket_send_frame(sock, "what is here? 🌍")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
# Make sure it got our frame
_, messages = test_app.get_received()
assert messages[1] == {
"type": "websocket.receive",
"text": "what is here? 🌍",
}
def test_binary_frames(self):
"""
Tests we can send and receive binary frames with things that are very
much not valid UTF-8.
"""
with DaphneTestingInstance() as test_app:
# Connect
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
test_app.add_send_messages(
[{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}]
)
# Send it a frame
self.websocket_send_frame(sock, b"what is here? \xe2")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
# Make sure it got our frame
_, messages = test_app.get_received()
assert messages[1] == {
"type": "websocket.receive",
"bytes": b"what is here? \xe2",
}
def test_http_timeout(self):
"""
Tests that the HTTP timeout doesn't kick in for WebSockets
"""
with DaphneTestingInstance(http_timeout=1) as test_app:
# Connect
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Wait 2 seconds
time.sleep(2)
# Prep frame for it to send
test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}])
# Send it a frame
self.websocket_send_frame(sock, "still alive?")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "cake"
def test_application_checker_handles_asyncio_cancellederror(self):
with CancellingTestingInstance() as app:
# Connect to the websocket app, it will immediately raise
# asyncio.CancelledError
sock, _ = self.websocket_handshake(app)
# Disconnect from the socket
sock.close()
# Wait for application_checker to clean up the applications for
# disconnected clients, and for the server to be stopped.
time.sleep(3)
# Make sure we received either no error, or a ConnectionsNotEmpty
while not app.process.errors.empty():
err, _tb = app.process.errors.get()
if not isinstance(err, ConnectionsNotEmpty):
raise err
self.fail(
"Server connections were not cleaned up after an asyncio.CancelledError was raised"
)
async def cancelling_application(scope, receive, send):
import asyncio
from twisted.internet import reactor
# Stop the server after a short delay so that the teardown is run.
reactor.callLater(2, reactor.stop)
await send({"type": "websocket.accept"})
raise asyncio.CancelledError()
class ConnectionsNotEmpty(Exception):
pass
class CancellingTestingInstance(BaseDaphneTestingInstance):
def __init__(self):
super().__init__(application=cancelling_application)
def process_teardown(self):
import multiprocessing
# Get a hold of the enclosing DaphneProcess (we're currently running in
# the same process as the application).
proc = multiprocessing.current_process()
# By now the (only) socket should have disconnected, and the
# application_checker should have run. If there are any connections
# still, it means that the application_checker did not clean them up.
if proc.server.connections:
raise ConnectionsNotEmpty()


@@ -1,8 +0,0 @@
[tox]
envlist =
py{39,310,311,312,313}
[testenv]
extras = tests
commands =
pytest -v {posargs}