Implement Black code formatting

This commit is contained in:
Andrew Godwin 2018-08-27 12:27:32 +10:00
parent 88792984e7
commit 0ed6294406
18 changed files with 513 additions and 596 deletions

View File

@ -3,22 +3,25 @@ sudo: false
language: python language: python
python: python:
- '3.5'
- '3.6' - '3.6'
- '3.5'
env: env:
- TWISTED="twisted==18.7.0"
- TWISTED="twisted" - TWISTED="twisted"
- TWISTED="twisted==18.7.0"
install: install:
- pip install $TWISTED isort unify flake8 -e .[tests] - pip install $TWISTED -e .[tests]
- pip freeze - pip freeze
script: script:
- pytest - pytest
- flake8
- isort --check-only --diff --recursive daphne tests stages:
- unify --check-only --recursive --quote \" daphne tests - lint
- test
- name: release
if: branch = master
jobs: jobs:
include: include:
@ -30,6 +33,13 @@ jobs:
env: TWISTED="twisted" env: TWISTED="twisted"
dist: xenial dist: xenial
sudo: required sudo: required
- stage: lint
install: pip install -U -e .[tests] black pyflakes isort
script:
- pyflakes .
- black --check .
- isort --check-only --diff --recursive channels_redis tests
- stage: release - stage: release
script: skip script: skip
deploy: deploy:

View File

@ -49,13 +49,16 @@ class AccessLogGenerator(object):
request="WSDISCONNECT %(path)s" % details, request="WSDISCONNECT %(path)s" % details,
) )
def write_entry(self, host, date, request, status=None, length=None, ident=None, user=None): def write_entry(
self, host, date, request, status=None, length=None, ident=None, user=None
):
""" """
Writes an NCSA-style entry to the log file (some liberty is taken with Writes an NCSA-style entry to the log file (some liberty is taken with
what the entries are for non-HTTP) what the entries are for non-HTTP)
""" """
self.stream.write( self.stream.write(
"%s %s %s [%s] \"%s\" %s %s\n" % ( '%s %s %s [%s] "%s" %s %s\n'
% (
host, host,
ident or "-", ident or "-",
user or "-", user or "-",

View File

@ -23,15 +23,9 @@ class CommandLineInterface(object):
server_class = Server server_class = Server
def __init__(self): def __init__(self):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(description=self.description)
description=self.description,
)
self.parser.add_argument( self.parser.add_argument(
"-p", "-p", "--port", type=int, help="Port number to listen on", default=None
"--port",
type=int,
help="Port number to listen on",
default=None,
) )
self.parser.add_argument( self.parser.add_argument(
"-b", "-b",
@ -128,7 +122,7 @@ class CommandLineInterface(object):
"--proxy-headers", "--proxy-headers",
dest="proxy_headers", dest="proxy_headers",
help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the " help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the "
"client address", "client address",
default=False, default=False,
action="store_true", action="store_true",
) )
@ -176,7 +170,15 @@ class CommandLineInterface(object):
sys.path.insert(0, ".") sys.path.insert(0, ".")
application = import_by_path(args.application) application = import_by_path(args.application)
# Set up port/host bindings # Set up port/host bindings
if not any([args.host, args.port is not None, args.unix_socket, args.file_descriptor, args.socket_strings]): if not any(
[
args.host,
args.port is not None,
args.unix_socket,
args.file_descriptor,
args.socket_strings,
]
):
# no advanced binding options passed, patch in defaults # no advanced binding options passed, patch in defaults
args.host = DEFAULT_HOST args.host = DEFAULT_HOST
args.port = DEFAULT_PORT args.port = DEFAULT_PORT
@ -189,16 +191,11 @@ class CommandLineInterface(object):
host=args.host, host=args.host,
port=args.port, port=args.port,
unix_socket=args.unix_socket, unix_socket=args.unix_socket,
file_descriptor=args.file_descriptor file_descriptor=args.file_descriptor,
)
endpoints = sorted(
args.socket_strings + endpoints
) )
endpoints = sorted(args.socket_strings + endpoints)
# Start the server # Start the server
logger.info( logger.info("Starting server at %s" % (", ".join(endpoints),))
"Starting server at %s" %
(", ".join(endpoints), )
)
self.server = self.server_class( self.server = self.server_class(
application=application, application=application,
endpoints=endpoints, endpoints=endpoints,
@ -208,12 +205,20 @@ class CommandLineInterface(object):
websocket_timeout=args.websocket_timeout, websocket_timeout=args.websocket_timeout,
websocket_connect_timeout=args.websocket_connect_timeout, websocket_connect_timeout=args.websocket_connect_timeout,
application_close_timeout=args.application_close_timeout, application_close_timeout=args.application_close_timeout,
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None, action_logger=AccessLogGenerator(access_log_stream)
if access_log_stream
else None,
ws_protocols=args.ws_protocols, ws_protocols=args.ws_protocols,
root_path=args.root_path, root_path=args.root_path,
verbosity=args.verbosity, verbosity=args.verbosity,
proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None, proxy_forwarded_address_header="X-Forwarded-For"
proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None, if args.proxy_headers
proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None, else None,
proxy_forwarded_port_header="X-Forwarded-Port"
if args.proxy_headers
else None,
proxy_forwarded_proto_header="X-Forwarded-Proto"
if args.proxy_headers
else None,
) )
self.server.run() self.server.run()

View File

@ -1,10 +1,5 @@
def build_endpoint_description_strings( def build_endpoint_description_strings(
host=None, host=None, port=None, unix_socket=None, file_descriptor=None
port=None,
unix_socket=None,
file_descriptor=None
): ):
""" """
Build a list of twisted endpoint description strings that the server will listen on. Build a list of twisted endpoint description strings that the server will listen on.

View File

@ -23,7 +23,8 @@ class WebRequest(http.Request):
GET and POST out. GET and POST out.
""" """
error_template = """ error_template = (
"""
<html> <html>
<head> <head>
<title>%(title)s</title> <title>%(title)s</title>
@ -40,7 +41,13 @@ class WebRequest(http.Request):
<footer>Daphne</footer> <footer>Daphne</footer>
</body> </body>
</html> </html>
""".replace("\n", "").replace(" ", " ").replace(" ", " ").replace(" ", " ") # Shorten it a bit, bytes wise """.replace(
"\n", ""
)
.replace(" ", " ")
.replace(" ", " ")
.replace(" ", " ")
) # Shorten it a bit, bytes wise
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
try: try:
@ -84,7 +91,7 @@ class WebRequest(http.Request):
self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header, self.server.proxy_forwarded_proto_header,
self.client_addr, self.client_addr,
self.client_scheme self.client_scheme,
) )
# Check for unicodeish path (or it'll crash when trying to parse) # Check for unicodeish path (or it'll crash when trying to parse)
try: try:
@ -105,7 +112,9 @@ class WebRequest(http.Request):
# Is it WebSocket? IS IT?! # Is it WebSocket? IS IT?!
if upgrade_header and upgrade_header.lower() == b"websocket": if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to # Make WebSocket protocol to hand off to
protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer()) protocol = self.server.ws_factory.buildProtocol(
self.transport.getPeer()
)
if not protocol: if not protocol:
# If protocol creation fails, we signal "internal server error" # If protocol creation fails, we signal "internal server error"
self.setResponseCode(500) self.setResponseCode(500)
@ -151,33 +160,38 @@ class WebRequest(http.Request):
logger.debug("HTTP %s request for %s", self.method, self.client_addr) logger.debug("HTTP %s request for %s", self.method, self.client_addr)
self.content.seek(0, 0) self.content.seek(0, 0)
# Work out the application scope and create application # Work out the application scope and create application
self.application_queue = yield maybeDeferred(self.server.create_application, self, { self.application_queue = yield maybeDeferred(
"type": "http", self.server.create_application,
# TODO: Correctly say if it's 1.1 or 1.0 self,
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"), {
"method": self.method.decode("ascii"), "type": "http",
"path": unquote(self.path.decode("ascii")), # TODO: Correctly say if it's 1.1 or 1.0
"root_path": self.root_path, "http_version": self.clientproto.split(b"/")[-1].decode(
"scheme": self.client_scheme, "ascii"
"query_string": self.query_string, ),
"headers": self.clean_headers, "method": self.method.decode("ascii"),
"client": self.client_addr, "path": unquote(self.path.decode("ascii")),
"server": self.server_addr, "root_path": self.root_path,
}) "scheme": self.client_scheme,
"query_string": self.query_string,
"headers": self.clean_headers,
"client": self.client_addr,
"server": self.server_addr,
},
)
# Check they didn't close an unfinished request # Check they didn't close an unfinished request
if self.application_queue is None or self.content.closed: if self.application_queue is None or self.content.closed:
# Not much we can do, the request is prematurely abandoned. # Not much we can do, the request is prematurely abandoned.
return return
# Run application against request # Run application against request
self.application_queue.put_nowait( self.application_queue.put_nowait(
{ {"type": "http.request", "body": self.content.read()}
"type": "http.request",
"body": self.content.read(),
},
) )
except Exception: except Exception:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error") self.basic_error(
500, b"Internal Server Error", "Daphne HTTP processing error"
)
def connectionLost(self, reason): def connectionLost(self, reason):
""" """
@ -217,16 +231,23 @@ class WebRequest(http.Request):
raise ValueError("HTTP response has already been started") raise ValueError("HTTP response has already been started")
self._response_started = True self._response_started = True
if "status" not in message: if "status" not in message:
raise ValueError("Specifying a status code is required for a Response message.") raise ValueError(
"Specifying a status code is required for a Response message."
)
# Set HTTP status code # Set HTTP status code
self.setResponseCode(message["status"]) self.setResponseCode(message["status"])
# Write headers # Write headers
for header, value in message.get("headers", {}): for header, value in message.get("headers", {}):
self.responseHeaders.addRawHeader(header, value) self.responseHeaders.addRawHeader(header, value)
logger.debug("HTTP %s response started for %s", message["status"], self.client_addr) logger.debug(
"HTTP %s response started for %s", message["status"], self.client_addr
)
elif message["type"] == "http.response.body": elif message["type"] == "http.response.body":
if not self._response_started: if not self._response_started:
raise ValueError("HTTP response has not yet been started but got %s" % message["type"]) raise ValueError(
"HTTP response has not yet been started but got %s"
% message["type"]
)
# Write out body # Write out body
http.Request.write(self, message.get("body", b"")) http.Request.write(self, message.get("body", b""))
# End if there's no more content # End if there's no more content
@ -239,15 +260,21 @@ class WebRequest(http.Request):
# The path is malformed somehow - do our best to log something # The path is malformed somehow - do our best to log something
uri = repr(self.uri) uri = repr(self.uri)
try: try:
self.server.log_action("http", "complete", { self.server.log_action(
"path": uri, "http",
"status": self.code, "complete",
"method": self.method.decode("ascii", "replace"), {
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, "path": uri,
"time_taken": self.duration(), "status": self.code,
"size": self.sentLength, "method": self.method.decode("ascii", "replace"),
}) "client": "%s:%s" % tuple(self.client_addr)
except Exception as e: if self.client_addr
else None,
"time_taken": self.duration(),
"size": self.sentLength,
},
)
except Exception:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
else: else:
logger.debug("HTTP response chunk for %s", self.client_addr) logger.debug("HTTP response chunk for %s", self.client_addr)
@ -270,7 +297,11 @@ class WebRequest(http.Request):
logger.warning("Application timed out while sending response") logger.warning("Application timed out while sending response")
self.finish() self.finish()
else: else:
self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.") self.basic_error(
503,
b"Service Unavailable",
"Application failed to respond within time limit.",
)
### Utility functions ### Utility functions
@ -281,11 +312,7 @@ class WebRequest(http.Request):
""" """
# If we don't yet have a path, then don't send as we never opened. # If we don't yet have a path, then don't send as we never opened.
if self.path: if self.path:
self.application_queue.put_nowait( self.application_queue.put_nowait({"type": "http.disconnect"})
{
"type": "http.disconnect",
},
)
def duration(self): def duration(self):
""" """
@ -299,20 +326,25 @@ class WebRequest(http.Request):
""" """
Responds with a server-level error page (very basic) Responds with a server-level error page (very basic)
""" """
self.handle_reply({ self.handle_reply(
"type": "http.response.start", {
"status": status, "type": "http.response.start",
"headers": [ "status": status,
(b"Content-Type", b"text/html; charset=utf-8"), "headers": [(b"Content-Type", b"text/html; charset=utf-8")],
], }
}) )
self.handle_reply({ self.handle_reply(
"type": "http.response.body", {
"body": (self.error_template % { "type": "http.response.body",
"title": str(status) + " " + status_text.decode("ascii"), "body": (
"body": body, self.error_template
}).encode("utf8"), % {
}) "title": str(status) + " " + status_text.decode("ascii"),
"body": body,
}
).encode("utf8"),
}
)
def __hash__(self): def __hash__(self):
return hash(id(self)) return hash(id(self))
@ -343,7 +375,7 @@ class HTTPFactory(http.HTTPFactory):
protocol = http.HTTPFactory.buildProtocol(self, addr) protocol = http.HTTPFactory.buildProtocol(self, addr)
protocol.requestFactory = WebRequest protocol.requestFactory = WebRequest
return protocol return protocol
except Exception as e: except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc()) logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise raise

