Implement Black code formatting

This commit is contained in:
Andrew Godwin 2018-08-27 12:27:32 +10:00
parent 88792984e7
commit 0ed6294406
18 changed files with 513 additions and 596 deletions

View File

@ -3,22 +3,25 @@ sudo: false
language: python
python:
- '3.5'
- '3.6'
- '3.5'
env:
- TWISTED="twisted==18.7.0"
- TWISTED="twisted"
- TWISTED="twisted==18.7.0"
install:
- pip install $TWISTED isort unify flake8 -e .[tests]
- pip install $TWISTED -e .[tests]
- pip freeze
script:
- pytest
- flake8
- isort --check-only --diff --recursive daphne tests
- unify --check-only --recursive --quote \" daphne tests
stages:
- lint
- test
- name: release
if: branch = master
jobs:
include:
@ -30,6 +33,13 @@ jobs:
env: TWISTED="twisted"
dist: xenial
sudo: required
- stage: lint
install: pip install -U -e .[tests] black pyflakes isort
script:
- pyflakes .
- black --check .
- isort --check-only --diff --recursive channels_redis tests
- stage: release
script: skip
deploy:

View File

@ -49,13 +49,16 @@ class AccessLogGenerator(object):
request="WSDISCONNECT %(path)s" % details,
)
def write_entry(self, host, date, request, status=None, length=None, ident=None, user=None):
def write_entry(
self, host, date, request, status=None, length=None, ident=None, user=None
):
"""
Writes an NCSA-style entry to the log file (some liberty is taken with
what the entries are for non-HTTP)
"""
self.stream.write(
"%s %s %s [%s] \"%s\" %s %s\n" % (
'%s %s %s [%s] "%s" %s %s\n'
% (
host,
ident or "-",
user or "-",

View File

@ -23,15 +23,9 @@ class CommandLineInterface(object):
server_class = Server
def __init__(self):
self.parser = argparse.ArgumentParser(
description=self.description,
)
self.parser = argparse.ArgumentParser(description=self.description)
self.parser.add_argument(
"-p",
"--port",
type=int,
help="Port number to listen on",
default=None,
"-p", "--port", type=int, help="Port number to listen on", default=None
)
self.parser.add_argument(
"-b",
@ -176,7 +170,15 @@ class CommandLineInterface(object):
sys.path.insert(0, ".")
application = import_by_path(args.application)
# Set up port/host bindings
if not any([args.host, args.port is not None, args.unix_socket, args.file_descriptor, args.socket_strings]):
if not any(
[
args.host,
args.port is not None,
args.unix_socket,
args.file_descriptor,
args.socket_strings,
]
):
# no advanced binding options passed, patch in defaults
args.host = DEFAULT_HOST
args.port = DEFAULT_PORT
@ -189,16 +191,11 @@ class CommandLineInterface(object):
host=args.host,
port=args.port,
unix_socket=args.unix_socket,
file_descriptor=args.file_descriptor
)
endpoints = sorted(
args.socket_strings + endpoints
file_descriptor=args.file_descriptor,
)
endpoints = sorted(args.socket_strings + endpoints)
# Start the server
logger.info(
"Starting server at %s" %
(", ".join(endpoints), )
)
logger.info("Starting server at %s" % (", ".join(endpoints),))
self.server = self.server_class(
application=application,
endpoints=endpoints,
@ -208,12 +205,20 @@ class CommandLineInterface(object):
websocket_timeout=args.websocket_timeout,
websocket_connect_timeout=args.websocket_connect_timeout,
application_close_timeout=args.application_close_timeout,
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
action_logger=AccessLogGenerator(access_log_stream)
if access_log_stream
else None,
ws_protocols=args.ws_protocols,
root_path=args.root_path,
verbosity=args.verbosity,
proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None,
proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None,
proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None,
proxy_forwarded_address_header="X-Forwarded-For"
if args.proxy_headers
else None,
proxy_forwarded_port_header="X-Forwarded-Port"
if args.proxy_headers
else None,
proxy_forwarded_proto_header="X-Forwarded-Proto"
if args.proxy_headers
else None,
)
self.server.run()

View File

@ -1,10 +1,5 @@
def build_endpoint_description_strings(
host=None,
port=None,
unix_socket=None,
file_descriptor=None
host=None, port=None, unix_socket=None, file_descriptor=None
):
"""
Build a list of twisted endpoint description strings that the server will listen on.

View File

@ -23,7 +23,8 @@ class WebRequest(http.Request):
GET and POST out.
"""
error_template = """
error_template = (
"""
<html>
<head>
<title>%(title)s</title>
@ -40,7 +41,13 @@ class WebRequest(http.Request):
<footer>Daphne</footer>
</body>
</html>
""".replace("\n", "").replace(" ", " ").replace(" ", " ").replace(" ", " ") # Shorten it a bit, bytes wise
""".replace(
"\n", ""
)
.replace(" ", " ")
.replace(" ", " ")
.replace(" ", " ")
) # Shorten it a bit, bytes wise
def __init__(self, *args, **kwargs):
try:
@ -84,7 +91,7 @@ class WebRequest(http.Request):
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr,
self.client_scheme
self.client_scheme,
)
# Check for unicodeish path (or it'll crash when trying to parse)
try:
@ -105,7 +112,9 @@ class WebRequest(http.Request):
# Is it WebSocket? IS IT?!
if upgrade_header and upgrade_header.lower() == b"websocket":
# Make WebSocket protocol to hand off to
protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer())
protocol = self.server.ws_factory.buildProtocol(
self.transport.getPeer()
)
if not protocol:
# If protocol creation fails, we signal "internal server error"
self.setResponseCode(500)
@ -151,10 +160,15 @@ class WebRequest(http.Request):
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
self.content.seek(0, 0)
# Work out the application scope and create application
self.application_queue = yield maybeDeferred(self.server.create_application, self, {
self.application_queue = yield maybeDeferred(
self.server.create_application,
self,
{
"type": "http",
# TODO: Correctly say if it's 1.1 or 1.0
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
"http_version": self.clientproto.split(b"/")[-1].decode(
"ascii"
),
"method": self.method.decode("ascii"),
"path": unquote(self.path.decode("ascii")),
"root_path": self.root_path,
@ -163,21 +177,21 @@ class WebRequest(http.Request):
"headers": self.clean_headers,
"client": self.client_addr,
"server": self.server_addr,
})
},
)
# Check they didn't close an unfinished request
if self.application_queue is None or self.content.closed:
# Not much we can do, the request is prematurely abandoned.
return
# Run application against request
self.application_queue.put_nowait(
{
"type": "http.request",
"body": self.content.read(),
},
{"type": "http.request", "body": self.content.read()}
)
except Exception:
logger.error(traceback.format_exc())
self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error")
self.basic_error(
500, b"Internal Server Error", "Daphne HTTP processing error"
)
def connectionLost(self, reason):
"""
@ -217,16 +231,23 @@ class WebRequest(http.Request):
raise ValueError("HTTP response has already been started")
self._response_started = True
if "status" not in message:
raise ValueError("Specifying a status code is required for a Response message.")
raise ValueError(
"Specifying a status code is required for a Response message."
)
# Set HTTP status code
self.setResponseCode(message["status"])
# Write headers
for header, value in message.get("headers", {}):
self.responseHeaders.addRawHeader(header, value)
logger.debug("HTTP %s response started for %s", message["status"], self.client_addr)
logger.debug(
"HTTP %s response started for %s", message["status"], self.client_addr
)
elif message["type"] == "http.response.body":
if not self._response_started:
raise ValueError("HTTP response has not yet been started but got %s" % message["type"])
raise ValueError(
"HTTP response has not yet been started but got %s"
% message["type"]
)
# Write out body
http.Request.write(self, message.get("body", b""))
# End if there's no more content
@ -239,15 +260,21 @@ class WebRequest(http.Request):
# The path is malformed somehow - do our best to log something
uri = repr(self.uri)
try:
self.server.log_action("http", "complete", {
self.server.log_action(
"http",
"complete",
{
"path": uri,
"status": self.code,
"method": self.method.decode("ascii", "replace"),
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
"time_taken": self.duration(),
"size": self.sentLength,
})
except Exception as e:
},
)
except Exception:
logger.error(traceback.format_exc())
else:
logger.debug("HTTP response chunk for %s", self.client_addr)
@ -270,7 +297,11 @@ class WebRequest(http.Request):
logger.warning("Application timed out while sending response")
self.finish()
else:
self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.")
self.basic_error(
503,
b"Service Unavailable",
"Application failed to respond within time limit.",
)
### Utility functions
@ -281,11 +312,7 @@ class WebRequest(http.Request):
"""
# If we don't yet have a path, then don't send as we never opened.
if self.path:
self.application_queue.put_nowait(
{
"type": "http.disconnect",
},
)
self.application_queue.put_nowait({"type": "http.disconnect"})
def duration(self):
"""
@ -299,20 +326,25 @@ class WebRequest(http.Request):
"""
Responds with a server-level error page (very basic)
"""
self.handle_reply({
self.handle_reply(
{
"type": "http.response.start",
"status": status,
"headers": [
(b"Content-Type", b"text/html; charset=utf-8"),
],
})
self.handle_reply({
"headers": [(b"Content-Type", b"text/html; charset=utf-8")],
}
)
self.handle_reply(
{
"type": "http.response.body",
"body": (self.error_template % {
"body": (
self.error_template
% {
"title": str(status) + " " + status_text.decode("ascii"),
"body": body,
}).encode("utf8"),
})
}
).encode("utf8"),
}
)
def __hash__(self):
return hash(id(self))
@ -343,7 +375,7 @@ class HTTPFactory(http.HTTPFactory):
protocol = http.HTTPFactory.buildProtocol(self, addr)
protocol.requestFactory = WebRequest
return protocol
except Exception as e:
except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise

View File

@ -2,13 +2,14 @@
import sys # isort:skip
import warnings # isort:skip
from twisted.internet import asyncioreactor # isort:skip
current_reactor = sys.modules.get("twisted.internet.reactor", None)
if current_reactor is not None:
if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor):
warnings.warn(
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " +
"you can fix this warning by importing daphne.server early in your codebase or " +
"finding the package that imports Twisted and importing it later on.",
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; "
+ "you can fix this warning by importing daphne.server early in your codebase or "
+ "finding the package that imports Twisted and importing it later on.",
UserWarning,
)
del sys.modules["twisted.internet.reactor"]
@ -34,7 +35,6 @@ logger = logging.getLogger(__name__)
class Server(object):
def __init__(
self,
application,
@ -91,11 +91,13 @@ class Server(object):
self.ws_factory.setProtocolOptions(
autoPingTimeout=self.ping_timeout,
allowNullOrigin=True,
openHandshakeTimeout=self.websocket_handshake_timeout
openHandshakeTimeout=self.websocket_handshake_timeout,
)
if self.verbosity <= 1:
# Redirect the Twisted log to nowhere
globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True)
globalLogBeginner.beginLoggingTo(
[lambda _: None], redirectStandardIO=False, discardBuffer=True
)
else:
globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
@ -103,7 +105,9 @@ class Server(object):
if http.H2_ENABLED:
logger.info("HTTP/2 support enabled")
else:
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
logger.info(
"HTTP/2 support not enabled (install the http2 and tls Twisted extras)"
)
# Kick off the timeout loop
reactor.callLater(1, self.application_checker)
@ -141,7 +145,11 @@ class Server(object):
host = port.getHost()
if hasattr(host, "host") and hasattr(host, "port"):
self.listening_addresses.append((host.host, host.port))
logger.info("Listening on TCP address %s:%s", port.getHost().host, port.getHost().port)
logger.info(
"Listening on TCP address %s:%s",
port.getHost().host,
port.getHost().port,
)
def listen_error(self, failure):
logger.critical("Listen failure: %s", failure.getErrorMessage())
@ -187,10 +195,13 @@ class Server(object):
# Run it, and stash the future for later checking
if protocol not in self.connections:
return None
self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance(
self.connections[protocol]["application_instance"] = asyncio.ensure_future(
application_instance(
receive=input_queue.get,
send=lambda message: self.handle_reply(protocol, message),
), loop=asyncio.get_event_loop())
),
loop=asyncio.get_event_loop(),
)
return input_queue
async def handle_reply(self, protocol, message):
@ -215,7 +226,10 @@ class Server(object):
application_instance = details.get("application_instance", None)
# First, see if the protocol disconnected and the app has taken
# too long to close up
if disconnected and time.time() - disconnected > self.application_close_timeout:
if (
disconnected
and time.time() - disconnected > self.application_close_timeout
):
if application_instance and not application_instance.done():
logger.warning(
"Application instance %r for connection %s took too long to shut down and was killed.",
@ -238,14 +252,11 @@ class Server(object):
else:
exception_output = "{}\n{}{}".format(
exception,
"".join(traceback.format_tb(
exception.__traceback__,
)),
"".join(traceback.format_tb(exception.__traceback__)),
" {}".format(exception),
)
logger.error(
"Exception inside application: %s",
exception_output,
"Exception inside application: %s", exception_output
)
if not disconnected:
protocol.handle_exception(exception)

View File

@ -100,9 +100,7 @@ class DaphneTestingInstance:
Adds messages for the application to send back.
The next time it receives an incoming message, it will reply with these.
"""
TestApplication.save_setup(
response_messages=messages,
)
TestApplication.save_setup(response_messages=messages)
class DaphneProcess(multiprocessing.Process):
@ -193,12 +191,7 @@ class TestApplication:
Stores setup information.
"""
with open(cls.setup_storage, "wb") as fh:
pickle.dump(
{
"response_messages": response_messages,
},
fh,
)
pickle.dump({"response_messages": response_messages}, fh)
@classmethod
def load_setup(cls):
@ -218,13 +211,7 @@ class TestApplication:
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(cls.result_storage, "wb") as fh:
pickle.dump(
{
"scope": scope,
"messages": messages,
},
fh,
)
pickle.dump({"scope": scope, "messages": messages}, fh)
@classmethod
def save_exception(cls, exception):
@ -233,12 +220,7 @@ class TestApplication:
We could use pickle here, but that seems wrong, still, somehow.
"""
with open(cls.result_storage, "wb") as fh:
pickle.dump(
{
"exception": exception,
},
fh,
)
pickle.dump({"exception": exception}, fh)
@classmethod
def load_result(cls):

