From 8c031239adc994b1c75f5a99ae319a3838a6a721 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 30 May 2018 09:52:47 -0700 Subject: [PATCH 01/17] Remove HTTP timeout by default, and mid-response error for it --- daphne/cli.py | 2 +- daphne/http_protocol.py | 8 ++++++-- daphne/server.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/daphne/cli.py b/daphne/cli.py index 7361914..d9d9fb4 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -86,7 +86,7 @@ class CommandLineInterface(object): "--http-timeout", type=int, help="How long to wait for worker before timing out HTTP connections", - default=120, + default=None, ) self.parser.add_argument( "--access-log", diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 060aab5..2c946e1 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -256,8 +256,12 @@ class WebRequest(http.Request): Called periodically to see if we should timeout something """ # Web timeout checking - if self.duration() > self.server.http_timeout: - self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.") + if self.server.http_timeout and self.duration() > self.server.http_timeout: + if self._response_started: + logger.warning("Application timed out while sending response") + self.finish() + else: + self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.") ### Utility functions diff --git a/daphne/server.py b/daphne/server.py index c9ab4a3..6a65ebb 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -42,7 +42,7 @@ class Server(object): endpoints=None, signal_handlers=True, action_logger=None, - http_timeout=120, + http_timeout=None, websocket_timeout=86400, websocket_connect_timeout=20, ping_interval=20, From ece52b8e7958b41c86fd89830bc7814368e42bde Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Sat, 2 Jun 2018 06:45:02 +0100 Subject: [PATCH 02/17] Don't try and read requests that are closed already (#205) --- daphne/http_protocol.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 2c946e1..1155135 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -159,6 +159,10 @@ class WebRequest(http.Request): "client": self.client_addr, "server": self.server_addr, }) + # Check they didn't close an unfinished request + if self.content.closed: + # Not much we can do, the request is prematurely abandoned. + return # Run application against request self.application_queue.put_nowait( { From bb54f41736f545ebf8b7d647421ae0575bce5eca Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 13 Jun 2018 11:55:20 -0700 Subject: [PATCH 03/17] Releasing 2.2.0 --- CHANGELOG.txt | 14 ++++++++++++++ daphne/__init__.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 5e40440..d4d5bbc 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -1,3 +1,17 @@ +2.2.0 (2018-06-13) +------------------ + +* HTTP timeouts have been removed by default, as they were only needed + with ASGI/Channels 1. You can re-enable them with the --http-timeout + argument to Daphne. + +* Occasional errors on application timeout for non-fully-opened sockets + and for trying to read closed requests under high load are fixed. + +* X-Forwarded-For headers are now correctly decoded in all environments + and no longer have unicode matching issues. + + 2.1.2 (2018-05-24) ------------------ diff --git a/daphne/__init__.py b/daphne/__init__.py index 4eabd0b..8a124bf 100755 --- a/daphne/__init__.py +++ b/daphne/__init__.py @@ -1 +1 @@ -__version__ = "2.1.2" +__version__ = "2.2.0" From 6dcc0d52b3c4ca97df8cc47bd93672076a87bd7e Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Sun, 24 Jun 2018 16:33:54 -0700 Subject: [PATCH 04/17] send() should not block once connection is closed --- daphne/server.py | 4 ++++ setup.cfg | 1 + 2 files changed, 5 insertions(+) diff --git a/daphne/server.py b/daphne/server.py index 6a65ebb..d8d43f2 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -195,6 +195,10 @@ class Server(object): """ Coroutine that jumps the reply message from asyncio to Twisted """ + # Don't do anything if the connection is closed + if self.connections[protocol].get("disconnected", None): + return + # Let the protocol handle it protocol.handle_reply(message) ### Utility diff --git a/setup.cfg b/setup.cfg index 4223ddf..8c9a5f5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,6 +7,7 @@ addopts = tests/ [isort] line_length = 120 multi_line_output = 3 +known_first_party = channels,daphne,asgiref [flake8] exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/* From d5611bccb64e2aa1e2c3e77d2c401f6146c8d9af Mon Sep 17 00:00:00 2001 From: Brian May Date: Fri, 6 Jul 2018 11:26:34 +1000 Subject: [PATCH 05/17] Don't crash if connection closed before application started (#213) Fixes #205. --- daphne/http_protocol.py | 2 +- daphne/server.py | 2 ++ daphne/ws_protocol.py | 5 +++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 1155135..3ff10be 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -160,7 +160,7 @@ class WebRequest(http.Request): "server": self.server_addr, }) # Check they didn't close an unfinished request - if self.content.closed: + if self.application_queue is None or self.content.closed: # Not much we can do, the request is prematurely abandoned. return # Run application against request diff --git a/daphne/server.py b/daphne/server.py index d8d43f2..63d34db 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -185,6 +185,8 @@ class Server(object): input_queue = asyncio.Queue() application_instance = yield deferToThread(self.application, scope=scope) # Run it, and stash the future for later checking + if protocol not in self.connections: + return None self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance( receive=input_queue.get, send=lambda message: self.handle_reply(protocol, message), diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 9edfce6..a7331b5 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -75,8 +75,9 @@ class WebSocketProtocol(WebSocketServerProtocol): "server": self.server_addr, "subprotocols": subprotocols, }) - self.application_deferred.addCallback(self.applicationCreateWorked) - self.application_deferred.addErrback(self.applicationCreateFailed) + if self.application_deferred is not None: + self.application_deferred.addCallback(self.applicationCreateWorked) + self.application_deferred.addErrback(self.applicationCreateFailed) except Exception as e: # Exceptions here are not displayed right, just 500. # Turn them into an ERROR log. From e16b58bcb542521590d4d5b9700f7096636a46f2 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Sun, 22 Jul 2018 09:45:59 -0700 Subject: [PATCH 06/17] Releasing 2.2.1 --- CHANGELOG.txt | 13 +++++++++++++ daphne/__init__.py | 2 +- setup.py | 3 ++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.txt b/CHANGELOG.txt index d4d5bbc..2740193 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -1,3 +1,16 @@ +2.2.1 (2018-07-22) +------------------ + +* Python 3.7 compatability is flagged and ensured by using Twisted 18.7 and + above as a dependency. + +* The send() awaitable in applications no longer blocks if the connection is + closed. + +* Fixed a race condition where applications would be cleaned up before they + had even started. + + 2.2.0 (2018-06-13) ------------------ diff --git a/daphne/__init__.py b/daphne/__init__.py index 8a124bf..b19ee4b 100755 --- a/daphne/__init__.py +++ b/daphne/__init__.py @@ -1 +1 @@ -__version__ = "2.2.0" +__version__ = "2.2.1" diff --git a/setup.py b/setup.py index 939c0cb..692ab6e 100755 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ setup( packages=find_packages() + ["twisted.plugins"], include_package_data=True, install_requires=[ - "twisted>=17.5", + "twisted>=18.7", "autobahn>=0.18", ], setup_requires=[ @@ -49,6 +49,7 @@ setup( "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", "Topic :: Internet :: WWW/HTTP", ], ) From adb622d4f5af4ab30f2b9e688aedb3c53bd65513 Mon Sep 17 00:00:00 2001 From: Anders Jensen Date: Sun, 22 Jul 2018 18:54:42 +0200 Subject: [PATCH 07/17] Removed deferToThread for ASGI instance constructor (#218) The previous behaviour was from an older spec. --- daphne/http_protocol.py | 4 ++-- daphne/server.py | 4 +--- daphne/ws_protocol.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 3ff10be..c4e17fe 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -3,7 +3,7 @@ import time import traceback from urllib.parse import unquote -from twisted.internet.defer import inlineCallbacks +from twisted.internet.defer import inlineCallbacks, maybeDeferred from twisted.internet.interfaces import IProtocolNegotiationFactory from twisted.protocols.policies import ProtocolWrapper from twisted.web import http @@ -146,7 +146,7 @@ class WebRequest(http.Request): logger.debug("HTTP %s request for %s", self.method, self.client_addr) self.content.seek(0, 0) # Work out the application scope and create application - self.application_queue = yield self.server.create_application(self, { + self.application_queue = yield maybeDeferred(self.server.create_application, self, { "type": "http", # TODO: Correctly say if it's 1.1 or 1.0 "http_version": self.clientproto.split(b"/")[-1].decode("ascii"), diff --git a/daphne/server.py b/daphne/server.py index 63d34db..cff80f2 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -24,7 +24,6 @@ from concurrent.futures import CancelledError from twisted.internet import defer, reactor from twisted.internet.endpoints import serverFromString -from twisted.internet.threads import deferToThread from twisted.logger import STDLibLogObserver, globalLogBeginner from twisted.web import http @@ -171,7 +170,6 @@ class Server(object): ### Internal event/message handling - @defer.inlineCallbacks def create_application(self, protocol, scope): """ Creates a new application instance that fronts a Protocol instance @@ -183,7 +181,7 @@ class Server(object): assert "application_instance" not in self.connections[protocol] # Make an instance of the application input_queue = asyncio.Queue() - application_instance = yield deferToThread(self.application, scope=scope) + application_instance = self.application(scope=scope) # Run it, and stash the future for later checking if protocol not in self.connections: return None diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index a7331b5..8ae8749 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -66,7 +66,7 @@ class WebSocketProtocol(WebSocketServerProtocol): ] # Make new application instance with scope self.path = request.path.encode("ascii") - self.application_deferred = self.server.create_application(self, { + self.application_deferred = defer.maybeDeferred(self.server.create_application, self, { "type": "websocket", "path": unquote(self.path.decode("ascii")), "headers": self.clean_headers, From 2f94210321358b5be1f204428999934ae467583d Mon Sep 17 00:00:00 2001 From: Nick Sellen Date: Tue, 24 Jul 2018 22:25:03 +0200 Subject: [PATCH 08/17] Add x-forwarded-proto support (#219) --- daphne/cli.py | 1 + daphne/http_protocol.py | 11 ++++++++--- daphne/server.py | 2 ++ daphne/testing.py | 1 + daphne/utils.py | 24 +++++++++++++++++------- daphne/ws_protocol.py | 3 ++- tests/test_utils.py | 35 +++++++++++++++++++---------------- 7 files changed, 50 insertions(+), 27 deletions(-) diff --git a/daphne/cli.py b/daphne/cli.py index d9d9fb4..7f42084 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -214,5 +214,6 @@ class CommandLineInterface(object): verbosity=args.verbosity, proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None, proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None, + proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None, ) self.server.run() diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index c4e17fe..915e475 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -73,13 +73,18 @@ class WebRequest(http.Request): else: self.client_addr = None self.server_addr = None + + self.client_scheme = "https" if self.isSecure() else "http" + # See if we need to get the address from a proxy header instead if self.server.proxy_forwarded_address_header: - self.client_addr = parse_x_forwarded_for( + self.client_addr, self.client_scheme = parse_x_forwarded_for( self.requestHeaders, self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_port_header, - self.client_addr + self.server.proxy_forwarded_proto_header, + self.client_addr, + self.client_scheme ) # Check for unicodeish path (or it'll crash when trying to parse) try: @@ -153,7 +158,7 @@ class WebRequest(http.Request): "method": self.method.decode("ascii"), "path": unquote(self.path.decode("ascii")), "root_path": self.root_path, - "scheme": "https" if self.isSecure() else "http", + "scheme": self.client_scheme, "query_string": self.query_string, "headers": self.clean_headers, "client": self.client_addr, diff --git a/daphne/server.py b/daphne/server.py index cff80f2..dfae544 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -49,6 +49,7 @@ class Server(object): root_path="", proxy_forwarded_address_header=None, proxy_forwarded_port_header=None, + proxy_forwarded_proto_header=None, verbosity=1, websocket_handshake_timeout=5, application_close_timeout=10, @@ -67,6 +68,7 @@ class Server(object): self.ping_timeout = ping_timeout self.proxy_forwarded_address_header = proxy_forwarded_address_header self.proxy_forwarded_port_header = proxy_forwarded_port_header + self.proxy_forwarded_proto_header = proxy_forwarded_proto_header self.websocket_timeout = websocket_timeout self.websocket_connect_timeout = websocket_connect_timeout self.websocket_handshake_timeout = websocket_handshake_timeout diff --git a/daphne/testing.py b/daphne/testing.py index 4c87c51..e606952 100644 --- a/daphne/testing.py +++ b/daphne/testing.py @@ -37,6 +37,7 @@ class DaphneTestingInstance: if self.xff: kwargs["proxy_forwarded_address_header"] = "X-Forwarded-For" kwargs["proxy_forwarded_port_header"] = "X-Forwarded-Port" + kwargs["proxy_forwarded_proto_header"] = "X-Forwarded-Proto" if self.http_timeout: kwargs["http_timeout"] = self.http_timeout # Start up process diff --git a/daphne/utils.py b/daphne/utils.py index cd9e86e..ad64439 100644 --- a/daphne/utils.py +++ b/daphne/utils.py @@ -25,18 +25,22 @@ def header_value(headers, header_name): def parse_x_forwarded_for(headers, address_header_name="X-Forwarded-For", port_header_name="X-Forwarded-Port", - original=None): + proto_header_name="X-Forwarded-Proto", + original_addr=None, + original_scheme=None): """ Parses an X-Forwarded-For header and returns a host/port pair as a list. @param headers: The twisted-style object containing a request's headers @param address_header_name: The name of the expected host header @param port_header_name: The name of the expected port header - @param original: A host/port pair that should be returned if the headers are not in the request + @param proto_header_name: The name of the expected proto header + @param original_addr: A host/port pair that should be returned if the headers are not in the request + @param original_scheme: A scheme that should be returned if the headers are not in the request @return: A list containing a host (string) as the first entry and a port (int) as the second. """ if not address_header_name: - return original + return original_addr, original_scheme # Convert twisted-style headers into dicts if isinstance(headers, Headers): @@ -49,14 +53,15 @@ def parse_x_forwarded_for(headers, assert all(isinstance(name, bytes) for name in headers.keys()) address_header_name = address_header_name.lower().encode("utf-8") - result = original + result_addr = original_addr + result_scheme = original_scheme if address_header_name in headers: address_value = header_value(headers, address_header_name) if "," in address_value: address_value = address_value.split(",")[0].strip() - result = [address_value, 0] + result_addr = [address_value, 0] if port_header_name: # We only want to parse the X-Forwarded-Port header if we also parsed the X-Forwarded-For @@ -65,8 +70,13 @@ def parse_x_forwarded_for(headers, if port_header_name in headers: port_value = header_value(headers, port_header_name) try: - result[1] = int(port_value) + result_addr[1] = int(port_value) except ValueError: pass - return result + if proto_header_name: + proto_header_name = proto_header_name.lower().encode("utf-8") + if proto_header_name in headers: + result_scheme = header_value(headers, proto_header_name) + + return result_addr, result_scheme diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 8ae8749..f0b7bda 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -49,10 +49,11 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server_addr = None if self.server.proxy_forwarded_address_header: - self.client_addr = parse_x_forwarded_for( + self.client_addr, self.client_scheme = parse_x_forwarded_for( dict(self.clean_headers), self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_port_header, + self.server.proxy_forwarded_proto_header, self.client_addr ) # Decode websocket subprotocol options diff --git a/tests/test_utils.py b/tests/test_utils.py index 786b8c9..5dada0f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,11 +15,13 @@ class TestXForwardedForHttpParsing(TestCase): def test_basic(self): headers = Headers({ b"X-Forwarded-For": [b"10.1.2.3"], - b"X-Forwarded-Port": [b"1234"] + b"X-Forwarded-Port": [b"1234"], + b"X-Forwarded-Proto": [b"https"] }) result = parse_x_forwarded_for(headers) - self.assertEqual(result, ["10.1.2.3", 1234]) - self.assertIsInstance(result[0], str) + self.assertEqual(result, (["10.1.2.3", 1234], "https")) + self.assertIsInstance(result[0][0], str) + self.assertIsInstance(result[1], str) def test_address_only(self): headers = Headers({ @@ -27,7 +29,7 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 0] + (["10.1.2.3", 0], None) ) def test_v6_address(self): @@ -36,7 +38,7 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ["1043::a321:0001", 0] + (["1043::a321:0001", 0], None) ) def test_multiple_proxys(self): @@ -45,19 +47,19 @@ class TestXForwardedForHttpParsing(TestCase): }) self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 0] + (["10.1.2.3", 0], None) ) def test_original(self): headers = Headers({}) self.assertEqual( - parse_x_forwarded_for(headers, original=["127.0.0.1", 80]), - ["127.0.0.1", 80] + parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), + (["127.0.0.1", 80], None) ) def test_no_original(self): headers = Headers({}) - self.assertIsNone(parse_x_forwarded_for(headers)) + self.assertEqual(parse_x_forwarded_for(headers), (None, None)) class TestXForwardedForWsParsing(TestCase): @@ -69,10 +71,11 @@ class TestXForwardedForWsParsing(TestCase): headers = { b"X-Forwarded-For": b"10.1.2.3", b"X-Forwarded-Port": b"1234", + b"X-Forwarded-Proto": b"https", } self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 1234] + (["10.1.2.3", 1234], "https") ) def test_address_only(self): @@ -81,7 +84,7 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 0] + (["10.1.2.3", 0], None) ) def test_v6_address(self): @@ -90,7 +93,7 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ["1043::a321:0001", 0] + (["1043::a321:0001", 0], None) ) def test_multiple_proxies(self): @@ -99,16 +102,16 @@ class TestXForwardedForWsParsing(TestCase): } self.assertEqual( parse_x_forwarded_for(headers), - ["10.1.2.3", 0] + (["10.1.2.3", 0], None) ) def test_original(self): headers = {} self.assertEqual( - parse_x_forwarded_for(headers, original=["127.0.0.1", 80]), - ["127.0.0.1", 80] + parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), + (["127.0.0.1", 80], None) ) def test_no_original(self): headers = {} - self.assertIsNone(parse_x_forwarded_for(headers)) + self.assertEqual(parse_x_forwarded_for(headers), (None, None)) From 5fe47cbbedb7d6b57e15698ded4f31a5b46a2fc2 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Thu, 9 Aug 2018 11:36:22 -0700 Subject: [PATCH 09/17] Add an issue template --- .github/ISSUE_TEMPLATE.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE.md diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..8c7aade --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,14 @@ +Issues are for **concrete, actionable bugs and feature requests** only - if you're just asking for debugging help or technical support we have to direct you elsewhere. If you just have questions or support requests please use: + +- Stack Overflow +- The Django Users mailing list django-users@googlegroups.com (https://groups.google.com/forum/#!forum/django-users) + +We have to limit this because of limited volunteer time to respond to issues! + +Please also try and include, if you can: + +- Your OS and runtime environment, and browser if applicable +- A `pip freeze` output showing your package versions +- What you expected to happen vs. what actually happened +- How you're running Channels (runserver? daphne/runworker? Nginx/Apache in front?) +- Console logs and full tracebacks of any errors From 47358c7c79db858527bee7df6cfab8c6454c68ae Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Thu, 16 Aug 2018 21:34:50 -0700 Subject: [PATCH 10/17] Releasing 2.2.2 --- CHANGELOG.txt | 10 ++++++++++ daphne/__init__.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 2740193..12a7ff2 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -1,3 +1,13 @@ +2.2.2 (2018-08-16) +------------------ + +* X-Forwarded-Proto support is now present and enabled if you turn on the + --proxy-headers flag + +* ASGI applications are no longer instantiated in a thread (the ASGI spec + was finalised to say all constructors must be non-blocking on the main thread) + + 2.2.1 (2018-07-22) ------------------ diff --git a/daphne/__init__.py b/daphne/__init__.py index b19ee4b..ba51ced 100755 --- a/daphne/__init__.py +++ b/daphne/__init__.py @@ -1 +1 @@ -__version__ = "2.2.1" +__version__ = "2.2.2" From 88792984e710a45f2c0ee7ef28b5ae5843c00573 Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 24 Aug 2018 23:46:04 +0000 Subject: [PATCH 11/17] Run tests against Python 3.7 (#224) --- .travis.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index ad50622..aa69b0b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,7 +7,7 @@ python: - '3.6' env: -- TWISTED="twisted==17.5.0" +- TWISTED="twisted==18.7.0" - TWISTED="twisted" install: @@ -22,6 +22,14 @@ script: jobs: include: + - python: '3.7' + env: TWISTED="twisted==18.7.0" + dist: xenial + sudo: required + - python: '3.7' + env: TWISTED="twisted" + dist: xenial + sudo: required - stage: release script: skip deploy: From 0ed6294406781f3895e1d11b4c6256f9cfccffca Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Mon, 27 Aug 2018 12:27:32 +1000 Subject: [PATCH 12/17] Implement Black code formatting --- .travis.yml | 22 ++++-- daphne/access.py | 7 +- daphne/cli.py | 49 ++++++------ daphne/endpoints.py | 7 +- daphne/http_protocol.py | 142 +++++++++++++++++++++-------------- daphne/server.py | 47 +++++++----- daphne/testing.py | 26 +------ daphne/utils.py | 14 ++-- daphne/ws_protocol.py | 133 ++++++++++++++++++++------------- setup.cfg | 4 +- setup.py | 21 ++---- tests/http_base.py | 66 ++++++++++------- tests/http_strategies.py | 43 +++++------ tests/test_cli.py | 111 +++++++++------------------ tests/test_http_request.py | 73 ++++++++++-------- tests/test_http_response.py | 144 ++++++++++++++---------------------- tests/test_utils.py | 75 ++++++------------- tests/test_websocket.py | 125 +++++++++---------------------- 18 files changed, 513 insertions(+), 596 deletions(-) diff --git a/.travis.yml b/.travis.yml index aa69b0b..6a617c4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,22 +3,25 @@ sudo: false language: python python: -- '3.5' - '3.6' +- '3.5' env: -- TWISTED="twisted==18.7.0" - TWISTED="twisted" +- TWISTED="twisted==18.7.0" install: -- pip install $TWISTED isort unify flake8 -e .[tests] +- pip install $TWISTED -e .[tests] - pip freeze script: - pytest -- flake8 -- isort --check-only --diff --recursive daphne tests -- unify --check-only --recursive --quote \" daphne tests + +stages: + - lint + - test + - name: release + if: branch = master jobs: include: @@ -30,6 +33,13 @@ jobs: env: TWISTED="twisted" dist: xenial sudo: required + - stage: lint + install: pip install -U -e .[tests] black pyflakes isort + script: + - pyflakes . + - black --check . + - isort --check-only --diff --recursive channels_redis tests + - stage: release script: skip deploy: diff --git a/daphne/access.py b/daphne/access.py index ce80f49..2b3b1cd 100644 --- a/daphne/access.py +++ b/daphne/access.py @@ -49,13 +49,16 @@ class AccessLogGenerator(object): request="WSDISCONNECT %(path)s" % details, ) - def write_entry(self, host, date, request, status=None, length=None, ident=None, user=None): + def write_entry( + self, host, date, request, status=None, length=None, ident=None, user=None + ): """ Writes an NCSA-style entry to the log file (some liberty is taken with what the entries are for non-HTTP) """ self.stream.write( - "%s %s %s [%s] \"%s\" %s %s\n" % ( + '%s %s %s [%s] "%s" %s %s\n' + % ( host, ident or "-", user or "-", diff --git a/daphne/cli.py b/daphne/cli.py index 7f42084..28cf1b9 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -23,15 +23,9 @@ class CommandLineInterface(object): server_class = Server def __init__(self): - self.parser = argparse.ArgumentParser( - description=self.description, - ) + self.parser = argparse.ArgumentParser(description=self.description) self.parser.add_argument( - "-p", - "--port", - type=int, - help="Port number to listen on", - default=None, + "-p", "--port", type=int, help="Port number to listen on", default=None ) self.parser.add_argument( "-b", @@ -128,7 +122,7 @@ class CommandLineInterface(object): "--proxy-headers", dest="proxy_headers", help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the " - "client address", + "client address", default=False, action="store_true", ) @@ -176,7 +170,15 @@ class CommandLineInterface(object): sys.path.insert(0, ".") application = import_by_path(args.application) # Set up port/host bindings - if not any([args.host, args.port is not None, args.unix_socket, args.file_descriptor, args.socket_strings]): + if not any( + [ + args.host, + args.port is not None, + args.unix_socket, + args.file_descriptor, + args.socket_strings, + ] + ): # no advanced binding options passed, patch in defaults args.host = DEFAULT_HOST args.port = DEFAULT_PORT @@ -189,16 +191,11 @@ class CommandLineInterface(object): host=args.host, port=args.port, unix_socket=args.unix_socket, - file_descriptor=args.file_descriptor - ) - endpoints = sorted( - args.socket_strings + endpoints + file_descriptor=args.file_descriptor, ) + endpoints = sorted(args.socket_strings + endpoints) # Start the server - logger.info( - "Starting server at %s" % - (", ".join(endpoints), ) - ) + logger.info("Starting server at %s" % (", ".join(endpoints),)) self.server = self.server_class( application=application, endpoints=endpoints, @@ -208,12 +205,20 @@ class CommandLineInterface(object): websocket_timeout=args.websocket_timeout, websocket_connect_timeout=args.websocket_connect_timeout, application_close_timeout=args.application_close_timeout, - action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None, + action_logger=AccessLogGenerator(access_log_stream) + if access_log_stream + else None, ws_protocols=args.ws_protocols, root_path=args.root_path, verbosity=args.verbosity, - proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None, - proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None, - proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None, + proxy_forwarded_address_header="X-Forwarded-For" + if args.proxy_headers + else None, + proxy_forwarded_port_header="X-Forwarded-Port" + if args.proxy_headers + else None, + proxy_forwarded_proto_header="X-Forwarded-Proto" + if args.proxy_headers + else None, ) self.server.run() diff --git a/daphne/endpoints.py b/daphne/endpoints.py index 6188357..83e472a 100644 --- a/daphne/endpoints.py +++ b/daphne/endpoints.py @@ -1,10 +1,5 @@ - - def build_endpoint_description_strings( - host=None, - port=None, - unix_socket=None, - file_descriptor=None + host=None, port=None, unix_socket=None, file_descriptor=None ): """ Build a list of twisted endpoint description strings that the server will listen on. diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 915e475..2c8d840 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -23,7 +23,8 @@ class WebRequest(http.Request): GET and POST out. """ - error_template = """ + error_template = ( + """ %(title)s @@ -40,7 +41,13 @@ class WebRequest(http.Request):
Daphne
- """.replace("\n", "").replace(" ", " ").replace(" ", " ").replace(" ", " ") # Shorten it a bit, bytes wise + """.replace( + "\n", "" + ) + .replace(" ", " ") + .replace(" ", " ") + .replace(" ", " ") + ) # Shorten it a bit, bytes wise def __init__(self, *args, **kwargs): try: @@ -84,7 +91,7 @@ class WebRequest(http.Request): self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_proto_header, self.client_addr, - self.client_scheme + self.client_scheme, ) # Check for unicodeish path (or it'll crash when trying to parse) try: @@ -105,7 +112,9 @@ class WebRequest(http.Request): # Is it WebSocket? IS IT?! if upgrade_header and upgrade_header.lower() == b"websocket": # Make WebSocket protocol to hand off to - protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer()) + protocol = self.server.ws_factory.buildProtocol( + self.transport.getPeer() + ) if not protocol: # If protocol creation fails, we signal "internal server error" self.setResponseCode(500) @@ -151,33 +160,38 @@ class WebRequest(http.Request): logger.debug("HTTP %s request for %s", self.method, self.client_addr) self.content.seek(0, 0) # Work out the application scope and create application - self.application_queue = yield maybeDeferred(self.server.create_application, self, { - "type": "http", - # TODO: Correctly say if it's 1.1 or 1.0 - "http_version": self.clientproto.split(b"/")[-1].decode("ascii"), - "method": self.method.decode("ascii"), - "path": unquote(self.path.decode("ascii")), - "root_path": self.root_path, - "scheme": self.client_scheme, - "query_string": self.query_string, - "headers": self.clean_headers, - "client": self.client_addr, - "server": self.server_addr, - }) + self.application_queue = yield maybeDeferred( + self.server.create_application, + self, + { + "type": "http", + # TODO: Correctly say if it's 1.1 or 1.0 + "http_version": self.clientproto.split(b"/")[-1].decode( + "ascii" + ), + "method": self.method.decode("ascii"), + "path": unquote(self.path.decode("ascii")), + "root_path": self.root_path, + "scheme": self.client_scheme, + "query_string": self.query_string, + "headers": self.clean_headers, + "client": self.client_addr, + "server": self.server_addr, + }, + ) # Check they didn't close an unfinished request if self.application_queue is None or self.content.closed: # Not much we can do, the request is prematurely abandoned. return # Run application against request self.application_queue.put_nowait( - { - "type": "http.request", - "body": self.content.read(), - }, + {"type": "http.request", "body": self.content.read()} ) except Exception: logger.error(traceback.format_exc()) - self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error") + self.basic_error( + 500, b"Internal Server Error", "Daphne HTTP processing error" + ) def connectionLost(self, reason): """ @@ -217,16 +231,23 @@ class WebRequest(http.Request): raise ValueError("HTTP response has already been started") self._response_started = True if "status" not in message: - raise ValueError("Specifying a status code is required for a Response message.") + raise ValueError( + "Specifying a status code is required for a Response message." + ) # Set HTTP status code self.setResponseCode(message["status"]) # Write headers for header, value in message.get("headers", {}): self.responseHeaders.addRawHeader(header, value) - logger.debug("HTTP %s response started for %s", message["status"], self.client_addr) + logger.debug( + "HTTP %s response started for %s", message["status"], self.client_addr + ) elif message["type"] == "http.response.body": if not self._response_started: - raise ValueError("HTTP response has not yet been started but got %s" % message["type"]) + raise ValueError( + "HTTP response has not yet been started but got %s" + % message["type"] + ) # Write out body http.Request.write(self, message.get("body", b"")) # End if there's no more content @@ -239,15 +260,21 @@ class WebRequest(http.Request): # The path is malformed somehow - do our best to log something uri = repr(self.uri) try: - self.server.log_action("http", "complete", { - "path": uri, - "status": self.code, - "method": self.method.decode("ascii", "replace"), - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - "time_taken": self.duration(), - "size": self.sentLength, - }) - except Exception as e: + self.server.log_action( + "http", + "complete", + { + "path": uri, + "status": self.code, + "method": self.method.decode("ascii", "replace"), + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + "time_taken": self.duration(), + "size": self.sentLength, + }, + ) + except Exception: logger.error(traceback.format_exc()) else: logger.debug("HTTP response chunk for %s", self.client_addr) @@ -270,7 +297,11 @@ class WebRequest(http.Request): logger.warning("Application timed out while sending response") self.finish() else: - self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.") + self.basic_error( + 503, + b"Service Unavailable", + "Application failed to respond within time limit.", + ) ### Utility functions @@ -281,11 +312,7 @@ class WebRequest(http.Request): """ # If we don't yet have a path, then don't send as we never opened. if self.path: - self.application_queue.put_nowait( - { - "type": "http.disconnect", - }, - ) + self.application_queue.put_nowait({"type": "http.disconnect"}) def duration(self): """ @@ -299,20 +326,25 @@ class WebRequest(http.Request): """ Responds with a server-level error page (very basic) """ - self.handle_reply({ - "type": "http.response.start", - "status": status, - "headers": [ - (b"Content-Type", b"text/html; charset=utf-8"), - ], - }) - self.handle_reply({ - "type": "http.response.body", - "body": (self.error_template % { - "title": str(status) + " " + status_text.decode("ascii"), - "body": body, - }).encode("utf8"), - }) + self.handle_reply( + { + "type": "http.response.start", + "status": status, + "headers": [(b"Content-Type", b"text/html; charset=utf-8")], + } + ) + self.handle_reply( + { + "type": "http.response.body", + "body": ( + self.error_template + % { + "title": str(status) + " " + status_text.decode("ascii"), + "body": body, + } + ).encode("utf8"), + } + ) def __hash__(self): return hash(id(self)) @@ -343,7 +375,7 @@ class HTTPFactory(http.HTTPFactory): protocol = http.HTTPFactory.buildProtocol(self, addr) protocol.requestFactory = WebRequest return protocol - except Exception as e: + except Exception: logger.error("Cannot build protocol: %s" % traceback.format_exc()) raise diff --git a/daphne/server.py b/daphne/server.py index dfae544..133762d 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -2,13 +2,14 @@ import sys # isort:skip import warnings # isort:skip from twisted.internet import asyncioreactor # isort:skip + current_reactor = sys.modules.get("twisted.internet.reactor", None) if current_reactor is not None: if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor): warnings.warn( - "Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " + - "you can fix this warning by importing daphne.server early in your codebase or " + - "finding the package that imports Twisted and importing it later on.", + "Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " + + "you can fix this warning by importing daphne.server early in your codebase or " + + "finding the package that imports Twisted and importing it later on.", UserWarning, ) del sys.modules["twisted.internet.reactor"] @@ -34,7 +35,6 @@ logger = logging.getLogger(__name__) class Server(object): - def __init__( self, application, @@ -91,11 +91,13 @@ class Server(object): self.ws_factory.setProtocolOptions( autoPingTimeout=self.ping_timeout, allowNullOrigin=True, - openHandshakeTimeout=self.websocket_handshake_timeout + openHandshakeTimeout=self.websocket_handshake_timeout, ) if self.verbosity <= 1: # Redirect the Twisted log to nowhere - globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True) + globalLogBeginner.beginLoggingTo( + [lambda _: None], redirectStandardIO=False, discardBuffer=True + ) else: globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)]) @@ -103,7 +105,9 @@ class Server(object): if http.H2_ENABLED: logger.info("HTTP/2 support enabled") else: - logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)") + logger.info( + "HTTP/2 support not enabled (install the http2 and tls Twisted extras)" + ) # Kick off the timeout loop reactor.callLater(1, self.application_checker) @@ -141,7 +145,11 @@ class Server(object): host = port.getHost() if hasattr(host, "host") and hasattr(host, "port"): self.listening_addresses.append((host.host, host.port)) - logger.info("Listening on TCP address %s:%s", port.getHost().host, port.getHost().port) + logger.info( + "Listening on TCP address %s:%s", + port.getHost().host, + port.getHost().port, + ) def listen_error(self, failure): logger.critical("Listen failure: %s", failure.getErrorMessage()) @@ -187,10 +195,13 @@ class Server(object): # Run it, and stash the future for later checking if protocol not in self.connections: return None - self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance( - receive=input_queue.get, - send=lambda message: self.handle_reply(protocol, message), - ), loop=asyncio.get_event_loop()) + self.connections[protocol]["application_instance"] = asyncio.ensure_future( + application_instance( + receive=input_queue.get, + send=lambda message: self.handle_reply(protocol, message), + ), + loop=asyncio.get_event_loop(), + ) return input_queue async def handle_reply(self, protocol, message): @@ -215,7 +226,10 @@ class Server(object): application_instance = details.get("application_instance", None) # First, see if the protocol disconnected and the app has taken # too long to close up - if disconnected and time.time() - disconnected > self.application_close_timeout: + if ( + disconnected + and time.time() - disconnected > self.application_close_timeout + ): if application_instance and not application_instance.done(): logger.warning( "Application instance %r for connection %s took too long to shut down and was killed.", @@ -238,14 +252,11 @@ class Server(object): else: exception_output = "{}\n{}{}".format( exception, - "".join(traceback.format_tb( - exception.__traceback__, - )), + "".join(traceback.format_tb(exception.__traceback__)), " {}".format(exception), ) logger.error( - "Exception inside application: %s", - exception_output, + "Exception inside application: %s", exception_output ) if not disconnected: protocol.handle_exception(exception) diff --git a/daphne/testing.py b/daphne/testing.py index e606952..f5f3724 100644 --- a/daphne/testing.py +++ b/daphne/testing.py @@ -100,9 +100,7 @@ class DaphneTestingInstance: Adds messages for the application to send back. The next time it receives an incoming message, it will reply with these. """ - TestApplication.save_setup( - response_messages=messages, - ) + TestApplication.save_setup(response_messages=messages) class DaphneProcess(multiprocessing.Process): @@ -193,12 +191,7 @@ class TestApplication: Stores setup information. """ with open(cls.setup_storage, "wb") as fh: - pickle.dump( - { - "response_messages": response_messages, - }, - fh, - ) + pickle.dump({"response_messages": response_messages}, fh) @classmethod def load_setup(cls): @@ -218,13 +211,7 @@ class TestApplication: We could use pickle here, but that seems wrong, still, somehow. """ with open(cls.result_storage, "wb") as fh: - pickle.dump( - { - "scope": scope, - "messages": messages, - }, - fh, - ) + pickle.dump({"scope": scope, "messages": messages}, fh) @classmethod def save_exception(cls, exception): @@ -233,12 +220,7 @@ class TestApplication: We could use pickle here, but that seems wrong, still, somehow. """ with open(cls.result_storage, "wb") as fh: - pickle.dump( - { - "exception": exception, - }, - fh, - ) + pickle.dump({"exception": exception}, fh) @classmethod def load_result(cls): diff --git a/daphne/utils.py b/daphne/utils.py index ad64439..81f1f9d 100644 --- a/daphne/utils.py +++ b/daphne/utils.py @@ -22,12 +22,14 @@ def header_value(headers, header_name): return value.decode("utf-8") -def parse_x_forwarded_for(headers, - address_header_name="X-Forwarded-For", - port_header_name="X-Forwarded-Port", - proto_header_name="X-Forwarded-Proto", - original_addr=None, - original_scheme=None): +def parse_x_forwarded_for( + headers, + address_header_name="X-Forwarded-For", + port_header_name="X-Forwarded-Port", + proto_header_name="X-Forwarded-Proto", + original_addr=None, + original_scheme=None, +): """ Parses an X-Forwarded-For header and returns a host/port pair as a list. diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index f0b7bda..edf1254 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -3,7 +3,11 @@ import time import traceback from urllib.parse import unquote -from autobahn.twisted.websocket import ConnectionDeny, WebSocketServerFactory, WebSocketServerProtocol +from autobahn.twisted.websocket import ( + ConnectionDeny, + WebSocketServerFactory, + WebSocketServerProtocol, +) from twisted.internet import defer from .utils import parse_x_forwarded_for @@ -54,32 +58,34 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_proto_header, - self.client_addr + self.client_addr, ) # Decode websocket subprotocol options subprotocols = [] for header, value in self.clean_headers: if header == b"sec-websocket-protocol": subprotocols = [ - x.strip() - for x in - unquote(value.decode("ascii")).split(",") + x.strip() for x in unquote(value.decode("ascii")).split(",") ] # Make new application instance with scope self.path = request.path.encode("ascii") - self.application_deferred = defer.maybeDeferred(self.server.create_application, self, { - "type": "websocket", - "path": unquote(self.path.decode("ascii")), - "headers": self.clean_headers, - "query_string": self._raw_query_string, # Passed by HTTP protocol - "client": self.client_addr, - "server": self.server_addr, - "subprotocols": subprotocols, - }) + self.application_deferred = defer.maybeDeferred( + self.server.create_application, + self, + { + "type": "websocket", + "path": unquote(self.path.decode("ascii")), + "headers": self.clean_headers, + "query_string": self._raw_query_string, # Passed by HTTP protocol + "client": self.client_addr, + "server": self.server_addr, + "subprotocols": subprotocols, + }, + ) if self.application_deferred is not None: self.application_deferred.addCallback(self.applicationCreateWorked) self.application_deferred.addErrback(self.applicationCreateFailed) - except Exception as e: + except Exception: # Exceptions here are not displayed right, just 500. # Turn them into an ERROR log. logger.error(traceback.format_exc()) @@ -98,10 +104,16 @@ class WebSocketProtocol(WebSocketServerProtocol): self.application_queue = application_queue # Send over the connect message self.application_queue.put_nowait({"type": "websocket.connect"}) - self.server.log_action("websocket", "connecting", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.server.log_action( + "websocket", + "connecting", + { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + }, + ) def applicationCreateFailed(self, failure): """ @@ -115,10 +127,16 @@ class WebSocketProtocol(WebSocketServerProtocol): def onOpen(self): # Send news that this channel is open logger.debug("WebSocket %s open and established", self.client_addr) - self.server.log_action("websocket", "connected", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.server.log_action( + "websocket", + "connected", + { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + }, + ) def onMessage(self, payload, isBinary): # If we're muted, do nothing. @@ -128,15 +146,13 @@ class WebSocketProtocol(WebSocketServerProtocol): logger.debug("WebSocket incoming frame on %s", self.client_addr) self.last_ping = time.time() if isBinary: - self.application_queue.put_nowait({ - "type": "websocket.receive", - "bytes": payload, - }) + self.application_queue.put_nowait( + {"type": "websocket.receive", "bytes": payload} + ) else: - self.application_queue.put_nowait({ - "type": "websocket.receive", - "text": payload.decode("utf8"), - }) + self.application_queue.put_nowait( + {"type": "websocket.receive", "text": payload.decode("utf8")} + ) def onClose(self, wasClean, code, reason): """ @@ -145,14 +161,19 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server.protocol_disconnected(self) logger.debug("WebSocket closed for %s", self.client_addr) if not self.muted and hasattr(self, "application_queue"): - self.application_queue.put_nowait({ - "type": "websocket.disconnect", - "code": code, - }) - self.server.log_action("websocket", "disconnected", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.application_queue.put_nowait( + {"type": "websocket.disconnect", "code": code} + ) + self.server.log_action( + "websocket", + "disconnected", + { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + }, + ) ### Internal event handling @@ -171,9 +192,8 @@ class WebSocketProtocol(WebSocketServerProtocol): raise ValueError("Socket has not been accepted, so cannot send over it") if message.get("bytes", None) and message.get("text", None): raise ValueError( - "Got invalid WebSocket reply message on %s - contains both bytes and text keys" % ( - message, - ) + "Got invalid WebSocket reply message on %s - contains both bytes and text keys" + % (message,) ) if message.get("bytes", None): self.serverSend(message["bytes"], True) @@ -187,7 +207,9 @@ class WebSocketProtocol(WebSocketServerProtocol): if hasattr(self, "handshake_deferred"): # If the handshake is still ongoing, we need to emit a HTTP error # code rather than a WebSocket one. - self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error")) + self.handshake_deferred.errback( + ConnectionDeny(code=500, reason="Internal server error") + ) else: self.sendCloseFrame(code=1011) @@ -203,14 +225,22 @@ class WebSocketProtocol(WebSocketServerProtocol): """ Called when we get a message saying to reject the connection. """ - self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) + self.handshake_deferred.errback( + ConnectionDeny(code=403, reason="Access denied") + ) del self.handshake_deferred self.server.protocol_disconnected(self) logger.debug("WebSocket %s rejected by application", self.client_addr) - self.server.log_action("websocket", "rejected", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.server.log_action( + "websocket", + "rejected", + { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + }, + ) def serverSend(self, content, binary=False): """ @@ -244,7 +274,10 @@ class WebSocketProtocol(WebSocketServerProtocol): Called periodically to see if we should timeout something """ # Web timeout checking - if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0: + if ( + self.duration() > self.server.websocket_timeout + and self.server.websocket_timeout >= 0 + ): self.serverClose() # Ping check # If we're still connecting, deny the connection @@ -287,6 +320,6 @@ class WebSocketFactory(WebSocketServerFactory): protocol = super(WebSocketFactory, self).buildProtocol(addr) protocol.factory = self return protocol - except Exception as e: + except Exception: logger.error("Cannot build protocol: %s" % traceback.format_exc()) raise diff --git a/setup.cfg b/setup.cfg index 8c9a5f5..46dd04f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,9 +5,9 @@ universal=1 addopts = tests/ [isort] -line_length = 120 +include_trailing_comma = True multi_line_output = 3 -known_first_party = channels,daphne,asgiref +known_first_party = channels,daphne,asgiref,channels_redis [flake8] exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/* diff --git a/setup.py b/setup.py index 692ab6e..ff98e27 100755 --- a/setup.py +++ b/setup.py @@ -22,23 +22,12 @@ setup( package_dir={"twisted": "daphne/twisted"}, packages=find_packages() + ["twisted.plugins"], include_package_data=True, - install_requires=[ - "twisted>=18.7", - "autobahn>=0.18", - ], - setup_requires=[ - "pytest-runner", - ], - extras_require={ - "tests": [ - "hypothesis", - "pytest", - "pytest-asyncio~=0.8", - ], + install_requires=["twisted>=18.7", "autobahn>=0.18"], + setup_requires=["pytest-runner"], + extras_require={"tests": ["hypothesis", "pytest", "pytest-asyncio~=0.8"]}, + entry_points={ + "console_scripts": ["daphne = daphne.cli:CommandLineInterface.entrypoint"] }, - entry_points={"console_scripts": [ - "daphne = daphne.cli:CommandLineInterface.entrypoint", - ]}, classifiers=[ "Development Status :: 4 - Beta", "Environment :: Web Environment", diff --git a/tests/http_base.py b/tests/http_base.py index c7cec48..866a066 100644 --- a/tests/http_base.py +++ b/tests/http_base.py @@ -19,7 +19,9 @@ class DaphneTestCase(unittest.TestCase): ### Plain HTTP helpers - def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False): + def run_daphne_http( + self, method, path, params, body, responses, headers=None, timeout=1, xff=False + ): """ Runs Daphne with the given request callback (given the base URL) and response messages. @@ -38,7 +40,9 @@ class DaphneTestCase(unittest.TestCase): # Manually send over headers (encoding any non-safe values as best we can) if headers: for header_name, header_value in headers: - conn.putheader(header_name.encode("utf8"), header_value.encode("utf8")) + conn.putheader( + header_name.encode("utf8"), header_value.encode("utf8") + ) # Send body if provided. if body: conn.putheader("Content-Length", str(len(body))) @@ -50,9 +54,11 @@ class DaphneTestCase(unittest.TestCase): except socket.timeout: # See if they left an exception for us to load test_app.get_received() - raise RuntimeError("Daphne timed out handling request, no exception found.") + raise RuntimeError( + "Daphne timed out handling request, no exception found." + ) # Return scope, messages, response - return test_app.get_received() + (response, ) + return test_app.get_received() + (response,) def run_daphne_raw(self, data, timeout=1): """ @@ -68,9 +74,13 @@ class DaphneTestCase(unittest.TestCase): try: return s.recv(1000000) except socket.timeout: - raise RuntimeError("Daphne timed out handling raw request, no exception found.") + raise RuntimeError( + "Daphne timed out handling raw request, no exception found." + ) - def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False): + def run_daphne_request( + self, method, path, params=None, body=None, headers=None, xff=False + ): """ Convenience method for just testing request handling. Returns (scope, messages) @@ -95,17 +105,21 @@ class DaphneTestCase(unittest.TestCase): Returns (scope, messages) """ _, _, response = self.run_daphne_http( - method="GET", - path="/", - params={}, - body=b"", - responses=response_messages, + method="GET", path="/", params={}, body=b"", responses=response_messages ) return response ### WebSocket helpers - def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1): + def websocket_handshake( + self, + test_app, + path="/", + params=None, + headers=None, + subprotocols=None, + timeout=1, + ): """ Runs a WebSocket handshake negotiation and returns the raw socket object & the selected subprotocol. @@ -124,14 +138,16 @@ class DaphneTestCase(unittest.TestCase): # Do WebSocket handshake headers + any other headers if headers is None: headers = [] - headers.extend([ - ("Host", "example.com"), - ("Upgrade", "websocket"), - ("Connection", "Upgrade"), - ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), - ("Sec-WebSocket-Version", "13"), - ("Origin", "http://example.com") - ]) + headers.extend( + [ + ("Host", "example.com"), + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), + ("Sec-WebSocket-Version", "13"), + ("Origin", "http://example.com"), + ] + ) if subprotocols: headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols))) if headers: @@ -149,10 +165,7 @@ class DaphneTestCase(unittest.TestCase): if response.status != 101: raise RuntimeError("WebSocket upgrade did not result in status code 101") # Prepare headers for subprotocol searching - response_headers = dict( - (n.lower(), v) - for n, v in response.getheaders() - ) + response_headers = dict((n.lower(), v) for n, v in response.getheaders()) response.read() assert not response.closed # Return the raw socket and any subprotocol @@ -234,10 +247,7 @@ class DaphneTestCase(unittest.TestCase): # Make sure all required keys are present self.assertTrue(required_keys <= present_keys) # Assert that no other keys are present - self.assertEqual( - set(), - present_keys - required_keys - optional_keys, - ) + self.assertEqual(set(), present_keys - required_keys - optional_keys) def assert_valid_path(self, path, request_path): """ diff --git a/tests/http_strategies.py b/tests/http_strategies.py index 4335afd..d78ac10 100644 --- a/tests/http_strategies.py +++ b/tests/http_strategies.py @@ -6,7 +6,9 @@ from hypothesis import strategies HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"] # Unicode characters of the "Letter" category -letters = strategies.characters(whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")) +letters = strategies.characters( + whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl") +) def http_method(): @@ -22,11 +24,9 @@ def http_path(): """ Returns a URL path (not encoded). """ - return strategies.lists( - _http_path_portion(), - min_size=0, - max_size=10, - ).map(lambda s: "/" + "/".join(s)) + return strategies.lists(_http_path_portion(), min_size=0, max_size=10).map( + lambda s: "/" + "/".join(s) + ) def http_body(): @@ -53,10 +53,7 @@ def valid_bidi(value): def _domain_label(): return strategies.text( - alphabet=letters, - min_size=1, - average_size=6, - max_size=63, + alphabet=letters, min_size=1, average_size=6, max_size=63 ).filter(valid_bidi) @@ -64,19 +61,14 @@ def international_domain_name(): """ Returns a byte string of a domain name, IDNA-encoded. """ - return strategies.lists( - _domain_label(), - min_size=2, - average_size=2, - ).map(lambda s: (".".join(s)).encode("idna")) + return strategies.lists(_domain_label(), min_size=2, average_size=2).map( + lambda s: (".".join(s)).encode("idna") + ) def _query_param(): return strategies.text( - alphabet=letters, - min_size=1, - average_size=10, - max_size=255, + alphabet=letters, min_size=1, average_size=10, max_size=255 ).map(lambda s: s.encode("utf8")) @@ -87,9 +79,7 @@ def query_params(): ensures that the total urlencoded query string is not longer than 1500 characters. """ return strategies.lists( - strategies.tuples(_query_param(), _query_param()), - min_size=0, - average_size=5, + strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5 ).filter(lambda x: len(parse.urlencode(x)) < 1500) @@ -101,9 +91,7 @@ def header_name(): and 20 characters long """ return strategies.text( - alphabet=string.ascii_letters + string.digits + "-", - min_size=1, - max_size=30, + alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30 ) @@ -115,7 +103,10 @@ def header_value(): https://en.wikipedia.org/wiki/List_of_HTTP_header_fields """ return strategies.text( - alphabet=string.ascii_letters + string.digits + string.punctuation.replace(",", "") + " /t", + alphabet=string.ascii_letters + + string.digits + + string.punctuation.replace(",", "") + + " /t", min_size=1, average_size=40, max_size=8190, diff --git a/tests/test_cli.py b/tests/test_cli.py index 9e90ab0..7bb45dc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -18,45 +18,32 @@ class TestEndpointDescriptions(TestCase): def testTcpPortBindings(self): self.assertEqual( build(port=1234, host="example.com"), - ["tcp:port=1234:interface=example.com"] + ["tcp:port=1234:interface=example.com"], ) self.assertEqual( - build(port=8000, host="127.0.0.1"), - ["tcp:port=8000:interface=127.0.0.1"] + build(port=8000, host="127.0.0.1"), ["tcp:port=8000:interface=127.0.0.1"] ) self.assertEqual( - build(port=8000, host="[200a::1]"), - [r'tcp:port=8000:interface=200a\:\:1'] + build(port=8000, host="[200a::1]"), [r"tcp:port=8000:interface=200a\:\:1"] ) self.assertEqual( - build(port=8000, host="200a::1"), - [r'tcp:port=8000:interface=200a\:\:1'] + build(port=8000, host="200a::1"), [r"tcp:port=8000:interface=200a\:\:1"] ) # incomplete port/host kwargs raise errors - self.assertRaises( - ValueError, - build, port=123 - ) - self.assertRaises( - ValueError, - build, host="example.com" - ) + self.assertRaises(ValueError, build, port=123) + self.assertRaises(ValueError, build, host="example.com") def testUnixSocketBinding(self): self.assertEqual( - build(unix_socket="/tmp/daphne.sock"), - ["unix:/tmp/daphne.sock"] + build(unix_socket="/tmp/daphne.sock"), ["unix:/tmp/daphne.sock"] ) def testFileDescriptorBinding(self): - self.assertEqual( - build(file_descriptor=5), - ["fd:fileno=5"] - ) + self.assertEqual(build(file_descriptor=5), ["fd:fileno=5"]) def testMultipleEnpoints(self): self.assertEqual( @@ -65,14 +52,16 @@ class TestEndpointDescriptions(TestCase): file_descriptor=123, unix_socket="/tmp/daphne.sock", port=8080, - host="10.0.0.1" + host="10.0.0.1", ) ), - sorted([ - "tcp:port=8080:interface=10.0.0.1", - "unix:/tmp/daphne.sock", - "fd:fileno=123" - ]) + sorted( + [ + "tcp:port=8080:interface=10.0.0.1", + "unix:/tmp/daphne.sock", + "fd:fileno=123", + ] + ), ) @@ -112,7 +101,9 @@ class TestCLIInterface(TestCase): Passes in a fake application automatically. """ cli = self.TestedCLI() - cli.run(args + ["daphne:__version__"]) # We just pass something importable as app + cli.run( + args + ["daphne:__version__"] + ) # We just pass something importable as app # Check the server got all arguments as intended for key, value in server_kwargs.items(): # Get the value and sort it if it's a list (for endpoint checking) @@ -123,52 +114,30 @@ class TestCLIInterface(TestCase): self.assertEqual( value, actual_value, - "Wrong value for server kwarg %s: %r != %r" % ( - key, - value, - actual_value, - ), + "Wrong value for server kwarg %s: %r != %r" + % (key, value, actual_value), ) def testCLIBasics(self): """ Tests basic endpoint generation. """ + self.assertCLI([], {"endpoints": ["tcp:port=8000:interface=127.0.0.1"]}) self.assertCLI( - [], - { - "endpoints": ["tcp:port=8000:interface=127.0.0.1"], - }, + ["-p", "123"], {"endpoints": ["tcp:port=123:interface=127.0.0.1"]} ) self.assertCLI( - ["-p", "123"], - { - "endpoints": ["tcp:port=123:interface=127.0.0.1"], - }, + ["-b", "10.0.0.1"], {"endpoints": ["tcp:port=8000:interface=10.0.0.1"]} ) self.assertCLI( - ["-b", "10.0.0.1"], - { - "endpoints": ["tcp:port=8000:interface=10.0.0.1"], - }, + ["-b", "200a::1"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]} ) self.assertCLI( - ["-b", "200a::1"], - { - "endpoints": [r'tcp:port=8000:interface=200a\:\:1'], - }, - ) - self.assertCLI( - ["-b", "[200a::1]"], - { - "endpoints": [r'tcp:port=8000:interface=200a\:\:1'], - }, + ["-b", "[200a::1]"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]} ) self.assertCLI( ["-p", "8080", "-b", "example.com"], - { - "endpoints": ["tcp:port=8080:interface=example.com"], - }, + {"endpoints": ["tcp:port=8080:interface=example.com"]}, ) def testUnixSockets(self): @@ -178,7 +147,7 @@ class TestCLIInterface(TestCase): "endpoints": [ "tcp:port=8080:interface=127.0.0.1", "unix:/tmp/daphne.sock", - ], + ] }, ) self.assertCLI( @@ -187,17 +156,12 @@ class TestCLIInterface(TestCase): "endpoints": [ "tcp:port=8000:interface=example.com", "unix:/tmp/daphne.sock", - ], + ] }, ) self.assertCLI( ["-u", "/tmp/daphne.sock", "--fd", "5"], - { - "endpoints": [ - "fd:fileno=5", - "unix:/tmp/daphne.sock" - ], - }, + {"endpoints": ["fd:fileno=5", "unix:/tmp/daphne.sock"]}, ) def testMixedCLIEndpointCreation(self): @@ -209,8 +173,8 @@ class TestCLIInterface(TestCase): { "endpoints": [ "tcp:port=8080:interface=127.0.0.1", - "unix:/tmp/daphne.sock" - ], + "unix:/tmp/daphne.sock", + ] }, ) self.assertCLI( @@ -219,7 +183,7 @@ class TestCLIInterface(TestCase): "endpoints": [ "tcp:port=8080:interface=127.0.0.1", "tcp:port=8080:interface=127.0.0.1", - ], + ] }, ) @@ -227,11 +191,4 @@ class TestCLIInterface(TestCase): """ Tests entirely custom endpoints """ - self.assertCLI( - ["-e", "imap:"], - { - "endpoints": [ - "imap:", - ], - }, - ) + self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]}) diff --git a/tests/test_http_request.py b/tests/test_http_request.py index 79274ed..e02b8b6 100644 --- a/tests/test_http_request.py +++ b/tests/test_http_request.py @@ -15,13 +15,7 @@ class TestHTTPRequest(DaphneTestCase): """ def assert_valid_http_scope( - self, - scope, - method, - path, - params=None, - headers=None, - scheme=None, + self, scope, method, path, params=None, headers=None, scheme=None ): """ Checks that the passed scope is a valid ASGI HTTP scope regarding types @@ -29,7 +23,14 @@ class TestHTTPRequest(DaphneTestCase): """ # Check overall keys self.assert_key_sets( - required_keys={"type", "http_version", "method", "path", "query_string", "headers"}, + required_keys={ + "type", + "http_version", + "method", + "path", + "query_string", + "headers", + }, optional_keys={"scheme", "root_path", "client", "server"}, actual_keys=scope.keys(), ) @@ -50,7 +51,9 @@ class TestHTTPRequest(DaphneTestCase): query_string = scope["query_string"] self.assertIsInstance(query_string, bytes) if params: - self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii")) + self.assertEqual( + query_string, parse.urlencode(params or []).encode("ascii") + ) # Ordering of header names is not important, but the order of values for a header # name is. To assert whether that order is kept, we transform both the request # headers and the channel message headers into a dictionary @@ -59,7 +62,7 @@ class TestHTTPRequest(DaphneTestCase): for name, value in scope["headers"]: transformed_scope_headers[name].append(value) transformed_request_headers = collections.defaultdict(list) - for name, value in (headers or []): + for name, value in headers or []: expected_name = name.lower().strip().encode("ascii") expected_value = value.strip().encode("ascii") transformed_request_headers[expected_name].append(expected_value) @@ -103,27 +106,31 @@ class TestHTTPRequest(DaphneTestCase): @given( request_path=http_strategies.http_path(), - request_params=http_strategies.query_params() + request_params=http_strategies.query_params(), ) @settings(max_examples=5, deadline=5000) def test_get_request(self, request_path, request_params): """ Tests a typical HTTP GET request, with a path and query parameters """ - scope, messages = self.run_daphne_request("GET", request_path, params=request_params) + scope, messages = self.run_daphne_request( + "GET", request_path, params=request_params + ) self.assert_valid_http_scope(scope, "GET", request_path, params=request_params) self.assert_valid_http_request_message(messages[0], body=b"") @given( request_path=http_strategies.http_path(), - request_body=http_strategies.http_body() + request_body=http_strategies.http_body(), ) @settings(max_examples=5, deadline=5000) def test_post_request(self, request_path, request_body): """ Tests a typical HTTP POST request, with a path and body. """ - scope, messages = self.run_daphne_request("POST", request_path, body=request_body) + scope, messages = self.run_daphne_request( + "POST", request_path, body=request_body + ) self.assert_valid_http_scope(scope, "POST", request_path) self.assert_valid_http_request_message(messages[0], body=request_body) @@ -134,8 +141,12 @@ class TestHTTPRequest(DaphneTestCase): Tests that HTTP header fields are handled as specified """ request_path = "/te st-à/" - scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers) - self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers) + scope, messages = self.run_daphne_request( + "OPTIONS", request_path, headers=request_headers + ) + self.assert_valid_http_scope( + scope, "OPTIONS", request_path, headers=request_headers + ) self.assert_valid_http_request_message(messages[0], body=b"") @given(request_headers=http_strategies.headers()) @@ -150,8 +161,12 @@ class TestHTTPRequest(DaphneTestCase): duplicated_headers = [(header_name, header[1]) for header in request_headers] # Run the request request_path = "/te st-à/" - scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=duplicated_headers) - self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=duplicated_headers) + scope, messages = self.run_daphne_request( + "OPTIONS", request_path, headers=duplicated_headers + ) + self.assert_valid_http_scope( + scope, "OPTIONS", request_path, headers=duplicated_headers + ) self.assert_valid_http_request_message(messages[0], body=b"") @given( @@ -222,10 +237,7 @@ class TestHTTPRequest(DaphneTestCase): """ Make sure that, by default, X-Forwarded-For is ignored. """ - headers = [ - ["X-Forwarded-For", "10.1.2.3"], - ["X-Forwarded-Port", "80"], - ] + headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -236,10 +248,7 @@ class TestHTTPRequest(DaphneTestCase): """ When X-Forwarded-For is enabled, make sure it is respected. """ - headers = [ - ["X-Forwarded-For", "10.1.2.3"], - ["X-Forwarded-Port", "80"], - ] + headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -251,9 +260,7 @@ class TestHTTPRequest(DaphneTestCase): When X-Forwarded-For is enabled but only the host is passed, make sure that at least makes it through. """ - headers = [ - ["X-Forwarded-For", "10.1.2.3"], - ] + headers = [["X-Forwarded-For", "10.1.2.3"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -265,8 +272,12 @@ class TestHTTPRequest(DaphneTestCase): Tests that requests with invalid (non-ASCII) characters fail. """ # Bad path - response = self.run_daphne_raw(b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n") + response = self.run_daphne_raw( + b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n" + ) self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request")) # Bad querystring - response = self.run_daphne_raw(b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n") + response = self.run_daphne_raw( + b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n" + ) self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request")) diff --git a/tests/test_http_response.py b/tests/test_http_response.py index 2efc4ec..3576697 100644 --- a/tests/test_http_response.py +++ b/tests/test_http_response.py @@ -15,26 +15,24 @@ class TestHTTPResponse(DaphneTestCase): """ Lowercases and sorts headers, and strips transfer-encoding ones. """ - return sorted([ - (name.lower(), value.strip()) - for name, value in headers - if name.lower() != "transfer-encoding" - ]) + return sorted( + [ + (name.lower(), value.strip()) + for name, value in headers + if name.lower() != "transfer-encoding" + ] + ) def test_minimal_response(self): """ Smallest viable example. Mostly verifies that our response building works. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 200, - }, - { - "type": "http.response.body", - "body": b"hello world", - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 200}, + {"type": "http.response.body", "body": b"hello world"}, + ] + ) self.assertEqual(response.status, 200) self.assertEqual(response.read(), b"hello world") @@ -46,30 +44,23 @@ class TestHTTPResponse(DaphneTestCase): to make sure it stays required. """ with self.assertRaises(ValueError): - self.run_daphne_response([ - { - "type": "http.response.start", - }, - { - "type": "http.response.body", - "body": b"hello world", - }, - ]) + self.run_daphne_response( + [ + {"type": "http.response.start"}, + {"type": "http.response.body", "body": b"hello world"}, + ] + ) def test_custom_status_code(self): """ Tries a non-default status code. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 201, - }, - { - "type": "http.response.body", - "body": b"i made a thing!", - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 201}, + {"type": "http.response.body", "body": b"i made a thing!"}, + ] + ) self.assertEqual(response.status, 201) self.assertEqual(response.read(), b"i made a thing!") @@ -77,21 +68,13 @@ class TestHTTPResponse(DaphneTestCase): """ Tries sending a response in multiple parts. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 201, - }, - { - "type": "http.response.body", - "body": b"chunk 1 ", - "more_body": True, - }, - { - "type": "http.response.body", - "body": b"chunk 2", - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 201}, + {"type": "http.response.body", "body": b"chunk 1 ", "more_body": True}, + {"type": "http.response.body", "body": b"chunk 2"}, + ] + ) self.assertEqual(response.status, 201) self.assertEqual(response.read(), b"chunk 1 chunk 2") @@ -99,25 +82,14 @@ class TestHTTPResponse(DaphneTestCase): """ Tries sending a response in multiple parts and an empty end. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 201, - }, - { - "type": "http.response.body", - "body": b"chunk 1 ", - "more_body": True, - }, - { - "type": "http.response.body", - "body": b"chunk 2", - "more_body": True, - }, - { - "type": "http.response.body", - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 201}, + {"type": "http.response.body", "body": b"chunk 1 ", "more_body": True}, + {"type": "http.response.body", "body": b"chunk 2", "more_body": True}, + {"type": "http.response.body"}, + ] + ) self.assertEqual(response.status, 201) self.assertEqual(response.read(), b"chunk 1 chunk 2") @@ -127,16 +99,12 @@ class TestHTTPResponse(DaphneTestCase): """ Tries body variants. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 200, - }, - { - "type": "http.response.body", - "body": body, - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 200}, + {"type": "http.response.body", "body": body}, + ] + ) self.assertEqual(response.status, 200) self.assertEqual(response.read(), body) @@ -144,16 +112,16 @@ class TestHTTPResponse(DaphneTestCase): @settings(max_examples=5, deadline=5000) def test_headers(self, headers): # The ASGI spec requires us to lowercase our header names - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 200, - "headers": self.normalize_headers(headers), - }, - { - "type": "http.response.body", - }, - ]) + response = self.run_daphne_response( + [ + { + "type": "http.response.start", + "status": 200, + "headers": self.normalize_headers(headers), + }, + {"type": "http.response.body"}, + ] + ) # Check headers in a sensible way. Ignore transfer-encoding. self.assertEqual( self.normalize_headers(response.getheaders()), diff --git a/tests/test_utils.py b/tests/test_utils.py index 5dada0f..6b04939 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,48 +13,35 @@ class TestXForwardedForHttpParsing(TestCase): """ def test_basic(self): - headers = Headers({ - b"X-Forwarded-For": [b"10.1.2.3"], - b"X-Forwarded-Port": [b"1234"], - b"X-Forwarded-Proto": [b"https"] - }) + headers = Headers( + { + b"X-Forwarded-For": [b"10.1.2.3"], + b"X-Forwarded-Port": [b"1234"], + b"X-Forwarded-Proto": [b"https"], + } + ) result = parse_x_forwarded_for(headers) self.assertEqual(result, (["10.1.2.3", 1234], "https")) self.assertIsInstance(result[0][0], str) self.assertIsInstance(result[1], str) def test_address_only(self): - headers = Headers({ - b"X-Forwarded-For": [b"10.1.2.3"], - }) - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 0], None) - ) + headers = Headers({b"X-Forwarded-For": [b"10.1.2.3"]}) + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None)) def test_v6_address(self): - headers = Headers({ - b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"], - }) - self.assertEqual( - parse_x_forwarded_for(headers), - (["1043::a321:0001", 0], None) - ) + headers = Headers({b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]}) + self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None)) def test_multiple_proxys(self): - headers = Headers({ - b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"], - }) - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 0], None) - ) + headers = Headers({b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"]}) + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None)) def test_original(self): headers = Headers({}) self.assertEqual( parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), - (["127.0.0.1", 80], None) + (["127.0.0.1", 80], None), ) def test_no_original(self): @@ -73,43 +60,25 @@ class TestXForwardedForWsParsing(TestCase): b"X-Forwarded-Port": b"1234", b"X-Forwarded-Proto": b"https", } - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 1234], "https") - ) + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 1234], "https")) def test_address_only(self): - headers = { - b"X-Forwarded-For": b"10.1.2.3", - } - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 0], None) - ) + headers = {b"X-Forwarded-For": b"10.1.2.3"} + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None)) def test_v6_address(self): - headers = { - b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"], - } - self.assertEqual( - parse_x_forwarded_for(headers), - (["1043::a321:0001", 0], None) - ) + headers = {b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]} + self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None)) def test_multiple_proxies(self): - headers = { - b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4", - } - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 0], None) - ) + headers = {b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4"} + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None)) def test_original(self): headers = {} self.assertEqual( parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), - (["127.0.0.1", 80], None) + (["127.0.0.1", 80], None), ) def test_no_original(self): diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 0ae1a21..80ec21d 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -16,13 +16,7 @@ class TestWebsocket(DaphneTestCase): """ def assert_valid_websocket_scope( - self, - scope, - path="/", - params=None, - headers=None, - scheme=None, - subprotocols=None, + self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None ): """ Checks that the passed scope is a valid ASGI HTTP scope regarding types @@ -46,7 +40,9 @@ class TestWebsocket(DaphneTestCase): query_string = scope["query_string"] self.assertIsInstance(query_string, bytes) if params: - self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii")) + self.assertEqual( + query_string, parse.urlencode(params or []).encode("ascii") + ) # Ordering of header names is not important, but the order of values for a header # name is. To assert whether that order is kept, we transform both the request # headers and the channel message headers into a dictionary @@ -59,7 +55,7 @@ class TestWebsocket(DaphneTestCase): if bit.strip(): transformed_scope_headers[name].append(bit.strip()) transformed_request_headers = collections.defaultdict(list) - for name, value in (headers or []): + for name, value in headers or []: expected_name = name.lower().strip().encode("ascii") expected_value = value.strip().encode("ascii") # Make sure to split out any headers collapsed with commas @@ -92,9 +88,7 @@ class TestWebsocket(DaphneTestCase): """ # Check overall keys self.assert_key_sets( - required_keys={"type"}, - optional_keys=set(), - actual_keys=message.keys(), + required_keys={"type"}, optional_keys=set(), actual_keys=message.keys() ) # Check that it is the right type self.assertEqual(message["type"], "websocket.connect") @@ -104,11 +98,7 @@ class TestWebsocket(DaphneTestCase): Tests we can open and accept a socket. """ with DaphneTestingInstance() as test_app: - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) self.websocket_handshake(test_app) # Validate the scope and messages we got scope, messages = test_app.get_received() @@ -120,11 +110,7 @@ class TestWebsocket(DaphneTestCase): Tests we can reject a socket and it won't complete the handshake. """ with DaphneTestingInstance() as test_app: - test_app.add_send_messages([ - { - "type": "websocket.close", - } - ]) + test_app.add_send_messages([{"type": "websocket.close"}]) with self.assertRaises(RuntimeError): self.websocket_handshake(test_app) @@ -134,13 +120,12 @@ class TestWebsocket(DaphneTestCase): """ subprotocols = ["proto1", "proto2"] with DaphneTestingInstance() as test_app: - test_app.add_send_messages([ - { - "type": "websocket.accept", - "subprotocol": "proto2", - } - ]) - _, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols) + test_app.add_send_messages( + [{"type": "websocket.accept", "subprotocol": "proto2"}] + ) + _, subprotocol = self.websocket_handshake( + test_app, subprotocols=subprotocols + ) # Validate the scope and messages we got assert subprotocol == "proto2" scope, messages = test_app.get_received() @@ -151,16 +136,9 @@ class TestWebsocket(DaphneTestCase): """ Tests that X-Forwarded-For headers get parsed right """ - headers = [ - ["X-Forwarded-For", "10.1.2.3"], - ["X-Forwarded-Port", "80"], - ] + headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] with DaphneTestingInstance(xff=True) as test_app: - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) self.websocket_handshake(test_app, headers=headers) # Validate the scope and messages we got scope, messages = test_app.get_received() @@ -174,22 +152,13 @@ class TestWebsocket(DaphneTestCase): request_headers=http_strategies.headers(), ) @settings(max_examples=5, deadline=2000) - def test_http_bits( - self, - request_path, - request_params, - request_headers, - ): + def test_http_bits(self, request_path, request_params, request_headers): """ Tests that various HTTP-level bits (query string params, path, headers) carry over into the scope. """ with DaphneTestingInstance() as test_app: - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) self.websocket_handshake( test_app, path=request_path, @@ -199,10 +168,7 @@ class TestWebsocket(DaphneTestCase): # Validate the scope and messages we got scope, messages = test_app.get_received() self.assert_valid_websocket_scope( - scope, - path=request_path, - params=request_params, - headers=request_headers, + scope, path=request_path, params=request_params, headers=request_headers ) self.assert_valid_websocket_connect_message(messages[0]) @@ -212,28 +178,24 @@ class TestWebsocket(DaphneTestCase): """ with DaphneTestingInstance() as test_app: # Connect - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) sock, _ = self.websocket_handshake(test_app) _, messages = test_app.get_received() self.assert_valid_websocket_connect_message(messages[0]) # Prep frame for it to send - test_app.add_send_messages([ - { - "type": "websocket.send", - "text": "here be dragons 🐉", - } - ]) + test_app.add_send_messages( + [{"type": "websocket.send", "text": "here be dragons 🐉"}] + ) # Send it a frame self.websocket_send_frame(sock, "what is here? 🌍") # Receive a frame and make sure it's correct assert self.websocket_receive_frame(sock) == "here be dragons 🐉" # Make sure it got our frame _, messages = test_app.get_received() - assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"} + assert messages[1] == { + "type": "websocket.receive", + "text": "what is here? 🌍", + } def test_binary_frames(self): """ @@ -242,28 +204,24 @@ class TestWebsocket(DaphneTestCase): """ with DaphneTestingInstance() as test_app: # Connect - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) sock, _ = self.websocket_handshake(test_app) _, messages = test_app.get_received() self.assert_valid_websocket_connect_message(messages[0]) # Prep frame for it to send - test_app.add_send_messages([ - { - "type": "websocket.send", - "bytes": b"here be \xe2 bytes", - } - ]) + test_app.add_send_messages( + [{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}] + ) # Send it a frame self.websocket_send_frame(sock, b"what is here? \xe2") # Receive a frame and make sure it's correct assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes" # Make sure it got our frame _, messages = test_app.get_received() - assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"} + assert messages[1] == { + "type": "websocket.receive", + "bytes": b"what is here? \xe2", + } def test_http_timeout(self): """ @@ -271,23 +229,14 @@ class TestWebsocket(DaphneTestCase): """ with DaphneTestingInstance(http_timeout=1) as test_app: # Connect - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) sock, _ = self.websocket_handshake(test_app) _, messages = test_app.get_received() self.assert_valid_websocket_connect_message(messages[0]) # Wait 2 seconds time.sleep(2) # Prep frame for it to send - test_app.add_send_messages([ - { - "type": "websocket.send", - "text": "cake", - } - ]) + test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}]) # Send it a frame self.websocket_send_frame(sock, "still alive?") # Receive a frame and make sure it's correct From 460bdf64dbef1f6b0c7d1555b44cb76c3cd2834a Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Mon, 27 Aug 2018 12:31:54 +1000 Subject: [PATCH 13/17] Only lint the daphne and tests directories --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6a617c4..b91d2c0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -36,8 +36,8 @@ jobs: - stage: lint install: pip install -U -e .[tests] black pyflakes isort script: - - pyflakes . - - black --check . + - pyflakes daphne tests + - black --check daphne tests - isort --check-only --diff --recursive channels_redis tests - stage: release From 02a299e5a79b8d1c342251e6f38e46f7954f35a2 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Mon, 27 Aug 2018 12:40:51 +1000 Subject: [PATCH 14/17] Fix isort in travis --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index b91d2c0..1dfeab3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,7 +38,7 @@ jobs: script: - pyflakes daphne tests - black --check daphne tests - - isort --check-only --diff --recursive channels_redis tests + - isort --check-only --diff --recursive daphne tests - stage: release script: skip From c5554cb817aff5219de6c3236a38ce147ad2ac4b Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Mon, 27 Aug 2018 14:21:40 +1000 Subject: [PATCH 15/17] Tidying up --- .travis.yml | 1 + setup.cfg | 1 + 2 files changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index 1dfeab3..6e31929 100644 --- a/.travis.yml +++ b/.travis.yml @@ -33,6 +33,7 @@ jobs: env: TWISTED="twisted" dist: xenial sudo: required + - stage: lint install: pip install -U -e .[tests] black pyflakes isort script: diff --git a/setup.cfg b/setup.cfg index 46dd04f..e50af68 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ addopts = tests/ include_trailing_comma = True multi_line_output = 3 known_first_party = channels,daphne,asgiref,channels_redis +line_length = 88 [flake8] exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/* From 3e4aab95e209cc15477e6dcdf2b8fb71a367da06 Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Wed, 29 Aug 2018 17:57:06 -0700 Subject: [PATCH 16/17] Fix Travis release stage --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 6e31929..d87df46 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,7 +21,7 @@ stages: - lint - test - name: release - if: branch = master + if: tag IS present jobs: include: From e93643ff5a2797f05e88bc59800d7b4dbf41765d Mon Sep 17 00:00:00 2001 From: Imblc Date: Fri, 28 Sep 2018 23:45:03 +0700 Subject: [PATCH 17/17] Fixed #229: Allow `bytes` headers only Previously Daphne was too lax and would happily accept strings too. --- daphne/server.py | 19 +++++++++++++++ tests/http_base.py | 22 ++++++++---------- tests/http_strategies.py | 24 +++++++++++-------- tests/test_http_request.py | 12 +++++----- tests/test_http_response.py | 46 +++++++++++++++++++++++++++++++++++-- tests/test_websocket.py | 4 ++-- 6 files changed, 95 insertions(+), 32 deletions(-) diff --git a/daphne/server.py b/daphne/server.py index 133762d..3f27bf5 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -211,9 +211,28 @@ class Server(object): # Don't do anything if the connection is closed if self.connections[protocol].get("disconnected", None): return + self.check_headers_type(message) # Let the protocol handle it protocol.handle_reply(message) + @staticmethod + def check_headers_type(message): + if not message["type"] == "http.response.start": + return + for k, v in message.get("headers", []): + if not isinstance(k, bytes): + raise ValueError( + "Header name '{}' expected to be `bytes`, but got `{}`".format( + k, type(k) + ) + ) + if not isinstance(v, bytes): + raise ValueError( + "Header value '{}' expected to be `bytes`, but got `{}`".format( + v, type(v) + ) + ) + ### Utility def application_checker(self): diff --git a/tests/http_base.py b/tests/http_base.py index 866a066..e6fc92b 100644 --- a/tests/http_base.py +++ b/tests/http_base.py @@ -37,12 +37,10 @@ class DaphneTestCase(unittest.TestCase): if params: path += "?" + parse.urlencode(params, doseq=True) conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True) - # Manually send over headers (encoding any non-safe values as best we can) + # Manually send over headers if headers: for header_name, header_value in headers: - conn.putheader( - header_name.encode("utf8"), header_value.encode("utf8") - ) + conn.putheader(header_name, header_value) # Send body if provided. if body: conn.putheader("Content-Length", str(len(body))) @@ -140,19 +138,19 @@ class DaphneTestCase(unittest.TestCase): headers = [] headers.extend( [ - ("Host", "example.com"), - ("Upgrade", "websocket"), - ("Connection", "Upgrade"), - ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), - ("Sec-WebSocket-Version", "13"), - ("Origin", "http://example.com"), + (b"Host", b"example.com"), + (b"Upgrade", b"websocket"), + (b"Connection", b"Upgrade"), + (b"Sec-WebSocket-Key", b"x3JJHMbDL1EzLkh9GBhXDw=="), + (b"Sec-WebSocket-Version", b"13"), + (b"Origin", b"http://example.com"), ] ) if subprotocols: - headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols))) + headers.append((b"Sec-WebSocket-Protocol", ", ".join(subprotocols))) if headers: for header_name, header_value in headers: - conn.putheader(header_name.encode("utf8"), header_value.encode("utf8")) + conn.putheader(header_name, header_value) conn.endheaders() # Read out the response try: diff --git a/tests/http_strategies.py b/tests/http_strategies.py index d78ac10..e9d8736 100644 --- a/tests/http_strategies.py +++ b/tests/http_strategies.py @@ -92,7 +92,7 @@ def header_name(): """ return strategies.text( alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30 - ) + ).map(lambda s: s.encode("utf-8")) def header_value(): @@ -102,15 +102,19 @@ def header_value(): "For example, the Apache 2.3 server by default limits the size of each field to 8190 bytes" https://en.wikipedia.org/wiki/List_of_HTTP_header_fields """ - return strategies.text( - alphabet=string.ascii_letters - + string.digits - + string.punctuation.replace(",", "") - + " /t", - min_size=1, - average_size=40, - max_size=8190, - ).filter(lambda s: len(s.encode("utf8")) < 8190) + return ( + strategies.text( + alphabet=string.ascii_letters + + string.digits + + string.punctuation.replace(",", "") + + " /t", + min_size=1, + average_size=40, + max_size=8190, + ) + .map(lambda s: s.encode("utf-8")) + .filter(lambda s: len(s) < 8190) + ) def headers(): diff --git a/tests/test_http_request.py b/tests/test_http_request.py index e02b8b6..c1efd64 100644 --- a/tests/test_http_request.py +++ b/tests/test_http_request.py @@ -63,8 +63,8 @@ class TestHTTPRequest(DaphneTestCase): transformed_scope_headers[name].append(value) transformed_request_headers = collections.defaultdict(list) for name, value in headers or []: - expected_name = name.lower().strip().encode("ascii") - expected_value = value.strip().encode("ascii") + expected_name = name.lower().strip() + expected_value = value.strip() transformed_request_headers[expected_name].append(expected_value) for name, value in transformed_request_headers.items(): self.assertIn(name, transformed_scope_headers) @@ -209,7 +209,7 @@ class TestHTTPRequest(DaphneTestCase): """ Make sure headers are normalized as the spec says they are. """ - headers = [("MYCUSTOMHEADER", " foobar ")] + headers = [(b"MYCUSTOMHEADER", b" foobar ")] scope, messages = self.run_daphne_request("GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -237,7 +237,7 @@ class TestHTTPRequest(DaphneTestCase): """ Make sure that, by default, X-Forwarded-For is ignored. """ - headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] + headers = [[b"X-Forwarded-For", b"10.1.2.3"], [b"X-Forwarded-Port", b"80"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -248,7 +248,7 @@ class TestHTTPRequest(DaphneTestCase): """ When X-Forwarded-For is enabled, make sure it is respected. """ - headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] + headers = [[b"X-Forwarded-For", b"10.1.2.3"], [b"X-Forwarded-Port", b"80"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -260,7 +260,7 @@ class TestHTTPRequest(DaphneTestCase): When X-Forwarded-For is enabled but only the host is passed, make sure that at least makes it through. """ - headers = [["X-Forwarded-For", "10.1.2.3"]] + headers = [[b"X-Forwarded-For", b"10.1.2.3"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") diff --git a/tests/test_http_response.py b/tests/test_http_response.py index 3576697..afb8e39 100644 --- a/tests/test_http_response.py +++ b/tests/test_http_response.py @@ -19,10 +19,16 @@ class TestHTTPResponse(DaphneTestCase): [ (name.lower(), value.strip()) for name, value in headers - if name.lower() != "transfer-encoding" + if name.lower() != b"transfer-encoding" ] ) + def encode_headers(self, headers): + def encode(s): + return s if isinstance(s, bytes) else s.encode("utf-8") + + return [[encode(k), encode(v)] for k, v in headers] + def test_minimal_response(self): """ Smallest viable example. Mostly verifies that our response building works. @@ -124,6 +130,42 @@ class TestHTTPResponse(DaphneTestCase): ) # Check headers in a sensible way. Ignore transfer-encoding. self.assertEqual( - self.normalize_headers(response.getheaders()), + self.normalize_headers(self.encode_headers(response.getheaders())), self.normalize_headers(headers), ) + + def test_headers_type(self): + """ + Headers should be `bytes` + """ + with self.assertRaises(ValueError) as context: + self.run_daphne_response( + [ + { + "type": "http.response.start", + "status": 200, + "headers": [["foo", b"bar"]], + }, + {"type": "http.response.body", "body": b""}, + ] + ) + self.assertEqual( + str(context.exception), + "Header name 'foo' expected to be `bytes`, but got ``", + ) + + with self.assertRaises(ValueError) as context: + self.run_daphne_response( + [ + { + "type": "http.response.start", + "status": 200, + "headers": [[b"foo", True]], + }, + {"type": "http.response.body", "body": b""}, + ] + ) + self.assertEqual( + str(context.exception), + "Header value 'True' expected to be `bytes`, but got ``", + ) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 80ec21d..69a54f8 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -56,8 +56,8 @@ class TestWebsocket(DaphneTestCase): transformed_scope_headers[name].append(bit.strip()) transformed_request_headers = collections.defaultdict(list) for name, value in headers or []: - expected_name = name.lower().strip().encode("ascii") - expected_value = value.strip().encode("ascii") + expected_name = name.lower().strip() + expected_value = value.strip() # Make sure to split out any headers collapsed with commas transformed_request_headers.setdefault(expected_name, []) for bit in expected_value.split(b","):