diff --git a/.travis.yml b/.travis.yml
index aa69b0b..6a617c4 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -3,22 +3,25 @@ sudo: false
language: python
python:
-- '3.5'
- '3.6'
+- '3.5'
env:
-- TWISTED="twisted==18.7.0"
- TWISTED="twisted"
+- TWISTED="twisted==18.7.0"
install:
-- pip install $TWISTED isort unify flake8 -e .[tests]
+- pip install $TWISTED -e .[tests]
- pip freeze
script:
- pytest
-- flake8
-- isort --check-only --diff --recursive daphne tests
-- unify --check-only --recursive --quote \" daphne tests
+
+stages:
+ - lint
+ - test
+ - name: release
+ if: branch = master
jobs:
include:
@@ -30,6 +33,13 @@ jobs:
env: TWISTED="twisted"
dist: xenial
sudo: required
+ - stage: lint
+ install: pip install -U -e .[tests] black pyflakes isort
+ script:
+ - pyflakes .
+ - black --check .
+ - isort --check-only --diff --recursive channels_redis tests
+
- stage: release
script: skip
deploy:
diff --git a/daphne/access.py b/daphne/access.py
index ce80f49..2b3b1cd 100644
--- a/daphne/access.py
+++ b/daphne/access.py
@@ -49,13 +49,16 @@ class AccessLogGenerator(object):
request="WSDISCONNECT %(path)s" % details,
)
- def write_entry(self, host, date, request, status=None, length=None, ident=None, user=None):
+ def write_entry(
+ self, host, date, request, status=None, length=None, ident=None, user=None
+ ):
"""
Writes an NCSA-style entry to the log file (some liberty is taken with
what the entries are for non-HTTP)
"""
self.stream.write(
- "%s %s %s [%s] \"%s\" %s %s\n" % (
+ '%s %s %s [%s] "%s" %s %s\n'
+ % (
host,
ident or "-",
user or "-",
diff --git a/daphne/cli.py b/daphne/cli.py
index 7f42084..28cf1b9 100755
--- a/daphne/cli.py
+++ b/daphne/cli.py
@@ -23,15 +23,9 @@ class CommandLineInterface(object):
server_class = Server
def __init__(self):
- self.parser = argparse.ArgumentParser(
- description=self.description,
- )
+ self.parser = argparse.ArgumentParser(description=self.description)
self.parser.add_argument(
- "-p",
- "--port",
- type=int,
- help="Port number to listen on",
- default=None,
+ "-p", "--port", type=int, help="Port number to listen on", default=None
)
self.parser.add_argument(
"-b",
@@ -128,7 +122,7 @@ class CommandLineInterface(object):
"--proxy-headers",
dest="proxy_headers",
help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the "
- "client address",
+ "client address",
default=False,
action="store_true",
)
@@ -176,7 +170,15 @@ class CommandLineInterface(object):
sys.path.insert(0, ".")
application = import_by_path(args.application)
# Set up port/host bindings
- if not any([args.host, args.port is not None, args.unix_socket, args.file_descriptor, args.socket_strings]):
+ if not any(
+ [
+ args.host,
+ args.port is not None,
+ args.unix_socket,
+ args.file_descriptor,
+ args.socket_strings,
+ ]
+ ):
# no advanced binding options passed, patch in defaults
args.host = DEFAULT_HOST
args.port = DEFAULT_PORT
@@ -189,16 +191,11 @@ class CommandLineInterface(object):
host=args.host,
port=args.port,
unix_socket=args.unix_socket,
- file_descriptor=args.file_descriptor
- )
- endpoints = sorted(
- args.socket_strings + endpoints
+ file_descriptor=args.file_descriptor,
)
+ endpoints = sorted(args.socket_strings + endpoints)
# Start the server
- logger.info(
- "Starting server at %s" %
- (", ".join(endpoints), )
- )
+ logger.info("Starting server at %s" % (", ".join(endpoints),))
self.server = self.server_class(
application=application,
endpoints=endpoints,
@@ -208,12 +205,20 @@ class CommandLineInterface(object):
websocket_timeout=args.websocket_timeout,
websocket_connect_timeout=args.websocket_connect_timeout,
application_close_timeout=args.application_close_timeout,
- action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
+ action_logger=AccessLogGenerator(access_log_stream)
+ if access_log_stream
+ else None,
ws_protocols=args.ws_protocols,
root_path=args.root_path,
verbosity=args.verbosity,
- proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None,
- proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None,
- proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None,
+ proxy_forwarded_address_header="X-Forwarded-For"
+ if args.proxy_headers
+ else None,
+ proxy_forwarded_port_header="X-Forwarded-Port"
+ if args.proxy_headers
+ else None,
+ proxy_forwarded_proto_header="X-Forwarded-Proto"
+ if args.proxy_headers
+ else None,
)
self.server.run()
diff --git a/daphne/endpoints.py b/daphne/endpoints.py
index 6188357..83e472a 100644
--- a/daphne/endpoints.py
+++ b/daphne/endpoints.py
@@ -1,10 +1,5 @@
-
-
def build_endpoint_description_strings(
- host=None,
- port=None,
- unix_socket=None,
- file_descriptor=None
+ host=None, port=None, unix_socket=None, file_descriptor=None
):
"""
Build a list of twisted endpoint description strings that the server will listen on.
diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py
index 915e475..2c8d840 100755
--- a/daphne/http_protocol.py
+++ b/daphne/http_protocol.py
@@ -23,7 +23,8 @@ class WebRequest(http.Request):
GET and POST out.
"""
- error_template = """
+ error_template = (
+ """
%(title)s
@@ -40,7 +41,13 @@ class WebRequest(http.Request):
- """.replace("\n", "").replace(" ", " ").replace(" ", " ").replace(" ", " ") # Shorten it a bit, bytes wise
+ """.replace(
+ "\n", ""
+ )
+ .replace(" ", " ")
+ .replace(" ", " ")
+ .replace(" ", " ")
+ ) # Shorten it a bit, bytes wise
def __init__(self, *args, **kwargs):
try:
@@ -84,7 +91,7 @@ class WebRequest(http.Request):
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
- self.client_scheme
+ self.client_scheme,
)
# Check for unicodeish path (or it'll crash when trying to parse)
try:
@@ -105,7 +112,9 @@ class WebRequest(http.Request):
# Is it WebSocket? IS IT?!
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
- protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer())
+ protocol = self.server.ws_factory.buildProtocol(
+ self.transport.getPeer()
+ )
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
@@ -151,33 +160,38 @@ class WebRequest(http.Request):
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
self.content.seek(0, 0)
# Work out the application scope and create application
- self.application_queue = yield maybeDeferred(self.server.create_application, self, {
- "type": "http",
- # TODO: Correctly say if it's 1.1 or 1.0
- "http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
- "method": self.method.decode("ascii"),
- "path": unquote(self.path.decode("ascii")),
- "root_path": self.root_path,
- "scheme": self.client_scheme,
- "query_string": self.query_string,
- "headers": self.clean_headers,
- "client": self.client_addr,
- "server": self.server_addr,
- })
+ self.application_queue = yield maybeDeferred(
+ self.server.create_application,
+ self,
+ {
+ "type": "http",
+ # TODO: Correctly say if it's 1.1 or 1.0
+ "http_version": self.clientproto.split(b"/")[-1].decode(
+ "ascii"
+ ),
+ "method": self.method.decode("ascii"),
+ "path": unquote(self.path.decode("ascii")),
+ "root_path": self.root_path,
+ "scheme": self.client_scheme,
+ "query_string": self.query_string,
+ "headers": self.clean_headers,
+ "client": self.client_addr,
+ "server": self.server_addr,
+ },
+ )
# Check they didn't close an unfinished request
if self.application_queue is None or self.content.closed:
# Not much we can do, the request is prematurely abandoned.
return
# Run application against request
self.application_queue.put_nowait(
- {
- "type": "http.request",
- "body": self.content.read(),
- },
+ {"type": "http.request", "body": self.content.read()}
)
except Exception:
logger.error(traceback.format_exc())
- self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error")
+ self.basic_error(
+ 500, b"Internal Server Error", "Daphne HTTP processing error"
+ )
def connectionLost(self, reason):
"""
@@ -217,16 +231,23 @@ class WebRequest(http.Request):
raise ValueError("HTTP response has already been started")
self._response_started = True
if "status" not in message:
- raise ValueError("Specifying a status code is required for a Response message.")
+ raise ValueError(
+ "Specifying a status code is required for a Response message."
+ )
# Set HTTP status code
self.setResponseCode(message["status"])
# Write headers
for header, value in message.get("headers", {}):
self.responseHeaders.addRawHeader(header, value)
- logger.debug("HTTP %s response started for %s", message["status"], self.client_addr)
+ logger.debug(
+ "HTTP %s response started for %s", message["status"], self.client_addr
+ )
elif message["type"] == "http.response.body":
if not self._response_started:
- raise ValueError("HTTP response has not yet been started but got %s" % message["type"])
+ raise ValueError(
+ "HTTP response has not yet been started but got %s"
+ % message["type"]
+ )
# Write out body
http.Request.write(self, message.get("body", b""))
# End if there's no more content
@@ -239,15 +260,21 @@ class WebRequest(http.Request):
# The path is malformed somehow - do our best to log something
uri = repr(self.uri)
try:
- self.server.log_action("http", "complete", {
- "path": uri,
- "status": self.code,
- "method": self.method.decode("ascii", "replace"),
- "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
- "time_taken": self.duration(),
- "size": self.sentLength,
- })
- except Exception as e:
+ self.server.log_action(
+ "http",
+ "complete",
+ {
+ "path": uri,
+ "status": self.code,
+ "method": self.method.decode("ascii", "replace"),
+ "client": "%s:%s" % tuple(self.client_addr)
+ if self.client_addr
+ else None,
+ "time_taken": self.duration(),
+ "size": self.sentLength,
+ },
+ )
+ except Exception:
logger.error(traceback.format_exc())
else:
logger.debug("HTTP response chunk for %s", self.client_addr)
@@ -270,7 +297,11 @@ class WebRequest(http.Request):
logger.warning("Application timed out while sending response")
self.finish()
else:
- self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.")
+ self.basic_error(
+ 503,
+ b"Service Unavailable",
+ "Application failed to respond within time limit.",
+ )
### Utility functions
@@ -281,11 +312,7 @@ class WebRequest(http.Request):
"""
# If we don't yet have a path, then don't send as we never opened.
if self.path:
- self.application_queue.put_nowait(
- {
- "type": "http.disconnect",
- },
- )
+ self.application_queue.put_nowait({"type": "http.disconnect"})
def duration(self):
"""
@@ -299,20 +326,25 @@ class WebRequest(http.Request):
"""
Responds with a server-level error page (very basic)
"""
- self.handle_reply({
- "type": "http.response.start",
- "status": status,
- "headers": [
- (b"Content-Type", b"text/html; charset=utf-8"),
- ],
- })
- self.handle_reply({
- "type": "http.response.body",
- "body": (self.error_template % {
- "title": str(status) + " " + status_text.decode("ascii"),
- "body": body,
- }).encode("utf8"),
- })
+ self.handle_reply(
+ {
+ "type": "http.response.start",
+ "status": status,
+ "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
+ }
+ )
+ self.handle_reply(
+ {
+ "type": "http.response.body",
+ "body": (
+ self.error_template
+ % {
+ "title": str(status) + " " + status_text.decode("ascii"),
+ "body": body,
+ }
+ ).encode("utf8"),
+ }
+ )
def __hash__(self):
return hash(id(self))
@@ -343,7 +375,7 @@ class HTTPFactory(http.HTTPFactory):
protocol = http.HTTPFactory.buildProtocol(self, addr)
protocol.requestFactory = WebRequest
return protocol
- except Exception as e:
+ except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise
diff --git a/daphne/server.py b/daphne/server.py
index dfae544..133762d 100755
--- a/daphne/server.py
+++ b/daphne/server.py
@@ -2,13 +2,14 @@
import sys # isort:skip
import warnings # isort:skip
from twisted.internet import asyncioreactor # isort:skip
+
current_reactor = sys.modules.get("twisted.internet.reactor", None)
if current_reactor is not None:
if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor):
warnings.warn(
- "Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " +
- "you can fix this warning by importing daphne.server early in your codebase or " +
- "finding the package that imports Twisted and importing it later on.",
+ "Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; "
+ + "you can fix this warning by importing daphne.server early in your codebase or "
+ + "finding the package that imports Twisted and importing it later on.",
UserWarning,
)
del sys.modules["twisted.internet.reactor"]
@@ -34,7 +35,6 @@ logger = logging.getLogger(__name__)
class Server(object):
-
def __init__(
self,
application,
@@ -91,11 +91,13 @@ class Server(object):
self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
- openHandshakeTimeout=self.websocket_handshake_timeout
+ openHandshakeTimeout=self.websocket_handshake_timeout,
)
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
- globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True)
+ globalLogBeginner.beginLoggingTo(
+ [lambda _: None], redirectStandardIO=False, discardBuffer=True
+ )
else:
globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
@@ -103,7 +105,9 @@ class Server(object):
if http.H2_ENABLED:
logger.info("HTTP/2 support enabled")
else:
- logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
+ logger.info(
+ "HTTP/2 support not enabled (install the http2 and tls Twisted extras)"
+ )
# Kick off the timeout loop
reactor.callLater(1, self.application_checker)
@@ -141,7 +145,11 @@ class Server(object):
host = port.getHost()
if hasattr(host, "host") and hasattr(host, "port"):
self.listening_addresses.append((host.host, host.port))
- logger.info("Listening on TCP address %s:%s", port.getHost().host, port.getHost().port)
+ logger.info(
+ "Listening on TCP address %s:%s",
+ port.getHost().host,
+ port.getHost().port,
+ )
def listen_error(self, failure):
logger.critical("Listen failure: %s", failure.getErrorMessage())
@@ -187,10 +195,13 @@ class Server(object):
# Run it, and stash the future for later checking
if protocol not in self.connections:
return None
- self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance(
- receive=input_queue.get,
- send=lambda message: self.handle_reply(protocol, message),
- ), loop=asyncio.get_event_loop())
+ self.connections[protocol]["application_instance"] = asyncio.ensure_future(
+ application_instance(
+ receive=input_queue.get,
+ send=lambda message: self.handle_reply(protocol, message),
+ ),
+ loop=asyncio.get_event_loop(),
+ )
return input_queue
async def handle_reply(self, protocol, message):
@@ -215,7 +226,10 @@ class Server(object):
application_instance = details.get("application_instance", None)
# First, see if the protocol disconnected and the app has taken
# too long to close up
- if disconnected and time.time() - disconnected > self.application_close_timeout:
+ if (
+ disconnected
+ and time.time() - disconnected > self.application_close_timeout
+ ):
if application_instance and not application_instance.done():
logger.warning(
"Application instance %r for connection %s took too long to shut down and was killed.",
@@ -238,14 +252,11 @@ class Server(object):
else:
exception_output = "{}\n{}{}".format(
exception,
- "".join(traceback.format_tb(
- exception.__traceback__,
- )),
+ "".join(traceback.format_tb(exception.__traceback__)),
" {}".format(exception),
)
logger.error(
- "Exception inside application: %s",
- exception_output,
+ "Exception inside application: %s", exception_output
)
if not disconnected:
protocol.handle_exception(exception)
diff --git a/daphne/testing.py b/daphne/testing.py
index e606952..f5f3724 100644
--- a/daphne/testing.py
+++ b/daphne/testing.py
@@ -100,9 +100,7 @@ class DaphneTestingInstance:
Adds messages for the application to send back.
The next time it receives an incoming message, it will reply with these.
"""
- TestApplication.save_setup(
- response_messages=messages,
- )
+ TestApplication.save_setup(response_messages=messages)
class DaphneProcess(multiprocessing.Process):
@@ -193,12 +191,7 @@ class TestApplication:
Stores setup information.
"""
with open(cls.setup_storage, "wb") as fh:
- pickle.dump(
- {
- "response_messages": response_messages,
- },
- fh,
- )
+ pickle.dump({"response_messages": response_messages}, fh)
@classmethod
def load_setup(cls):
@@ -218,13 +211,7 @@ class TestApplication:
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(cls.result_storage, "wb") as fh:
- pickle.dump(
- {
- "scope": scope,
- "messages": messages,
- },
- fh,
- )
+ pickle.dump({"scope": scope, "messages": messages}, fh)
@classmethod
def save_exception(cls, exception):
@@ -233,12 +220,7 @@ class TestApplication:
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(cls.result_storage, "wb") as fh:
- pickle.dump(
- {
- "exception": exception,
- },
- fh,
- )
+ pickle.dump({"exception": exception}, fh)
@classmethod
def load_result(cls):
diff --git a/daphne/utils.py b/daphne/utils.py
index ad64439..81f1f9d 100644
--- a/daphne/utils.py
+++ b/daphne/utils.py
@@ -22,12 +22,14 @@ def header_value(headers, header_name):
return value.decode("utf-8")
-def parse_x_forwarded_for(headers,
- address_header_name="X-Forwarded-For",
- port_header_name="X-Forwarded-Port",
- proto_header_name="X-Forwarded-Proto",
- original_addr=None,
- original_scheme=None):
+def parse_x_forwarded_for(
+ headers,
+ address_header_name="X-Forwarded-For",
+ port_header_name="X-Forwarded-Port",
+ proto_header_name="X-Forwarded-Proto",
+ original_addr=None,
+ original_scheme=None,
+):
"""
Parses an X-Forwarded-For header and returns a host/port pair as a list.
diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py
index f0b7bda..edf1254 100755
--- a/daphne/ws_protocol.py
+++ b/daphne/ws_protocol.py
@@ -3,7 +3,11 @@ import time
import traceback
from urllib.parse import unquote
-from autobahn.twisted.websocket import ConnectionDeny, WebSocketServerFactory, WebSocketServerProtocol
+from autobahn.twisted.websocket import (
+ ConnectionDeny,
+ WebSocketServerFactory,
+ WebSocketServerProtocol,
+)
from twisted.internet import defer
from .utils import parse_x_forwarded_for
@@ -54,32 +58,34 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
- self.client_addr
+ self.client_addr,
)
# Decode websocket subprotocol options
subprotocols = []
for header, value in self.clean_headers:
if header == b"sec-websocket-protocol":
subprotocols = [
- x.strip()
- for x in
- unquote(value.decode("ascii")).split(",")
+ x.strip() for x in unquote(value.decode("ascii")).split(",")
]
# Make new application instance with scope
self.path = request.path.encode("ascii")
- self.application_deferred = defer.maybeDeferred(self.server.create_application, self, {
- "type": "websocket",
- "path": unquote(self.path.decode("ascii")),
- "headers": self.clean_headers,
- "query_string": self._raw_query_string, # Passed by HTTP protocol
- "client": self.client_addr,
- "server": self.server_addr,
- "subprotocols": subprotocols,
- })
+ self.application_deferred = defer.maybeDeferred(
+ self.server.create_application,
+ self,
+ {
+ "type": "websocket",
+ "path": unquote(self.path.decode("ascii")),
+ "headers": self.clean_headers,
+ "query_string": self._raw_query_string, # Passed by HTTP protocol
+ "client": self.client_addr,
+ "server": self.server_addr,
+ "subprotocols": subprotocols,
+ },
+ )
if self.application_deferred is not None:
self.application_deferred.addCallback(self.applicationCreateWorked)
self.application_deferred.addErrback(self.applicationCreateFailed)
- except Exception as e:
+ except Exception:
# Exceptions here are not displayed right, just 500.
# Turn them into an ERROR log.
logger.error(traceback.format_exc())
@@ -98,10 +104,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.application_queue = application_queue
# Send over the connect message
self.application_queue.put_nowait({"type": "websocket.connect"})
- self.server.log_action("websocket", "connecting", {
- "path": self.request.path,
- "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
- })
+ self.server.log_action(
+ "websocket",
+ "connecting",
+ {
+ "path": self.request.path,
+ "client": "%s:%s" % tuple(self.client_addr)
+ if self.client_addr
+ else None,
+ },
+ )
def applicationCreateFailed(self, failure):
"""
@@ -115,10 +127,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onOpen(self):
# Send news that this channel is open
logger.debug("WebSocket %s open and established", self.client_addr)
- self.server.log_action("websocket", "connected", {
- "path": self.request.path,
- "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
- })
+ self.server.log_action(
+ "websocket",
+ "connected",
+ {
+ "path": self.request.path,
+ "client": "%s:%s" % tuple(self.client_addr)
+ if self.client_addr
+ else None,
+ },
+ )
def onMessage(self, payload, isBinary):
# If we're muted, do nothing.
@@ -128,15 +146,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time()
if isBinary:
- self.application_queue.put_nowait({
- "type": "websocket.receive",
- "bytes": payload,
- })
+ self.application_queue.put_nowait(
+ {"type": "websocket.receive", "bytes": payload}
+ )
else:
- self.application_queue.put_nowait({
- "type": "websocket.receive",
- "text": payload.decode("utf8"),
- })
+ self.application_queue.put_nowait(
+ {"type": "websocket.receive", "text": payload.decode("utf8")}
+ )
def onClose(self, wasClean, code, reason):
"""
@@ -145,14 +161,19 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted and hasattr(self, "application_queue"):
- self.application_queue.put_nowait({
- "type": "websocket.disconnect",
- "code": code,
- })
- self.server.log_action("websocket", "disconnected", {
- "path": self.request.path,
- "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
- })
+ self.application_queue.put_nowait(
+ {"type": "websocket.disconnect", "code": code}
+ )
+ self.server.log_action(
+ "websocket",
+ "disconnected",
+ {
+ "path": self.request.path,
+ "client": "%s:%s" % tuple(self.client_addr)
+ if self.client_addr
+ else None,
+ },
+ )
### Internal event handling
@@ -171,9 +192,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
raise ValueError("Socket has not been accepted, so cannot send over it")
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
- "Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
- message,
- )
+ "Got invalid WebSocket reply message on %s - contains both bytes and text keys"
+ % (message,)
)
if message.get("bytes", None):
self.serverSend(message["bytes"], True)
@@ -187,7 +207,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one.
- self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error"))
+ self.handshake_deferred.errback(
+ ConnectionDeny(code=500, reason="Internal server error")
+ )
else:
self.sendCloseFrame(code=1011)
@@ -203,14 +225,22 @@ class WebSocketProtocol(WebSocketServerProtocol):
"""
Called when we get a message saying to reject the connection.
"""
- self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
+ self.handshake_deferred.errback(
+ ConnectionDeny(code=403, reason="Access denied")
+ )
del self.handshake_deferred
self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
- self.server.log_action("websocket", "rejected", {
- "path": self.request.path,
- "client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
- })
+ self.server.log_action(
+ "websocket",
+ "rejected",
+ {
+ "path": self.request.path,
+ "client": "%s:%s" % tuple(self.client_addr)
+ if self.client_addr
+ else None,
+ },
+ )
def serverSend(self, content, binary=False):
"""
@@ -244,7 +274,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
Called periodically to see if we should timeout something
"""
# Web timeout checking
- if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0:
+ if (
+ self.duration() > self.server.websocket_timeout
+ and self.server.websocket_timeout >= 0
+ ):
self.serverClose()
# Ping check
# If we're still connecting, deny the connection
@@ -287,6 +320,6 @@ class WebSocketFactory(WebSocketServerFactory):
protocol = super(WebSocketFactory, self).buildProtocol(addr)
protocol.factory = self
return protocol
- except Exception as e:
+ except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise
diff --git a/setup.cfg b/setup.cfg
index 8c9a5f5..46dd04f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,9 +5,9 @@ universal=1
addopts = tests/
[isort]
-line_length = 120
+include_trailing_comma = True
multi_line_output = 3
-known_first_party = channels,daphne,asgiref
+known_first_party = channels,daphne,asgiref,channels_redis
[flake8]
exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/*
diff --git a/setup.py b/setup.py
index 692ab6e..ff98e27 100755
--- a/setup.py
+++ b/setup.py
@@ -22,23 +22,12 @@ setup(
package_dir={"twisted": "daphne/twisted"},
packages=find_packages() + ["twisted.plugins"],
include_package_data=True,
- install_requires=[
- "twisted>=18.7",
- "autobahn>=0.18",
- ],
- setup_requires=[
- "pytest-runner",
- ],
- extras_require={
- "tests": [
- "hypothesis",
- "pytest",
- "pytest-asyncio~=0.8",
- ],
+ install_requires=["twisted>=18.7", "autobahn>=0.18"],
+ setup_requires=["pytest-runner"],
+ extras_require={"tests": ["hypothesis", "pytest", "pytest-asyncio~=0.8"]},
+ entry_points={
+ "console_scripts": ["daphne = daphne.cli:CommandLineInterface.entrypoint"]
},
- entry_points={"console_scripts": [
- "daphne = daphne.cli:CommandLineInterface.entrypoint",
- ]},
classifiers=[
"Development Status :: 4 - Beta",
"Environment :: Web Environment",
diff --git a/tests/http_base.py b/tests/http_base.py
index c7cec48..866a066 100644
--- a/tests/http_base.py
+++ b/tests/http_base.py
@@ -19,7 +19,9 @@ class DaphneTestCase(unittest.TestCase):
### Plain HTTP helpers
- def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
+ def run_daphne_http(
+ self, method, path, params, body, responses, headers=None, timeout=1, xff=False
+ ):
"""
Runs Daphne with the given request callback (given the base URL)
and response messages.
@@ -38,7 +40,9 @@ class DaphneTestCase(unittest.TestCase):
# Manually send over headers (encoding any non-safe values as best we can)
if headers:
for header_name, header_value in headers:
- conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))
+ conn.putheader(
+ header_name.encode("utf8"), header_value.encode("utf8")
+ )
# Send body if provided.
if body:
conn.putheader("Content-Length", str(len(body)))
@@ -50,9 +54,11 @@ class DaphneTestCase(unittest.TestCase):
except socket.timeout:
# See if they left an exception for us to load
test_app.get_received()
- raise RuntimeError("Daphne timed out handling request, no exception found.")
+ raise RuntimeError(
+ "Daphne timed out handling request, no exception found."
+ )
# Return scope, messages, response
- return test_app.get_received() + (response, )
+ return test_app.get_received() + (response,)
def run_daphne_raw(self, data, timeout=1):
"""
@@ -68,9 +74,13 @@ class DaphneTestCase(unittest.TestCase):
try:
return s.recv(1000000)
except socket.timeout:
- raise RuntimeError("Daphne timed out handling raw request, no exception found.")
+ raise RuntimeError(
+ "Daphne timed out handling raw request, no exception found."
+ )
- def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
+ def run_daphne_request(
+ self, method, path, params=None, body=None, headers=None, xff=False
+ ):
"""
Convenience method for just testing request handling.
Returns (scope, messages)
@@ -95,17 +105,21 @@ class DaphneTestCase(unittest.TestCase):
Returns (scope, messages)
"""
_, _, response = self.run_daphne_http(
- method="GET",
- path="/",
- params={},
- body=b"",
- responses=response_messages,
+ method="GET", path="/", params={}, body=b"", responses=response_messages
)
return response
### WebSocket helpers
- def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1):
+ def websocket_handshake(
+ self,
+ test_app,
+ path="/",
+ params=None,
+ headers=None,
+ subprotocols=None,
+ timeout=1,
+ ):
"""
Runs a WebSocket handshake negotiation and returns the raw socket
object & the selected subprotocol.
@@ -124,14 +138,16 @@ class DaphneTestCase(unittest.TestCase):
# Do WebSocket handshake headers + any other headers
if headers is None:
headers = []
- headers.extend([
- ("Host", "example.com"),
- ("Upgrade", "websocket"),
- ("Connection", "Upgrade"),
- ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
- ("Sec-WebSocket-Version", "13"),
- ("Origin", "http://example.com")
- ])
+ headers.extend(
+ [
+ ("Host", "example.com"),
+ ("Upgrade", "websocket"),
+ ("Connection", "Upgrade"),
+ ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
+ ("Sec-WebSocket-Version", "13"),
+ ("Origin", "http://example.com"),
+ ]
+ )
if subprotocols:
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
if headers:
@@ -149,10 +165,7 @@ class DaphneTestCase(unittest.TestCase):
if response.status != 101:
raise RuntimeError("WebSocket upgrade did not result in status code 101")
# Prepare headers for subprotocol searching
- response_headers = dict(
- (n.lower(), v)
- for n, v in response.getheaders()
- )
+ response_headers = dict((n.lower(), v) for n, v in response.getheaders())
response.read()
assert not response.closed
# Return the raw socket and any subprotocol
@@ -234,10 +247,7 @@ class DaphneTestCase(unittest.TestCase):
# Make sure all required keys are present
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
- self.assertEqual(
- set(),
- present_keys - required_keys - optional_keys,
- )
+ self.assertEqual(set(), present_keys - required_keys - optional_keys)
def assert_valid_path(self, path, request_path):
"""
diff --git a/tests/http_strategies.py b/tests/http_strategies.py
index 4335afd..d78ac10 100644
--- a/tests/http_strategies.py
+++ b/tests/http_strategies.py
@@ -6,7 +6,9 @@ from hypothesis import strategies
HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"]
# Unicode characters of the "Letter" category
-letters = strategies.characters(whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl"))
+letters = strategies.characters(
+ whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")
+)
def http_method():
@@ -22,11 +24,9 @@ def http_path():
"""
Returns a URL path (not encoded).
"""
- return strategies.lists(
- _http_path_portion(),
- min_size=0,
- max_size=10,
- ).map(lambda s: "/" + "/".join(s))
+ return strategies.lists(_http_path_portion(), min_size=0, max_size=10).map(
+ lambda s: "/" + "/".join(s)
+ )
def http_body():
@@ -53,10 +53,7 @@ def valid_bidi(value):
def _domain_label():
return strategies.text(
- alphabet=letters,
- min_size=1,
- average_size=6,
- max_size=63,
+ alphabet=letters, min_size=1, average_size=6, max_size=63
).filter(valid_bidi)
@@ -64,19 +61,14 @@ def international_domain_name():
"""
Returns a byte string of a domain name, IDNA-encoded.
"""
- return strategies.lists(
- _domain_label(),
- min_size=2,
- average_size=2,
- ).map(lambda s: (".".join(s)).encode("idna"))
+ return strategies.lists(_domain_label(), min_size=2, average_size=2).map(
+ lambda s: (".".join(s)).encode("idna")
+ )
def _query_param():
return strategies.text(
- alphabet=letters,
- min_size=1,
- average_size=10,
- max_size=255,
+ alphabet=letters, min_size=1, average_size=10, max_size=255
).map(lambda s: s.encode("utf8"))
@@ -87,9 +79,7 @@ def query_params():
ensures that the total urlencoded query string is not longer than 1500 characters.
"""
return strategies.lists(
- strategies.tuples(_query_param(), _query_param()),
- min_size=0,
- average_size=5,
+ strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5
).filter(lambda x: len(parse.urlencode(x)) < 1500)
@@ -101,9 +91,7 @@ def header_name():
and 20 characters long
"""
return strategies.text(
- alphabet=string.ascii_letters + string.digits + "-",
- min_size=1,
- max_size=30,
+ alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30
)
@@ -115,7 +103,10 @@ def header_value():
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
"""
return strategies.text(
- alphabet=string.ascii_letters + string.digits + string.punctuation.replace(",", "") + " /t",
+ alphabet=string.ascii_letters
+ + string.digits
+ + string.punctuation.replace(",", "")
+ + " /t",
min_size=1,
average_size=40,
max_size=8190,
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 9e90ab0..7bb45dc 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -18,45 +18,32 @@ class TestEndpointDescriptions(TestCase):
def testTcpPortBindings(self):
self.assertEqual(
build(port=1234, host="example.com"),
- ["tcp:port=1234:interface=example.com"]
+ ["tcp:port=1234:interface=example.com"],
)
self.assertEqual(
- build(port=8000, host="127.0.0.1"),
- ["tcp:port=8000:interface=127.0.0.1"]
+ build(port=8000, host="127.0.0.1"), ["tcp:port=8000:interface=127.0.0.1"]
)
self.assertEqual(
- build(port=8000, host="[200a::1]"),
- [r'tcp:port=8000:interface=200a\:\:1']
+ build(port=8000, host="[200a::1]"), [r"tcp:port=8000:interface=200a\:\:1"]
)
self.assertEqual(
- build(port=8000, host="200a::1"),
- [r'tcp:port=8000:interface=200a\:\:1']
+ build(port=8000, host="200a::1"), [r"tcp:port=8000:interface=200a\:\:1"]
)
# incomplete port/host kwargs raise errors
- self.assertRaises(
- ValueError,
- build, port=123
- )
- self.assertRaises(
- ValueError,
- build, host="example.com"
- )
+ self.assertRaises(ValueError, build, port=123)
+ self.assertRaises(ValueError, build, host="example.com")
def testUnixSocketBinding(self):
self.assertEqual(
- build(unix_socket="/tmp/daphne.sock"),
- ["unix:/tmp/daphne.sock"]
+ build(unix_socket="/tmp/daphne.sock"), ["unix:/tmp/daphne.sock"]
)
def testFileDescriptorBinding(self):
- self.assertEqual(
- build(file_descriptor=5),
- ["fd:fileno=5"]
- )
+ self.assertEqual(build(file_descriptor=5), ["fd:fileno=5"])
def testMultipleEnpoints(self):
self.assertEqual(
@@ -65,14 +52,16 @@ class TestEndpointDescriptions(TestCase):
file_descriptor=123,
unix_socket="/tmp/daphne.sock",
port=8080,
- host="10.0.0.1"
+ host="10.0.0.1",
)
),
- sorted([
- "tcp:port=8080:interface=10.0.0.1",
- "unix:/tmp/daphne.sock",
- "fd:fileno=123"
- ])
+ sorted(
+ [
+ "tcp:port=8080:interface=10.0.0.1",
+ "unix:/tmp/daphne.sock",
+ "fd:fileno=123",
+ ]
+ ),
)
@@ -112,7 +101,9 @@ class TestCLIInterface(TestCase):
Passes in a fake application automatically.
"""
cli = self.TestedCLI()
- cli.run(args + ["daphne:__version__"]) # We just pass something importable as app
+ cli.run(
+ args + ["daphne:__version__"]
+ ) # We just pass something importable as app
# Check the server got all arguments as intended
for key, value in server_kwargs.items():
# Get the value and sort it if it's a list (for endpoint checking)
@@ -123,52 +114,30 @@ class TestCLIInterface(TestCase):
self.assertEqual(
value,
actual_value,
- "Wrong value for server kwarg %s: %r != %r" % (
- key,
- value,
- actual_value,
- ),
+ "Wrong value for server kwarg %s: %r != %r"
+ % (key, value, actual_value),
)
def testCLIBasics(self):
"""
Tests basic endpoint generation.
"""
+ self.assertCLI([], {"endpoints": ["tcp:port=8000:interface=127.0.0.1"]})
self.assertCLI(
- [],
- {
- "endpoints": ["tcp:port=8000:interface=127.0.0.1"],
- },
+ ["-p", "123"], {"endpoints": ["tcp:port=123:interface=127.0.0.1"]}
)
self.assertCLI(
- ["-p", "123"],
- {
- "endpoints": ["tcp:port=123:interface=127.0.0.1"],
- },
+ ["-b", "10.0.0.1"], {"endpoints": ["tcp:port=8000:interface=10.0.0.1"]}
)
self.assertCLI(
- ["-b", "10.0.0.1"],
- {
- "endpoints": ["tcp:port=8000:interface=10.0.0.1"],
- },
+ ["-b", "200a::1"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
)
self.assertCLI(
- ["-b", "200a::1"],
- {
- "endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
- },
- )
- self.assertCLI(
- ["-b", "[200a::1]"],
- {
- "endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
- },
+ ["-b", "[200a::1]"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
)
self.assertCLI(
["-p", "8080", "-b", "example.com"],
- {
- "endpoints": ["tcp:port=8080:interface=example.com"],
- },
+ {"endpoints": ["tcp:port=8080:interface=example.com"]},
)
def testUnixSockets(self):
@@ -178,7 +147,7 @@ class TestCLIInterface(TestCase):
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
"unix:/tmp/daphne.sock",
- ],
+ ]
},
)
self.assertCLI(
@@ -187,17 +156,12 @@ class TestCLIInterface(TestCase):
"endpoints": [
"tcp:port=8000:interface=example.com",
"unix:/tmp/daphne.sock",
- ],
+ ]
},
)
self.assertCLI(
["-u", "/tmp/daphne.sock", "--fd", "5"],
- {
- "endpoints": [
- "fd:fileno=5",
- "unix:/tmp/daphne.sock"
- ],
- },
+ {"endpoints": ["fd:fileno=5", "unix:/tmp/daphne.sock"]},
)
def testMixedCLIEndpointCreation(self):
@@ -209,8 +173,8 @@ class TestCLIInterface(TestCase):
{
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
- "unix:/tmp/daphne.sock"
- ],
+ "unix:/tmp/daphne.sock",
+ ]
},
)
self.assertCLI(
@@ -219,7 +183,7 @@ class TestCLIInterface(TestCase):
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
"tcp:port=8080:interface=127.0.0.1",
- ],
+ ]
},
)
@@ -227,11 +191,4 @@ class TestCLIInterface(TestCase):
"""
Tests entirely custom endpoints
"""
- self.assertCLI(
- ["-e", "imap:"],
- {
- "endpoints": [
- "imap:",
- ],
- },
- )
+ self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]})
diff --git a/tests/test_http_request.py b/tests/test_http_request.py
index 79274ed..e02b8b6 100644
--- a/tests/test_http_request.py
+++ b/tests/test_http_request.py
@@ -15,13 +15,7 @@ class TestHTTPRequest(DaphneTestCase):
"""
def assert_valid_http_scope(
- self,
- scope,
- method,
- path,
- params=None,
- headers=None,
- scheme=None,
+ self, scope, method, path, params=None, headers=None, scheme=None
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
@@ -29,7 +23,14 @@ class TestHTTPRequest(DaphneTestCase):
"""
# Check overall keys
self.assert_key_sets(
- required_keys={"type", "http_version", "method", "path", "query_string", "headers"},
+ required_keys={
+ "type",
+ "http_version",
+ "method",
+ "path",
+ "query_string",
+ "headers",
+ },
optional_keys={"scheme", "root_path", "client", "server"},
actual_keys=scope.keys(),
)
@@ -50,7 +51,9 @@ class TestHTTPRequest(DaphneTestCase):
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
- self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
+ self.assertEqual(
+ query_string, parse.urlencode(params or []).encode("ascii")
+ )
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
@@ -59,7 +62,7 @@ class TestHTTPRequest(DaphneTestCase):
for name, value in scope["headers"]:
transformed_scope_headers[name].append(value)
transformed_request_headers = collections.defaultdict(list)
- for name, value in (headers or []):
+ for name, value in headers or []:
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
@@ -103,27 +106,31 @@ class TestHTTPRequest(DaphneTestCase):
@given(
request_path=http_strategies.http_path(),
- request_params=http_strategies.query_params()
+ request_params=http_strategies.query_params(),
)
@settings(max_examples=5, deadline=5000)
def test_get_request(self, request_path, request_params):
"""
Tests a typical HTTP GET request, with a path and query parameters
"""
- scope, messages = self.run_daphne_request("GET", request_path, params=request_params)
+ scope, messages = self.run_daphne_request(
+ "GET", request_path, params=request_params
+ )
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
request_path=http_strategies.http_path(),
- request_body=http_strategies.http_body()
+ request_body=http_strategies.http_body(),
)
@settings(max_examples=5, deadline=5000)
def test_post_request(self, request_path, request_body):
"""
Tests a typical HTTP POST request, with a path and body.
"""
- scope, messages = self.run_daphne_request("POST", request_path, body=request_body)
+ scope, messages = self.run_daphne_request(
+ "POST", request_path, body=request_body
+ )
self.assert_valid_http_scope(scope, "POST", request_path)
self.assert_valid_http_request_message(messages[0], body=request_body)
@@ -134,8 +141,12 @@ class TestHTTPRequest(DaphneTestCase):
Tests that HTTP header fields are handled as specified
"""
request_path = "/te st-à/"
- scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers)
- self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers)
+ scope, messages = self.run_daphne_request(
+ "OPTIONS", request_path, headers=request_headers
+ )
+ self.assert_valid_http_scope(
+ scope, "OPTIONS", request_path, headers=request_headers
+ )
self.assert_valid_http_request_message(messages[0], body=b"")
@given(request_headers=http_strategies.headers())
@@ -150,8 +161,12 @@ class TestHTTPRequest(DaphneTestCase):
duplicated_headers = [(header_name, header[1]) for header in request_headers]
# Run the request
request_path = "/te st-à/"
- scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=duplicated_headers)
- self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=duplicated_headers)
+ scope, messages = self.run_daphne_request(
+ "OPTIONS", request_path, headers=duplicated_headers
+ )
+ self.assert_valid_http_scope(
+ scope, "OPTIONS", request_path, headers=duplicated_headers
+ )
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
@@ -222,10 +237,7 @@ class TestHTTPRequest(DaphneTestCase):
"""
Make sure that, by default, X-Forwarded-For is ignored.
"""
- headers = [
- ["X-Forwarded-For", "10.1.2.3"],
- ["X-Forwarded-Port", "80"],
- ]
+ headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
@@ -236,10 +248,7 @@ class TestHTTPRequest(DaphneTestCase):
"""
When X-Forwarded-For is enabled, make sure it is respected.
"""
- headers = [
- ["X-Forwarded-For", "10.1.2.3"],
- ["X-Forwarded-Port", "80"],
- ]
+ headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
@@ -251,9 +260,7 @@ class TestHTTPRequest(DaphneTestCase):
When X-Forwarded-For is enabled but only the host is passed, make sure
that at least makes it through.
"""
- headers = [
- ["X-Forwarded-For", "10.1.2.3"],
- ]
+ headers = [["X-Forwarded-For", "10.1.2.3"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
@@ -265,8 +272,12 @@ class TestHTTPRequest(DaphneTestCase):
Tests that requests with invalid (non-ASCII) characters fail.
"""
# Bad path
- response = self.run_daphne_raw(b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n")
+ response = self.run_daphne_raw(
+ b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
+ )
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
# Bad querystring
- response = self.run_daphne_raw(b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n")
+ response = self.run_daphne_raw(
+ b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
+ )
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
diff --git a/tests/test_http_response.py b/tests/test_http_response.py
index 2efc4ec..3576697 100644
--- a/tests/test_http_response.py
+++ b/tests/test_http_response.py
@@ -15,26 +15,24 @@ class TestHTTPResponse(DaphneTestCase):
"""
Lowercases and sorts headers, and strips transfer-encoding ones.
"""
- return sorted([
- (name.lower(), value.strip())
- for name, value in headers
- if name.lower() != "transfer-encoding"
- ])
+ return sorted(
+ [
+ (name.lower(), value.strip())
+ for name, value in headers
+ if name.lower() != "transfer-encoding"
+ ]
+ )
def test_minimal_response(self):
"""
Smallest viable example. Mostly verifies that our response building works.
"""
- response = self.run_daphne_response([
- {
- "type": "http.response.start",
- "status": 200,
- },
- {
- "type": "http.response.body",
- "body": b"hello world",
- },
- ])
+ response = self.run_daphne_response(
+ [
+ {"type": "http.response.start", "status": 200},
+ {"type": "http.response.body", "body": b"hello world"},
+ ]
+ )
self.assertEqual(response.status, 200)
self.assertEqual(response.read(), b"hello world")
@@ -46,30 +44,23 @@ class TestHTTPResponse(DaphneTestCase):
to make sure it stays required.
"""
with self.assertRaises(ValueError):
- self.run_daphne_response([
- {
- "type": "http.response.start",
- },
- {
- "type": "http.response.body",
- "body": b"hello world",
- },
- ])
+ self.run_daphne_response(
+ [
+ {"type": "http.response.start"},
+ {"type": "http.response.body", "body": b"hello world"},
+ ]
+ )
def test_custom_status_code(self):
"""
Tries a non-default status code.
"""
- response = self.run_daphne_response([
- {
- "type": "http.response.start",
- "status": 201,
- },
- {
- "type": "http.response.body",
- "body": b"i made a thing!",
- },
- ])
+ response = self.run_daphne_response(
+ [
+ {"type": "http.response.start", "status": 201},
+ {"type": "http.response.body", "body": b"i made a thing!"},
+ ]
+ )
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"i made a thing!")
@@ -77,21 +68,13 @@ class TestHTTPResponse(DaphneTestCase):
"""
Tries sending a response in multiple parts.
"""
- response = self.run_daphne_response([
- {
- "type": "http.response.start",
- "status": 201,
- },
- {
- "type": "http.response.body",
- "body": b"chunk 1 ",
- "more_body": True,
- },
- {
- "type": "http.response.body",
- "body": b"chunk 2",
- },
- ])
+ response = self.run_daphne_response(
+ [
+ {"type": "http.response.start", "status": 201},
+ {"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
+ {"type": "http.response.body", "body": b"chunk 2"},
+ ]
+ )
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"chunk 1 chunk 2")
@@ -99,25 +82,14 @@ class TestHTTPResponse(DaphneTestCase):
"""
Tries sending a response in multiple parts and an empty end.
"""
- response = self.run_daphne_response([
- {
- "type": "http.response.start",
- "status": 201,
- },
- {
- "type": "http.response.body",
- "body": b"chunk 1 ",
- "more_body": True,
- },
- {
- "type": "http.response.body",
- "body": b"chunk 2",
- "more_body": True,
- },
- {
- "type": "http.response.body",
- },
- ])
+ response = self.run_daphne_response(
+ [
+ {"type": "http.response.start", "status": 201},
+ {"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
+ {"type": "http.response.body", "body": b"chunk 2", "more_body": True},
+ {"type": "http.response.body"},
+ ]
+ )
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"chunk 1 chunk 2")
@@ -127,16 +99,12 @@ class TestHTTPResponse(DaphneTestCase):
"""
Tries body variants.
"""
- response = self.run_daphne_response([
- {
- "type": "http.response.start",
- "status": 200,
- },
- {
- "type": "http.response.body",
- "body": body,
- },
- ])
+ response = self.run_daphne_response(
+ [
+ {"type": "http.response.start", "status": 200},
+ {"type": "http.response.body", "body": body},
+ ]
+ )
self.assertEqual(response.status, 200)
self.assertEqual(response.read(), body)
@@ -144,16 +112,16 @@ class TestHTTPResponse(DaphneTestCase):
@settings(max_examples=5, deadline=5000)
def test_headers(self, headers):
# The ASGI spec requires us to lowercase our header names
- response = self.run_daphne_response([
- {
- "type": "http.response.start",
- "status": 200,
- "headers": self.normalize_headers(headers),
- },
- {
- "type": "http.response.body",
- },
- ])
+ response = self.run_daphne_response(
+ [
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": self.normalize_headers(headers),
+ },
+ {"type": "http.response.body"},
+ ]
+ )
# Check headers in a sensible way. Ignore transfer-encoding.
self.assertEqual(
self.normalize_headers(response.getheaders()),
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 5dada0f..6b04939 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -13,48 +13,35 @@ class TestXForwardedForHttpParsing(TestCase):
"""
def test_basic(self):
- headers = Headers({
- b"X-Forwarded-For": [b"10.1.2.3"],
- b"X-Forwarded-Port": [b"1234"],
- b"X-Forwarded-Proto": [b"https"]
- })
+ headers = Headers(
+ {
+ b"X-Forwarded-For": [b"10.1.2.3"],
+ b"X-Forwarded-Port": [b"1234"],
+ b"X-Forwarded-Proto": [b"https"],
+ }
+ )
result = parse_x_forwarded_for(headers)
self.assertEqual(result, (["10.1.2.3", 1234], "https"))
self.assertIsInstance(result[0][0], str)
self.assertIsInstance(result[1], str)
def test_address_only(self):
- headers = Headers({
- b"X-Forwarded-For": [b"10.1.2.3"],
- })
- self.assertEqual(
- parse_x_forwarded_for(headers),
- (["10.1.2.3", 0], None)
- )
+ headers = Headers({b"X-Forwarded-For": [b"10.1.2.3"]})
+ self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_v6_address(self):
- headers = Headers({
- b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
- })
- self.assertEqual(
- parse_x_forwarded_for(headers),
- (["1043::a321:0001", 0], None)
- )
+ headers = Headers({b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]})
+ self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
def test_multiple_proxys(self):
- headers = Headers({
- b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"],
- })
- self.assertEqual(
- parse_x_forwarded_for(headers),
- (["10.1.2.3", 0], None)
- )
+ headers = Headers({b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"]})
+ self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_original(self):
headers = Headers({})
self.assertEqual(
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
- (["127.0.0.1", 80], None)
+ (["127.0.0.1", 80], None),
)
def test_no_original(self):
@@ -73,43 +60,25 @@ class TestXForwardedForWsParsing(TestCase):
b"X-Forwarded-Port": b"1234",
b"X-Forwarded-Proto": b"https",
}
- self.assertEqual(
- parse_x_forwarded_for(headers),
- (["10.1.2.3", 1234], "https")
- )
+ self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 1234], "https"))
def test_address_only(self):
- headers = {
- b"X-Forwarded-For": b"10.1.2.3",
- }
- self.assertEqual(
- parse_x_forwarded_for(headers),
- (["10.1.2.3", 0], None)
- )
+ headers = {b"X-Forwarded-For": b"10.1.2.3"}
+ self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_v6_address(self):
- headers = {
- b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
- }
- self.assertEqual(
- parse_x_forwarded_for(headers),
- (["1043::a321:0001", 0], None)
- )
+ headers = {b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]}
+ self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
def test_multiple_proxies(self):
- headers = {
- b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
- }
- self.assertEqual(
- parse_x_forwarded_for(headers),
- (["10.1.2.3", 0], None)
- )
+ headers = {b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4"}
+ self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_original(self):
headers = {}
self.assertEqual(
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
- (["127.0.0.1", 80], None)
+ (["127.0.0.1", 80], None),
)
def test_no_original(self):
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index 0ae1a21..80ec21d 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -16,13 +16,7 @@ class TestWebsocket(DaphneTestCase):
"""
def assert_valid_websocket_scope(
- self,
- scope,
- path="/",
- params=None,
- headers=None,
- scheme=None,
- subprotocols=None,
+ self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
@@ -46,7 +40,9 @@ class TestWebsocket(DaphneTestCase):
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
- self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
+ self.assertEqual(
+ query_string, parse.urlencode(params or []).encode("ascii")
+ )
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
@@ -59,7 +55,7 @@ class TestWebsocket(DaphneTestCase):
if bit.strip():
transformed_scope_headers[name].append(bit.strip())
transformed_request_headers = collections.defaultdict(list)
- for name, value in (headers or []):
+ for name, value in headers or []:
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
# Make sure to split out any headers collapsed with commas
@@ -92,9 +88,7 @@ class TestWebsocket(DaphneTestCase):
"""
# Check overall keys
self.assert_key_sets(
- required_keys={"type"},
- optional_keys=set(),
- actual_keys=message.keys(),
+ required_keys={"type"}, optional_keys=set(), actual_keys=message.keys()
)
# Check that it is the right type
self.assertEqual(message["type"], "websocket.connect")
@@ -104,11 +98,7 @@ class TestWebsocket(DaphneTestCase):
Tests we can open and accept a socket.
"""
with DaphneTestingInstance() as test_app:
- test_app.add_send_messages([
- {
- "type": "websocket.accept",
- }
- ])
+ test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
@@ -120,11 +110,7 @@ class TestWebsocket(DaphneTestCase):
Tests we can reject a socket and it won't complete the handshake.
"""
with DaphneTestingInstance() as test_app:
- test_app.add_send_messages([
- {
- "type": "websocket.close",
- }
- ])
+ test_app.add_send_messages([{"type": "websocket.close"}])
with self.assertRaises(RuntimeError):
self.websocket_handshake(test_app)
@@ -134,13 +120,12 @@ class TestWebsocket(DaphneTestCase):
"""
subprotocols = ["proto1", "proto2"]
with DaphneTestingInstance() as test_app:
- test_app.add_send_messages([
- {
- "type": "websocket.accept",
- "subprotocol": "proto2",
- }
- ])
- _, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols)
+ test_app.add_send_messages(
+ [{"type": "websocket.accept", "subprotocol": "proto2"}]
+ )
+ _, subprotocol = self.websocket_handshake(
+ test_app, subprotocols=subprotocols
+ )
# Validate the scope and messages we got
assert subprotocol == "proto2"
scope, messages = test_app.get_received()
@@ -151,16 +136,9 @@ class TestWebsocket(DaphneTestCase):
"""
Tests that X-Forwarded-For headers get parsed right
"""
- headers = [
- ["X-Forwarded-For", "10.1.2.3"],
- ["X-Forwarded-Port", "80"],
- ]
+ headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
with DaphneTestingInstance(xff=True) as test_app:
- test_app.add_send_messages([
- {
- "type": "websocket.accept",
- }
- ])
+ test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app, headers=headers)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
@@ -174,22 +152,13 @@ class TestWebsocket(DaphneTestCase):
request_headers=http_strategies.headers(),
)
@settings(max_examples=5, deadline=2000)
- def test_http_bits(
- self,
- request_path,
- request_params,
- request_headers,
- ):
+ def test_http_bits(self, request_path, request_params, request_headers):
"""
Tests that various HTTP-level bits (query string params, path, headers)
carry over into the scope.
"""
with DaphneTestingInstance() as test_app:
- test_app.add_send_messages([
- {
- "type": "websocket.accept",
- }
- ])
+ test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(
test_app,
path=request_path,
@@ -199,10 +168,7 @@ class TestWebsocket(DaphneTestCase):
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(
- scope,
- path=request_path,
- params=request_params,
- headers=request_headers,
+ scope, path=request_path, params=request_params, headers=request_headers
)
self.assert_valid_websocket_connect_message(messages[0])
@@ -212,28 +178,24 @@ class TestWebsocket(DaphneTestCase):
"""
with DaphneTestingInstance() as test_app:
# Connect
- test_app.add_send_messages([
- {
- "type": "websocket.accept",
- }
- ])
+ test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
- test_app.add_send_messages([
- {
- "type": "websocket.send",
- "text": "here be dragons 🐉",
- }
- ])
+ test_app.add_send_messages(
+ [{"type": "websocket.send", "text": "here be dragons 🐉"}]
+ )
# Send it a frame
self.websocket_send_frame(sock, "what is here? 🌍")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
# Make sure it got our frame
_, messages = test_app.get_received()
- assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"}
+ assert messages[1] == {
+ "type": "websocket.receive",
+ "text": "what is here? 🌍",
+ }
def test_binary_frames(self):
"""
@@ -242,28 +204,24 @@ class TestWebsocket(DaphneTestCase):
"""
with DaphneTestingInstance() as test_app:
# Connect
- test_app.add_send_messages([
- {
- "type": "websocket.accept",
- }
- ])
+ test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
- test_app.add_send_messages([
- {
- "type": "websocket.send",
- "bytes": b"here be \xe2 bytes",
- }
- ])
+ test_app.add_send_messages(
+ [{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}]
+ )
# Send it a frame
self.websocket_send_frame(sock, b"what is here? \xe2")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
# Make sure it got our frame
_, messages = test_app.get_received()
- assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"}
+ assert messages[1] == {
+ "type": "websocket.receive",
+ "bytes": b"what is here? \xe2",
+ }
def test_http_timeout(self):
"""
@@ -271,23 +229,14 @@ class TestWebsocket(DaphneTestCase):
"""
with DaphneTestingInstance(http_timeout=1) as test_app:
# Connect
- test_app.add_send_messages([
- {
- "type": "websocket.accept",
- }
- ])
+ test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Wait 2 seconds
time.sleep(2)
# Prep frame for it to send
- test_app.add_send_messages([
- {
- "type": "websocket.send",
- "text": "cake",
- }
- ])
+ test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}])
# Send it a frame
self.websocket_send_frame(sock, "still alive?")
# Receive a frame and make sure it's correct