View File

@ -22,12 +22,14 @@ def header_value(headers, header_name):
return value.decode("utf-8")
def parse_x_forwarded_for(headers,
def parse_x_forwarded_for(
headers,
address_header_name="X-Forwarded-For",
port_header_name="X-Forwarded-Port",
proto_header_name="X-Forwarded-Proto",
original_addr=None,
original_scheme=None):
original_scheme=None,
):
"""
Parses an X-Forwarded-For header and returns a host/port pair as a list.

View File

@ -3,7 +3,11 @@ import time
import traceback
from urllib.parse import unquote
from autobahn.twisted.websocket import ConnectionDeny, WebSocketServerFactory, WebSocketServerProtocol
from autobahn.twisted.websocket import (
ConnectionDeny,
WebSocketServerFactory,
WebSocketServerProtocol,
)
from twisted.internet import defer
from .utils import parse_x_forwarded_for
@ -54,20 +58,21 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.proxy_forwarded_address_header,
self.server.proxy_forwarded_port_header,
self.server.proxy_forwarded_proto_header,
self.client_addr
self.client_addr,
)
# Decode websocket subprotocol options
subprotocols = []
for header, value in self.clean_headers:
if header == b"sec-websocket-protocol":
subprotocols = [
x.strip()
for x in
unquote(value.decode("ascii")).split(",")
x.strip() for x in unquote(value.decode("ascii")).split(",")
]
# Make new application instance with scope
self.path = request.path.encode("ascii")
self.application_deferred = defer.maybeDeferred(self.server.create_application, self, {
self.application_deferred = defer.maybeDeferred(
self.server.create_application,
self,
{
"type": "websocket",
"path": unquote(self.path.decode("ascii")),
"headers": self.clean_headers,
@ -75,11 +80,12 @@ class WebSocketProtocol(WebSocketServerProtocol):
"client": self.client_addr,
"server": self.server_addr,
"subprotocols": subprotocols,
})
},
)
if self.application_deferred is not None:
self.application_deferred.addCallback(self.applicationCreateWorked)
self.application_deferred.addErrback(self.applicationCreateFailed)
except Exception as e:
except Exception:
# Exceptions here are not displayed right, just 500.
# Turn them into an ERROR log.
logger.error(traceback.format_exc())
@ -98,10 +104,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.application_queue = application_queue
# Send over the connect message
self.application_queue.put_nowait({"type": "websocket.connect"})
self.server.log_action("websocket", "connecting", {
self.server.log_action(
"websocket",
"connecting",
{
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def applicationCreateFailed(self, failure):
"""
@ -115,10 +127,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
def onOpen(self):
# Send news that this channel is open
logger.debug("WebSocket %s open and established", self.client_addr)
self.server.log_action("websocket", "connected", {
self.server.log_action(
"websocket",
"connected",
{
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def onMessage(self, payload, isBinary):
# If we're muted, do nothing.
@ -128,15 +146,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
logger.debug("WebSocket incoming frame on %s", self.client_addr)
self.last_ping = time.time()
if isBinary:
self.application_queue.put_nowait({
"type": "websocket.receive",
"bytes": payload,
})
self.application_queue.put_nowait(
{"type": "websocket.receive", "bytes": payload}
)
else:
self.application_queue.put_nowait({
"type": "websocket.receive",
"text": payload.decode("utf8"),
})
self.application_queue.put_nowait(
{"type": "websocket.receive", "text": payload.decode("utf8")}
)
def onClose(self, wasClean, code, reason):
"""
@ -145,14 +161,19 @@ class WebSocketProtocol(WebSocketServerProtocol):
self.server.protocol_disconnected(self)
logger.debug("WebSocket closed for %s", self.client_addr)
if not self.muted and hasattr(self, "application_queue"):
self.application_queue.put_nowait({
"type": "websocket.disconnect",
"code": code,
})
self.server.log_action("websocket", "disconnected", {
self.application_queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
)
self.server.log_action(
"websocket",
"disconnected",
{
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
### Internal event handling
@ -171,9 +192,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
raise ValueError("Socket has not been accepted, so cannot send over it")
if message.get("bytes", None) and message.get("text", None):
raise ValueError(
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
message,
)
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
% (message,)
)
if message.get("bytes", None):
self.serverSend(message["bytes"], True)
@ -187,7 +207,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
if hasattr(self, "handshake_deferred"):
# If the handshake is still ongoing, we need to emit a HTTP error
# code rather than a WebSocket one.
self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error"))
self.handshake_deferred.errback(
ConnectionDeny(code=500, reason="Internal server error")
)
else:
self.sendCloseFrame(code=1011)
@ -203,14 +225,22 @@ class WebSocketProtocol(WebSocketServerProtocol):
"""
Called when we get a message saying to reject the connection.
"""
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
self.handshake_deferred.errback(
ConnectionDeny(code=403, reason="Access denied")
)
del self.handshake_deferred
self.server.protocol_disconnected(self)
logger.debug("WebSocket %s rejected by application", self.client_addr)
self.server.log_action("websocket", "rejected", {
self.server.log_action(
"websocket",
"rejected",
{
"path": self.request.path,
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
})
"client": "%s:%s" % tuple(self.client_addr)
if self.client_addr
else None,
},
)
def serverSend(self, content, binary=False):
"""
@ -244,7 +274,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
Called periodically to see if we should timeout something
"""
# Web timeout checking
if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0:
if (
self.duration() > self.server.websocket_timeout
and self.server.websocket_timeout >= 0
):
self.serverClose()
# Ping check
# If we're still connecting, deny the connection
@ -287,6 +320,6 @@ class WebSocketFactory(WebSocketServerFactory):
protocol = super(WebSocketFactory, self).buildProtocol(addr)
protocol.factory = self
return protocol
except Exception as e:
except Exception:
logger.error("Cannot build protocol: %s" % traceback.format_exc())
raise

View File

@ -5,9 +5,9 @@ universal=1
addopts = tests/
[isort]
line_length = 120
include_trailing_comma = True
multi_line_output = 3
known_first_party = channels,daphne,asgiref
known_first_party = channels,daphne,asgiref,channels_redis
[flake8]
exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/*

View File

@ -22,23 +22,12 @@ setup(
package_dir={"twisted": "daphne/twisted"},
packages=find_packages() + ["twisted.plugins"],
include_package_data=True,
install_requires=[
"twisted>=18.7",
"autobahn>=0.18",
],
setup_requires=[
"pytest-runner",
],
extras_require={
"tests": [
"hypothesis",
"pytest",
"pytest-asyncio~=0.8",
],
install_requires=["twisted>=18.7", "autobahn>=0.18"],
setup_requires=["pytest-runner"],
extras_require={"tests": ["hypothesis", "pytest", "pytest-asyncio~=0.8"]},
entry_points={
"console_scripts": ["daphne = daphne.cli:CommandLineInterface.entrypoint"]
},
entry_points={"console_scripts": [
"daphne = daphne.cli:CommandLineInterface.entrypoint",
]},
classifiers=[
"Development Status :: 4 - Beta",
"Environment :: Web Environment",

View File

@ -19,7 +19,9 @@ class DaphneTestCase(unittest.TestCase):
### Plain HTTP helpers
def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
def run_daphne_http(
self, method, path, params, body, responses, headers=None, timeout=1, xff=False
):
"""
Runs Daphne with the given request callback (given the base URL)
and response messages.
@ -38,7 +40,9 @@ class DaphneTestCase(unittest.TestCase):
# Manually send over headers (encoding any non-safe values as best we can)
if headers:
for header_name, header_value in headers:
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))
conn.putheader(
header_name.encode("utf8"), header_value.encode("utf8")
)
# Send body if provided.
if body:
conn.putheader("Content-Length", str(len(body)))
@ -50,7 +54,9 @@ class DaphneTestCase(unittest.TestCase):
except socket.timeout:
# See if they left an exception for us to load
test_app.get_received()
raise RuntimeError("Daphne timed out handling request, no exception found.")
raise RuntimeError(
"Daphne timed out handling request, no exception found."
)
# Return scope, messages, response
return test_app.get_received() + (response,)
@ -68,9 +74,13 @@ class DaphneTestCase(unittest.TestCase):
try:
return s.recv(1000000)
except socket.timeout:
raise RuntimeError("Daphne timed out handling raw request, no exception found.")
raise RuntimeError(
"Daphne timed out handling raw request, no exception found."
)
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
def run_daphne_request(
self, method, path, params=None, body=None, headers=None, xff=False
):
"""
Convenience method for just testing request handling.
Returns (scope, messages)
@ -95,17 +105,21 @@ class DaphneTestCase(unittest.TestCase):
Returns (scope, messages)
"""
_, _, response = self.run_daphne_http(
method="GET",
path="/",
params={},
body=b"",
responses=response_messages,
method="GET", path="/", params={}, body=b"", responses=response_messages
)
return response
### WebSocket helpers
def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1):
def websocket_handshake(
self,
test_app,
path="/",
params=None,
headers=None,
subprotocols=None,
timeout=1,
):
"""
Runs a WebSocket handshake negotiation and returns the raw socket
object & the selected subprotocol.
@ -124,14 +138,16 @@ class DaphneTestCase(unittest.TestCase):
# Do WebSocket handshake headers + any other headers
if headers is None:
headers = []
headers.extend([
headers.extend(
[
("Host", "example.com"),
("Upgrade", "websocket"),
("Connection", "Upgrade"),
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
("Sec-WebSocket-Version", "13"),
("Origin", "http://example.com")
])
("Origin", "http://example.com"),
]
)
if subprotocols:
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
if headers:
@ -149,10 +165,7 @@ class DaphneTestCase(unittest.TestCase):
if response.status != 101:
raise RuntimeError("WebSocket upgrade did not result in status code 101")
# Prepare headers for subprotocol searching
response_headers = dict(
(n.lower(), v)
for n, v in response.getheaders()
)
response_headers = dict((n.lower(), v) for n, v in response.getheaders())
response.read()
assert not response.closed
# Return the raw socket and any subprotocol
@ -234,10 +247,7 @@ class DaphneTestCase(unittest.TestCase):
# Make sure all required keys are present
self.assertTrue(required_keys <= present_keys)
# Assert that no other keys are present
self.assertEqual(
set(),
present_keys - required_keys - optional_keys,
)
self.assertEqual(set(), present_keys - required_keys - optional_keys)
def assert_valid_path(self, path, request_path):
"""

View File

@ -6,7 +6,9 @@ from hypothesis import strategies
HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"]
# Unicode characters of the "Letter" category
letters = strategies.characters(whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl"))
letters = strategies.characters(
whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")
)
def http_method():
@ -22,11 +24,9 @@ def http_path():
"""
Returns a URL path (not encoded).
"""
return strategies.lists(
_http_path_portion(),
min_size=0,
max_size=10,
).map(lambda s: "/" + "/".join(s))
return strategies.lists(_http_path_portion(), min_size=0, max_size=10).map(
lambda s: "/" + "/".join(s)
)
def http_body():
@ -53,10 +53,7 @@ def valid_bidi(value):
def _domain_label():
return strategies.text(
alphabet=letters,
min_size=1,
average_size=6,
max_size=63,
alphabet=letters, min_size=1, average_size=6, max_size=63
).filter(valid_bidi)
@ -64,19 +61,14 @@ def international_domain_name():
"""
Returns a byte string of a domain name, IDNA-encoded.
"""
return strategies.lists(
_domain_label(),
min_size=2,
average_size=2,
).map(lambda s: (".".join(s)).encode("idna"))
return strategies.lists(_domain_label(), min_size=2, average_size=2).map(
lambda s: (".".join(s)).encode("idna")
)
def _query_param():
return strategies.text(
alphabet=letters,
min_size=1,
average_size=10,
max_size=255,
alphabet=letters, min_size=1, average_size=10, max_size=255
).map(lambda s: s.encode("utf8"))
@ -87,9 +79,7 @@ def query_params():
ensures that the total urlencoded query string is not longer than 1500 characters.
"""
return strategies.lists(
strategies.tuples(_query_param(), _query_param()),
min_size=0,
average_size=5,
strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5
).filter(lambda x: len(parse.urlencode(x)) < 1500)
@ -101,9 +91,7 @@ def header_name():
and 20 characters long
"""
return strategies.text(
alphabet=string.ascii_letters + string.digits + "-",
min_size=1,
max_size=30,
alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30
)
@ -115,7 +103,10 @@ def header_value():
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
"""
return strategies.text(
alphabet=string.ascii_letters + string.digits + string.punctuation.replace(",", "") + " /t",
alphabet=string.ascii_letters
+ string.digits
+ string.punctuation.replace(",", "")
+ " /t",
min_size=1,
average_size=40,
max_size=8190,

View File

@ -18,45 +18,32 @@ class TestEndpointDescriptions(TestCase):
def testTcpPortBindings(self):
self.assertEqual(
build(port=1234, host="example.com"),
["tcp:port=1234:interface=example.com"]
["tcp:port=1234:interface=example.com"],
)
self.assertEqual(
build(port=8000, host="127.0.0.1"),
["tcp:port=8000:interface=127.0.0.1"]
build(port=8000, host="127.0.0.1"), ["tcp:port=8000:interface=127.0.0.1"]
)
self.assertEqual(
build(port=8000, host="[200a::1]"),
[r'tcp:port=8000:interface=200a\:\:1']
build(port=8000, host="[200a::1]"), [r"tcp:port=8000:interface=200a\:\:1"]
)
self.assertEqual(
build(port=8000, host="200a::1"),
[r'tcp:port=8000:interface=200a\:\:1']
build(port=8000, host="200a::1"), [r"tcp:port=8000:interface=200a\:\:1"]
)
# incomplete port/host kwargs raise errors
self.assertRaises(
ValueError,
build, port=123
)
self.assertRaises(
ValueError,
build, host="example.com"
)
self.assertRaises(ValueError, build, port=123)
self.assertRaises(ValueError, build, host="example.com")
def testUnixSocketBinding(self):
self.assertEqual(
build(unix_socket="/tmp/daphne.sock"),
["unix:/tmp/daphne.sock"]
build(unix_socket="/tmp/daphne.sock"), ["unix:/tmp/daphne.sock"]
)
def testFileDescriptorBinding(self):
self.assertEqual(
build(file_descriptor=5),
["fd:fileno=5"]
)
self.assertEqual(build(file_descriptor=5), ["fd:fileno=5"])
def testMultipleEnpoints(self):
self.assertEqual(
@ -65,14 +52,16 @@ class TestEndpointDescriptions(TestCase):
file_descriptor=123,
unix_socket="/tmp/daphne.sock",
port=8080,
host="10.0.0.1"
host="10.0.0.1",
)
),
sorted([
sorted(
[
"tcp:port=8080:interface=10.0.0.1",
"unix:/tmp/daphne.sock",
"fd:fileno=123"
])
"fd:fileno=123",
]
),
)
@ -112,7 +101,9 @@ class TestCLIInterface(TestCase):
Passes in a fake application automatically.
"""
cli = self.TestedCLI()
cli.run(args + ["daphne:__version__"]) # We just pass something importable as app
cli.run(
args + ["daphne:__version__"]
) # We just pass something importable as app
# Check the server got all arguments as intended
for key, value in server_kwargs.items():
# Get the value and sort it if it's a list (for endpoint checking)
@ -123,52 +114,30 @@ class TestCLIInterface(TestCase):
self.assertEqual(
value,
actual_value,
"Wrong value for server kwarg %s: %r != %r" % (
key,
value,
actual_value,
),
"Wrong value for server kwarg %s: %r != %r"
% (key, value, actual_value),
)
def testCLIBasics(self):
"""
Tests basic endpoint generation.
"""
self.assertCLI([], {"endpoints": ["tcp:port=8000:interface=127.0.0.1"]})
self.assertCLI(
[],
{
"endpoints": ["tcp:port=8000:interface=127.0.0.1"],
},
["-p", "123"], {"endpoints": ["tcp:port=123:interface=127.0.0.1"]}
)
self.assertCLI(
["-p", "123"],
{
"endpoints": ["tcp:port=123:interface=127.0.0.1"],
},
["-b", "10.0.0.1"], {"endpoints": ["tcp:port=8000:interface=10.0.0.1"]}
)
self.assertCLI(
["-b", "10.0.0.1"],
{
"endpoints": ["tcp:port=8000:interface=10.0.0.1"],
},
["-b", "200a::1"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
)
self.assertCLI(
["-b", "200a::1"],
{
"endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
},
)
self.assertCLI(
["-b", "[200a::1]"],
{
"endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
},
["-b", "[200a::1]"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
)
self.assertCLI(
["-p", "8080", "-b", "example.com"],
{
"endpoints": ["tcp:port=8080:interface=example.com"],
},
{"endpoints": ["tcp:port=8080:interface=example.com"]},
)
def testUnixSockets(self):
@ -178,7 +147,7 @@ class TestCLIInterface(TestCase):
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
"unix:/tmp/daphne.sock",
],
]
},
)
self.assertCLI(
@ -187,17 +156,12 @@ class TestCLIInterface(TestCase):
"endpoints": [
"tcp:port=8000:interface=example.com",
"unix:/tmp/daphne.sock",
],
]
},
)
self.assertCLI(
["-u", "/tmp/daphne.sock", "--fd", "5"],
{
"endpoints": [
"fd:fileno=5",
"unix:/tmp/daphne.sock"
],
},
{"endpoints": ["fd:fileno=5", "unix:/tmp/daphne.sock"]},
)
def testMixedCLIEndpointCreation(self):
@ -209,8 +173,8 @@ class TestCLIInterface(TestCase):
{
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
"unix:/tmp/daphne.sock"
],
"unix:/tmp/daphne.sock",
]
},
)
self.assertCLI(
@ -219,7 +183,7 @@ class TestCLIInterface(TestCase):
"endpoints": [
"tcp:port=8080:interface=127.0.0.1",
"tcp:port=8080:interface=127.0.0.1",
],
]
},
)
@ -227,11 +191,4 @@ class TestCLIInterface(TestCase):
"""
Tests entirely custom endpoints
"""
self.assertCLI(
["-e", "imap:"],
{
"endpoints": [
"imap:",
],
},
)
self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]})