View File

@ -2,13 +2,14 @@
import sys # isort:skip import sys # isort:skip
import warnings # isort:skip import warnings # isort:skip
from twisted.internet import asyncioreactor # isort:skip from twisted.internet import asyncioreactor # isort:skip
current_reactor = sys.modules.get("twisted.internet.reactor", None) current_reactor = sys.modules.get("twisted.internet.reactor", None)
if current_reactor is not None: if current_reactor is not None:
if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor): if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor):
warnings.warn( warnings.warn(
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " + "Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; "
"you can fix this warning by importing daphne.server early in your codebase or " + + "you can fix this warning by importing daphne.server early in your codebase or "
"finding the package that imports Twisted and importing it later on.", + "finding the package that imports Twisted and importing it later on.",
UserWarning, UserWarning,
) )
del sys.modules["twisted.internet.reactor"] del sys.modules["twisted.internet.reactor"]
@ -34,7 +35,6 @@ logger = logging.getLogger(__name__)
class Server(object): class Server(object):
def __init__( def __init__(
self, self,
application, application,
@ -91,11 +91,13 @@ class Server(object):
self.ws_factory.setProtocolOptions( self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout, autoPingTimeout=self.ping_timeout,
allowNullOrigin=True, allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout openHandshakeTimeout=self.websocket_handshake_timeout,
) )
if self.verbosity <= 1: if self.verbosity <= 1:
# Redirect the Twisted log to nowhere # Redirect the Twisted log to nowhere
globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True) globalLogBeginner.beginLoggingTo(
[lambda _: None], redirectStandardIO=False, discardBuffer=True
)
else: else:
globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)]) globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
@ -103,7 +105,9 @@ class Server(object):
if http.H2_ENABLED: if http.H2_ENABLED:
logger.info("HTTP/2 support enabled") logger.info("HTTP/2 support enabled")
else: else:
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)") logger.info(
"HTTP/2 support not enabled (install the http2 and tls Twisted extras)"
)
# Kick off the timeout loop # Kick off the timeout loop
reactor.callLater(1, self.application_checker) reactor.callLater(1, self.application_checker)
@ -141,7 +145,11 @@ class Server(object):
host = port.getHost() host = port.getHost()
if hasattr(host, "host") and hasattr(host, "port"): if hasattr(host, "host") and hasattr(host, "port"):
self.listening_addresses.append((host.host, host.port)) self.listening_addresses.append((host.host, host.port))
logger.info("Listening on TCP address %s:%s", port.getHost().host, port.getHost().port) logger.info(
"Listening on TCP address %s:%s",
port.getHost().host,
port.getHost().port,
)
def listen_error(self, failure): def listen_error(self, failure):
logger.critical("Listen failure: %s", failure.getErrorMessage()) logger.critical("Listen failure: %s", failure.getErrorMessage())
@ -187,10 +195,13 @@ class Server(object):
# Run it, and stash the future for later checking # Run it, and stash the future for later checking
if protocol not in self.connections: if protocol not in self.connections:
return None return None
self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance( self.connections[protocol]["application_instance"] = asyncio.ensure_future(
receive=input_queue.get, application_instance(
send=lambda message: self.handle_reply(protocol, message), receive=input_queue.get,
), loop=asyncio.get_event_loop()) send=lambda message: self.handle_reply(protocol, message),
),
loop=asyncio.get_event_loop(),
)
return input_queue return input_queue
async def handle_reply(self, protocol, message): async def handle_reply(self, protocol, message):
@ -215,7 +226,10 @@ class Server(object):
application_instance = details.get("application_instance", None) application_instance = details.get("application_instance", None)
# First, see if the protocol disconnected and the app has taken # First, see if the protocol disconnected and the app has taken
# too long to close up # too long to close up
if disconnected and time.time() - disconnected > self.application_close_timeout: if (
disconnected
and time.time() - disconnected > self.application_close_timeout
):
if application_instance and not application_instance.done(): if application_instance and not application_instance.done():
logger.warning( logger.warning(
"Application instance %r for connection %s took too long to shut down and was killed.", "Application instance %r for connection %s took too long to shut down and was killed.",
@ -238,14 +252,11 @@ class Server(object):
else: else:
exception_output = "{}\n{}{}".format( exception_output = "{}\n{}{}".format(
exception, exception,
"".join(traceback.format_tb( "".join(traceback.format_tb(exception.__traceback__)),
exception.__traceback__,
)),
" {}".format(exception), " {}".format(exception),
) )
logger.error( logger.error(
"Exception inside application: %s", "Exception inside application: %s", exception_output
exception_output,
) )
if not disconnected: if not disconnected:
protocol.handle_exception(exception) protocol.handle_exception(exception)

View File

@ -100,9 +100,7 @@ class DaphneTestingInstance:
Adds messages for the application to send back. Adds messages for the application to send back.
The next time it receives an incoming message, it will reply with these. The next time it receives an incoming message, it will reply with these.
""" """
TestApplication.save_setup( TestApplication.save_setup(response_messages=messages)
response_messages=messages,
)
class DaphneProcess(multiprocessing.Process): class DaphneProcess(multiprocessing.Process):
@ -193,12 +191,7 @@ class TestApplication:
Stores setup information. Stores setup information.
""" """
with open(cls.setup_storage, "wb") as fh: with open(cls.setup_storage, "wb") as fh:
pickle.dump( pickle.dump({"response_messages": response_messages}, fh)
{
"response_messages": response_messages,
},
fh,
)
@classmethod @classmethod
def load_setup(cls): def load_setup(cls):
@ -218,13 +211,7 @@ class TestApplication:
We could use pickle here, but that seems wrong, still, somehow. We could use pickle here, but that seems wrong, still, somehow.
""" """
with open(cls.result_storage, "wb") as fh: with open(cls.result_storage, "wb") as fh:
pickle.dump( pickle.dump({"scope": scope, "messages": messages}, fh)
{
"scope": scope,
"messages": messages,
},
fh,
)
@classmethod @classmethod
def save_exception(cls, exception): def save_exception(cls, exception):
@ -233,12 +220,7 @@ class TestApplication:
We could use pickle here, but that seems wrong, still, somehow. We could use pickle here, but that seems wrong, still, somehow.
""" """
with open(cls.result_storage, "wb") as fh: with open(cls.result_storage, "wb") as fh:
pickle.dump( pickle.dump({"exception": exception}, fh)
{
"exception": exception,
},
fh,
)
@classmethod @classmethod
def load_result(cls): def load_result(cls):

View File

@ -22,12 +22,14 @@ def header_value(headers, header_name):
return value.decode("utf-8") return value.decode("utf-8")
def parse_x_forwarded_for(headers, def parse_x_forwarded_for(
address_header_name="X-Forwarded-For", headers,
port_header_name="X-Forwarded-Port", address_header_name="X-Forwarded-For",
proto_header_name="X-Forwarded-Proto", port_header_name="X-Forwarded-Port",
original_addr=None, proto_header_name="X-Forwarded-Proto",
original_scheme=None): original_addr=None,
original_scheme=None,
):
""" """
Parses an X-Forwarded-For header and returns a host/port pair as a list. Parses an X-Forwarded-For header and returns a host/port pair as a list.

View File

@ -3,7 +3,11 @@ import time
import traceback import traceback
from urllib.parse import unquote from urllib.parse import unquote
from autobahn.twisted.websocket import ConnectionDeny, WebSocketServerFactory, WebSocketServerProtocol from autobahn.twisted.websocket import (
ConnectionDeny,
WebSocketServerFactory,
WebSocketServerProtocol,
)
from twisted.internet import defer from twisted.internet import defer
from .utils import parse_x_forwarded_for from .utils import parse_x_forwarded_for
@ -54,32 +58,34 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.proxy_forwarded_address_header, self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header, self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header, self.server.proxy_forwarded_proto_header,
self.client_addr self.client_addr,
) )
# Decode websocket subprotocol options # Decode websocket subprotocol options
subprotocols = [] subprotocols = []
for header, value in self.clean_headers: for header, value in self.clean_headers:
if header == b"sec-websocket-protocol": if header == b"sec-websocket-protocol":
subprotocols = [ subprotocols = [
x.strip() x.strip() for x in unquote(value.decode("ascii")).split(",")
for x in
unquote(value.decode("ascii")).split(",")
] ]
# Make new application instance with scope # Make new application instance with scope
self.path = request.path.encode("ascii") self.path = request.path.encode("ascii")
self.application_deferred = defer.maybeDeferred(self.server.create_application, self, { self.application_deferred = defer.maybeDeferred(
"type": "websocket", self.server.create_application,
"path": unquote(self.path.decode("ascii")), self,
"headers": self.clean_headers, {
"query_string": self._raw_query_string, # Passed by HTTP protocol "type": "websocket",
"client": self.client_addr, "path": unquote(self.path.decode("ascii")),
"server": self.server_addr, "headers": self.clean_headers,
"subprotocols": subprotocols, "query_string": self._raw_query_string, # Passed by HTTP protocol
}) "client": self.client_addr,
"server": self.server_addr,
"subprotocols": subprotocols,
},
)
if self.application_deferred is not None: if self.application_deferred is not None:
self.application_deferred.addCallback(self.applicationCreateWorked) self.application_deferred.addCallback(self.applicationCreateWorked)
self.application_deferred.addErrback(self.applicationCreateFailed) self.application_deferred.addErrback(self.applicationCreateFailed)
except Exception as e: except Exception:
# Exceptions here are not displayed right, just 500. # Exceptions here are not displayed right, just 500.
# Turn them into an ERROR log. # Turn them into an ERROR log.
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
@ -98,10 +104,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.application_queue = application_queue self.application_queue = application_queue
# Send over the connect message # Send over the connect message
self.application_queue.put_nowait({"type": "websocket.connect"}) self.application_queue.put_nowait({"type": "websocket.connect"})
self.server.log_action("websocket", "connecting", { self.server.log_action(
"path": self.request.path, "websocket",
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, "connecting",
}) {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def applicationCreateFailed(self, failure): def applicationCreateFailed(self, failure):
""" """
@ -115,10 +127,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onOpen(self): def onOpen(self):
# Send news that this channel is open # Send news that this channel is open
logger.debug("WebSocket %s open and established", self.client_addr) logger.debug("WebSocket %s open and established", self.client_addr)
self.server.log_action("websocket", "connected", { self.server.log_action(
"path": self.request.path, "websocket",
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, "connected",
}) {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def onMessage(self, payload, isBinary): def onMessage(self, payload, isBinary):
# If we're muted, do nothing. # If we're muted, do nothing.
@ -128,15 +146,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.debug("WebSocket incoming frame on %s", self.client_addr) logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time() self.last_ping = time.time()
if isBinary: if isBinary:
self.application_queue.put_nowait({ self.application_queue.put_nowait(
"type": "websocket.receive", {"type": "websocket.receive", "bytes": payload}
"bytes": payload, )
})
else: else:
self.application_queue.put_nowait({ self.application_queue.put_nowait(
"type": "websocket.receive", {"type": "websocket.receive", "text": payload.decode("utf8")}
"text": payload.decode("utf8"), )
})
def onClose(self, wasClean, code, reason): def onClose(self, wasClean, code, reason):
""" """
@ -145,14 +161,19 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.protocol_disconnected(self) self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr) logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted and hasattr(self, "application_queue"): if not self.muted and hasattr(self, "application_queue"):
self.application_queue.put_nowait({ self.application_queue.put_nowait(
"type": "websocket.disconnect", {"type": "websocket.disconnect", "code": code}
"code": code, )
}) self.server.log_action(
self.server.log_action("websocket", "disconnected", { "websocket",
"path": self.request.path, "disconnected",
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, {
}) "path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
### Internal event handling ### Internal event handling
@ -171,9 +192,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
raise ValueError("Socket has not been accepted, so cannot send over it") raise ValueError("Socket has not been accepted, so cannot send over it")
if message.get("bytes", None) and message.get("text", None): if message.get("bytes", None) and message.get("text", None):
raise ValueError( raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % ( "Got invalid WebSocket reply message on %s - contains both bytes and text keys"
message, % (message,)
)
) )
if message.get("bytes", None): if message.get("bytes", None):
self.serverSend(message["bytes"], True) self.serverSend(message["bytes"], True)
@ -187,7 +207,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
if hasattr(self, "handshake_deferred"): if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error # If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one. # code rather than a WebSocket one.
self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error")) self.handshake_deferred.errback(
ConnectionDeny(code=500, reason="Internal server error")
)
else: else:
self.sendCloseFrame(code=1011) self.sendCloseFrame(code=1011)
@ -203,14 +225,22 @@ class WebSocketProtocol(WebSocketServerProtocol):
""" """
Called when we get a message saying to reject the connection. Called when we get a message saying to reject the connection.
""" """
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied")) self.handshake_deferred.errback(
ConnectionDeny(code=403, reason="Access denied")
)
del self.handshake_deferred del self.handshake_deferred
self.server.protocol_disconnected(self) self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr) logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action("websocket", "rejected", { self.server.log_action(
"path": self.request.path, "websocket",
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None, "rejected",
}) {
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def serverSend(self, content, binary=False): def serverSend(self, content, binary=False):
""" """
@ -244,7 +274,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
Called periodically to see if we should timeout something Called periodically to see if we should timeout something
""" """
# Web timeout checking # Web timeout checking
if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0: if (
self.duration() > self.server.websocket_timeout
and self.server.websocket_timeout >= 0
):
self.serverClose() self.serverClose()
# Ping check # Ping check
# If we're still connecting, deny the connection # If we're still connecting, deny the connection
@ -287,6 +320,6 @@ class WebSocketFactory(WebSocketServerFactory):
protocol = super(WebSocketFactory, self).buildProtocol(addr) protocol = super(WebSocketFactory, self).buildProtocol(addr)
protocol.factory = self protocol.factory = self
return protocol return protocol
except Exception as e: except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc()) logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise raise

View File

@ -5,9 +5,9 @@ universal=1
addopts = tests/ addopts = tests/
[isort] [isort]
line_length = 120 include_trailing_comma = True
multi_line_output = 3 multi_line_output = 3
known_first_party = channels,daphne,asgiref known_first_party = channels,daphne,asgiref,channels_redis
[flake8] [flake8]
exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/* exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/*

View File

@ -22,23 +22,12 @@ setup(
package_dir={"twisted": "daphne/twisted"}, package_dir={"twisted": "daphne/twisted"},
packages=find_packages() + ["twisted.plugins"], packages=find_packages() + ["twisted.plugins"],
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=["twisted>=18.7", "autobahn>=0.18"],
"twisted>=18.7", setup_requires=["pytest-runner"],
"autobahn>=0.18", extras_require={"tests": ["hypothesis", "pytest", "pytest-asyncio~=0.8"]},
], entry_points={
setup_requires=[ "console_scripts": ["daphne = daphne.cli:CommandLineInterface.entrypoint"]
"pytest-runner",
],
extras_require={
"tests": [
"hypothesis",
"pytest",
"pytest-asyncio~=0.8",
],
}, },
entry_points={"console_scripts": [
"daphne = daphne.cli:CommandLineInterface.entrypoint",
]},
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Environment :: Web Environment", "Environment :: Web Environment",

View File

@ -19,7 +19,9 @@ class DaphneTestCase(unittest.TestCase):
### Plain HTTP helpers ### Plain HTTP helpers
def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False): def run_daphne_http(
self, method, path, params, body, responses, headers=None, timeout=1, xff=False
):
""" """
Runs Daphne with the given request callback (given the base URL) Runs Daphne with the given request callback (given the base URL)
and response messages. and response messages.
@ -38,7 +40,9 @@ class DaphneTestCase(unittest.TestCase):
# Manually send over headers (encoding any non-safe values as best we can) # Manually send over headers (encoding any non-safe values as best we can)
if headers: if headers:
for header_name, header_value in headers: for header_name, header_value in headers:
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8")) conn.putheader(
header_name.encode("utf8"), header_value.encode("utf8")
)
# Send body if provided. # Send body if provided.
if body: if body:
conn.putheader("Content-Length", str(len(body))) conn.putheader("Content-Length", str(len(body)))
@ -50,9 +54,11 @@ class DaphneTestCase(unittest.TestCase):
except socket.timeout: except socket.timeout:
# See if they left an exception for us to load # See if they left an exception for us to load
test_app.get_received() test_app.get_received()
raise RuntimeError("Daphne timed out handling request, no exception found.") raise RuntimeError(
"Daphne timed out handling request, no exception found."
)
# Return scope, messages, response # Return scope, messages, response
return test_app.get_received() + (response, ) return test_app.get_received() + (response,)
def run_daphne_raw(self, data, timeout=1): def run_daphne_raw(self, data, timeout=1):
""" """
@ -68,9 +74,13 @@ class DaphneTestCase(unittest.TestCase):
try: try:
return s.recv(1000000) return s.recv(1000000)
except socket.timeout: except socket.timeout:
raise RuntimeError("Daphne timed out handling raw request, no exception found.") raise RuntimeError(
"Daphne timed out handling raw request, no exception found."
)
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False): def run_daphne_request(
self, method, path, params=None, body=None, headers=None, xff=False
):
""" """
Convenience method for just testing request handling. Convenience method for just testing request handling.
Returns (scope, messages) Returns (scope, messages)
@ -95,17 +105,21 @@ class DaphneTestCase(unittest.TestCase):
Returns (scope, messages) Returns (scope, messages)
""" """
_, _, response = self.run_daphne_http( _, _, response = self.run_daphne_http(
method="GET", method="GET", path="/", params={}, body=b"", responses=response_messages
path="/",
params={},
body=b"",
responses=response_messages,
) )
return response return response
### WebSocket helpers ### WebSocket helpers
def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1): def websocket_handshake(
self,
test_app,
path="/",
params=None,
headers=None,
subprotocols=None,
timeout=1,
):
""" """
Runs a WebSocket handshake negotiation and returns the raw socket Runs a WebSocket handshake negotiation and returns the raw socket
object & the selected subprotocol. object & the selected subprotocol.
@ -124,14 +138,16 @@ class DaphneTestCase(unittest.TestCase):
# Do WebSocket handshake headers + any other headers # Do WebSocket handshake headers + any other headers
if headers is None: if headers is None:
headers = [] headers = []
headers.extend([ headers.extend(
("Host", "example.com"), [
("Upgrade", "websocket"), ("Host", "example.com"),
("Connection", "Upgrade"), ("Upgrade", "websocket"),
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="), ("Connection", "Upgrade"),
("Sec-WebSocket-Version", "13"), ("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
("Origin", "http://example.com") ("Sec-WebSocket-Version", "13"),
]) ("Origin", "http://example.com"),
]
)
if subprotocols: if subprotocols:
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols))) headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
if headers: if headers:
@ -149,10 +165,7 @@ class DaphneTestCase(unittest.TestCase):
if response.status != 101: if response.status != 101:
raise RuntimeError("WebSocket upgrade did not result in status code 101") raise RuntimeError("WebSocket upgrade did not result in status code 101")
# Prepare headers for subprotocol searching # Prepare headers for subprotocol searching
response_headers = dict( response_headers = dict((n.lower(), v) for n, v in response.getheaders())
(n.lower(), v)
for n, v in response.getheaders()
)
response.read() response.read()
assert not response.closed assert not response.closed
# Return the raw socket and any subprotocol # Return the raw socket and any subprotocol
@ -234,10 +247,7 @@ class DaphneTestCase(unittest.TestCase):
# Make sure all required keys are present # Make sure all required keys are present
self.assertTrue(required_keys <= present_keys) self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present # Assert that no other keys are present
self.assertEqual( self.assertEqual(set(), present_keys - required_keys - optional_keys)
set(),
present_keys - required_keys - optional_keys,
)
def assert_valid_path(self, path, request_path): def assert_valid_path(self, path, request_path):
""" """

View File

@ -6,7 +6,9 @@ from hypothesis import strategies
HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"] HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"]
# Unicode characters of the "Letter" category # Unicode characters of the "Letter" category
letters = strategies.characters(whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")) letters = strategies.characters(
whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")
)
def http_method(): def http_method():
@ -22,11 +24,9 @@ def http_path():
""" """
Returns a URL path (not encoded). Returns a URL path (not encoded).
""" """
return strategies.lists( return strategies.lists(_http_path_portion(), min_size=0, max_size=10).map(
_http_path_portion(), lambda s: "/" + "/".join(s)
min_size=0, )
max_size=10,
).map(lambda s: "/" + "/".join(s))
def http_body(): def http_body():
@ -53,10 +53,7 @@ def valid_bidi(value):
def _domain_label(): def _domain_label():
return strategies.text( return strategies.text(
alphabet=letters, alphabet=letters, min_size=1, average_size=6, max_size=63
min_size=1,
average_size=6,
max_size=63,
).filter(valid_bidi) ).filter(valid_bidi)
@ -64,19 +61,14 @@ def international_domain_name():
""" """
Returns a byte string of a domain name, IDNA-encoded. Returns a byte string of a domain name, IDNA-encoded.
""" """
return strategies.lists( return strategies.lists(_domain_label(), min_size=2, average_size=2).map(
_domain_label(), lambda s: (".".join(s)).encode("idna")
min_size=2, )
average_size=2,
).map(lambda s: (".".join(s)).encode("idna"))
def _query_param(): def _query_param():
return strategies.text( return strategies.text(
alphabet=letters, alphabet=letters, min_size=1, average_size=10, max_size=255
min_size=1,
average_size=10,
max_size=255,
).map(lambda s: s.encode("utf8")) ).map(lambda s: s.encode("utf8"))
@ -87,9 +79,7 @@ def query_params():
ensures that the total urlencoded query string is not longer than 1500 characters. ensures that the total urlencoded query string is not longer than 1500 characters.
""" """
return strategies.lists( return strategies.lists(
strategies.tuples(_query_param(), _query_param()), strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5
min_size=0,
average_size=5,
).filter(lambda x: len(parse.urlencode(x)) < 1500) ).filter(lambda x: len(parse.urlencode(x)) < 1500)
@ -101,9 +91,7 @@ def header_name():
and 20 characters long and 20 characters long
""" """
return strategies.text( return strategies.text(
alphabet=string.ascii_letters + string.digits + "-", alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30
min_size=1,
max_size=30,
) )
@ -115,7 +103,10 @@ def header_value():
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
""" """
return strategies.text( return strategies.text(
alphabet=string.ascii_letters + string.digits + string.punctuation.replace(",", "") + " /t", alphabet=string.ascii_letters
+ string.digits
+ string.punctuation.replace(",", "")
+ " /t",
min_size=1, min_size=1,
average_size=40, average_size=40,
max_size=8190, max_size=8190,

View File

@ -18,45 +18,32 @@ class TestEndpointDescriptions(TestCase):
def testTcpPortBindings(self): def testTcpPortBindings(self):
self.assertEqual( self.assertEqual(
build(port=1234, host="example.com"), build(port=1234, host="example.com"),
["tcp:port=1234:interface=example.com"] ["tcp:port=1234:interface=example.com"],
) )
self.assertEqual( self.assertEqual(
build(port=8000, host="127.0.0.1"), build(port=8000, host="127.0.0.1"), ["tcp:port=8000:interface=127.0.0.1"]
["tcp:port=8000:interface=127.0.0.1"]
) )
self.assertEqual( self.assertEqual(
build(port=8000, host="[200a::1]"), build(port=8000, host="[200a::1]"), [r"tcp:port=8000:interface=200a\:\:1"]
[r'tcp:port=8000:interface=200a\:\:1']
) )
self.assertEqual( self.assertEqual(
build(port=8000, host="200a::1"), build(port=8000, host="200a::1"), [r"tcp:port=8000:interface=200a\:\:1"]
[r'tcp:port=8000:interface=200a\:\:1']
) )
# incomplete port/host kwargs raise errors # incomplete port/host kwargs raise errors
self.assertRaises( self.assertRaises(ValueError, build, port=123)
ValueError, self.assertRaises(ValueError, build, host="example.com")
build, port=123
)
self.assertRaises(
ValueError,
build, host="example.com"
)
def testUnixSocketBinding(self): def testUnixSocketBinding(self):
self.assertEqual( self.assertEqual(
build(unix_socket="/tmp/daphne.sock"), build(unix_socket="/tmp/daphne.sock"), ["unix:/tmp/daphne.sock"]
["unix:/tmp/daphne.sock"]
) )
def testFileDescriptorBinding(self): def testFileDescriptorBinding(self):
self.assertEqual( self.assertEqual(build(file_descriptor=5), ["fd:fileno=5"])
build(file_descriptor=5),
["fd:fileno=5"]
)
def testMultipleEnpoints(self): def testMultipleEnpoints(self):
self.assertEqual( self.assertEqual(
@ -65,14 +52,16 @@ class TestEndpointDescriptions(TestCase):
file_descriptor=123, file_descriptor=123,
unix_socket="/tmp/daphne.sock", unix_socket="/tmp/daphne.sock",
port=8080, port=8080,
host="10.0.0.1" host="10.0.0.1",
) )
), ),
sorted([ sorted(
"tcp:port=8080:interface=10.0.0.1", [
"unix:/tmp/daphne.sock", "tcp:port=8080:interface=10.0.0.1",
"fd:fileno=123" "unix:/tmp/daphne.sock",
]) "fd:fileno=123",
]
),
) )
@ -112,7 +101,9 @@ class TestCLIInterface(TestCase):
Passes in a fake application automatically. Passes in a fake application automatically.
""" """
cli = self.TestedCLI() cli = self.TestedCLI()
cli.run(args + ["daphne:__version__"]) # We just pass something importable as app cli.run(
args + ["daphne:__version__"]
) # We just pass something importable as app
# Check the server got all arguments as intended # Check the server got all arguments as intended
for key, value in server_kwargs.items(): for key, value in server_kwargs.items():
# Get the value and sort it if it's a list (for endpoint checking) # Get the value and sort it if it's a list (for endpoint checking)
@ -123,52 +114,30 @@ class TestCLIInterface(TestCase):
self.assertEqual( self.assertEqual(
value, value,
actual_value, actual_value,
"Wrong value for server kwarg %s: %r != %r" % ( "Wrong value for server kwarg %s: %r != %r"
key, % (key, value, actual_value),
value,
actual_value,
),
) )
def testCLIBasics(self): def testCLIBasics(self):
""" """
Tests basic endpoint generation. Tests basic endpoint generation.
""" """
self.assertCLI([], {"endpoints": ["tcp:port=8000:interface=127.0.0.1"]})
self.assertCLI( self.assertCLI(
[], ["-p", "123"], {"endpoints": ["tcp:port=123:interface=127.0.0.1"]}
{
"endpoints": ["tcp:port=8000:interface=127.0.0.1"],
},
) )
self.assertCLI( self.assertCLI(
["-p", "123"], ["-b", "10.0.0.1"], {"endpoints": ["tcp:port=8000:interface=10.0.0.1"]}
{
"endpoints": ["tcp:port=123:interface=127.0.0.1"],
},
) )
self.assertCLI( self.assertCLI(
["-b", "10.0.0.1"], ["-b", "200a::1"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
{
"endpoints": ["tcp:port=8000:interface=10.0.0.1"],
},
) )
self.assertCLI( self.assertCLI(
["-b", "200a::1"], ["-b", "[200a::1]"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
{
"endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
},
)
self.assertCLI(
["-b", "[200a::1]"],
{
"endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
},
) )
self.assertCLI( self.assertCLI(
["-p", "8080", "-b", "example.com"], ["-p", "8080", "-b", "example.com"],
{ {"endpoints": ["tcp:port=8080:interface=example.com"]},
"endpoints": ["tcp:port=8080:interface=example.com"],
},
) )
def testUnixSockets(self): def testUnixSockets(self):
@ -178,7 +147,7 @@ class TestCLIInterface(TestCase):
"endpoints": [ "endpoints": [
"tcp:port=8080:interface=127.0.0.1", "tcp:port=8080:interface=127.0.0.1",
"unix:/tmp/daphne.sock", "unix:/tmp/daphne.sock",
], ]
}, },
) )
self.assertCLI( self.assertCLI(
@ -187,17 +156,12 @@ class TestCLIInterface(TestCase):
"endpoints": [ "endpoints": [
"tcp:port=8000:interface=example.com", "tcp:port=8000:interface=example.com",
"unix:/tmp/daphne.sock", "unix:/tmp/daphne.sock",
], ]
}, },
) )
self.assertCLI( self.assertCLI(
["-u", "/tmp/daphne.sock", "--fd", "5"], ["-u", "/tmp/daphne.sock", "--fd", "5"],
{ {"endpoints": ["fd:fileno=5", "unix:/tmp/daphne.sock"]},
"endpoints": [
"fd:fileno=5",
"unix:/tmp/daphne.sock"
],
},
) )
def testMixedCLIEndpointCreation(self): def testMixedCLIEndpointCreation(self):
@ -209,8 +173,8 @@ class TestCLIInterface(TestCase):
{ {
"endpoints": [ "endpoints": [
"tcp:port=8080:interface=127.0.0.1", "tcp:port=8080:interface=127.0.0.1",
"unix:/tmp/daphne.sock" "unix:/tmp/daphne.sock",
], ]
}, },
) )
self.assertCLI( self.assertCLI(
@ -219,7 +183,7 @@ class TestCLIInterface(TestCase):
"endpoints": [ "endpoints": [
"tcp:port=8080:interface=127.0.0.1", "tcp:port=8080:interface=127.0.0.1",
"tcp:port=8080:interface=127.0.0.1", "tcp:port=8080:interface=127.0.0.1",
], ]
}, },
) )
@ -227,11 +191,4 @@ class TestCLIInterface(TestCase):
""" """
Tests entirely custom endpoints Tests entirely custom endpoints
""" """
self.assertCLI( self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]})
["-e", "imap:"],
{
"endpoints": [
"imap:",
],
},
)

View File

@ -15,13 +15,7 @@ class TestHTTPRequest(DaphneTestCase):
""" """
def assert_valid_http_scope( def assert_valid_http_scope(
self, self, scope, method, path, params=None, headers=None, scheme=None
scope,
method,
path,
params=None,
headers=None,
scheme=None,
): ):
""" """
Checks that the passed scope is a valid ASGI HTTP scope regarding types Checks that the passed scope is a valid ASGI HTTP scope regarding types
@ -29,7 +23,14 @@ class TestHTTPRequest(DaphneTestCase):
""" """
# Check overall keys # Check overall keys
self.assert_key_sets( self.assert_key_sets(
required_keys={"type", "http_version", "method", "path", "query_string", "headers"}, required_keys={
"type",
"http_version",
"method",
"path",
"query_string",
"headers",
},
optional_keys={"scheme", "root_path", "client", "server"}, optional_keys={"scheme", "root_path", "client", "server"},
actual_keys=scope.keys(), actual_keys=scope.keys(),
) )
@ -50,7 +51,9 @@ class TestHTTPRequest(DaphneTestCase):
query_string = scope["query_string"] query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes) self.assertIsInstance(query_string, bytes)
if params: if params:
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii")) self.assertEqual(
query_string, parse.urlencode(params or []).encode("ascii")
)
# Ordering of header names is not important, but the order of values for a header # Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request # name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary # headers and the channel message headers into a dictionary
@ -59,7 +62,7 @@ class TestHTTPRequest(DaphneTestCase):
for name, value in scope["headers"]: for name, value in scope["headers"]:
transformed_scope_headers[name].append(value) transformed_scope_headers[name].append(value)
transformed_request_headers = collections.defaultdict(list) transformed_request_headers = collections.defaultdict(list)
for name, value in (headers or []): for name, value in headers or []:
expected_name = name.lower().strip().encode("ascii") expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii") expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value) transformed_request_headers[expected_name].append(expected_value)
@ -103,27 +106,31 @@ class TestHTTPRequest(DaphneTestCase):
@given( @given(
request_path=http_strategies.http_path(), request_path=http_strategies.http_path(),
request_params=http_strategies.query_params() request_params=http_strategies.query_params(),
) )
@settings(max_examples=5, deadline=5000) @settings(max_examples=5, deadline=5000)
def test_get_request(self, request_path, request_params): def test_get_request(self, request_path, request_params):
""" """
Tests a typical HTTP GET request, with a path and query parameters Tests a typical HTTP GET request, with a path and query parameters
""" """
scope, messages = self.run_daphne_request("GET", request_path, params=request_params) scope, messages = self.run_daphne_request(
"GET", request_path, params=request_params
)
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params) self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@given( @given(
request_path=http_strategies.http_path(), request_path=http_strategies.http_path(),
request_body=http_strategies.http_body() request_body=http_strategies.http_body(),
) )
@settings(max_examples=5, deadline=5000) @settings(max_examples=5, deadline=5000)
def test_post_request(self, request_path, request_body): def test_post_request(self, request_path, request_body):
""" """
Tests a typical HTTP POST request, with a path and body. Tests a typical HTTP POST request, with a path and body.
""" """
scope, messages = self.run_daphne_request("POST", request_path, body=request_body) scope, messages = self.run_daphne_request(
"POST", request_path, body=request_body
)
self.assert_valid_http_scope(scope, "POST", request_path) self.assert_valid_http_scope(scope, "POST", request_path)
self.assert_valid_http_request_message(messages[0], body=request_body) self.assert_valid_http_request_message(messages[0], body=request_body)
@ -134,8 +141,12 @@ class TestHTTPRequest(DaphneTestCase):
Tests that HTTP header fields are handled as specified Tests that HTTP header fields are handled as specified
""" """
request_path = "/te st-à/" request_path = "/te st-à/"
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers) scope, messages = self.run_daphne_request(
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers) "OPTIONS", request_path, headers=request_headers
)
self.assert_valid_http_scope(
scope, "OPTIONS", request_path, headers=request_headers
)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@given(request_headers=http_strategies.headers()) @given(request_headers=http_strategies.headers())
@ -150,8 +161,12 @@ class TestHTTPRequest(DaphneTestCase):
duplicated_headers = [(header_name, header[1]) for header in request_headers] duplicated_headers = [(header_name, header[1]) for header in request_headers]
# Run the request # Run the request
request_path = "/te st-à/" request_path = "/te st-à/"
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=duplicated_headers) scope, messages = self.run_daphne_request(
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=duplicated_headers) "OPTIONS", request_path, headers=duplicated_headers
)
self.assert_valid_http_scope(
scope, "OPTIONS", request_path, headers=duplicated_headers
)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@given( @given(
@ -222,10 +237,7 @@ class TestHTTPRequest(DaphneTestCase):
""" """
Make sure that, by default, X-Forwarded-For is ignored. Make sure that, by default, X-Forwarded-For is ignored.
""" """
headers = [ headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
["X-Forwarded-For", "10.1.2.3"],
["X-Forwarded-Port", "80"],
]
scope, messages = self.run_daphne_request("GET", "/", headers=headers) scope, messages = self.run_daphne_request("GET", "/", headers=headers)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@ -236,10 +248,7 @@ class TestHTTPRequest(DaphneTestCase):
""" """
When X-Forwarded-For is enabled, make sure it is respected. When X-Forwarded-For is enabled, make sure it is respected.
""" """
headers = [ headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
["X-Forwarded-For", "10.1.2.3"],
["X-Forwarded-Port", "80"],
]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@ -251,9 +260,7 @@ class TestHTTPRequest(DaphneTestCase):
When X-Forwarded-For is enabled but only the host is passed, make sure When X-Forwarded-For is enabled but only the host is passed, make sure
that at least makes it through. that at least makes it through.
""" """
headers = [ headers = [["X-Forwarded-For", "10.1.2.3"]]
["X-Forwarded-For", "10.1.2.3"],
]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True) scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers) self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"") self.assert_valid_http_request_message(messages[0], body=b"")
@ -265,8 +272,12 @@ class TestHTTPRequest(DaphneTestCase):
Tests that requests with invalid (non-ASCII) characters fail. Tests that requests with invalid (non-ASCII) characters fail.
""" """
# Bad path # Bad path
response = self.run_daphne_raw(b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n") response = self.run_daphne_raw(
b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
)
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request")) self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
# Bad querystring # Bad querystring
response = self.run_daphne_raw(b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n") response = self.run_daphne_raw(
b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
)
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request")) self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))

View File

@ -15,26 +15,24 @@ class TestHTTPResponse(DaphneTestCase):
""" """
Lowercases and sorts headers, and strips transfer-encoding ones. Lowercases and sorts headers, and strips transfer-encoding ones.
""" """
return sorted([ return sorted(
(name.lower(), value.strip()) [
for name, value in headers (name.lower(), value.strip())
if name.lower() != "transfer-encoding" for name, value in headers
]) if name.lower() != "transfer-encoding"
]
)
def test_minimal_response(self): def test_minimal_response(self):
""" """
Smallest viable example. Mostly verifies that our response building works. Smallest viable example. Mostly verifies that our response building works.
""" """
response = self.run_daphne_response([ response = self.run_daphne_response(
{ [
"type": "http.response.start", {"type": "http.response.start", "status": 200},
"status": 200, {"type": "http.response.body", "body": b"hello world"},
}, ]
{ )
"type": "http.response.body",
"body": b"hello world",
},
])
self.assertEqual(response.status, 200) self.assertEqual(response.status, 200)
self.assertEqual(response.read(), b"hello world") self.assertEqual(response.read(), b"hello world")
@ -46,30 +44,23 @@ class TestHTTPResponse(DaphneTestCase):
to make sure it stays required. to make sure it stays required.
""" """
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.run_daphne_response([ self.run_daphne_response(
{ [
"type": "http.response.start", {"type": "http.response.start"},
}, {"type": "http.response.body", "body": b"hello world"},
{ ]
"type": "http.response.body", )
"body": b"hello world",
},
])
def test_custom_status_code(self): def test_custom_status_code(self):
""" """
Tries a non-default status code. Tries a non-default status code.
""" """
response = self.run_daphne_response([ response = self.run_daphne_response(
{ [
"type": "http.response.start", {"type": "http.response.start", "status": 201},
"status": 201, {"type": "http.response.body", "body": b"i made a thing!"},
}, ]
{ )
"type": "http.response.body",
"body": b"i made a thing!",
},
])
self.assertEqual(response.status, 201) self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"i made a thing!") self.assertEqual(response.read(), b"i made a thing!")
@ -77,21 +68,13 @@ class TestHTTPResponse(DaphneTestCase):
""" """
Tries sending a response in multiple parts. Tries sending a response in multiple parts.
""" """
response = self.run_daphne_response([ response = self.run_daphne_response(
{ [
"type": "http.response.start", {"type": "http.response.start", "status": 201},
"status": 201, {"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
}, {"type": "http.response.body", "body": b"chunk 2"},
{ ]
"type": "http.response.body", )
"body": b"chunk 1 ",
"more_body": True,
},
{
"type": "http.response.body",
"body": b"chunk 2",
},
])
self.assertEqual(response.status, 201) self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"chunk 1 chunk 2") self.assertEqual(response.read(), b"chunk 1 chunk 2")
@ -99,25 +82,14 @@ class TestHTTPResponse(DaphneTestCase):
""" """
Tries sending a response in multiple parts and an empty end. Tries sending a response in multiple parts and an empty end.
""" """
response = self.run_daphne_response([ response = self.run_daphne_response(
{ [
"type": "http.response.start", {"type": "http.response.start", "status": 201},
"status": 201, {"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
}, {"type": "http.response.body", "body": b"chunk 2", "more_body": True},
{ {"type": "http.response.body"},
"type": "http.response.body", ]
"body": b"chunk 1 ", )
"more_body": True,
},
{
"type": "http.response.body",
"body": b"chunk 2",
"more_body": True,
},
{
"type": "http.response.body",
},
])
self.assertEqual(response.status, 201) self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"chunk 1 chunk 2") self.assertEqual(response.read(), b"chunk 1 chunk 2")
@ -127,16 +99,12 @@ class TestHTTPResponse(DaphneTestCase):
""" """
Tries body variants. Tries body variants.
""" """
response = self.run_daphne_response([ response = self.run_daphne_response(
{ [
"type": "http.response.start", {"type": "http.response.start", "status": 200},
"status": 200, {"type": "http.response.body", "body": body},
}, ]
{ )
"type": "http.response.body",
"body": body,
},
])
self.assertEqual(response.status, 200) self.assertEqual(response.status, 200)
self.assertEqual(response.read(), body) self.assertEqual(response.read(), body)
@ -144,16 +112,16 @@ class TestHTTPResponse(DaphneTestCase):
@settings(max_examples=5, deadline=5000) @settings(max_examples=5, deadline=5000)
def test_headers(self, headers): def test_headers(self, headers):
# The ASGI spec requires us to lowercase our header names # The ASGI spec requires us to lowercase our header names
response = self.run_daphne_response([ response = self.run_daphne_response(
{ [
"type": "http.response.start", {
"status": 200, "type": "http.response.start",
"headers": self.normalize_headers(headers), "status": 200,
}, "headers": self.normalize_headers(headers),
{ },
"type": "http.response.body", {"type": "http.response.body"},
}, ]
]) )
# Check headers in a sensible way. Ignore transfer-encoding. # Check headers in a sensible way. Ignore transfer-encoding.
self.assertEqual( self.assertEqual(
self.normalize_headers(response.getheaders()), self.normalize_headers(response.getheaders()),

View File

@ -13,48 +13,35 @@ class TestXForwardedForHttpParsing(TestCase):
""" """
def test_basic(self): def test_basic(self):
headers = Headers({ headers = Headers(
b"X-Forwarded-For": [b"10.1.2.3"], {
b"X-Forwarded-Port": [b"1234"], b"X-Forwarded-For": [b"10.1.2.3"],
b"X-Forwarded-Proto": [b"https"] b"X-Forwarded-Port": [b"1234"],
}) b"X-Forwarded-Proto": [b"https"],
}
)
result = parse_x_forwarded_for(headers) result = parse_x_forwarded_for(headers)
self.assertEqual(result, (["10.1.2.3", 1234], "https")) self.assertEqual(result, (["10.1.2.3", 1234], "https"))
self.assertIsInstance(result[0][0], str) self.assertIsInstance(result[0][0], str)
self.assertIsInstance(result[1], str) self.assertIsInstance(result[1], str)
def test_address_only(self): def test_address_only(self):
headers = Headers({ headers = Headers({b"X-Forwarded-For": [b"10.1.2.3"]})
b"X-Forwarded-For": [b"10.1.2.3"], self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
})
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 0], None)
)
def test_v6_address(self): def test_v6_address(self):
headers = Headers({ headers = Headers({b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]})
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"], self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
})
self.assertEqual(
parse_x_forwarded_for(headers),
(["1043::a321:0001", 0], None)
)
def test_multiple_proxys(self): def test_multiple_proxys(self):
headers = Headers({ headers = Headers({b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"]})
b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"], self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
})
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 0], None)
)
def test_original(self): def test_original(self):
headers = Headers({}) headers = Headers({})
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
(["127.0.0.1", 80], None) (["127.0.0.1", 80], None),
) )
def test_no_original(self): def test_no_original(self):
@ -73,43 +60,25 @@ class TestXForwardedForWsParsing(TestCase):
b"X-Forwarded-Port": b"1234", b"X-Forwarded-Port": b"1234",
b"X-Forwarded-Proto": b"https", b"X-Forwarded-Proto": b"https",
} }
self.assertEqual( self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 1234], "https"))
parse_x_forwarded_for(headers),
(["10.1.2.3", 1234], "https")
)
def test_address_only(self): def test_address_only(self):
headers = { headers = {b"X-Forwarded-For": b"10.1.2.3"}
b"X-Forwarded-For": b"10.1.2.3", self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
}
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 0], None)
)
def test_v6_address(self): def test_v6_address(self):
headers = { headers = {b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]}
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"], self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
}
self.assertEqual(
parse_x_forwarded_for(headers),
(["1043::a321:0001", 0], None)
)
def test_multiple_proxies(self): def test_multiple_proxies(self):
headers = { headers = {b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4"}
b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4", self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
}
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 0], None)
)
def test_original(self): def test_original(self):
headers = {} headers = {}
self.assertEqual( self.assertEqual(
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]), parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
(["127.0.0.1", 80], None) (["127.0.0.1", 80], None),
) )
def test_no_original(self): def test_no_original(self):

View File

@ -16,13 +16,7 @@ class TestWebsocket(DaphneTestCase):
""" """
def assert_valid_websocket_scope( def assert_valid_websocket_scope(
self, self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None
scope,
path="/",
params=None,
headers=None,
scheme=None,
subprotocols=None,
): ):
""" """
Checks that the passed scope is a valid ASGI HTTP scope regarding types Checks that the passed scope is a valid ASGI HTTP scope regarding types
@ -46,7 +40,9 @@ class TestWebsocket(DaphneTestCase):
query_string = scope["query_string"] query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes) self.assertIsInstance(query_string, bytes)
if params: if params:
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii")) self.assertEqual(
query_string, parse.urlencode(params or []).encode("ascii")
)
# Ordering of header names is not important, but the order of values for a header # Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request # name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary # headers and the channel message headers into a dictionary
@ -59,7 +55,7 @@ class TestWebsocket(DaphneTestCase):
if bit.strip(): if bit.strip():
transformed_scope_headers[name].append(bit.strip()) transformed_scope_headers[name].append(bit.strip())
transformed_request_headers = collections.defaultdict(list) transformed_request_headers = collections.defaultdict(list)
for name, value in (headers or []): for name, value in headers or []:
expected_name = name.lower().strip().encode("ascii") expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii") expected_value = value.strip().encode("ascii")
# Make sure to split out any headers collapsed with commas # Make sure to split out any headers collapsed with commas
@ -92,9 +88,7 @@ class TestWebsocket(DaphneTestCase):
""" """
# Check overall keys # Check overall keys
self.assert_key_sets( self.assert_key_sets(
required_keys={"type"}, required_keys={"type"}, optional_keys=set(), actual_keys=message.keys()
optional_keys=set(),
actual_keys=message.keys(),
) )
# Check that it is the right type # Check that it is the right type
self.assertEqual(message["type"], "websocket.connect") self.assertEqual(message["type"], "websocket.connect")
@ -104,11 +98,7 @@ class TestWebsocket(DaphneTestCase):
Tests we can open and accept a socket. Tests we can open and accept a socket.
""" """
with DaphneTestingInstance() as test_app: with DaphneTestingInstance() as test_app:
test_app.add_send_messages([ test_app.add_send_messages([{"type": "websocket.accept"}])
{
"type": "websocket.accept",
}
])
self.websocket_handshake(test_app) self.websocket_handshake(test_app)
# Validate the scope and messages we got # Validate the scope and messages we got
scope, messages = test_app.get_received() scope, messages = test_app.get_received()
@ -120,11 +110,7 @@ class TestWebsocket(DaphneTestCase):
Tests we can reject a socket and it won't complete the handshake. Tests we can reject a socket and it won't complete the handshake.
""" """
with DaphneTestingInstance() as test_app: with DaphneTestingInstance() as test_app:
test_app.add_send_messages([ test_app.add_send_messages([{"type": "websocket.close"}])
{
"type": "websocket.close",
}
])
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
self.websocket_handshake(test_app) self.websocket_handshake(test_app)
@ -134,13 +120,12 @@ class TestWebsocket(DaphneTestCase):
""" """
subprotocols = ["proto1", "proto2"] subprotocols = ["proto1", "proto2"]
with DaphneTestingInstance() as test_app: with DaphneTestingInstance() as test_app:
test_app.add_send_messages([ test_app.add_send_messages(
{ [{"type": "websocket.accept", "subprotocol": "proto2"}]
"type": "websocket.accept", )
"subprotocol": "proto2", _, subprotocol = self.websocket_handshake(
} test_app, subprotocols=subprotocols
]) )
_, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols)
# Validate the scope and messages we got # Validate the scope and messages we got
assert subprotocol == "proto2" assert subprotocol == "proto2"
scope, messages = test_app.get_received() scope, messages = test_app.get_received()
@ -151,16 +136,9 @@ class TestWebsocket(DaphneTestCase):
""" """
Tests that X-Forwarded-For headers get parsed right Tests that X-Forwarded-For headers get parsed right
""" """
headers = [ headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
["X-Forwarded-For", "10.1.2.3"],
["X-Forwarded-Port", "80"],
]
with DaphneTestingInstance(xff=True) as test_app: with DaphneTestingInstance(xff=True) as test_app:
test_app.add_send_messages([ test_app.add_send_messages([{"type": "websocket.accept"}])
{
"type": "websocket.accept",
}
])
self.websocket_handshake(test_app, headers=headers) self.websocket_handshake(test_app, headers=headers)
# Validate the scope and messages we got # Validate the scope and messages we got
scope, messages = test_app.get_received() scope, messages = test_app.get_received()
@ -174,22 +152,13 @@ class TestWebsocket(DaphneTestCase):
request_headers=http_strategies.headers(), request_headers=http_strategies.headers(),
) )
@settings(max_examples=5, deadline=2000) @settings(max_examples=5, deadline=2000)
def test_http_bits( def test_http_bits(self, request_path, request_params, request_headers):
self,
request_path,
request_params,
request_headers,
):
""" """
Tests that various HTTP-level bits (query string params, path, headers) Tests that various HTTP-level bits (query string params, path, headers)
carry over into the scope. carry over into the scope.
""" """
with DaphneTestingInstance() as test_app: with DaphneTestingInstance() as test_app:
test_app.add_send_messages([ test_app.add_send_messages([{"type": "websocket.accept"}])
{
"type": "websocket.accept",
}
])
self.websocket_handshake( self.websocket_handshake(
test_app, test_app,
path=request_path, path=request_path,
@ -199,10 +168,7 @@ class TestWebsocket(DaphneTestCase):
# Validate the scope and messages we got # Validate the scope and messages we got
scope, messages = test_app.get_received() scope, messages = test_app.get_received()
self.assert_valid_websocket_scope( self.assert_valid_websocket_scope(
scope, scope, path=request_path, params=request_params, headers=request_headers
path=request_path,
params=request_params,
headers=request_headers,
) )
self.assert_valid_websocket_connect_message(messages[0]) self.assert_valid_websocket_connect_message(messages[0])
@ -212,28 +178,24 @@ class TestWebsocket(DaphneTestCase):
""" """
with DaphneTestingInstance() as test_app: with DaphneTestingInstance() as test_app:
# Connect # Connect
test_app.add_send_messages([ test_app.add_send_messages([{"type": "websocket.accept"}])
{
"type": "websocket.accept",
}
])
sock, _ = self.websocket_handshake(test_app) sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received() _, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0]) self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send # Prep frame for it to send
test_app.add_send_messages([ test_app.add_send_messages(
{ [{"type": "websocket.send", "text": "here be dragons 🐉"}]
"type": "websocket.send", )
"text": "here be dragons 🐉",
}
])
# Send it a frame # Send it a frame
self.websocket_send_frame(sock, "what is here? 🌍") self.websocket_send_frame(sock, "what is here? 🌍")
# Receive a frame and make sure it's correct # Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "here be dragons 🐉" assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
# Make sure it got our frame # Make sure it got our frame
_, messages = test_app.get_received() _, messages = test_app.get_received()
assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"} assert messages[1] == {
"type": "websocket.receive",
"text": "what is here? 🌍",
}
def test_binary_frames(self): def test_binary_frames(self):
""" """
@ -242,28 +204,24 @@ class TestWebsocket(DaphneTestCase):
""" """
with DaphneTestingInstance() as test_app: with DaphneTestingInstance() as test_app:
# Connect # Connect
test_app.add_send_messages([ test_app.add_send_messages([{"type": "websocket.accept"}])
{
"type": "websocket.accept",
}
])
sock, _ = self.websocket_handshake(test_app) sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received() _, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0]) self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send # Prep frame for it to send
test_app.add_send_messages([ test_app.add_send_messages(
{ [{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}]
"type": "websocket.send", )
"bytes": b"here be \xe2 bytes",
}
])
# Send it a frame # Send it a frame
self.websocket_send_frame(sock, b"what is here? \xe2") self.websocket_send_frame(sock, b"what is here? \xe2")
# Receive a frame and make sure it's correct # Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes" assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
# Make sure it got our frame # Make sure it got our frame
_, messages = test_app.get_received() _, messages = test_app.get_received()
assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"} assert messages[1] == {
"type": "websocket.receive",
"bytes": b"what is here? \xe2",
}
def test_http_timeout(self): def test_http_timeout(self):
""" """
@ -271,23 +229,14 @@ class TestWebsocket(DaphneTestCase):
""" """
with DaphneTestingInstance(http_timeout=1) as test_app: with DaphneTestingInstance(http_timeout=1) as test_app:
# Connect # Connect
test_app.add_send_messages([ test_app.add_send_messages([{"type": "websocket.accept"}])
{
"type": "websocket.accept",
}
])
sock, _ = self.websocket_handshake(test_app) sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received() _, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0]) self.assert_valid_websocket_connect_message(messages[0])
# Wait 2 seconds # Wait 2 seconds
time.sleep(2) time.sleep(2)
# Prep frame for it to send # Prep frame for it to send
test_app.add_send_messages([ test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}])
{
"type": "websocket.send",
"text": "cake",
}
])
# Send it a frame # Send it a frame
self.websocket_send_frame(sock, "still alive?") self.websocket_send_frame(sock, "still alive?")
# Receive a frame and make sure it's correct # Receive a frame and make sure it's correct