mirror of
https://github.com/django/daphne.git
synced 2024-11-21 15:36:33 +03:00
Implement Black code formatting
This commit is contained in:
parent
88792984e7
commit
0ed6294406
22
.travis.yml
22
.travis.yml
|
@ -3,22 +3,25 @@ sudo: false
|
|||
language: python
|
||||
|
||||
python:
|
||||
- '3.5'
|
||||
- '3.6'
|
||||
- '3.5'
|
||||
|
||||
env:
|
||||
- TWISTED="twisted==18.7.0"
|
||||
- TWISTED="twisted"
|
||||
- TWISTED="twisted==18.7.0"
|
||||
|
||||
install:
|
||||
- pip install $TWISTED isort unify flake8 -e .[tests]
|
||||
- pip install $TWISTED -e .[tests]
|
||||
- pip freeze
|
||||
|
||||
script:
|
||||
- pytest
|
||||
- flake8
|
||||
- isort --check-only --diff --recursive daphne tests
|
||||
- unify --check-only --recursive --quote \" daphne tests
|
||||
|
||||
stages:
|
||||
- lint
|
||||
- test
|
||||
- name: release
|
||||
if: branch = master
|
||||
|
||||
jobs:
|
||||
include:
|
||||
|
@ -30,6 +33,13 @@ jobs:
|
|||
env: TWISTED="twisted"
|
||||
dist: xenial
|
||||
sudo: required
|
||||
- stage: lint
|
||||
install: pip install -U -e .[tests] black pyflakes isort
|
||||
script:
|
||||
- pyflakes .
|
||||
- black --check .
|
||||
- isort --check-only --diff --recursive channels_redis tests
|
||||
|
||||
- stage: release
|
||||
script: skip
|
||||
deploy:
|
||||
|
|
|
@ -49,13 +49,16 @@ class AccessLogGenerator(object):
|
|||
request="WSDISCONNECT %(path)s" % details,
|
||||
)
|
||||
|
||||
def write_entry(self, host, date, request, status=None, length=None, ident=None, user=None):
|
||||
def write_entry(
|
||||
self, host, date, request, status=None, length=None, ident=None, user=None
|
||||
):
|
||||
"""
|
||||
Writes an NCSA-style entry to the log file (some liberty is taken with
|
||||
what the entries are for non-HTTP)
|
||||
"""
|
||||
self.stream.write(
|
||||
"%s %s %s [%s] \"%s\" %s %s\n" % (
|
||||
'%s %s %s [%s] "%s" %s %s\n'
|
||||
% (
|
||||
host,
|
||||
ident or "-",
|
||||
user or "-",
|
||||
|
|
|
@ -23,15 +23,9 @@ class CommandLineInterface(object):
|
|||
server_class = Server
|
||||
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser(
|
||||
description=self.description,
|
||||
)
|
||||
self.parser = argparse.ArgumentParser(description=self.description)
|
||||
self.parser.add_argument(
|
||||
"-p",
|
||||
"--port",
|
||||
type=int,
|
||||
help="Port number to listen on",
|
||||
default=None,
|
||||
"-p", "--port", type=int, help="Port number to listen on", default=None
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"-b",
|
||||
|
@ -128,7 +122,7 @@ class CommandLineInterface(object):
|
|||
"--proxy-headers",
|
||||
dest="proxy_headers",
|
||||
help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the "
|
||||
"client address",
|
||||
"client address",
|
||||
default=False,
|
||||
action="store_true",
|
||||
)
|
||||
|
@ -176,7 +170,15 @@ class CommandLineInterface(object):
|
|||
sys.path.insert(0, ".")
|
||||
application = import_by_path(args.application)
|
||||
# Set up port/host bindings
|
||||
if not any([args.host, args.port is not None, args.unix_socket, args.file_descriptor, args.socket_strings]):
|
||||
if not any(
|
||||
[
|
||||
args.host,
|
||||
args.port is not None,
|
||||
args.unix_socket,
|
||||
args.file_descriptor,
|
||||
args.socket_strings,
|
||||
]
|
||||
):
|
||||
# no advanced binding options passed, patch in defaults
|
||||
args.host = DEFAULT_HOST
|
||||
args.port = DEFAULT_PORT
|
||||
|
@ -189,16 +191,11 @@ class CommandLineInterface(object):
|
|||
host=args.host,
|
||||
port=args.port,
|
||||
unix_socket=args.unix_socket,
|
||||
file_descriptor=args.file_descriptor
|
||||
)
|
||||
endpoints = sorted(
|
||||
args.socket_strings + endpoints
|
||||
file_descriptor=args.file_descriptor,
|
||||
)
|
||||
endpoints = sorted(args.socket_strings + endpoints)
|
||||
# Start the server
|
||||
logger.info(
|
||||
"Starting server at %s" %
|
||||
(", ".join(endpoints), )
|
||||
)
|
||||
logger.info("Starting server at %s" % (", ".join(endpoints),))
|
||||
self.server = self.server_class(
|
||||
application=application,
|
||||
endpoints=endpoints,
|
||||
|
@ -208,12 +205,20 @@ class CommandLineInterface(object):
|
|||
websocket_timeout=args.websocket_timeout,
|
||||
websocket_connect_timeout=args.websocket_connect_timeout,
|
||||
application_close_timeout=args.application_close_timeout,
|
||||
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
|
||||
action_logger=AccessLogGenerator(access_log_stream)
|
||||
if access_log_stream
|
||||
else None,
|
||||
ws_protocols=args.ws_protocols,
|
||||
root_path=args.root_path,
|
||||
verbosity=args.verbosity,
|
||||
proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None,
|
||||
proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None,
|
||||
proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None,
|
||||
proxy_forwarded_address_header="X-Forwarded-For"
|
||||
if args.proxy_headers
|
||||
else None,
|
||||
proxy_forwarded_port_header="X-Forwarded-Port"
|
||||
if args.proxy_headers
|
||||
else None,
|
||||
proxy_forwarded_proto_header="X-Forwarded-Proto"
|
||||
if args.proxy_headers
|
||||
else None,
|
||||
)
|
||||
self.server.run()
|
||||
|
|
|
@ -1,10 +1,5 @@
|
|||
|
||||
|
||||
def build_endpoint_description_strings(
|
||||
host=None,
|
||||
port=None,
|
||||
unix_socket=None,
|
||||
file_descriptor=None
|
||||
host=None, port=None, unix_socket=None, file_descriptor=None
|
||||
):
|
||||
"""
|
||||
Build a list of twisted endpoint description strings that the server will listen on.
|
||||
|
|
|
@ -23,7 +23,8 @@ class WebRequest(http.Request):
|
|||
GET and POST out.
|
||||
"""
|
||||
|
||||
error_template = """
|
||||
error_template = (
|
||||
"""
|
||||
<html>
|
||||
<head>
|
||||
<title>%(title)s</title>
|
||||
|
@ -40,7 +41,13 @@ class WebRequest(http.Request):
|
|||
<footer>Daphne</footer>
|
||||
</body>
|
||||
</html>
|
||||
""".replace("\n", "").replace(" ", " ").replace(" ", " ").replace(" ", " ") # Shorten it a bit, bytes wise
|
||||
""".replace(
|
||||
"\n", ""
|
||||
)
|
||||
.replace(" ", " ")
|
||||
.replace(" ", " ")
|
||||
.replace(" ", " ")
|
||||
) # Shorten it a bit, bytes wise
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
try:
|
||||
|
@ -84,7 +91,7 @@ class WebRequest(http.Request):
|
|||
self.server.proxy_forwarded_port_header,
|
||||
self.server.proxy_forwarded_proto_header,
|
||||
self.client_addr,
|
||||
self.client_scheme
|
||||
self.client_scheme,
|
||||
)
|
||||
# Check for unicodeish path (or it'll crash when trying to parse)
|
||||
try:
|
||||
|
@ -105,7 +112,9 @@ class WebRequest(http.Request):
|
|||
# Is it WebSocket? IS IT?!
|
||||
if upgrade_header and upgrade_header.lower() == b"websocket":
|
||||
# Make WebSocket protocol to hand off to
|
||||
protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer())
|
||||
protocol = self.server.ws_factory.buildProtocol(
|
||||
self.transport.getPeer()
|
||||
)
|
||||
if not protocol:
|
||||
# If protocol creation fails, we signal "internal server error"
|
||||
self.setResponseCode(500)
|
||||
|
@ -151,33 +160,38 @@ class WebRequest(http.Request):
|
|||
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
|
||||
self.content.seek(0, 0)
|
||||
# Work out the application scope and create application
|
||||
self.application_queue = yield maybeDeferred(self.server.create_application, self, {
|
||||
"type": "http",
|
||||
# TODO: Correctly say if it's 1.1 or 1.0
|
||||
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
|
||||
"method": self.method.decode("ascii"),
|
||||
"path": unquote(self.path.decode("ascii")),
|
||||
"root_path": self.root_path,
|
||||
"scheme": self.client_scheme,
|
||||
"query_string": self.query_string,
|
||||
"headers": self.clean_headers,
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
})
|
||||
self.application_queue = yield maybeDeferred(
|
||||
self.server.create_application,
|
||||
self,
|
||||
{
|
||||
"type": "http",
|
||||
# TODO: Correctly say if it's 1.1 or 1.0
|
||||
"http_version": self.clientproto.split(b"/")[-1].decode(
|
||||
"ascii"
|
||||
),
|
||||
"method": self.method.decode("ascii"),
|
||||
"path": unquote(self.path.decode("ascii")),
|
||||
"root_path": self.root_path,
|
||||
"scheme": self.client_scheme,
|
||||
"query_string": self.query_string,
|
||||
"headers": self.clean_headers,
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
},
|
||||
)
|
||||
# Check they didn't close an unfinished request
|
||||
if self.application_queue is None or self.content.closed:
|
||||
# Not much we can do, the request is prematurely abandoned.
|
||||
return
|
||||
# Run application against request
|
||||
self.application_queue.put_nowait(
|
||||
{
|
||||
"type": "http.request",
|
||||
"body": self.content.read(),
|
||||
},
|
||||
{"type": "http.request", "body": self.content.read()}
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error")
|
||||
self.basic_error(
|
||||
500, b"Internal Server Error", "Daphne HTTP processing error"
|
||||
)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
|
@ -217,16 +231,23 @@ class WebRequest(http.Request):
|
|||
raise ValueError("HTTP response has already been started")
|
||||
self._response_started = True
|
||||
if "status" not in message:
|
||||
raise ValueError("Specifying a status code is required for a Response message.")
|
||||
raise ValueError(
|
||||
"Specifying a status code is required for a Response message."
|
||||
)
|
||||
# Set HTTP status code
|
||||
self.setResponseCode(message["status"])
|
||||
# Write headers
|
||||
for header, value in message.get("headers", {}):
|
||||
self.responseHeaders.addRawHeader(header, value)
|
||||
logger.debug("HTTP %s response started for %s", message["status"], self.client_addr)
|
||||
logger.debug(
|
||||
"HTTP %s response started for %s", message["status"], self.client_addr
|
||||
)
|
||||
elif message["type"] == "http.response.body":
|
||||
if not self._response_started:
|
||||
raise ValueError("HTTP response has not yet been started but got %s" % message["type"])
|
||||
raise ValueError(
|
||||
"HTTP response has not yet been started but got %s"
|
||||
% message["type"]
|
||||
)
|
||||
# Write out body
|
||||
http.Request.write(self, message.get("body", b""))
|
||||
# End if there's no more content
|
||||
|
@ -239,15 +260,21 @@ class WebRequest(http.Request):
|
|||
# The path is malformed somehow - do our best to log something
|
||||
uri = repr(self.uri)
|
||||
try:
|
||||
self.server.log_action("http", "complete", {
|
||||
"path": uri,
|
||||
"status": self.code,
|
||||
"method": self.method.decode("ascii", "replace"),
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
"time_taken": self.duration(),
|
||||
"size": self.sentLength,
|
||||
})
|
||||
except Exception as e:
|
||||
self.server.log_action(
|
||||
"http",
|
||||
"complete",
|
||||
{
|
||||
"path": uri,
|
||||
"status": self.code,
|
||||
"method": self.method.decode("ascii", "replace"),
|
||||
"client": "%s:%s" % tuple(self.client_addr)
|
||||
if self.client_addr
|
||||
else None,
|
||||
"time_taken": self.duration(),
|
||||
"size": self.sentLength,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.debug("HTTP response chunk for %s", self.client_addr)
|
||||
|
@ -270,7 +297,11 @@ class WebRequest(http.Request):
|
|||
logger.warning("Application timed out while sending response")
|
||||
self.finish()
|
||||
else:
|
||||
self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.")
|
||||
self.basic_error(
|
||||
503,
|
||||
b"Service Unavailable",
|
||||
"Application failed to respond within time limit.",
|
||||
)
|
||||
|
||||
### Utility functions
|
||||
|
||||
|
@ -281,11 +312,7 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
# If we don't yet have a path, then don't send as we never opened.
|
||||
if self.path:
|
||||
self.application_queue.put_nowait(
|
||||
{
|
||||
"type": "http.disconnect",
|
||||
},
|
||||
)
|
||||
self.application_queue.put_nowait({"type": "http.disconnect"})
|
||||
|
||||
def duration(self):
|
||||
"""
|
||||
|
@ -299,20 +326,25 @@ class WebRequest(http.Request):
|
|||
"""
|
||||
Responds with a server-level error page (very basic)
|
||||
"""
|
||||
self.handle_reply({
|
||||
"type": "http.response.start",
|
||||
"status": status,
|
||||
"headers": [
|
||||
(b"Content-Type", b"text/html; charset=utf-8"),
|
||||
],
|
||||
})
|
||||
self.handle_reply({
|
||||
"type": "http.response.body",
|
||||
"body": (self.error_template % {
|
||||
"title": str(status) + " " + status_text.decode("ascii"),
|
||||
"body": body,
|
||||
}).encode("utf8"),
|
||||
})
|
||||
self.handle_reply(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status,
|
||||
"headers": [(b"Content-Type", b"text/html; charset=utf-8")],
|
||||
}
|
||||
)
|
||||
self.handle_reply(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": (
|
||||
self.error_template
|
||||
% {
|
||||
"title": str(status) + " " + status_text.decode("ascii"),
|
||||
"body": body,
|
||||
}
|
||||
).encode("utf8"),
|
||||
}
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(id(self))
|
||||
|
@ -343,7 +375,7 @@ class HTTPFactory(http.HTTPFactory):
|
|||
protocol = http.HTTPFactory.buildProtocol(self, addr)
|
||||
protocol.requestFactory = WebRequest
|
||||
return protocol
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||
raise
|
||||
|
||||
|
|
|
@ -2,13 +2,14 @@
|
|||
import sys # isort:skip
|
||||
import warnings # isort:skip
|
||||
from twisted.internet import asyncioreactor # isort:skip
|
||||
|
||||
current_reactor = sys.modules.get("twisted.internet.reactor", None)
|
||||
if current_reactor is not None:
|
||||
if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor):
|
||||
warnings.warn(
|
||||
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " +
|
||||
"you can fix this warning by importing daphne.server early in your codebase or " +
|
||||
"finding the package that imports Twisted and importing it later on.",
|
||||
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; "
|
||||
+ "you can fix this warning by importing daphne.server early in your codebase or "
|
||||
+ "finding the package that imports Twisted and importing it later on.",
|
||||
UserWarning,
|
||||
)
|
||||
del sys.modules["twisted.internet.reactor"]
|
||||
|
@ -34,7 +35,6 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Server(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application,
|
||||
|
@ -91,11 +91,13 @@ class Server(object):
|
|||
self.ws_factory.setProtocolOptions(
|
||||
autoPingTimeout=self.ping_timeout,
|
||||
allowNullOrigin=True,
|
||||
openHandshakeTimeout=self.websocket_handshake_timeout
|
||||
openHandshakeTimeout=self.websocket_handshake_timeout,
|
||||
)
|
||||
if self.verbosity <= 1:
|
||||
# Redirect the Twisted log to nowhere
|
||||
globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True)
|
||||
globalLogBeginner.beginLoggingTo(
|
||||
[lambda _: None], redirectStandardIO=False, discardBuffer=True
|
||||
)
|
||||
else:
|
||||
globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
|
||||
|
||||
|
@ -103,7 +105,9 @@ class Server(object):
|
|||
if http.H2_ENABLED:
|
||||
logger.info("HTTP/2 support enabled")
|
||||
else:
|
||||
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
|
||||
logger.info(
|
||||
"HTTP/2 support not enabled (install the http2 and tls Twisted extras)"
|
||||
)
|
||||
|
||||
# Kick off the timeout loop
|
||||
reactor.callLater(1, self.application_checker)
|
||||
|
@ -141,7 +145,11 @@ class Server(object):
|
|||
host = port.getHost()
|
||||
if hasattr(host, "host") and hasattr(host, "port"):
|
||||
self.listening_addresses.append((host.host, host.port))
|
||||
logger.info("Listening on TCP address %s:%s", port.getHost().host, port.getHost().port)
|
||||
logger.info(
|
||||
"Listening on TCP address %s:%s",
|
||||
port.getHost().host,
|
||||
port.getHost().port,
|
||||
)
|
||||
|
||||
def listen_error(self, failure):
|
||||
logger.critical("Listen failure: %s", failure.getErrorMessage())
|
||||
|
@ -187,10 +195,13 @@ class Server(object):
|
|||
# Run it, and stash the future for later checking
|
||||
if protocol not in self.connections:
|
||||
return None
|
||||
self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance(
|
||||
receive=input_queue.get,
|
||||
send=lambda message: self.handle_reply(protocol, message),
|
||||
), loop=asyncio.get_event_loop())
|
||||
self.connections[protocol]["application_instance"] = asyncio.ensure_future(
|
||||
application_instance(
|
||||
receive=input_queue.get,
|
||||
send=lambda message: self.handle_reply(protocol, message),
|
||||
),
|
||||
loop=asyncio.get_event_loop(),
|
||||
)
|
||||
return input_queue
|
||||
|
||||
async def handle_reply(self, protocol, message):
|
||||
|
@ -215,7 +226,10 @@ class Server(object):
|
|||
application_instance = details.get("application_instance", None)
|
||||
# First, see if the protocol disconnected and the app has taken
|
||||
# too long to close up
|
||||
if disconnected and time.time() - disconnected > self.application_close_timeout:
|
||||
if (
|
||||
disconnected
|
||||
and time.time() - disconnected > self.application_close_timeout
|
||||
):
|
||||
if application_instance and not application_instance.done():
|
||||
logger.warning(
|
||||
"Application instance %r for connection %s took too long to shut down and was killed.",
|
||||
|
@ -238,14 +252,11 @@ class Server(object):
|
|||
else:
|
||||
exception_output = "{}\n{}{}".format(
|
||||
exception,
|
||||
"".join(traceback.format_tb(
|
||||
exception.__traceback__,
|
||||
)),
|
||||
"".join(traceback.format_tb(exception.__traceback__)),
|
||||
" {}".format(exception),
|
||||
)
|
||||
logger.error(
|
||||
"Exception inside application: %s",
|
||||
exception_output,
|
||||
"Exception inside application: %s", exception_output
|
||||
)
|
||||
if not disconnected:
|
||||
protocol.handle_exception(exception)
|
||||
|
|
|
@ -100,9 +100,7 @@ class DaphneTestingInstance:
|
|||
Adds messages for the application to send back.
|
||||
The next time it receives an incoming message, it will reply with these.
|
||||
"""
|
||||
TestApplication.save_setup(
|
||||
response_messages=messages,
|
||||
)
|
||||
TestApplication.save_setup(response_messages=messages)
|
||||
|
||||
|
||||
class DaphneProcess(multiprocessing.Process):
|
||||
|
@ -193,12 +191,7 @@ class TestApplication:
|
|||
Stores setup information.
|
||||
"""
|
||||
with open(cls.setup_storage, "wb") as fh:
|
||||
pickle.dump(
|
||||
{
|
||||
"response_messages": response_messages,
|
||||
},
|
||||
fh,
|
||||
)
|
||||
pickle.dump({"response_messages": response_messages}, fh)
|
||||
|
||||
@classmethod
|
||||
def load_setup(cls):
|
||||
|
@ -218,13 +211,7 @@ class TestApplication:
|
|||
We could use pickle here, but that seems wrong, still, somehow.
|
||||
"""
|
||||
with open(cls.result_storage, "wb") as fh:
|
||||
pickle.dump(
|
||||
{
|
||||
"scope": scope,
|
||||
"messages": messages,
|
||||
},
|
||||
fh,
|
||||
)
|
||||
pickle.dump({"scope": scope, "messages": messages}, fh)
|
||||
|
||||
@classmethod
|
||||
def save_exception(cls, exception):
|
||||
|
@ -233,12 +220,7 @@ class TestApplication:
|
|||
We could use pickle here, but that seems wrong, still, somehow.
|
||||
"""
|
||||
with open(cls.result_storage, "wb") as fh:
|
||||
pickle.dump(
|
||||
{
|
||||
"exception": exception,
|
||||
},
|
||||
fh,
|
||||
)
|
||||
pickle.dump({"exception": exception}, fh)
|
||||
|
||||
@classmethod
|
||||
def load_result(cls):
|
||||
|
|
|
@ -22,12 +22,14 @@ def header_value(headers, header_name):
|
|||
return value.decode("utf-8")
|
||||
|
||||
|
||||
def parse_x_forwarded_for(headers,
|
||||
address_header_name="X-Forwarded-For",
|
||||
port_header_name="X-Forwarded-Port",
|
||||
proto_header_name="X-Forwarded-Proto",
|
||||
original_addr=None,
|
||||
original_scheme=None):
|
||||
def parse_x_forwarded_for(
|
||||
headers,
|
||||
address_header_name="X-Forwarded-For",
|
||||
port_header_name="X-Forwarded-Port",
|
||||
proto_header_name="X-Forwarded-Proto",
|
||||
original_addr=None,
|
||||
original_scheme=None,
|
||||
):
|
||||
"""
|
||||
Parses an X-Forwarded-For header and returns a host/port pair as a list.
|
||||
|
||||
|
|
|
@ -3,7 +3,11 @@ import time
|
|||
import traceback
|
||||
from urllib.parse import unquote
|
||||
|
||||
from autobahn.twisted.websocket import ConnectionDeny, WebSocketServerFactory, WebSocketServerProtocol
|
||||
from autobahn.twisted.websocket import (
|
||||
ConnectionDeny,
|
||||
WebSocketServerFactory,
|
||||
WebSocketServerProtocol,
|
||||
)
|
||||
from twisted.internet import defer
|
||||
|
||||
from .utils import parse_x_forwarded_for
|
||||
|
@ -54,32 +58,34 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.server.proxy_forwarded_address_header,
|
||||
self.server.proxy_forwarded_port_header,
|
||||
self.server.proxy_forwarded_proto_header,
|
||||
self.client_addr
|
||||
self.client_addr,
|
||||
)
|
||||
# Decode websocket subprotocol options
|
||||
subprotocols = []
|
||||
for header, value in self.clean_headers:
|
||||
if header == b"sec-websocket-protocol":
|
||||
subprotocols = [
|
||||
x.strip()
|
||||
for x in
|
||||
unquote(value.decode("ascii")).split(",")
|
||||
x.strip() for x in unquote(value.decode("ascii")).split(",")
|
||||
]
|
||||
# Make new application instance with scope
|
||||
self.path = request.path.encode("ascii")
|
||||
self.application_deferred = defer.maybeDeferred(self.server.create_application, self, {
|
||||
"type": "websocket",
|
||||
"path": unquote(self.path.decode("ascii")),
|
||||
"headers": self.clean_headers,
|
||||
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
"subprotocols": subprotocols,
|
||||
})
|
||||
self.application_deferred = defer.maybeDeferred(
|
||||
self.server.create_application,
|
||||
self,
|
||||
{
|
||||
"type": "websocket",
|
||||
"path": unquote(self.path.decode("ascii")),
|
||||
"headers": self.clean_headers,
|
||||
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
"subprotocols": subprotocols,
|
||||
},
|
||||
)
|
||||
if self.application_deferred is not None:
|
||||
self.application_deferred.addCallback(self.applicationCreateWorked)
|
||||
self.application_deferred.addErrback(self.applicationCreateFailed)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Exceptions here are not displayed right, just 500.
|
||||
# Turn them into an ERROR log.
|
||||
logger.error(traceback.format_exc())
|
||||
|
@ -98,10 +104,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.application_queue = application_queue
|
||||
# Send over the connect message
|
||||
self.application_queue.put_nowait({"type": "websocket.connect"})
|
||||
self.server.log_action("websocket", "connecting", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"connecting",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr)
|
||||
if self.client_addr
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
def applicationCreateFailed(self, failure):
|
||||
"""
|
||||
|
@ -115,10 +127,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
def onOpen(self):
|
||||
# Send news that this channel is open
|
||||
logger.debug("WebSocket %s open and established", self.client_addr)
|
||||
self.server.log_action("websocket", "connected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"connected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr)
|
||||
if self.client_addr
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
def onMessage(self, payload, isBinary):
|
||||
# If we're muted, do nothing.
|
||||
|
@ -128,15 +146,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
||||
self.last_ping = time.time()
|
||||
if isBinary:
|
||||
self.application_queue.put_nowait({
|
||||
"type": "websocket.receive",
|
||||
"bytes": payload,
|
||||
})
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.receive", "bytes": payload}
|
||||
)
|
||||
else:
|
||||
self.application_queue.put_nowait({
|
||||
"type": "websocket.receive",
|
||||
"text": payload.decode("utf8"),
|
||||
})
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.receive", "text": payload.decode("utf8")}
|
||||
)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
"""
|
||||
|
@ -145,14 +161,19 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.server.protocol_disconnected(self)
|
||||
logger.debug("WebSocket closed for %s", self.client_addr)
|
||||
if not self.muted and hasattr(self, "application_queue"):
|
||||
self.application_queue.put_nowait({
|
||||
"type": "websocket.disconnect",
|
||||
"code": code,
|
||||
})
|
||||
self.server.log_action("websocket", "disconnected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
self.application_queue.put_nowait(
|
||||
{"type": "websocket.disconnect", "code": code}
|
||||
)
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"disconnected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr)
|
||||
if self.client_addr
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
### Internal event handling
|
||||
|
||||
|
@ -171,9 +192,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
raise ValueError("Socket has not been accepted, so cannot send over it")
|
||||
if message.get("bytes", None) and message.get("text", None):
|
||||
raise ValueError(
|
||||
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
|
||||
message,
|
||||
)
|
||||
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
|
||||
% (message,)
|
||||
)
|
||||
if message.get("bytes", None):
|
||||
self.serverSend(message["bytes"], True)
|
||||
|
@ -187,7 +207,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
if hasattr(self, "handshake_deferred"):
|
||||
# If the handshake is still ongoing, we need to emit a HTTP error
|
||||
# code rather than a WebSocket one.
|
||||
self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error"))
|
||||
self.handshake_deferred.errback(
|
||||
ConnectionDeny(code=500, reason="Internal server error")
|
||||
)
|
||||
else:
|
||||
self.sendCloseFrame(code=1011)
|
||||
|
||||
|
@ -203,14 +225,22 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
"""
|
||||
Called when we get a message saying to reject the connection.
|
||||
"""
|
||||
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
||||
self.handshake_deferred.errback(
|
||||
ConnectionDeny(code=403, reason="Access denied")
|
||||
)
|
||||
del self.handshake_deferred
|
||||
self.server.protocol_disconnected(self)
|
||||
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||
self.server.log_action("websocket", "rejected", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
})
|
||||
self.server.log_action(
|
||||
"websocket",
|
||||
"rejected",
|
||||
{
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr)
|
||||
if self.client_addr
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
def serverSend(self, content, binary=False):
|
||||
"""
|
||||
|
@ -244,7 +274,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
Called periodically to see if we should timeout something
|
||||
"""
|
||||
# Web timeout checking
|
||||
if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0:
|
||||
if (
|
||||
self.duration() > self.server.websocket_timeout
|
||||
and self.server.websocket_timeout >= 0
|
||||
):
|
||||
self.serverClose()
|
||||
# Ping check
|
||||
# If we're still connecting, deny the connection
|
||||
|
@ -287,6 +320,6 @@ class WebSocketFactory(WebSocketServerFactory):
|
|||
protocol = super(WebSocketFactory, self).buildProtocol(addr)
|
||||
protocol.factory = self
|
||||
return protocol
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||
raise
|
||||
|
|
|
@ -5,9 +5,9 @@ universal=1
|
|||
addopts = tests/
|
||||
|
||||
[isort]
|
||||
line_length = 120
|
||||
include_trailing_comma = True
|
||||
multi_line_output = 3
|
||||
known_first_party = channels,daphne,asgiref
|
||||
known_first_party = channels,daphne,asgiref,channels_redis
|
||||
|
||||
[flake8]
|
||||
exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/*
|
||||
|
|
21
setup.py
21
setup.py
|
@ -22,23 +22,12 @@ setup(
|
|||
package_dir={"twisted": "daphne/twisted"},
|
||||
packages=find_packages() + ["twisted.plugins"],
|
||||
include_package_data=True,
|
||||
install_requires=[
|
||||
"twisted>=18.7",
|
||||
"autobahn>=0.18",
|
||||
],
|
||||
setup_requires=[
|
||||
"pytest-runner",
|
||||
],
|
||||
extras_require={
|
||||
"tests": [
|
||||
"hypothesis",
|
||||
"pytest",
|
||||
"pytest-asyncio~=0.8",
|
||||
],
|
||||
install_requires=["twisted>=18.7", "autobahn>=0.18"],
|
||||
setup_requires=["pytest-runner"],
|
||||
extras_require={"tests": ["hypothesis", "pytest", "pytest-asyncio~=0.8"]},
|
||||
entry_points={
|
||||
"console_scripts": ["daphne = daphne.cli:CommandLineInterface.entrypoint"]
|
||||
},
|
||||
entry_points={"console_scripts": [
|
||||
"daphne = daphne.cli:CommandLineInterface.entrypoint",
|
||||
]},
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Environment :: Web Environment",
|
||||
|
|
|
@ -19,7 +19,9 @@ class DaphneTestCase(unittest.TestCase):
|
|||
|
||||
### Plain HTTP helpers
|
||||
|
||||
def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
|
||||
def run_daphne_http(
|
||||
self, method, path, params, body, responses, headers=None, timeout=1, xff=False
|
||||
):
|
||||
"""
|
||||
Runs Daphne with the given request callback (given the base URL)
|
||||
and response messages.
|
||||
|
@ -38,7 +40,9 @@ class DaphneTestCase(unittest.TestCase):
|
|||
# Manually send over headers (encoding any non-safe values as best we can)
|
||||
if headers:
|
||||
for header_name, header_value in headers:
|
||||
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))
|
||||
conn.putheader(
|
||||
header_name.encode("utf8"), header_value.encode("utf8")
|
||||
)
|
||||
# Send body if provided.
|
||||
if body:
|
||||
conn.putheader("Content-Length", str(len(body)))
|
||||
|
@ -50,9 +54,11 @@ class DaphneTestCase(unittest.TestCase):
|
|||
except socket.timeout:
|
||||
# See if they left an exception for us to load
|
||||
test_app.get_received()
|
||||
raise RuntimeError("Daphne timed out handling request, no exception found.")
|
||||
raise RuntimeError(
|
||||
"Daphne timed out handling request, no exception found."
|
||||
)
|
||||
# Return scope, messages, response
|
||||
return test_app.get_received() + (response, )
|
||||
return test_app.get_received() + (response,)
|
||||
|
||||
def run_daphne_raw(self, data, timeout=1):
|
||||
"""
|
||||
|
@ -68,9 +74,13 @@ class DaphneTestCase(unittest.TestCase):
|
|||
try:
|
||||
return s.recv(1000000)
|
||||
except socket.timeout:
|
||||
raise RuntimeError("Daphne timed out handling raw request, no exception found.")
|
||||
raise RuntimeError(
|
||||
"Daphne timed out handling raw request, no exception found."
|
||||
)
|
||||
|
||||
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
|
||||
def run_daphne_request(
|
||||
self, method, path, params=None, body=None, headers=None, xff=False
|
||||
):
|
||||
"""
|
||||
Convenience method for just testing request handling.
|
||||
Returns (scope, messages)
|
||||
|
@ -95,17 +105,21 @@ class DaphneTestCase(unittest.TestCase):
|
|||
Returns (scope, messages)
|
||||
"""
|
||||
_, _, response = self.run_daphne_http(
|
||||
method="GET",
|
||||
path="/",
|
||||
params={},
|
||||
body=b"",
|
||||
responses=response_messages,
|
||||
method="GET", path="/", params={}, body=b"", responses=response_messages
|
||||
)
|
||||
return response
|
||||
|
||||
### WebSocket helpers
|
||||
|
||||
def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1):
|
||||
def websocket_handshake(
|
||||
self,
|
||||
test_app,
|
||||
path="/",
|
||||
params=None,
|
||||
headers=None,
|
||||
subprotocols=None,
|
||||
timeout=1,
|
||||
):
|
||||
"""
|
||||
Runs a WebSocket handshake negotiation and returns the raw socket
|
||||
object & the selected subprotocol.
|
||||
|
@ -124,14 +138,16 @@ class DaphneTestCase(unittest.TestCase):
|
|||
# Do WebSocket handshake headers + any other headers
|
||||
if headers is None:
|
||||
headers = []
|
||||
headers.extend([
|
||||
("Host", "example.com"),
|
||||
("Upgrade", "websocket"),
|
||||
("Connection", "Upgrade"),
|
||||
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
|
||||
("Sec-WebSocket-Version", "13"),
|
||||
("Origin", "http://example.com")
|
||||
])
|
||||
headers.extend(
|
||||
[
|
||||
("Host", "example.com"),
|
||||
("Upgrade", "websocket"),
|
||||
("Connection", "Upgrade"),
|
||||
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
|
||||
("Sec-WebSocket-Version", "13"),
|
||||
("Origin", "http://example.com"),
|
||||
]
|
||||
)
|
||||
if subprotocols:
|
||||
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
|
||||
if headers:
|
||||
|
@ -149,10 +165,7 @@ class DaphneTestCase(unittest.TestCase):
|
|||
if response.status != 101:
|
||||
raise RuntimeError("WebSocket upgrade did not result in status code 101")
|
||||
# Prepare headers for subprotocol searching
|
||||
response_headers = dict(
|
||||
(n.lower(), v)
|
||||
for n, v in response.getheaders()
|
||||
)
|
||||
response_headers = dict((n.lower(), v) for n, v in response.getheaders())
|
||||
response.read()
|
||||
assert not response.closed
|
||||
# Return the raw socket and any subprotocol
|
||||
|
@ -234,10 +247,7 @@ class DaphneTestCase(unittest.TestCase):
|
|||
# Make sure all required keys are present
|
||||
self.assertTrue(required_keys <= present_keys)
|
||||
# Assert that no other keys are present
|
||||
self.assertEqual(
|
||||
set(),
|
||||
present_keys - required_keys - optional_keys,
|
||||
)
|
||||
self.assertEqual(set(), present_keys - required_keys - optional_keys)
|
||||
|
||||
def assert_valid_path(self, path, request_path):
|
||||
"""
|
||||
|
|
|
@ -6,7 +6,9 @@ from hypothesis import strategies
|
|||
HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"]
|
||||
|
||||
# Unicode characters of the "Letter" category
|
||||
letters = strategies.characters(whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl"))
|
||||
letters = strategies.characters(
|
||||
whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")
|
||||
)
|
||||
|
||||
|
||||
def http_method():
|
||||
|
@ -22,11 +24,9 @@ def http_path():
|
|||
"""
|
||||
Returns a URL path (not encoded).
|
||||
"""
|
||||
return strategies.lists(
|
||||
_http_path_portion(),
|
||||
min_size=0,
|
||||
max_size=10,
|
||||
).map(lambda s: "/" + "/".join(s))
|
||||
return strategies.lists(_http_path_portion(), min_size=0, max_size=10).map(
|
||||
lambda s: "/" + "/".join(s)
|
||||
)
|
||||
|
||||
|
||||
def http_body():
|
||||
|
@ -53,10 +53,7 @@ def valid_bidi(value):
|
|||
|
||||
def _domain_label():
|
||||
return strategies.text(
|
||||
alphabet=letters,
|
||||
min_size=1,
|
||||
average_size=6,
|
||||
max_size=63,
|
||||
alphabet=letters, min_size=1, average_size=6, max_size=63
|
||||
).filter(valid_bidi)
|
||||
|
||||
|
||||
|
@ -64,19 +61,14 @@ def international_domain_name():
|
|||
"""
|
||||
Returns a byte string of a domain name, IDNA-encoded.
|
||||
"""
|
||||
return strategies.lists(
|
||||
_domain_label(),
|
||||
min_size=2,
|
||||
average_size=2,
|
||||
).map(lambda s: (".".join(s)).encode("idna"))
|
||||
return strategies.lists(_domain_label(), min_size=2, average_size=2).map(
|
||||
lambda s: (".".join(s)).encode("idna")
|
||||
)
|
||||
|
||||
|
||||
def _query_param():
|
||||
return strategies.text(
|
||||
alphabet=letters,
|
||||
min_size=1,
|
||||
average_size=10,
|
||||
max_size=255,
|
||||
alphabet=letters, min_size=1, average_size=10, max_size=255
|
||||
).map(lambda s: s.encode("utf8"))
|
||||
|
||||
|
||||
|
@ -87,9 +79,7 @@ def query_params():
|
|||
ensures that the total urlencoded query string is not longer than 1500 characters.
|
||||
"""
|
||||
return strategies.lists(
|
||||
strategies.tuples(_query_param(), _query_param()),
|
||||
min_size=0,
|
||||
average_size=5,
|
||||
strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5
|
||||
).filter(lambda x: len(parse.urlencode(x)) < 1500)
|
||||
|
||||
|
||||
|
@ -101,9 +91,7 @@ def header_name():
|
|||
and 20 characters long
|
||||
"""
|
||||
return strategies.text(
|
||||
alphabet=string.ascii_letters + string.digits + "-",
|
||||
min_size=1,
|
||||
max_size=30,
|
||||
alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30
|
||||
)
|
||||
|
||||
|
||||
|
@ -115,7 +103,10 @@ def header_value():
|
|||
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
|
||||
"""
|
||||
return strategies.text(
|
||||
alphabet=string.ascii_letters + string.digits + string.punctuation.replace(",", "") + " /t",
|
||||
alphabet=string.ascii_letters
|
||||
+ string.digits
|
||||
+ string.punctuation.replace(",", "")
|
||||
+ " /t",
|
||||
min_size=1,
|
||||
average_size=40,
|
||||
max_size=8190,
|
||||
|
|
|
@ -18,45 +18,32 @@ class TestEndpointDescriptions(TestCase):
|
|||
def testTcpPortBindings(self):
|
||||
self.assertEqual(
|
||||
build(port=1234, host="example.com"),
|
||||
["tcp:port=1234:interface=example.com"]
|
||||
["tcp:port=1234:interface=example.com"],
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
build(port=8000, host="127.0.0.1"),
|
||||
["tcp:port=8000:interface=127.0.0.1"]
|
||||
build(port=8000, host="127.0.0.1"), ["tcp:port=8000:interface=127.0.0.1"]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
build(port=8000, host="[200a::1]"),
|
||||
[r'tcp:port=8000:interface=200a\:\:1']
|
||||
build(port=8000, host="[200a::1]"), [r"tcp:port=8000:interface=200a\:\:1"]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
build(port=8000, host="200a::1"),
|
||||
[r'tcp:port=8000:interface=200a\:\:1']
|
||||
build(port=8000, host="200a::1"), [r"tcp:port=8000:interface=200a\:\:1"]
|
||||
)
|
||||
|
||||
# incomplete port/host kwargs raise errors
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
build, port=123
|
||||
)
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
build, host="example.com"
|
||||
)
|
||||
self.assertRaises(ValueError, build, port=123)
|
||||
self.assertRaises(ValueError, build, host="example.com")
|
||||
|
||||
def testUnixSocketBinding(self):
|
||||
self.assertEqual(
|
||||
build(unix_socket="/tmp/daphne.sock"),
|
||||
["unix:/tmp/daphne.sock"]
|
||||
build(unix_socket="/tmp/daphne.sock"), ["unix:/tmp/daphne.sock"]
|
||||
)
|
||||
|
||||
def testFileDescriptorBinding(self):
|
||||
self.assertEqual(
|
||||
build(file_descriptor=5),
|
||||
["fd:fileno=5"]
|
||||
)
|
||||
self.assertEqual(build(file_descriptor=5), ["fd:fileno=5"])
|
||||
|
||||
def testMultipleEnpoints(self):
|
||||
self.assertEqual(
|
||||
|
@ -65,14 +52,16 @@ class TestEndpointDescriptions(TestCase):
|
|||
file_descriptor=123,
|
||||
unix_socket="/tmp/daphne.sock",
|
||||
port=8080,
|
||||
host="10.0.0.1"
|
||||
host="10.0.0.1",
|
||||
)
|
||||
),
|
||||
sorted([
|
||||
"tcp:port=8080:interface=10.0.0.1",
|
||||
"unix:/tmp/daphne.sock",
|
||||
"fd:fileno=123"
|
||||
])
|
||||
sorted(
|
||||
[
|
||||
"tcp:port=8080:interface=10.0.0.1",
|
||||
"unix:/tmp/daphne.sock",
|
||||
"fd:fileno=123",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@ -112,7 +101,9 @@ class TestCLIInterface(TestCase):
|
|||
Passes in a fake application automatically.
|
||||
"""
|
||||
cli = self.TestedCLI()
|
||||
cli.run(args + ["daphne:__version__"]) # We just pass something importable as app
|
||||
cli.run(
|
||||
args + ["daphne:__version__"]
|
||||
) # We just pass something importable as app
|
||||
# Check the server got all arguments as intended
|
||||
for key, value in server_kwargs.items():
|
||||
# Get the value and sort it if it's a list (for endpoint checking)
|
||||
|
@ -123,52 +114,30 @@ class TestCLIInterface(TestCase):
|
|||
self.assertEqual(
|
||||
value,
|
||||
actual_value,
|
||||
"Wrong value for server kwarg %s: %r != %r" % (
|
||||
key,
|
||||
value,
|
||||
actual_value,
|
||||
),
|
||||
"Wrong value for server kwarg %s: %r != %r"
|
||||
% (key, value, actual_value),
|
||||
)
|
||||
|
||||
def testCLIBasics(self):
|
||||
"""
|
||||
Tests basic endpoint generation.
|
||||
"""
|
||||
self.assertCLI([], {"endpoints": ["tcp:port=8000:interface=127.0.0.1"]})
|
||||
self.assertCLI(
|
||||
[],
|
||||
{
|
||||
"endpoints": ["tcp:port=8000:interface=127.0.0.1"],
|
||||
},
|
||||
["-p", "123"], {"endpoints": ["tcp:port=123:interface=127.0.0.1"]}
|
||||
)
|
||||
self.assertCLI(
|
||||
["-p", "123"],
|
||||
{
|
||||
"endpoints": ["tcp:port=123:interface=127.0.0.1"],
|
||||
},
|
||||
["-b", "10.0.0.1"], {"endpoints": ["tcp:port=8000:interface=10.0.0.1"]}
|
||||
)
|
||||
self.assertCLI(
|
||||
["-b", "10.0.0.1"],
|
||||
{
|
||||
"endpoints": ["tcp:port=8000:interface=10.0.0.1"],
|
||||
},
|
||||
["-b", "200a::1"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
|
||||
)
|
||||
self.assertCLI(
|
||||
["-b", "200a::1"],
|
||||
{
|
||||
"endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
|
||||
},
|
||||
)
|
||||
self.assertCLI(
|
||||
["-b", "[200a::1]"],
|
||||
{
|
||||
"endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
|
||||
},
|
||||
["-b", "[200a::1]"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
|
||||
)
|
||||
self.assertCLI(
|
||||
["-p", "8080", "-b", "example.com"],
|
||||
{
|
||||
"endpoints": ["tcp:port=8080:interface=example.com"],
|
||||
},
|
||||
{"endpoints": ["tcp:port=8080:interface=example.com"]},
|
||||
)
|
||||
|
||||
def testUnixSockets(self):
|
||||
|
@ -178,7 +147,7 @@ class TestCLIInterface(TestCase):
|
|||
"endpoints": [
|
||||
"tcp:port=8080:interface=127.0.0.1",
|
||||
"unix:/tmp/daphne.sock",
|
||||
],
|
||||
]
|
||||
},
|
||||
)
|
||||
self.assertCLI(
|
||||
|
@ -187,17 +156,12 @@ class TestCLIInterface(TestCase):
|
|||
"endpoints": [
|
||||
"tcp:port=8000:interface=example.com",
|
||||
"unix:/tmp/daphne.sock",
|
||||
],
|
||||
]
|
||||
},
|
||||
)
|
||||
self.assertCLI(
|
||||
["-u", "/tmp/daphne.sock", "--fd", "5"],
|
||||
{
|
||||
"endpoints": [
|
||||
"fd:fileno=5",
|
||||
"unix:/tmp/daphne.sock"
|
||||
],
|
||||
},
|
||||
{"endpoints": ["fd:fileno=5", "unix:/tmp/daphne.sock"]},
|
||||
)
|
||||
|
||||
def testMixedCLIEndpointCreation(self):
|
||||
|
@ -209,8 +173,8 @@ class TestCLIInterface(TestCase):
|
|||
{
|
||||
"endpoints": [
|
||||
"tcp:port=8080:interface=127.0.0.1",
|
||||
"unix:/tmp/daphne.sock"
|
||||
],
|
||||
"unix:/tmp/daphne.sock",
|
||||
]
|
||||
},
|
||||
)
|
||||
self.assertCLI(
|
||||
|
@ -219,7 +183,7 @@ class TestCLIInterface(TestCase):
|
|||
"endpoints": [
|
||||
"tcp:port=8080:interface=127.0.0.1",
|
||||
"tcp:port=8080:interface=127.0.0.1",
|
||||
],
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -227,11 +191,4 @@ class TestCLIInterface(TestCase):
|
|||
"""
|
||||
Tests entirely custom endpoints
|
||||
"""
|
||||
self.assertCLI(
|
||||
["-e", "imap:"],
|
||||
{
|
||||
"endpoints": [
|
||||
"imap:",
|
||||
],
|
||||
},
|
||||
)
|
||||
self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]})
|
||||
|
|
|
@ -15,13 +15,7 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
"""
|
||||
|
||||
def assert_valid_http_scope(
|
||||
self,
|
||||
scope,
|
||||
method,
|
||||
path,
|
||||
params=None,
|
||||
headers=None,
|
||||
scheme=None,
|
||||
self, scope, method, path, params=None, headers=None, scheme=None
|
||||
):
|
||||
"""
|
||||
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
||||
|
@ -29,7 +23,14 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
"""
|
||||
# Check overall keys
|
||||
self.assert_key_sets(
|
||||
required_keys={"type", "http_version", "method", "path", "query_string", "headers"},
|
||||
required_keys={
|
||||
"type",
|
||||
"http_version",
|
||||
"method",
|
||||
"path",
|
||||
"query_string",
|
||||
"headers",
|
||||
},
|
||||
optional_keys={"scheme", "root_path", "client", "server"},
|
||||
actual_keys=scope.keys(),
|
||||
)
|
||||
|
@ -50,7 +51,9 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
query_string = scope["query_string"]
|
||||
self.assertIsInstance(query_string, bytes)
|
||||
if params:
|
||||
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
|
||||
self.assertEqual(
|
||||
query_string, parse.urlencode(params or []).encode("ascii")
|
||||
)
|
||||
# Ordering of header names is not important, but the order of values for a header
|
||||
# name is. To assert whether that order is kept, we transform both the request
|
||||
# headers and the channel message headers into a dictionary
|
||||
|
@ -59,7 +62,7 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
for name, value in scope["headers"]:
|
||||
transformed_scope_headers[name].append(value)
|
||||
transformed_request_headers = collections.defaultdict(list)
|
||||
for name, value in (headers or []):
|
||||
for name, value in headers or []:
|
||||
expected_name = name.lower().strip().encode("ascii")
|
||||
expected_value = value.strip().encode("ascii")
|
||||
transformed_request_headers[expected_name].append(expected_value)
|
||||
|
@ -103,27 +106,31 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
|
||||
@given(
|
||||
request_path=http_strategies.http_path(),
|
||||
request_params=http_strategies.query_params()
|
||||
request_params=http_strategies.query_params(),
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
def test_get_request(self, request_path, request_params):
|
||||
"""
|
||||
Tests a typical HTTP GET request, with a path and query parameters
|
||||
"""
|
||||
scope, messages = self.run_daphne_request("GET", request_path, params=request_params)
|
||||
scope, messages = self.run_daphne_request(
|
||||
"GET", request_path, params=request_params
|
||||
)
|
||||
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
||||
@given(
|
||||
request_path=http_strategies.http_path(),
|
||||
request_body=http_strategies.http_body()
|
||||
request_body=http_strategies.http_body(),
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
def test_post_request(self, request_path, request_body):
|
||||
"""
|
||||
Tests a typical HTTP POST request, with a path and body.
|
||||
"""
|
||||
scope, messages = self.run_daphne_request("POST", request_path, body=request_body)
|
||||
scope, messages = self.run_daphne_request(
|
||||
"POST", request_path, body=request_body
|
||||
)
|
||||
self.assert_valid_http_scope(scope, "POST", request_path)
|
||||
self.assert_valid_http_request_message(messages[0], body=request_body)
|
||||
|
||||
|
@ -134,8 +141,12 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
Tests that HTTP header fields are handled as specified
|
||||
"""
|
||||
request_path = "/te st-à/"
|
||||
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers)
|
||||
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers)
|
||||
scope, messages = self.run_daphne_request(
|
||||
"OPTIONS", request_path, headers=request_headers
|
||||
)
|
||||
self.assert_valid_http_scope(
|
||||
scope, "OPTIONS", request_path, headers=request_headers
|
||||
)
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
||||
@given(request_headers=http_strategies.headers())
|
||||
|
@ -150,8 +161,12 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
duplicated_headers = [(header_name, header[1]) for header in request_headers]
|
||||
# Run the request
|
||||
request_path = "/te st-à/"
|
||||
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=duplicated_headers)
|
||||
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=duplicated_headers)
|
||||
scope, messages = self.run_daphne_request(
|
||||
"OPTIONS", request_path, headers=duplicated_headers
|
||||
)
|
||||
self.assert_valid_http_scope(
|
||||
scope, "OPTIONS", request_path, headers=duplicated_headers
|
||||
)
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
||||
@given(
|
||||
|
@ -222,10 +237,7 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
"""
|
||||
Make sure that, by default, X-Forwarded-For is ignored.
|
||||
"""
|
||||
headers = [
|
||||
["X-Forwarded-For", "10.1.2.3"],
|
||||
["X-Forwarded-Port", "80"],
|
||||
]
|
||||
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
|
||||
scope, messages = self.run_daphne_request("GET", "/", headers=headers)
|
||||
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
@ -236,10 +248,7 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
"""
|
||||
When X-Forwarded-For is enabled, make sure it is respected.
|
||||
"""
|
||||
headers = [
|
||||
["X-Forwarded-For", "10.1.2.3"],
|
||||
["X-Forwarded-Port", "80"],
|
||||
]
|
||||
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
|
||||
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
|
||||
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
@ -251,9 +260,7 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
When X-Forwarded-For is enabled but only the host is passed, make sure
|
||||
that at least makes it through.
|
||||
"""
|
||||
headers = [
|
||||
["X-Forwarded-For", "10.1.2.3"],
|
||||
]
|
||||
headers = [["X-Forwarded-For", "10.1.2.3"]]
|
||||
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
|
||||
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||
|
@ -265,8 +272,12 @@ class TestHTTPRequest(DaphneTestCase):
|
|||
Tests that requests with invalid (non-ASCII) characters fail.
|
||||
"""
|
||||
# Bad path
|
||||
response = self.run_daphne_raw(b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n")
|
||||
response = self.run_daphne_raw(
|
||||
b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
|
||||
)
|
||||
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
||||
# Bad querystring
|
||||
response = self.run_daphne_raw(b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n")
|
||||
response = self.run_daphne_raw(
|
||||
b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
|
||||
)
|
||||
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
||||
|
|
|
@ -15,26 +15,24 @@ class TestHTTPResponse(DaphneTestCase):
|
|||
"""
|
||||
Lowercases and sorts headers, and strips transfer-encoding ones.
|
||||
"""
|
||||
return sorted([
|
||||
(name.lower(), value.strip())
|
||||
for name, value in headers
|
||||
if name.lower() != "transfer-encoding"
|
||||
])
|
||||
return sorted(
|
||||
[
|
||||
(name.lower(), value.strip())
|
||||
for name, value in headers
|
||||
if name.lower() != "transfer-encoding"
|
||||
]
|
||||
)
|
||||
|
||||
def test_minimal_response(self):
|
||||
"""
|
||||
Smallest viable example. Mostly verifies that our response building works.
|
||||
"""
|
||||
response = self.run_daphne_response([
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": b"hello world",
|
||||
},
|
||||
])
|
||||
response = self.run_daphne_response(
|
||||
[
|
||||
{"type": "http.response.start", "status": 200},
|
||||
{"type": "http.response.body", "body": b"hello world"},
|
||||
]
|
||||
)
|
||||
self.assertEqual(response.status, 200)
|
||||
self.assertEqual(response.read(), b"hello world")
|
||||
|
||||
|
@ -46,30 +44,23 @@ class TestHTTPResponse(DaphneTestCase):
|
|||
to make sure it stays required.
|
||||
"""
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_daphne_response([
|
||||
{
|
||||
"type": "http.response.start",
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": b"hello world",
|
||||
},
|
||||
])
|
||||
self.run_daphne_response(
|
||||
[
|
||||
{"type": "http.response.start"},
|
||||
{"type": "http.response.body", "body": b"hello world"},
|
||||
]
|
||||
)
|
||||
|
||||
def test_custom_status_code(self):
|
||||
"""
|
||||
Tries a non-default status code.
|
||||
"""
|
||||
response = self.run_daphne_response([
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 201,
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": b"i made a thing!",
|
||||
},
|
||||
])
|
||||
response = self.run_daphne_response(
|
||||
[
|
||||
{"type": "http.response.start", "status": 201},
|
||||
{"type": "http.response.body", "body": b"i made a thing!"},
|
||||
]
|
||||
)
|
||||
self.assertEqual(response.status, 201)
|
||||
self.assertEqual(response.read(), b"i made a thing!")
|
||||
|
||||
|
@ -77,21 +68,13 @@ class TestHTTPResponse(DaphneTestCase):
|
|||
"""
|
||||
Tries sending a response in multiple parts.
|
||||
"""
|
||||
response = self.run_daphne_response([
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 201,
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": b"chunk 1 ",
|
||||
"more_body": True,
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": b"chunk 2",
|
||||
},
|
||||
])
|
||||
response = self.run_daphne_response(
|
||||
[
|
||||
{"type": "http.response.start", "status": 201},
|
||||
{"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
|
||||
{"type": "http.response.body", "body": b"chunk 2"},
|
||||
]
|
||||
)
|
||||
self.assertEqual(response.status, 201)
|
||||
self.assertEqual(response.read(), b"chunk 1 chunk 2")
|
||||
|
||||
|
@ -99,25 +82,14 @@ class TestHTTPResponse(DaphneTestCase):
|
|||
"""
|
||||
Tries sending a response in multiple parts and an empty end.
|
||||
"""
|
||||
response = self.run_daphne_response([
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 201,
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": b"chunk 1 ",
|
||||
"more_body": True,
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": b"chunk 2",
|
||||
"more_body": True,
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
},
|
||||
])
|
||||
response = self.run_daphne_response(
|
||||
[
|
||||
{"type": "http.response.start", "status": 201},
|
||||
{"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
|
||||
{"type": "http.response.body", "body": b"chunk 2", "more_body": True},
|
||||
{"type": "http.response.body"},
|
||||
]
|
||||
)
|
||||
self.assertEqual(response.status, 201)
|
||||
self.assertEqual(response.read(), b"chunk 1 chunk 2")
|
||||
|
||||
|
@ -127,16 +99,12 @@ class TestHTTPResponse(DaphneTestCase):
|
|||
"""
|
||||
Tries body variants.
|
||||
"""
|
||||
response = self.run_daphne_response([
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": body,
|
||||
},
|
||||
])
|
||||
response = self.run_daphne_response(
|
||||
[
|
||||
{"type": "http.response.start", "status": 200},
|
||||
{"type": "http.response.body", "body": body},
|
||||
]
|
||||
)
|
||||
self.assertEqual(response.status, 200)
|
||||
self.assertEqual(response.read(), body)
|
||||
|
||||
|
@ -144,16 +112,16 @@ class TestHTTPResponse(DaphneTestCase):
|
|||
@settings(max_examples=5, deadline=5000)
|
||||
def test_headers(self, headers):
|
||||
# The ASGI spec requires us to lowercase our header names
|
||||
response = self.run_daphne_response([
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": self.normalize_headers(headers),
|
||||
},
|
||||
{
|
||||
"type": "http.response.body",
|
||||
},
|
||||
])
|
||||
response = self.run_daphne_response(
|
||||
[
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": self.normalize_headers(headers),
|
||||
},
|
||||
{"type": "http.response.body"},
|
||||
]
|
||||
)
|
||||
# Check headers in a sensible way. Ignore transfer-encoding.
|
||||
self.assertEqual(
|
||||
self.normalize_headers(response.getheaders()),
|
||||
|
|
|
@ -13,48 +13,35 @@ class TestXForwardedForHttpParsing(TestCase):
|
|||
"""
|
||||
|
||||
def test_basic(self):
|
||||
headers = Headers({
|
||||
b"X-Forwarded-For": [b"10.1.2.3"],
|
||||
b"X-Forwarded-Port": [b"1234"],
|
||||
b"X-Forwarded-Proto": [b"https"]
|
||||
})
|
||||
headers = Headers(
|
||||
{
|
||||
b"X-Forwarded-For": [b"10.1.2.3"],
|
||||
b"X-Forwarded-Port": [b"1234"],
|
||||
b"X-Forwarded-Proto": [b"https"],
|
||||
}
|
||||
)
|
||||
result = parse_x_forwarded_for(headers)
|
||||
self.assertEqual(result, (["10.1.2.3", 1234], "https"))
|
||||
self.assertIsInstance(result[0][0], str)
|
||||
self.assertIsInstance(result[1], str)
|
||||
|
||||
def test_address_only(self):
|
||||
headers = Headers({
|
||||
b"X-Forwarded-For": [b"10.1.2.3"],
|
||||
})
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers),
|
||||
(["10.1.2.3", 0], None)
|
||||
)
|
||||
headers = Headers({b"X-Forwarded-For": [b"10.1.2.3"]})
|
||||
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
|
||||
|
||||
def test_v6_address(self):
|
||||
headers = Headers({
|
||||
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
|
||||
})
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers),
|
||||
(["1043::a321:0001", 0], None)
|
||||
)
|
||||
headers = Headers({b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]})
|
||||
self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
|
||||
|
||||
def test_multiple_proxys(self):
|
||||
headers = Headers({
|
||||
b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"],
|
||||
})
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers),
|
||||
(["10.1.2.3", 0], None)
|
||||
)
|
||||
headers = Headers({b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"]})
|
||||
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
|
||||
|
||||
def test_original(self):
|
||||
headers = Headers({})
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
|
||||
(["127.0.0.1", 80], None)
|
||||
(["127.0.0.1", 80], None),
|
||||
)
|
||||
|
||||
def test_no_original(self):
|
||||
|
@ -73,43 +60,25 @@ class TestXForwardedForWsParsing(TestCase):
|
|||
b"X-Forwarded-Port": b"1234",
|
||||
b"X-Forwarded-Proto": b"https",
|
||||
}
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers),
|
||||
(["10.1.2.3", 1234], "https")
|
||||
)
|
||||
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 1234], "https"))
|
||||
|
||||
def test_address_only(self):
|
||||
headers = {
|
||||
b"X-Forwarded-For": b"10.1.2.3",
|
||||
}
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers),
|
||||
(["10.1.2.3", 0], None)
|
||||
)
|
||||
headers = {b"X-Forwarded-For": b"10.1.2.3"}
|
||||
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
|
||||
|
||||
def test_v6_address(self):
|
||||
headers = {
|
||||
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
|
||||
}
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers),
|
||||
(["1043::a321:0001", 0], None)
|
||||
)
|
||||
headers = {b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]}
|
||||
self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
|
||||
|
||||
def test_multiple_proxies(self):
|
||||
headers = {
|
||||
b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
|
||||
}
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers),
|
||||
(["10.1.2.3", 0], None)
|
||||
)
|
||||
headers = {b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4"}
|
||||
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
|
||||
|
||||
def test_original(self):
|
||||
headers = {}
|
||||
self.assertEqual(
|
||||
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
|
||||
(["127.0.0.1", 80], None)
|
||||
(["127.0.0.1", 80], None),
|
||||
)
|
||||
|
||||
def test_no_original(self):
|
||||
|
|
|
@ -16,13 +16,7 @@ class TestWebsocket(DaphneTestCase):
|
|||
"""
|
||||
|
||||
def assert_valid_websocket_scope(
|
||||
self,
|
||||
scope,
|
||||
path="/",
|
||||
params=None,
|
||||
headers=None,
|
||||
scheme=None,
|
||||
subprotocols=None,
|
||||
self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None
|
||||
):
|
||||
"""
|
||||
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
||||
|
@ -46,7 +40,9 @@ class TestWebsocket(DaphneTestCase):
|
|||
query_string = scope["query_string"]
|
||||
self.assertIsInstance(query_string, bytes)
|
||||
if params:
|
||||
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
|
||||
self.assertEqual(
|
||||
query_string, parse.urlencode(params or []).encode("ascii")
|
||||
)
|
||||
# Ordering of header names is not important, but the order of values for a header
|
||||
# name is. To assert whether that order is kept, we transform both the request
|
||||
# headers and the channel message headers into a dictionary
|
||||
|
@ -59,7 +55,7 @@ class TestWebsocket(DaphneTestCase):
|
|||
if bit.strip():
|
||||
transformed_scope_headers[name].append(bit.strip())
|
||||
transformed_request_headers = collections.defaultdict(list)
|
||||
for name, value in (headers or []):
|
||||
for name, value in headers or []:
|
||||
expected_name = name.lower().strip().encode("ascii")
|
||||
expected_value = value.strip().encode("ascii")
|
||||
# Make sure to split out any headers collapsed with commas
|
||||
|
@ -92,9 +88,7 @@ class TestWebsocket(DaphneTestCase):
|
|||
"""
|
||||
# Check overall keys
|
||||
self.assert_key_sets(
|
||||
required_keys={"type"},
|
||||
optional_keys=set(),
|
||||
actual_keys=message.keys(),
|
||||
required_keys={"type"}, optional_keys=set(), actual_keys=message.keys()
|
||||
)
|
||||
# Check that it is the right type
|
||||
self.assertEqual(message["type"], "websocket.connect")
|
||||
|
@ -104,11 +98,7 @@ class TestWebsocket(DaphneTestCase):
|
|||
Tests we can open and accept a socket.
|
||||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||
self.websocket_handshake(test_app)
|
||||
# Validate the scope and messages we got
|
||||
scope, messages = test_app.get_received()
|
||||
|
@ -120,11 +110,7 @@ class TestWebsocket(DaphneTestCase):
|
|||
Tests we can reject a socket and it won't complete the handshake.
|
||||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.close",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages([{"type": "websocket.close"}])
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.websocket_handshake(test_app)
|
||||
|
||||
|
@ -134,13 +120,12 @@ class TestWebsocket(DaphneTestCase):
|
|||
"""
|
||||
subprotocols = ["proto1", "proto2"]
|
||||
with DaphneTestingInstance() as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
"subprotocol": "proto2",
|
||||
}
|
||||
])
|
||||
_, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols)
|
||||
test_app.add_send_messages(
|
||||
[{"type": "websocket.accept", "subprotocol": "proto2"}]
|
||||
)
|
||||
_, subprotocol = self.websocket_handshake(
|
||||
test_app, subprotocols=subprotocols
|
||||
)
|
||||
# Validate the scope and messages we got
|
||||
assert subprotocol == "proto2"
|
||||
scope, messages = test_app.get_received()
|
||||
|
@ -151,16 +136,9 @@ class TestWebsocket(DaphneTestCase):
|
|||
"""
|
||||
Tests that X-Forwarded-For headers get parsed right
|
||||
"""
|
||||
headers = [
|
||||
["X-Forwarded-For", "10.1.2.3"],
|
||||
["X-Forwarded-Port", "80"],
|
||||
]
|
||||
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
|
||||
with DaphneTestingInstance(xff=True) as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||
self.websocket_handshake(test_app, headers=headers)
|
||||
# Validate the scope and messages we got
|
||||
scope, messages = test_app.get_received()
|
||||
|
@ -174,22 +152,13 @@ class TestWebsocket(DaphneTestCase):
|
|||
request_headers=http_strategies.headers(),
|
||||
)
|
||||
@settings(max_examples=5, deadline=2000)
|
||||
def test_http_bits(
|
||||
self,
|
||||
request_path,
|
||||
request_params,
|
||||
request_headers,
|
||||
):
|
||||
def test_http_bits(self, request_path, request_params, request_headers):
|
||||
"""
|
||||
Tests that various HTTP-level bits (query string params, path, headers)
|
||||
carry over into the scope.
|
||||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||
self.websocket_handshake(
|
||||
test_app,
|
||||
path=request_path,
|
||||
|
@ -199,10 +168,7 @@ class TestWebsocket(DaphneTestCase):
|
|||
# Validate the scope and messages we got
|
||||
scope, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_scope(
|
||||
scope,
|
||||
path=request_path,
|
||||
params=request_params,
|
||||
headers=request_headers,
|
||||
scope, path=request_path, params=request_params, headers=request_headers
|
||||
)
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
|
||||
|
@ -212,28 +178,24 @@ class TestWebsocket(DaphneTestCase):
|
|||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
# Connect
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||
sock, _ = self.websocket_handshake(test_app)
|
||||
_, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
# Prep frame for it to send
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.send",
|
||||
"text": "here be dragons 🐉",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages(
|
||||
[{"type": "websocket.send", "text": "here be dragons 🐉"}]
|
||||
)
|
||||
# Send it a frame
|
||||
self.websocket_send_frame(sock, "what is here? 🌍")
|
||||
# Receive a frame and make sure it's correct
|
||||
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
|
||||
# Make sure it got our frame
|
||||
_, messages = test_app.get_received()
|
||||
assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"}
|
||||
assert messages[1] == {
|
||||
"type": "websocket.receive",
|
||||
"text": "what is here? 🌍",
|
||||
}
|
||||
|
||||
def test_binary_frames(self):
|
||||
"""
|
||||
|
@ -242,28 +204,24 @@ class TestWebsocket(DaphneTestCase):
|
|||
"""
|
||||
with DaphneTestingInstance() as test_app:
|
||||
# Connect
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||
sock, _ = self.websocket_handshake(test_app)
|
||||
_, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
# Prep frame for it to send
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.send",
|
||||
"bytes": b"here be \xe2 bytes",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages(
|
||||
[{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}]
|
||||
)
|
||||
# Send it a frame
|
||||
self.websocket_send_frame(sock, b"what is here? \xe2")
|
||||
# Receive a frame and make sure it's correct
|
||||
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
|
||||
# Make sure it got our frame
|
||||
_, messages = test_app.get_received()
|
||||
assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"}
|
||||
assert messages[1] == {
|
||||
"type": "websocket.receive",
|
||||
"bytes": b"what is here? \xe2",
|
||||
}
|
||||
|
||||
def test_http_timeout(self):
|
||||
"""
|
||||
|
@ -271,23 +229,14 @@ class TestWebsocket(DaphneTestCase):
|
|||
"""
|
||||
with DaphneTestingInstance(http_timeout=1) as test_app:
|
||||
# Connect
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.accept",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||
sock, _ = self.websocket_handshake(test_app)
|
||||
_, messages = test_app.get_received()
|
||||
self.assert_valid_websocket_connect_message(messages[0])
|
||||
# Wait 2 seconds
|
||||
time.sleep(2)
|
||||
# Prep frame for it to send
|
||||
test_app.add_send_messages([
|
||||
{
|
||||
"type": "websocket.send",
|
||||
"text": "cake",
|
||||
}
|
||||
])
|
||||
test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}])
|
||||
# Send it a frame
|
||||
self.websocket_send_frame(sock, "still alive?")
|
||||
# Receive a frame and make sure it's correct
|
||||
|
|
Loading…
Reference in New Issue
Block a user