View File

@ -15,13 +15,7 @@ class TestHTTPRequest(DaphneTestCase):
"""
def assert_valid_http_scope(
self,
scope,
method,
path,
params=None,
headers=None,
scheme=None,
self, scope, method, path, params=None, headers=None, scheme=None
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
@ -29,7 +23,14 @@ class TestHTTPRequest(DaphneTestCase):
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type", "http_version", "method", "path", "query_string", "headers"},
required_keys={
"type",
"http_version",
"method",
"path",
"query_string",
"headers",
},
optional_keys={"scheme", "root_path", "client", "server"},
actual_keys=scope.keys(),
)
@ -50,7 +51,9 @@ class TestHTTPRequest(DaphneTestCase):
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
self.assertEqual(
query_string, parse.urlencode(params or []).encode("ascii")
)
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
@ -59,7 +62,7 @@ class TestHTTPRequest(DaphneTestCase):
for name, value in scope["headers"]:
transformed_scope_headers[name].append(value)
transformed_request_headers = collections.defaultdict(list)
for name, value in (headers or []):
for name, value in headers or []:
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
transformed_request_headers[expected_name].append(expected_value)
@ -103,27 +106,31 @@ class TestHTTPRequest(DaphneTestCase):
@given(
request_path=http_strategies.http_path(),
request_params=http_strategies.query_params()
request_params=http_strategies.query_params(),
)
@settings(max_examples=5, deadline=5000)
def test_get_request(self, request_path, request_params):
"""
Tests a typical HTTP GET request, with a path and query parameters
"""
scope, messages = self.run_daphne_request("GET", request_path, params=request_params)
scope, messages = self.run_daphne_request(
"GET", request_path, params=request_params
)
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
request_path=http_strategies.http_path(),
request_body=http_strategies.http_body()
request_body=http_strategies.http_body(),
)
@settings(max_examples=5, deadline=5000)
def test_post_request(self, request_path, request_body):
"""
Tests a typical HTTP POST request, with a path and body.
"""
scope, messages = self.run_daphne_request("POST", request_path, body=request_body)
scope, messages = self.run_daphne_request(
"POST", request_path, body=request_body
)
self.assert_valid_http_scope(scope, "POST", request_path)
self.assert_valid_http_request_message(messages[0], body=request_body)
@ -134,8 +141,12 @@ class TestHTTPRequest(DaphneTestCase):
Tests that HTTP header fields are handled as specified
"""
request_path = "/te st-à/"
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers)
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers)
scope, messages = self.run_daphne_request(
"OPTIONS", request_path, headers=request_headers
)
self.assert_valid_http_scope(
scope, "OPTIONS", request_path, headers=request_headers
)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(request_headers=http_strategies.headers())
@ -150,8 +161,12 @@ class TestHTTPRequest(DaphneTestCase):
duplicated_headers = [(header_name, header[1]) for header in request_headers]
# Run the request
request_path = "/te st-à/"
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=duplicated_headers)
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=duplicated_headers)
scope, messages = self.run_daphne_request(
"OPTIONS", request_path, headers=duplicated_headers
)
self.assert_valid_http_scope(
scope, "OPTIONS", request_path, headers=duplicated_headers
)
self.assert_valid_http_request_message(messages[0], body=b"")
@given(
@ -222,10 +237,7 @@ class TestHTTPRequest(DaphneTestCase):
"""
Make sure that, by default, X-Forwarded-For is ignored.
"""
headers = [
["X-Forwarded-For", "10.1.2.3"],
["X-Forwarded-Port", "80"],
]
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
@ -236,10 +248,7 @@ class TestHTTPRequest(DaphneTestCase):
"""
When X-Forwarded-For is enabled, make sure it is respected.
"""
headers = [
["X-Forwarded-For", "10.1.2.3"],
["X-Forwarded-Port", "80"],
]
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
@ -251,9 +260,7 @@ class TestHTTPRequest(DaphneTestCase):
When X-Forwarded-For is enabled but only the host is passed, make sure
that at least makes it through.
"""
headers = [
["X-Forwarded-For", "10.1.2.3"],
]
headers = [["X-Forwarded-For", "10.1.2.3"]]
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
self.assert_valid_http_request_message(messages[0], body=b"")
@ -265,8 +272,12 @@ class TestHTTPRequest(DaphneTestCase):
Tests that requests with invalid (non-ASCII) characters fail.
"""
# Bad path
response = self.run_daphne_raw(b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n")
response = self.run_daphne_raw(
b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
)
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
# Bad querystring
response = self.run_daphne_raw(b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n")
response = self.run_daphne_raw(
b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
)
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))

View File

@ -15,26 +15,24 @@ class TestHTTPResponse(DaphneTestCase):
"""
Lowercases and sorts headers, and strips transfer-encoding ones.
"""
return sorted([
return sorted(
[
(name.lower(), value.strip())
for name, value in headers
if name.lower() != "transfer-encoding"
])
]
)
def test_minimal_response(self):
"""
Smallest viable example. Mostly verifies that our response building works.
"""
response = self.run_daphne_response([
{
"type": "http.response.start",
"status": 200,
},
{
"type": "http.response.body",
"body": b"hello world",
},
])
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 200},
{"type": "http.response.body", "body": b"hello world"},
]
)
self.assertEqual(response.status, 200)
self.assertEqual(response.read(), b"hello world")
@ -46,30 +44,23 @@ class TestHTTPResponse(DaphneTestCase):
to make sure it stays required.
"""
with self.assertRaises(ValueError):
self.run_daphne_response([
{
"type": "http.response.start",
},
{
"type": "http.response.body",
"body": b"hello world",
},
])
self.run_daphne_response(
[
{"type": "http.response.start"},
{"type": "http.response.body", "body": b"hello world"},
]
)
def test_custom_status_code(self):
"""
Tries a non-default status code.
"""
response = self.run_daphne_response([
{
"type": "http.response.start",
"status": 201,
},
{
"type": "http.response.body",
"body": b"i made a thing!",
},
])
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 201},
{"type": "http.response.body", "body": b"i made a thing!"},
]
)
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"i made a thing!")
@ -77,21 +68,13 @@ class TestHTTPResponse(DaphneTestCase):
"""
Tries sending a response in multiple parts.
"""
response = self.run_daphne_response([
{
"type": "http.response.start",
"status": 201,
},
{
"type": "http.response.body",
"body": b"chunk 1 ",
"more_body": True,
},
{
"type": "http.response.body",
"body": b"chunk 2",
},
])
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 201},
{"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
{"type": "http.response.body", "body": b"chunk 2"},
]
)
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"chunk 1 chunk 2")
@ -99,25 +82,14 @@ class TestHTTPResponse(DaphneTestCase):
"""
Tries sending a response in multiple parts and an empty end.
"""
response = self.run_daphne_response([
{
"type": "http.response.start",
"status": 201,
},
{
"type": "http.response.body",
"body": b"chunk 1 ",
"more_body": True,
},
{
"type": "http.response.body",
"body": b"chunk 2",
"more_body": True,
},
{
"type": "http.response.body",
},
])
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 201},
{"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
{"type": "http.response.body", "body": b"chunk 2", "more_body": True},
{"type": "http.response.body"},
]
)
self.assertEqual(response.status, 201)
self.assertEqual(response.read(), b"chunk 1 chunk 2")
@ -127,16 +99,12 @@ class TestHTTPResponse(DaphneTestCase):
"""
Tries body variants.
"""
response = self.run_daphne_response([
{
"type": "http.response.start",
"status": 200,
},
{
"type": "http.response.body",
"body": body,
},
])
response = self.run_daphne_response(
[
{"type": "http.response.start", "status": 200},
{"type": "http.response.body", "body": body},
]
)
self.assertEqual(response.status, 200)
self.assertEqual(response.read(), body)
@ -144,16 +112,16 @@ class TestHTTPResponse(DaphneTestCase):
@settings(max_examples=5, deadline=5000)
def test_headers(self, headers):
# The ASGI spec requires us to lowercase our header names
response = self.run_daphne_response([
response = self.run_daphne_response(
[
{
"type": "http.response.start",
"status": 200,
"headers": self.normalize_headers(headers),
},
{
"type": "http.response.body",
},
])
{"type": "http.response.body"},
]
)
# Check headers in a sensible way. Ignore transfer-encoding.
self.assertEqual(
self.normalize_headers(response.getheaders()),

View File

@ -13,48 +13,35 @@ class TestXForwardedForHttpParsing(TestCase):
"""
def test_basic(self):
headers = Headers({
headers = Headers(
{
b"X-Forwarded-For": [b"10.1.2.3"],
b"X-Forwarded-Port": [b"1234"],
b"X-Forwarded-Proto": [b"https"]
})
b"X-Forwarded-Proto": [b"https"],
}
)
result = parse_x_forwarded_for(headers)
self.assertEqual(result, (["10.1.2.3", 1234], "https"))
self.assertIsInstance(result[0][0], str)
self.assertIsInstance(result[1], str)
def test_address_only(self):
headers = Headers({
b"X-Forwarded-For": [b"10.1.2.3"],
})
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 0], None)
)
headers = Headers({b"X-Forwarded-For": [b"10.1.2.3"]})
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_v6_address(self):
headers = Headers({
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
})
self.assertEqual(
parse_x_forwarded_for(headers),
(["1043::a321:0001", 0], None)
)
headers = Headers({b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]})
self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
def test_multiple_proxys(self):
headers = Headers({
b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"],
})
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 0], None)
)
headers = Headers({b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"]})
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_original(self):
headers = Headers({})
self.assertEqual(
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
(["127.0.0.1", 80], None)
(["127.0.0.1", 80], None),
)
def test_no_original(self):
@ -73,43 +60,25 @@ class TestXForwardedForWsParsing(TestCase):
b"X-Forwarded-Port": b"1234",
b"X-Forwarded-Proto": b"https",
}
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 1234], "https")
)
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 1234], "https"))
def test_address_only(self):
headers = {
b"X-Forwarded-For": b"10.1.2.3",
}
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 0], None)
)
headers = {b"X-Forwarded-For": b"10.1.2.3"}
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_v6_address(self):
headers = {
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
}
self.assertEqual(
parse_x_forwarded_for(headers),
(["1043::a321:0001", 0], None)
)
headers = {b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]}
self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
def test_multiple_proxies(self):
headers = {
b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
}
self.assertEqual(
parse_x_forwarded_for(headers),
(["10.1.2.3", 0], None)
)
headers = {b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4"}
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
def test_original(self):
headers = {}
self.assertEqual(
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
(["127.0.0.1", 80], None)
(["127.0.0.1", 80], None),
)
def test_no_original(self):

View File

@ -16,13 +16,7 @@ class TestWebsocket(DaphneTestCase):
"""
def assert_valid_websocket_scope(
self,
scope,
path="/",
params=None,
headers=None,
scheme=None,
subprotocols=None,
self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None
):
"""
Checks that the passed scope is a valid ASGI HTTP scope regarding types
@ -46,7 +40,9 @@ class TestWebsocket(DaphneTestCase):
query_string = scope["query_string"]
self.assertIsInstance(query_string, bytes)
if params:
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
self.assertEqual(
query_string, parse.urlencode(params or []).encode("ascii")
)
# Ordering of header names is not important, but the order of values for a header
# name is. To assert whether that order is kept, we transform both the request
# headers and the channel message headers into a dictionary
@ -59,7 +55,7 @@ class TestWebsocket(DaphneTestCase):
if bit.strip():
transformed_scope_headers[name].append(bit.strip())
transformed_request_headers = collections.defaultdict(list)
for name, value in (headers or []):
for name, value in headers or []:
expected_name = name.lower().strip().encode("ascii")
expected_value = value.strip().encode("ascii")
# Make sure to split out any headers collapsed with commas
@ -92,9 +88,7 @@ class TestWebsocket(DaphneTestCase):
"""
# Check overall keys
self.assert_key_sets(
required_keys={"type"},
optional_keys=set(),
actual_keys=message.keys(),
required_keys={"type"}, optional_keys=set(), actual_keys=message.keys()
)
# Check that it is the right type
self.assertEqual(message["type"], "websocket.connect")
@ -104,11 +98,7 @@ class TestWebsocket(DaphneTestCase):
Tests we can open and accept a socket.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
@ -120,11 +110,7 @@ class TestWebsocket(DaphneTestCase):
Tests we can reject a socket and it won't complete the handshake.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([
{
"type": "websocket.close",
}
])
test_app.add_send_messages([{"type": "websocket.close"}])
with self.assertRaises(RuntimeError):
self.websocket_handshake(test_app)
@ -134,13 +120,12 @@ class TestWebsocket(DaphneTestCase):
"""
subprotocols = ["proto1", "proto2"]
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([
{
"type": "websocket.accept",
"subprotocol": "proto2",
}
])
_, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols)
test_app.add_send_messages(
[{"type": "websocket.accept", "subprotocol": "proto2"}]
)
_, subprotocol = self.websocket_handshake(
test_app, subprotocols=subprotocols
)
# Validate the scope and messages we got
assert subprotocol == "proto2"
scope, messages = test_app.get_received()
@ -151,16 +136,9 @@ class TestWebsocket(DaphneTestCase):
"""
Tests that X-Forwarded-For headers get parsed right
"""
headers = [
["X-Forwarded-For", "10.1.2.3"],
["X-Forwarded-Port", "80"],
]
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
with DaphneTestingInstance(xff=True) as test_app:
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(test_app, headers=headers)
# Validate the scope and messages we got
scope, messages = test_app.get_received()
@ -174,22 +152,13 @@ class TestWebsocket(DaphneTestCase):
request_headers=http_strategies.headers(),
)
@settings(max_examples=5, deadline=2000)
def test_http_bits(
self,
request_path,
request_params,
request_headers,
):
def test_http_bits(self, request_path, request_params, request_headers):
"""
Tests that various HTTP-level bits (query string params, path, headers)
carry over into the scope.
"""
with DaphneTestingInstance() as test_app:
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
test_app.add_send_messages([{"type": "websocket.accept"}])
self.websocket_handshake(
test_app,
path=request_path,
@ -199,10 +168,7 @@ class TestWebsocket(DaphneTestCase):
# Validate the scope and messages we got
scope, messages = test_app.get_received()
self.assert_valid_websocket_scope(
scope,
path=request_path,
params=request_params,
headers=request_headers,
scope, path=request_path, params=request_params, headers=request_headers
)
self.assert_valid_websocket_connect_message(messages[0])
@ -212,28 +178,24 @@ class TestWebsocket(DaphneTestCase):
"""
with DaphneTestingInstance() as test_app:
# Connect
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
test_app.add_send_messages([
{
"type": "websocket.send",
"text": "here be dragons 🐉",
}
])
test_app.add_send_messages(
[{"type": "websocket.send", "text": "here be dragons 🐉"}]
)
# Send it a frame
self.websocket_send_frame(sock, "what is here? 🌍")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
# Make sure it got our frame
_, messages = test_app.get_received()
assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"}
assert messages[1] == {
"type": "websocket.receive",
"text": "what is here? 🌍",
}
def test_binary_frames(self):
"""
@ -242,28 +204,24 @@ class TestWebsocket(DaphneTestCase):
"""
with DaphneTestingInstance() as test_app:
# Connect
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Prep frame for it to send
test_app.add_send_messages([
{
"type": "websocket.send",
"bytes": b"here be \xe2 bytes",
}
])
test_app.add_send_messages(
[{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}]
)
# Send it a frame
self.websocket_send_frame(sock, b"what is here? \xe2")
# Receive a frame and make sure it's correct
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
# Make sure it got our frame
_, messages = test_app.get_received()
assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"}
assert messages[1] == {
"type": "websocket.receive",
"bytes": b"what is here? \xe2",
}
def test_http_timeout(self):
"""
@ -271,23 +229,14 @@ class TestWebsocket(DaphneTestCase):
"""
with DaphneTestingInstance(http_timeout=1) as test_app:
# Connect
test_app.add_send_messages([
{
"type": "websocket.accept",
}
])
test_app.add_send_messages([{"type": "websocket.accept"}])
sock, _ = self.websocket_handshake(test_app)
_, messages = test_app.get_received()
self.assert_valid_websocket_connect_message(messages[0])
# Wait 2 seconds
time.sleep(2)
# Prep frame for it to send
test_app.add_send_messages([
{
"type": "websocket.send",
"text": "cake",
}
])
test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}])
# Send it a frame
self.websocket_send_frame(sock, "still alive?")
# Receive a frame and make sure it's correct