From 0ed6294406781f3895e1d11b4c6256f9cfccffca Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Mon, 27 Aug 2018 12:27:32 +1000 Subject: [PATCH] Implement Black code formatting --- .travis.yml | 22 ++++-- daphne/access.py | 7 +- daphne/cli.py | 49 ++++++------ daphne/endpoints.py | 7 +- daphne/http_protocol.py | 142 +++++++++++++++++++++-------------- daphne/server.py | 47 +++++++----- daphne/testing.py | 26 +------ daphne/utils.py | 14 ++-- daphne/ws_protocol.py | 133 ++++++++++++++++++++------------- setup.cfg | 4 +- setup.py | 21 ++---- tests/http_base.py | 66 ++++++++++------- tests/http_strategies.py | 43 +++++------ tests/test_cli.py | 111 +++++++++------------------ tests/test_http_request.py | 73 ++++++++++-------- tests/test_http_response.py | 144 ++++++++++++++---------------------- tests/test_utils.py | 75 ++++++------------- tests/test_websocket.py | 125 +++++++++---------------------- 18 files changed, 513 insertions(+), 596 deletions(-) diff --git a/.travis.yml b/.travis.yml index aa69b0b..6a617c4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,22 +3,25 @@ sudo: false language: python python: -- '3.5' - '3.6' +- '3.5' env: -- TWISTED="twisted==18.7.0" - TWISTED="twisted" +- TWISTED="twisted==18.7.0" install: -- pip install $TWISTED isort unify flake8 -e .[tests] +- pip install $TWISTED -e .[tests] - pip freeze script: - pytest -- flake8 -- isort --check-only --diff --recursive daphne tests -- unify --check-only --recursive --quote \" daphne tests + +stages: + - lint + - test + - name: release + if: branch = master jobs: include: @@ -30,6 +33,13 @@ jobs: env: TWISTED="twisted" dist: xenial sudo: required + - stage: lint + install: pip install -U -e .[tests] black pyflakes isort + script: + - pyflakes . + - black --check . + - isort --check-only --diff --recursive channels_redis tests + - stage: release script: skip deploy: diff --git a/daphne/access.py b/daphne/access.py index ce80f49..2b3b1cd 100644 --- a/daphne/access.py +++ b/daphne/access.py @@ -49,13 +49,16 @@ class AccessLogGenerator(object): request="WSDISCONNECT %(path)s" % details, ) - def write_entry(self, host, date, request, status=None, length=None, ident=None, user=None): + def write_entry( + self, host, date, request, status=None, length=None, ident=None, user=None + ): """ Writes an NCSA-style entry to the log file (some liberty is taken with what the entries are for non-HTTP) """ self.stream.write( - "%s %s %s [%s] \"%s\" %s %s\n" % ( + '%s %s %s [%s] "%s" %s %s\n' + % ( host, ident or "-", user or "-", diff --git a/daphne/cli.py b/daphne/cli.py index 7f42084..28cf1b9 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -23,15 +23,9 @@ class CommandLineInterface(object): server_class = Server def __init__(self): - self.parser = argparse.ArgumentParser( - description=self.description, - ) + self.parser = argparse.ArgumentParser(description=self.description) self.parser.add_argument( - "-p", - "--port", - type=int, - help="Port number to listen on", - default=None, + "-p", "--port", type=int, help="Port number to listen on", default=None ) self.parser.add_argument( "-b", @@ -128,7 +122,7 @@ class CommandLineInterface(object): "--proxy-headers", dest="proxy_headers", help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the " - "client address", + "client address", default=False, action="store_true", ) @@ -176,7 +170,15 @@ class CommandLineInterface(object): sys.path.insert(0, ".") application = import_by_path(args.application) # Set up port/host bindings - if not any([args.host, args.port is not None, args.unix_socket, args.file_descriptor, args.socket_strings]): + if not any( + [ + args.host, + args.port is not None, + args.unix_socket, + args.file_descriptor, + args.socket_strings, + ] + ): # no advanced binding options passed, patch in defaults args.host = DEFAULT_HOST args.port = DEFAULT_PORT @@ -189,16 +191,11 @@ class CommandLineInterface(object): host=args.host, port=args.port, unix_socket=args.unix_socket, - file_descriptor=args.file_descriptor - ) - endpoints = sorted( - args.socket_strings + endpoints + file_descriptor=args.file_descriptor, ) + endpoints = sorted(args.socket_strings + endpoints) # Start the server - logger.info( - "Starting server at %s" % - (", ".join(endpoints), ) - ) + logger.info("Starting server at %s" % (", ".join(endpoints),)) self.server = self.server_class( application=application, endpoints=endpoints, @@ -208,12 +205,20 @@ class CommandLineInterface(object): websocket_timeout=args.websocket_timeout, websocket_connect_timeout=args.websocket_connect_timeout, application_close_timeout=args.application_close_timeout, - action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None, + action_logger=AccessLogGenerator(access_log_stream) + if access_log_stream + else None, ws_protocols=args.ws_protocols, root_path=args.root_path, verbosity=args.verbosity, - proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None, - proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None, - proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None, + proxy_forwarded_address_header="X-Forwarded-For" + if args.proxy_headers + else None, + proxy_forwarded_port_header="X-Forwarded-Port" + if args.proxy_headers + else None, + proxy_forwarded_proto_header="X-Forwarded-Proto" + if args.proxy_headers + else None, ) self.server.run() diff --git a/daphne/endpoints.py b/daphne/endpoints.py index 6188357..83e472a 100644 --- a/daphne/endpoints.py +++ b/daphne/endpoints.py @@ -1,10 +1,5 @@ - - def build_endpoint_description_strings( - host=None, - port=None, - unix_socket=None, - file_descriptor=None + host=None, port=None, unix_socket=None, file_descriptor=None ): """ Build a list of twisted endpoint description strings that the server will listen on. diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index 915e475..2c8d840 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -23,7 +23,8 @@ class WebRequest(http.Request): GET and POST out. """ - error_template = """ + error_template = ( + """ %(title)s @@ -40,7 +41,13 @@ class WebRequest(http.Request): - """.replace("\n", "").replace(" ", " ").replace(" ", " ").replace(" ", " ") # Shorten it a bit, bytes wise + """.replace( + "\n", "" + ) + .replace(" ", " ") + .replace(" ", " ") + .replace(" ", " ") + ) # Shorten it a bit, bytes wise def __init__(self, *args, **kwargs): try: @@ -84,7 +91,7 @@ class WebRequest(http.Request): self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_proto_header, self.client_addr, - self.client_scheme + self.client_scheme, ) # Check for unicodeish path (or it'll crash when trying to parse) try: @@ -105,7 +112,9 @@ class WebRequest(http.Request): # Is it WebSocket? IS IT?! if upgrade_header and upgrade_header.lower() == b"websocket": # Make WebSocket protocol to hand off to - protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer()) + protocol = self.server.ws_factory.buildProtocol( + self.transport.getPeer() + ) if not protocol: # If protocol creation fails, we signal "internal server error" self.setResponseCode(500) @@ -151,33 +160,38 @@ class WebRequest(http.Request): logger.debug("HTTP %s request for %s", self.method, self.client_addr) self.content.seek(0, 0) # Work out the application scope and create application - self.application_queue = yield maybeDeferred(self.server.create_application, self, { - "type": "http", - # TODO: Correctly say if it's 1.1 or 1.0 - "http_version": self.clientproto.split(b"/")[-1].decode("ascii"), - "method": self.method.decode("ascii"), - "path": unquote(self.path.decode("ascii")), - "root_path": self.root_path, - "scheme": self.client_scheme, - "query_string": self.query_string, - "headers": self.clean_headers, - "client": self.client_addr, - "server": self.server_addr, - }) + self.application_queue = yield maybeDeferred( + self.server.create_application, + self, + { + "type": "http", + # TODO: Correctly say if it's 1.1 or 1.0 + "http_version": self.clientproto.split(b"/")[-1].decode( + "ascii" + ), + "method": self.method.decode("ascii"), + "path": unquote(self.path.decode("ascii")), + "root_path": self.root_path, + "scheme": self.client_scheme, + "query_string": self.query_string, + "headers": self.clean_headers, + "client": self.client_addr, + "server": self.server_addr, + }, + ) # Check they didn't close an unfinished request if self.application_queue is None or self.content.closed: # Not much we can do, the request is prematurely abandoned. return # Run application against request self.application_queue.put_nowait( - { - "type": "http.request", - "body": self.content.read(), - }, + {"type": "http.request", "body": self.content.read()} ) except Exception: logger.error(traceback.format_exc()) - self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error") + self.basic_error( + 500, b"Internal Server Error", "Daphne HTTP processing error" + ) def connectionLost(self, reason): """ @@ -217,16 +231,23 @@ class WebRequest(http.Request): raise ValueError("HTTP response has already been started") self._response_started = True if "status" not in message: - raise ValueError("Specifying a status code is required for a Response message.") + raise ValueError( + "Specifying a status code is required for a Response message." + ) # Set HTTP status code self.setResponseCode(message["status"]) # Write headers for header, value in message.get("headers", {}): self.responseHeaders.addRawHeader(header, value) - logger.debug("HTTP %s response started for %s", message["status"], self.client_addr) + logger.debug( + "HTTP %s response started for %s", message["status"], self.client_addr + ) elif message["type"] == "http.response.body": if not self._response_started: - raise ValueError("HTTP response has not yet been started but got %s" % message["type"]) + raise ValueError( + "HTTP response has not yet been started but got %s" + % message["type"] + ) # Write out body http.Request.write(self, message.get("body", b"")) # End if there's no more content @@ -239,15 +260,21 @@ class WebRequest(http.Request): # The path is malformed somehow - do our best to log something uri = repr(self.uri) try: - self.server.log_action("http", "complete", { - "path": uri, - "status": self.code, - "method": self.method.decode("ascii", "replace"), - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - "time_taken": self.duration(), - "size": self.sentLength, - }) - except Exception as e: + self.server.log_action( + "http", + "complete", + { + "path": uri, + "status": self.code, + "method": self.method.decode("ascii", "replace"), + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + "time_taken": self.duration(), + "size": self.sentLength, + }, + ) + except Exception: logger.error(traceback.format_exc()) else: logger.debug("HTTP response chunk for %s", self.client_addr) @@ -270,7 +297,11 @@ class WebRequest(http.Request): logger.warning("Application timed out while sending response") self.finish() else: - self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.") + self.basic_error( + 503, + b"Service Unavailable", + "Application failed to respond within time limit.", + ) ### Utility functions @@ -281,11 +312,7 @@ class WebRequest(http.Request): """ # If we don't yet have a path, then don't send as we never opened. if self.path: - self.application_queue.put_nowait( - { - "type": "http.disconnect", - }, - ) + self.application_queue.put_nowait({"type": "http.disconnect"}) def duration(self): """ @@ -299,20 +326,25 @@ class WebRequest(http.Request): """ Responds with a server-level error page (very basic) """ - self.handle_reply({ - "type": "http.response.start", - "status": status, - "headers": [ - (b"Content-Type", b"text/html; charset=utf-8"), - ], - }) - self.handle_reply({ - "type": "http.response.body", - "body": (self.error_template % { - "title": str(status) + " " + status_text.decode("ascii"), - "body": body, - }).encode("utf8"), - }) + self.handle_reply( + { + "type": "http.response.start", + "status": status, + "headers": [(b"Content-Type", b"text/html; charset=utf-8")], + } + ) + self.handle_reply( + { + "type": "http.response.body", + "body": ( + self.error_template + % { + "title": str(status) + " " + status_text.decode("ascii"), + "body": body, + } + ).encode("utf8"), + } + ) def __hash__(self): return hash(id(self)) @@ -343,7 +375,7 @@ class HTTPFactory(http.HTTPFactory): protocol = http.HTTPFactory.buildProtocol(self, addr) protocol.requestFactory = WebRequest return protocol - except Exception as e: + except Exception: logger.error("Cannot build protocol: %s" % traceback.format_exc()) raise diff --git a/daphne/server.py b/daphne/server.py index dfae544..133762d 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -2,13 +2,14 @@ import sys # isort:skip import warnings # isort:skip from twisted.internet import asyncioreactor # isort:skip + current_reactor = sys.modules.get("twisted.internet.reactor", None) if current_reactor is not None: if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor): warnings.warn( - "Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " + - "you can fix this warning by importing daphne.server early in your codebase or " + - "finding the package that imports Twisted and importing it later on.", + "Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " + + "you can fix this warning by importing daphne.server early in your codebase or " + + "finding the package that imports Twisted and importing it later on.", UserWarning, ) del sys.modules["twisted.internet.reactor"] @@ -34,7 +35,6 @@ logger = logging.getLogger(__name__) class Server(object): - def __init__( self, application, @@ -91,11 +91,13 @@ class Server(object): self.ws_factory.setProtocolOptions( autoPingTimeout=self.ping_timeout, allowNullOrigin=True, - openHandshakeTimeout=self.websocket_handshake_timeout + openHandshakeTimeout=self.websocket_handshake_timeout, ) if self.verbosity <= 1: # Redirect the Twisted log to nowhere - globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True) + globalLogBeginner.beginLoggingTo( + [lambda _: None], redirectStandardIO=False, discardBuffer=True + ) else: globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)]) @@ -103,7 +105,9 @@ class Server(object): if http.H2_ENABLED: logger.info("HTTP/2 support enabled") else: - logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)") + logger.info( + "HTTP/2 support not enabled (install the http2 and tls Twisted extras)" + ) # Kick off the timeout loop reactor.callLater(1, self.application_checker) @@ -141,7 +145,11 @@ class Server(object): host = port.getHost() if hasattr(host, "host") and hasattr(host, "port"): self.listening_addresses.append((host.host, host.port)) - logger.info("Listening on TCP address %s:%s", port.getHost().host, port.getHost().port) + logger.info( + "Listening on TCP address %s:%s", + port.getHost().host, + port.getHost().port, + ) def listen_error(self, failure): logger.critical("Listen failure: %s", failure.getErrorMessage()) @@ -187,10 +195,13 @@ class Server(object): # Run it, and stash the future for later checking if protocol not in self.connections: return None - self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance( - receive=input_queue.get, - send=lambda message: self.handle_reply(protocol, message), - ), loop=asyncio.get_event_loop()) + self.connections[protocol]["application_instance"] = asyncio.ensure_future( + application_instance( + receive=input_queue.get, + send=lambda message: self.handle_reply(protocol, message), + ), + loop=asyncio.get_event_loop(), + ) return input_queue async def handle_reply(self, protocol, message): @@ -215,7 +226,10 @@ class Server(object): application_instance = details.get("application_instance", None) # First, see if the protocol disconnected and the app has taken # too long to close up - if disconnected and time.time() - disconnected > self.application_close_timeout: + if ( + disconnected + and time.time() - disconnected > self.application_close_timeout + ): if application_instance and not application_instance.done(): logger.warning( "Application instance %r for connection %s took too long to shut down and was killed.", @@ -238,14 +252,11 @@ class Server(object): else: exception_output = "{}\n{}{}".format( exception, - "".join(traceback.format_tb( - exception.__traceback__, - )), + "".join(traceback.format_tb(exception.__traceback__)), " {}".format(exception), ) logger.error( - "Exception inside application: %s", - exception_output, + "Exception inside application: %s", exception_output ) if not disconnected: protocol.handle_exception(exception) diff --git a/daphne/testing.py b/daphne/testing.py index e606952..f5f3724 100644 --- a/daphne/testing.py +++ b/daphne/testing.py @@ -100,9 +100,7 @@ class DaphneTestingInstance: Adds messages for the application to send back. The next time it receives an incoming message, it will reply with these. """ - TestApplication.save_setup( - response_messages=messages, - ) + TestApplication.save_setup(response_messages=messages) class DaphneProcess(multiprocessing.Process): @@ -193,12 +191,7 @@ class TestApplication: Stores setup information. """ with open(cls.setup_storage, "wb") as fh: - pickle.dump( - { - "response_messages": response_messages, - }, - fh, - ) + pickle.dump({"response_messages": response_messages}, fh) @classmethod def load_setup(cls): @@ -218,13 +211,7 @@ class TestApplication: We could use pickle here, but that seems wrong, still, somehow. """ with open(cls.result_storage, "wb") as fh: - pickle.dump( - { - "scope": scope, - "messages": messages, - }, - fh, - ) + pickle.dump({"scope": scope, "messages": messages}, fh) @classmethod def save_exception(cls, exception): @@ -233,12 +220,7 @@ class TestApplication: We could use pickle here, but that seems wrong, still, somehow. """ with open(cls.result_storage, "wb") as fh: - pickle.dump( - { - "exception": exception, - }, - fh, - ) + pickle.dump({"exception": exception}, fh) @classmethod def load_result(cls): diff --git a/daphne/utils.py b/daphne/utils.py index ad64439..81f1f9d 100644 --- a/daphne/utils.py +++ b/daphne/utils.py @@ -22,12 +22,14 @@ def header_value(headers, header_name): return value.decode("utf-8") -def parse_x_forwarded_for(headers, - address_header_name="X-Forwarded-For", - port_header_name="X-Forwarded-Port", - proto_header_name="X-Forwarded-Proto", - original_addr=None, - original_scheme=None): +def parse_x_forwarded_for( + headers, + address_header_name="X-Forwarded-For", + port_header_name="X-Forwarded-Port", + proto_header_name="X-Forwarded-Proto", + original_addr=None, + original_scheme=None, +): """ Parses an X-Forwarded-For header and returns a host/port pair as a list. diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index f0b7bda..edf1254 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -3,7 +3,11 @@ import time import traceback from urllib.parse import unquote -from autobahn.twisted.websocket import ConnectionDeny, WebSocketServerFactory, WebSocketServerProtocol +from autobahn.twisted.websocket import ( + ConnectionDeny, + WebSocketServerFactory, + WebSocketServerProtocol, +) from twisted.internet import defer from .utils import parse_x_forwarded_for @@ -54,32 +58,34 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_proto_header, - self.client_addr + self.client_addr, ) # Decode websocket subprotocol options subprotocols = [] for header, value in self.clean_headers: if header == b"sec-websocket-protocol": subprotocols = [ - x.strip() - for x in - unquote(value.decode("ascii")).split(",") + x.strip() for x in unquote(value.decode("ascii")).split(",") ] # Make new application instance with scope self.path = request.path.encode("ascii") - self.application_deferred = defer.maybeDeferred(self.server.create_application, self, { - "type": "websocket", - "path": unquote(self.path.decode("ascii")), - "headers": self.clean_headers, - "query_string": self._raw_query_string, # Passed by HTTP protocol - "client": self.client_addr, - "server": self.server_addr, - "subprotocols": subprotocols, - }) + self.application_deferred = defer.maybeDeferred( + self.server.create_application, + self, + { + "type": "websocket", + "path": unquote(self.path.decode("ascii")), + "headers": self.clean_headers, + "query_string": self._raw_query_string, # Passed by HTTP protocol + "client": self.client_addr, + "server": self.server_addr, + "subprotocols": subprotocols, + }, + ) if self.application_deferred is not None: self.application_deferred.addCallback(self.applicationCreateWorked) self.application_deferred.addErrback(self.applicationCreateFailed) - except Exception as e: + except Exception: # Exceptions here are not displayed right, just 500. # Turn them into an ERROR log. logger.error(traceback.format_exc()) @@ -98,10 +104,16 @@ class WebSocketProtocol(WebSocketServerProtocol): self.application_queue = application_queue # Send over the connect message self.application_queue.put_nowait({"type": "websocket.connect"}) - self.server.log_action("websocket", "connecting", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.server.log_action( + "websocket", + "connecting", + { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + }, + ) def applicationCreateFailed(self, failure): """ @@ -115,10 +127,16 @@ class WebSocketProtocol(WebSocketServerProtocol): def onOpen(self): # Send news that this channel is open logger.debug("WebSocket %s open and established", self.client_addr) - self.server.log_action("websocket", "connected", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.server.log_action( + "websocket", + "connected", + { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + }, + ) def onMessage(self, payload, isBinary): # If we're muted, do nothing. @@ -128,15 +146,13 @@ class WebSocketProtocol(WebSocketServerProtocol): logger.debug("WebSocket incoming frame on %s", self.client_addr) self.last_ping = time.time() if isBinary: - self.application_queue.put_nowait({ - "type": "websocket.receive", - "bytes": payload, - }) + self.application_queue.put_nowait( + {"type": "websocket.receive", "bytes": payload} + ) else: - self.application_queue.put_nowait({ - "type": "websocket.receive", - "text": payload.decode("utf8"), - }) + self.application_queue.put_nowait( + {"type": "websocket.receive", "text": payload.decode("utf8")} + ) def onClose(self, wasClean, code, reason): """ @@ -145,14 +161,19 @@ class WebSocketProtocol(WebSocketServerProtocol): self.server.protocol_disconnected(self) logger.debug("WebSocket closed for %s", self.client_addr) if not self.muted and hasattr(self, "application_queue"): - self.application_queue.put_nowait({ - "type": "websocket.disconnect", - "code": code, - }) - self.server.log_action("websocket", "disconnected", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.application_queue.put_nowait( + {"type": "websocket.disconnect", "code": code} + ) + self.server.log_action( + "websocket", + "disconnected", + { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + }, + ) ### Internal event handling @@ -171,9 +192,8 @@ class WebSocketProtocol(WebSocketServerProtocol): raise ValueError("Socket has not been accepted, so cannot send over it") if message.get("bytes", None) and message.get("text", None): raise ValueError( - "Got invalid WebSocket reply message on %s - contains both bytes and text keys" % ( - message, - ) + "Got invalid WebSocket reply message on %s - contains both bytes and text keys" + % (message,) ) if message.get("bytes", None): self.serverSend(message["bytes"], True) @@ -187,7 +207,9 @@ class WebSocketProtocol(WebSocketServerProtocol): if hasattr(self, "handshake_deferred"): # If the handshake is still ongoing, we need to emit a HTTP error # code rather than a WebSocket one. - self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error")) + self.handshake_deferred.errback( + ConnectionDeny(code=500, reason="Internal server error") + ) else: self.sendCloseFrame(code=1011) @@ -203,14 +225,22 @@ class WebSocketProtocol(WebSocketServerProtocol): """ Called when we get a message saying to reject the connection. """ - self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) + self.handshake_deferred.errback( + ConnectionDeny(code=403, reason="Access denied") + ) del self.handshake_deferred self.server.protocol_disconnected(self) logger.debug("WebSocket %s rejected by application", self.client_addr) - self.server.log_action("websocket", "rejected", { - "path": self.request.path, - "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, - }) + self.server.log_action( + "websocket", + "rejected", + { + "path": self.request.path, + "client": "%s:%s" % tuple(self.client_addr) + if self.client_addr + else None, + }, + ) def serverSend(self, content, binary=False): """ @@ -244,7 +274,10 @@ class WebSocketProtocol(WebSocketServerProtocol): Called periodically to see if we should timeout something """ # Web timeout checking - if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0: + if ( + self.duration() > self.server.websocket_timeout + and self.server.websocket_timeout >= 0 + ): self.serverClose() # Ping check # If we're still connecting, deny the connection @@ -287,6 +320,6 @@ class WebSocketFactory(WebSocketServerFactory): protocol = super(WebSocketFactory, self).buildProtocol(addr) protocol.factory = self return protocol - except Exception as e: + except Exception: logger.error("Cannot build protocol: %s" % traceback.format_exc()) raise diff --git a/setup.cfg b/setup.cfg index 8c9a5f5..46dd04f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,9 +5,9 @@ universal=1 addopts = tests/ [isort] -line_length = 120 +include_trailing_comma = True multi_line_output = 3 -known_first_party = channels,daphne,asgiref +known_first_party = channels,daphne,asgiref,channels_redis [flake8] exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/* diff --git a/setup.py b/setup.py index 692ab6e..ff98e27 100755 --- a/setup.py +++ b/setup.py @@ -22,23 +22,12 @@ setup( package_dir={"twisted": "daphne/twisted"}, packages=find_packages() + ["twisted.plugins"], include_package_data=True, - install_requires=[ - "twisted>=18.7", - "autobahn>=0.18", - ], - setup_requires=[ - "pytest-runner", - ], - extras_require={ - "tests": [ - "hypothesis", - "pytest", - "pytest-asyncio~=0.8", - ], + install_requires=["twisted>=18.7", "autobahn>=0.18"], + setup_requires=["pytest-runner"], + extras_require={"tests": ["hypothesis", "pytest", "pytest-asyncio~=0.8"]}, + entry_points={ + "console_scripts": ["daphne = daphne.cli:CommandLineInterface.entrypoint"] }, - entry_points={"console_scripts": [ - "daphne = daphne.cli:CommandLineInterface.entrypoint", - ]}, classifiers=[ "Development Status :: 4 - Beta", "Environment :: Web Environment", diff --git a/tests/http_base.py b/tests/http_base.py index c7cec48..866a066 100644 --- a/tests/http_base.py +++ b/tests/http_base.py @@ -19,7 +19,9 @@ class DaphneTestCase(unittest.TestCase): ### Plain HTTP helpers - def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False): + def run_daphne_http( + self, method, path, params, body, responses, headers=None, timeout=1, xff=False + ): """ Runs Daphne with the given request callback (given the base URL) and response messages. @@ -38,7 +40,9 @@ class DaphneTestCase(unittest.TestCase): # Manually send over headers (encoding any non-safe values as best we can) if headers: for header_name, header_value in headers: - conn.putheader(header_name.encode("utf8"), header_value.encode("utf8")) + conn.putheader( + header_name.encode("utf8"), header_value.encode("utf8") + ) # Send body if provided. if body: conn.putheader("Content-Length", str(len(body))) @@ -50,9 +54,11 @@ class DaphneTestCase(unittest.TestCase): except socket.timeout: # See if they left an exception for us to load test_app.get_received() - raise RuntimeError("Daphne timed out handling request, no exception found.") + raise RuntimeError( + "Daphne timed out handling request, no exception found." + ) # Return scope, messages, response - return test_app.get_received() + (response, ) + return test_app.get_received() + (response,) def run_daphne_raw(self, data, timeout=1): """ @@ -68,9 +74,13 @@ class DaphneTestCase(unittest.TestCase): try: return s.recv(1000000) except socket.timeout: - raise RuntimeError("Daphne timed out handling raw request, no exception found.") + raise RuntimeError( + "Daphne timed out handling raw request, no exception found." + ) - def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False): + def run_daphne_request( + self, method, path, params=None, body=None, headers=None, xff=False + ): """ Convenience method for just testing request handling. Returns (scope, messages) @@ -95,17 +105,21 @@ class DaphneTestCase(unittest.TestCase): Returns (scope, messages) """ _, _, response = self.run_daphne_http( - method="GET", - path="/", - params={}, - body=b"", - responses=response_messages, + method="GET", path="/", params={}, body=b"", responses=response_messages ) return response ### WebSocket helpers - def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1): + def websocket_handshake( + self, + test_app, + path="/", + params=None, + headers=None, + subprotocols=None, + timeout=1, + ): """ Runs a WebSocket handshake negotiation and returns the raw socket object & the selected subprotocol. @@ -124,14 +138,16 @@ class DaphneTestCase(unittest.TestCase): # Do WebSocket handshake headers + any other headers if headers is None: headers = [] - headers.extend([ - ("Host", "example.com"), - ("Upgrade", "websocket"), - ("Connection", "Upgrade"), - ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), - ("Sec-WebSocket-Version", "13"), - ("Origin", "http://example.com") - ]) + headers.extend( + [ + ("Host", "example.com"), + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), + ("Sec-WebSocket-Version", "13"), + ("Origin", "http://example.com"), + ] + ) if subprotocols: headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols))) if headers: @@ -149,10 +165,7 @@ class DaphneTestCase(unittest.TestCase): if response.status != 101: raise RuntimeError("WebSocket upgrade did not result in status code 101") # Prepare headers for subprotocol searching - response_headers = dict( - (n.lower(), v) - for n, v in response.getheaders() - ) + response_headers = dict((n.lower(), v) for n, v in response.getheaders()) response.read() assert not response.closed # Return the raw socket and any subprotocol @@ -234,10 +247,7 @@ class DaphneTestCase(unittest.TestCase): # Make sure all required keys are present self.assertTrue(required_keys <= present_keys) # Assert that no other keys are present - self.assertEqual( - set(), - present_keys - required_keys - optional_keys, - ) + self.assertEqual(set(), present_keys - required_keys - optional_keys) def assert_valid_path(self, path, request_path): """ diff --git a/tests/http_strategies.py b/tests/http_strategies.py index 4335afd..d78ac10 100644 --- a/tests/http_strategies.py +++ b/tests/http_strategies.py @@ -6,7 +6,9 @@ from hypothesis import strategies HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"] # Unicode characters of the "Letter" category -letters = strategies.characters(whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")) +letters = strategies.characters( + whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl") +) def http_method(): @@ -22,11 +24,9 @@ def http_path(): """ Returns a URL path (not encoded). """ - return strategies.lists( - _http_path_portion(), - min_size=0, - max_size=10, - ).map(lambda s: "/" + "/".join(s)) + return strategies.lists(_http_path_portion(), min_size=0, max_size=10).map( + lambda s: "/" + "/".join(s) + ) def http_body(): @@ -53,10 +53,7 @@ def valid_bidi(value): def _domain_label(): return strategies.text( - alphabet=letters, - min_size=1, - average_size=6, - max_size=63, + alphabet=letters, min_size=1, average_size=6, max_size=63 ).filter(valid_bidi) @@ -64,19 +61,14 @@ def international_domain_name(): """ Returns a byte string of a domain name, IDNA-encoded. """ - return strategies.lists( - _domain_label(), - min_size=2, - average_size=2, - ).map(lambda s: (".".join(s)).encode("idna")) + return strategies.lists(_domain_label(), min_size=2, average_size=2).map( + lambda s: (".".join(s)).encode("idna") + ) def _query_param(): return strategies.text( - alphabet=letters, - min_size=1, - average_size=10, - max_size=255, + alphabet=letters, min_size=1, average_size=10, max_size=255 ).map(lambda s: s.encode("utf8")) @@ -87,9 +79,7 @@ def query_params(): ensures that the total urlencoded query string is not longer than 1500 characters. """ return strategies.lists( - strategies.tuples(_query_param(), _query_param()), - min_size=0, - average_size=5, + strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5 ).filter(lambda x: len(parse.urlencode(x)) < 1500) @@ -101,9 +91,7 @@ def header_name(): and 20 characters long """ return strategies.text( - alphabet=string.ascii_letters + string.digits + "-", - min_size=1, - max_size=30, + alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30 ) @@ -115,7 +103,10 @@ def header_value(): https://en.wikipedia.org/wiki/List_of_HTTP_header_fields """ return strategies.text( - alphabet=string.ascii_letters + string.digits + string.punctuation.replace(",", "") + " /t", + alphabet=string.ascii_letters + + string.digits + + string.punctuation.replace(",", "") + + " /t", min_size=1, average_size=40, max_size=8190, diff --git a/tests/test_cli.py b/tests/test_cli.py index 9e90ab0..7bb45dc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -18,45 +18,32 @@ class TestEndpointDescriptions(TestCase): def testTcpPortBindings(self): self.assertEqual( build(port=1234, host="example.com"), - ["tcp:port=1234:interface=example.com"] + ["tcp:port=1234:interface=example.com"], ) self.assertEqual( - build(port=8000, host="127.0.0.1"), - ["tcp:port=8000:interface=127.0.0.1"] + build(port=8000, host="127.0.0.1"), ["tcp:port=8000:interface=127.0.0.1"] ) self.assertEqual( - build(port=8000, host="[200a::1]"), - [r'tcp:port=8000:interface=200a\:\:1'] + build(port=8000, host="[200a::1]"), [r"tcp:port=8000:interface=200a\:\:1"] ) self.assertEqual( - build(port=8000, host="200a::1"), - [r'tcp:port=8000:interface=200a\:\:1'] + build(port=8000, host="200a::1"), [r"tcp:port=8000:interface=200a\:\:1"] ) # incomplete port/host kwargs raise errors - self.assertRaises( - ValueError, - build, port=123 - ) - self.assertRaises( - ValueError, - build, host="example.com" - ) + self.assertRaises(ValueError, build, port=123) + self.assertRaises(ValueError, build, host="example.com") def testUnixSocketBinding(self): self.assertEqual( - build(unix_socket="/tmp/daphne.sock"), - ["unix:/tmp/daphne.sock"] + build(unix_socket="/tmp/daphne.sock"), ["unix:/tmp/daphne.sock"] ) def testFileDescriptorBinding(self): - self.assertEqual( - build(file_descriptor=5), - ["fd:fileno=5"] - ) + self.assertEqual(build(file_descriptor=5), ["fd:fileno=5"]) def testMultipleEnpoints(self): self.assertEqual( @@ -65,14 +52,16 @@ class TestEndpointDescriptions(TestCase): file_descriptor=123, unix_socket="/tmp/daphne.sock", port=8080, - host="10.0.0.1" + host="10.0.0.1", ) ), - sorted([ - "tcp:port=8080:interface=10.0.0.1", - "unix:/tmp/daphne.sock", - "fd:fileno=123" - ]) + sorted( + [ + "tcp:port=8080:interface=10.0.0.1", + "unix:/tmp/daphne.sock", + "fd:fileno=123", + ] + ), ) @@ -112,7 +101,9 @@ class TestCLIInterface(TestCase): Passes in a fake application automatically. """ cli = self.TestedCLI() - cli.run(args + ["daphne:__version__"]) # We just pass something importable as app + cli.run( + args + ["daphne:__version__"] + ) # We just pass something importable as app # Check the server got all arguments as intended for key, value in server_kwargs.items(): # Get the value and sort it if it's a list (for endpoint checking) @@ -123,52 +114,30 @@ class TestCLIInterface(TestCase): self.assertEqual( value, actual_value, - "Wrong value for server kwarg %s: %r != %r" % ( - key, - value, - actual_value, - ), + "Wrong value for server kwarg %s: %r != %r" + % (key, value, actual_value), ) def testCLIBasics(self): """ Tests basic endpoint generation. """ + self.assertCLI([], {"endpoints": ["tcp:port=8000:interface=127.0.0.1"]}) self.assertCLI( - [], - { - "endpoints": ["tcp:port=8000:interface=127.0.0.1"], - }, + ["-p", "123"], {"endpoints": ["tcp:port=123:interface=127.0.0.1"]} ) self.assertCLI( - ["-p", "123"], - { - "endpoints": ["tcp:port=123:interface=127.0.0.1"], - }, + ["-b", "10.0.0.1"], {"endpoints": ["tcp:port=8000:interface=10.0.0.1"]} ) self.assertCLI( - ["-b", "10.0.0.1"], - { - "endpoints": ["tcp:port=8000:interface=10.0.0.1"], - }, + ["-b", "200a::1"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]} ) self.assertCLI( - ["-b", "200a::1"], - { - "endpoints": [r'tcp:port=8000:interface=200a\:\:1'], - }, - ) - self.assertCLI( - ["-b", "[200a::1]"], - { - "endpoints": [r'tcp:port=8000:interface=200a\:\:1'], - }, + ["-b", "[200a::1]"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]} ) self.assertCLI( ["-p", "8080", "-b", "example.com"], - { - "endpoints": ["tcp:port=8080:interface=example.com"], - }, + {"endpoints": ["tcp:port=8080:interface=example.com"]}, ) def testUnixSockets(self): @@ -178,7 +147,7 @@ class TestCLIInterface(TestCase): "endpoints": [ "tcp:port=8080:interface=127.0.0.1", "unix:/tmp/daphne.sock", - ], + ] }, ) self.assertCLI( @@ -187,17 +156,12 @@ class TestCLIInterface(TestCase): "endpoints": [ "tcp:port=8000:interface=example.com", "unix:/tmp/daphne.sock", - ], + ] }, ) self.assertCLI( ["-u", "/tmp/daphne.sock", "--fd", "5"], - { - "endpoints": [ - "fd:fileno=5", - "unix:/tmp/daphne.sock" - ], - }, + {"endpoints": ["fd:fileno=5", "unix:/tmp/daphne.sock"]}, ) def testMixedCLIEndpointCreation(self): @@ -209,8 +173,8 @@ class TestCLIInterface(TestCase): { "endpoints": [ "tcp:port=8080:interface=127.0.0.1", - "unix:/tmp/daphne.sock" - ], + "unix:/tmp/daphne.sock", + ] }, ) self.assertCLI( @@ -219,7 +183,7 @@ class TestCLIInterface(TestCase): "endpoints": [ "tcp:port=8080:interface=127.0.0.1", "tcp:port=8080:interface=127.0.0.1", - ], + ] }, ) @@ -227,11 +191,4 @@ class TestCLIInterface(TestCase): """ Tests entirely custom endpoints """ - self.assertCLI( - ["-e", "imap:"], - { - "endpoints": [ - "imap:", - ], - }, - ) + self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]}) diff --git a/tests/test_http_request.py b/tests/test_http_request.py index 79274ed..e02b8b6 100644 --- a/tests/test_http_request.py +++ b/tests/test_http_request.py @@ -15,13 +15,7 @@ class TestHTTPRequest(DaphneTestCase): """ def assert_valid_http_scope( - self, - scope, - method, - path, - params=None, - headers=None, - scheme=None, + self, scope, method, path, params=None, headers=None, scheme=None ): """ Checks that the passed scope is a valid ASGI HTTP scope regarding types @@ -29,7 +23,14 @@ class TestHTTPRequest(DaphneTestCase): """ # Check overall keys self.assert_key_sets( - required_keys={"type", "http_version", "method", "path", "query_string", "headers"}, + required_keys={ + "type", + "http_version", + "method", + "path", + "query_string", + "headers", + }, optional_keys={"scheme", "root_path", "client", "server"}, actual_keys=scope.keys(), ) @@ -50,7 +51,9 @@ class TestHTTPRequest(DaphneTestCase): query_string = scope["query_string"] self.assertIsInstance(query_string, bytes) if params: - self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii")) + self.assertEqual( + query_string, parse.urlencode(params or []).encode("ascii") + ) # Ordering of header names is not important, but the order of values for a header # name is. To assert whether that order is kept, we transform both the request # headers and the channel message headers into a dictionary @@ -59,7 +62,7 @@ class TestHTTPRequest(DaphneTestCase): for name, value in scope["headers"]: transformed_scope_headers[name].append(value) transformed_request_headers = collections.defaultdict(list) - for name, value in (headers or []): + for name, value in headers or []: expected_name = name.lower().strip().encode("ascii") expected_value = value.strip().encode("ascii") transformed_request_headers[expected_name].append(expected_value) @@ -103,27 +106,31 @@ class TestHTTPRequest(DaphneTestCase): @given( request_path=http_strategies.http_path(), - request_params=http_strategies.query_params() + request_params=http_strategies.query_params(), ) @settings(max_examples=5, deadline=5000) def test_get_request(self, request_path, request_params): """ Tests a typical HTTP GET request, with a path and query parameters """ - scope, messages = self.run_daphne_request("GET", request_path, params=request_params) + scope, messages = self.run_daphne_request( + "GET", request_path, params=request_params + ) self.assert_valid_http_scope(scope, "GET", request_path, params=request_params) self.assert_valid_http_request_message(messages[0], body=b"") @given( request_path=http_strategies.http_path(), - request_body=http_strategies.http_body() + request_body=http_strategies.http_body(), ) @settings(max_examples=5, deadline=5000) def test_post_request(self, request_path, request_body): """ Tests a typical HTTP POST request, with a path and body. """ - scope, messages = self.run_daphne_request("POST", request_path, body=request_body) + scope, messages = self.run_daphne_request( + "POST", request_path, body=request_body + ) self.assert_valid_http_scope(scope, "POST", request_path) self.assert_valid_http_request_message(messages[0], body=request_body) @@ -134,8 +141,12 @@ class TestHTTPRequest(DaphneTestCase): Tests that HTTP header fields are handled as specified """ request_path = "/te st-à/" - scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers) - self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers) + scope, messages = self.run_daphne_request( + "OPTIONS", request_path, headers=request_headers + ) + self.assert_valid_http_scope( + scope, "OPTIONS", request_path, headers=request_headers + ) self.assert_valid_http_request_message(messages[0], body=b"") @given(request_headers=http_strategies.headers()) @@ -150,8 +161,12 @@ class TestHTTPRequest(DaphneTestCase): duplicated_headers = [(header_name, header[1]) for header in request_headers] # Run the request request_path = "/te st-à/" - scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=duplicated_headers) - self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=duplicated_headers) + scope, messages = self.run_daphne_request( + "OPTIONS", request_path, headers=duplicated_headers + ) + self.assert_valid_http_scope( + scope, "OPTIONS", request_path, headers=duplicated_headers + ) self.assert_valid_http_request_message(messages[0], body=b"") @given( @@ -222,10 +237,7 @@ class TestHTTPRequest(DaphneTestCase): """ Make sure that, by default, X-Forwarded-For is ignored. """ - headers = [ - ["X-Forwarded-For", "10.1.2.3"], - ["X-Forwarded-Port", "80"], - ] + headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -236,10 +248,7 @@ class TestHTTPRequest(DaphneTestCase): """ When X-Forwarded-For is enabled, make sure it is respected. """ - headers = [ - ["X-Forwarded-For", "10.1.2.3"], - ["X-Forwarded-Port", "80"], - ] + headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -251,9 +260,7 @@ class TestHTTPRequest(DaphneTestCase): When X-Forwarded-For is enabled but only the host is passed, make sure that at least makes it through. """ - headers = [ - ["X-Forwarded-For", "10.1.2.3"], - ] + headers = [["X-Forwarded-For", "10.1.2.3"]] scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_request_message(messages[0], body=b"") @@ -265,8 +272,12 @@ class TestHTTPRequest(DaphneTestCase): Tests that requests with invalid (non-ASCII) characters fail. """ # Bad path - response = self.run_daphne_raw(b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n") + response = self.run_daphne_raw( + b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n" + ) self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request")) # Bad querystring - response = self.run_daphne_raw(b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n") + response = self.run_daphne_raw( + b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n" + ) self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request")) diff --git a/tests/test_http_response.py b/tests/test_http_response.py index 2efc4ec..3576697 100644 --- a/tests/test_http_response.py +++ b/tests/test_http_response.py @@ -15,26 +15,24 @@ class TestHTTPResponse(DaphneTestCase): """ Lowercases and sorts headers, and strips transfer-encoding ones. """ - return sorted([ - (name.lower(), value.strip()) - for name, value in headers - if name.lower() != "transfer-encoding" - ]) + return sorted( + [ + (name.lower(), value.strip()) + for name, value in headers + if name.lower() != "transfer-encoding" + ] + ) def test_minimal_response(self): """ Smallest viable example. Mostly verifies that our response building works. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 200, - }, - { - "type": "http.response.body", - "body": b"hello world", - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 200}, + {"type": "http.response.body", "body": b"hello world"}, + ] + ) self.assertEqual(response.status, 200) self.assertEqual(response.read(), b"hello world") @@ -46,30 +44,23 @@ class TestHTTPResponse(DaphneTestCase): to make sure it stays required. """ with self.assertRaises(ValueError): - self.run_daphne_response([ - { - "type": "http.response.start", - }, - { - "type": "http.response.body", - "body": b"hello world", - }, - ]) + self.run_daphne_response( + [ + {"type": "http.response.start"}, + {"type": "http.response.body", "body": b"hello world"}, + ] + ) def test_custom_status_code(self): """ Tries a non-default status code. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 201, - }, - { - "type": "http.response.body", - "body": b"i made a thing!", - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 201}, + {"type": "http.response.body", "body": b"i made a thing!"}, + ] + ) self.assertEqual(response.status, 201) self.assertEqual(response.read(), b"i made a thing!") @@ -77,21 +68,13 @@ class TestHTTPResponse(DaphneTestCase): """ Tries sending a response in multiple parts. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 201, - }, - { - "type": "http.response.body", - "body": b"chunk 1 ", - "more_body": True, - }, - { - "type": "http.response.body", - "body": b"chunk 2", - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 201}, + {"type": "http.response.body", "body": b"chunk 1 ", "more_body": True}, + {"type": "http.response.body", "body": b"chunk 2"}, + ] + ) self.assertEqual(response.status, 201) self.assertEqual(response.read(), b"chunk 1 chunk 2") @@ -99,25 +82,14 @@ class TestHTTPResponse(DaphneTestCase): """ Tries sending a response in multiple parts and an empty end. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 201, - }, - { - "type": "http.response.body", - "body": b"chunk 1 ", - "more_body": True, - }, - { - "type": "http.response.body", - "body": b"chunk 2", - "more_body": True, - }, - { - "type": "http.response.body", - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 201}, + {"type": "http.response.body", "body": b"chunk 1 ", "more_body": True}, + {"type": "http.response.body", "body": b"chunk 2", "more_body": True}, + {"type": "http.response.body"}, + ] + ) self.assertEqual(response.status, 201) self.assertEqual(response.read(), b"chunk 1 chunk 2") @@ -127,16 +99,12 @@ class TestHTTPResponse(DaphneTestCase): """ Tries body variants. """ - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 200, - }, - { - "type": "http.response.body", - "body": body, - }, - ]) + response = self.run_daphne_response( + [ + {"type": "http.response.start", "status": 200}, + {"type": "http.response.body", "body": body}, + ] + ) self.assertEqual(response.status, 200) self.assertEqual(response.read(), body) @@ -144,16 +112,16 @@ class TestHTTPResponse(DaphneTestCase): @settings(max_examples=5, deadline=5000) def test_headers(self, headers): # The ASGI spec requires us to lowercase our header names - response = self.run_daphne_response([ - { - "type": "http.response.start", - "status": 200, - "headers": self.normalize_headers(headers), - }, - { - "type": "http.response.body", - }, - ]) + response = self.run_daphne_response( + [ + { + "type": "http.response.start", + "status": 200, + "headers": self.normalize_headers(headers), + }, + {"type": "http.response.body"}, + ] + ) # Check headers in a sensible way. Ignore transfer-encoding. self.assertEqual( self.normalize_headers(response.getheaders()), diff --git a/tests/test_utils.py b/tests/test_utils.py index 5dada0f..6b04939 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,48 +13,35 @@ class TestXForwardedForHttpParsing(TestCase): """ def test_basic(self): - headers = Headers({ - b"X-Forwarded-For": [b"10.1.2.3"], - b"X-Forwarded-Port": [b"1234"], - b"X-Forwarded-Proto": [b"https"] - }) + headers = Headers( + { + b"X-Forwarded-For": [b"10.1.2.3"], + b"X-Forwarded-Port": [b"1234"], + b"X-Forwarded-Proto": [b"https"], + } + ) result = parse_x_forwarded_for(headers) self.assertEqual(result, (["10.1.2.3", 1234], "https")) self.assertIsInstance(result[0][0], str) self.assertIsInstance(result[1], str) def test_address_only(self): - headers = Headers({ - b"X-Forwarded-For": [b"10.1.2.3"], - }) - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 0], None) - ) + headers = Headers({b"X-Forwarded-For": [b"10.1.2.3"]}) + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None)) def test_v6_address(self): - headers = Headers({ - b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"], - }) - self.assertEqual( - parse_x_forwarded_for(headers), - (["1043::a321:0001", 0], None) - ) + headers = Headers({b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]}) + self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None)) def test_multiple_proxys(self): - headers = Headers({ - b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"], - }) - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 0], None) - ) + headers = Headers({b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"]}) + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None)) def test_original(self): headers = Headers({}) self.assertEqual( parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), - (["127.0.0.1", 80], None) + (["127.0.0.1", 80], None), ) def test_no_original(self): @@ -73,43 +60,25 @@ class TestXForwardedForWsParsing(TestCase): b"X-Forwarded-Port": b"1234", b"X-Forwarded-Proto": b"https", } - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 1234], "https") - ) + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 1234], "https")) def test_address_only(self): - headers = { - b"X-Forwarded-For": b"10.1.2.3", - } - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 0], None) - ) + headers = {b"X-Forwarded-For": b"10.1.2.3"} + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None)) def test_v6_address(self): - headers = { - b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"], - } - self.assertEqual( - parse_x_forwarded_for(headers), - (["1043::a321:0001", 0], None) - ) + headers = {b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]} + self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None)) def test_multiple_proxies(self): - headers = { - b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4", - } - self.assertEqual( - parse_x_forwarded_for(headers), - (["10.1.2.3", 0], None) - ) + headers = {b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4"} + self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None)) def test_original(self): headers = {} self.assertEqual( parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), - (["127.0.0.1", 80], None) + (["127.0.0.1", 80], None), ) def test_no_original(self): diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 0ae1a21..80ec21d 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -16,13 +16,7 @@ class TestWebsocket(DaphneTestCase): """ def assert_valid_websocket_scope( - self, - scope, - path="/", - params=None, - headers=None, - scheme=None, - subprotocols=None, + self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None ): """ Checks that the passed scope is a valid ASGI HTTP scope regarding types @@ -46,7 +40,9 @@ class TestWebsocket(DaphneTestCase): query_string = scope["query_string"] self.assertIsInstance(query_string, bytes) if params: - self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii")) + self.assertEqual( + query_string, parse.urlencode(params or []).encode("ascii") + ) # Ordering of header names is not important, but the order of values for a header # name is. To assert whether that order is kept, we transform both the request # headers and the channel message headers into a dictionary @@ -59,7 +55,7 @@ class TestWebsocket(DaphneTestCase): if bit.strip(): transformed_scope_headers[name].append(bit.strip()) transformed_request_headers = collections.defaultdict(list) - for name, value in (headers or []): + for name, value in headers or []: expected_name = name.lower().strip().encode("ascii") expected_value = value.strip().encode("ascii") # Make sure to split out any headers collapsed with commas @@ -92,9 +88,7 @@ class TestWebsocket(DaphneTestCase): """ # Check overall keys self.assert_key_sets( - required_keys={"type"}, - optional_keys=set(), - actual_keys=message.keys(), + required_keys={"type"}, optional_keys=set(), actual_keys=message.keys() ) # Check that it is the right type self.assertEqual(message["type"], "websocket.connect") @@ -104,11 +98,7 @@ class TestWebsocket(DaphneTestCase): Tests we can open and accept a socket. """ with DaphneTestingInstance() as test_app: - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) self.websocket_handshake(test_app) # Validate the scope and messages we got scope, messages = test_app.get_received() @@ -120,11 +110,7 @@ class TestWebsocket(DaphneTestCase): Tests we can reject a socket and it won't complete the handshake. """ with DaphneTestingInstance() as test_app: - test_app.add_send_messages([ - { - "type": "websocket.close", - } - ]) + test_app.add_send_messages([{"type": "websocket.close"}]) with self.assertRaises(RuntimeError): self.websocket_handshake(test_app) @@ -134,13 +120,12 @@ class TestWebsocket(DaphneTestCase): """ subprotocols = ["proto1", "proto2"] with DaphneTestingInstance() as test_app: - test_app.add_send_messages([ - { - "type": "websocket.accept", - "subprotocol": "proto2", - } - ]) - _, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols) + test_app.add_send_messages( + [{"type": "websocket.accept", "subprotocol": "proto2"}] + ) + _, subprotocol = self.websocket_handshake( + test_app, subprotocols=subprotocols + ) # Validate the scope and messages we got assert subprotocol == "proto2" scope, messages = test_app.get_received() @@ -151,16 +136,9 @@ class TestWebsocket(DaphneTestCase): """ Tests that X-Forwarded-For headers get parsed right """ - headers = [ - ["X-Forwarded-For", "10.1.2.3"], - ["X-Forwarded-Port", "80"], - ] + headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]] with DaphneTestingInstance(xff=True) as test_app: - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) self.websocket_handshake(test_app, headers=headers) # Validate the scope and messages we got scope, messages = test_app.get_received() @@ -174,22 +152,13 @@ class TestWebsocket(DaphneTestCase): request_headers=http_strategies.headers(), ) @settings(max_examples=5, deadline=2000) - def test_http_bits( - self, - request_path, - request_params, - request_headers, - ): + def test_http_bits(self, request_path, request_params, request_headers): """ Tests that various HTTP-level bits (query string params, path, headers) carry over into the scope. """ with DaphneTestingInstance() as test_app: - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) self.websocket_handshake( test_app, path=request_path, @@ -199,10 +168,7 @@ class TestWebsocket(DaphneTestCase): # Validate the scope and messages we got scope, messages = test_app.get_received() self.assert_valid_websocket_scope( - scope, - path=request_path, - params=request_params, - headers=request_headers, + scope, path=request_path, params=request_params, headers=request_headers ) self.assert_valid_websocket_connect_message(messages[0]) @@ -212,28 +178,24 @@ class TestWebsocket(DaphneTestCase): """ with DaphneTestingInstance() as test_app: # Connect - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) sock, _ = self.websocket_handshake(test_app) _, messages = test_app.get_received() self.assert_valid_websocket_connect_message(messages[0]) # Prep frame for it to send - test_app.add_send_messages([ - { - "type": "websocket.send", - "text": "here be dragons 🐉", - } - ]) + test_app.add_send_messages( + [{"type": "websocket.send", "text": "here be dragons 🐉"}] + ) # Send it a frame self.websocket_send_frame(sock, "what is here? 🌍") # Receive a frame and make sure it's correct assert self.websocket_receive_frame(sock) == "here be dragons 🐉" # Make sure it got our frame _, messages = test_app.get_received() - assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"} + assert messages[1] == { + "type": "websocket.receive", + "text": "what is here? 🌍", + } def test_binary_frames(self): """ @@ -242,28 +204,24 @@ class TestWebsocket(DaphneTestCase): """ with DaphneTestingInstance() as test_app: # Connect - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) sock, _ = self.websocket_handshake(test_app) _, messages = test_app.get_received() self.assert_valid_websocket_connect_message(messages[0]) # Prep frame for it to send - test_app.add_send_messages([ - { - "type": "websocket.send", - "bytes": b"here be \xe2 bytes", - } - ]) + test_app.add_send_messages( + [{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}] + ) # Send it a frame self.websocket_send_frame(sock, b"what is here? \xe2") # Receive a frame and make sure it's correct assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes" # Make sure it got our frame _, messages = test_app.get_received() - assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"} + assert messages[1] == { + "type": "websocket.receive", + "bytes": b"what is here? \xe2", + } def test_http_timeout(self): """ @@ -271,23 +229,14 @@ class TestWebsocket(DaphneTestCase): """ with DaphneTestingInstance(http_timeout=1) as test_app: # Connect - test_app.add_send_messages([ - { - "type": "websocket.accept", - } - ]) + test_app.add_send_messages([{"type": "websocket.accept"}]) sock, _ = self.websocket_handshake(test_app) _, messages = test_app.get_received() self.assert_valid_websocket_connect_message(messages[0]) # Wait 2 seconds time.sleep(2) # Prep frame for it to send - test_app.add_send_messages([ - { - "type": "websocket.send", - "text": "cake", - } - ]) + test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}]) # Send it a frame self.websocket_send_frame(sock, "still alive?") # Receive a frame and make sure it's correct