mirror of
https://github.com/django/daphne.git
synced 2024-11-21 23:46:33 +03:00
Implement Black code formatting
This commit is contained in:
parent
88792984e7
commit
0ed6294406
22
.travis.yml
22
.travis.yml
|
@ -3,22 +3,25 @@ sudo: false
|
||||||
language: python
|
language: python
|
||||||
|
|
||||||
python:
|
python:
|
||||||
- '3.5'
|
|
||||||
- '3.6'
|
- '3.6'
|
||||||
|
- '3.5'
|
||||||
|
|
||||||
env:
|
env:
|
||||||
- TWISTED="twisted==18.7.0"
|
|
||||||
- TWISTED="twisted"
|
- TWISTED="twisted"
|
||||||
|
- TWISTED="twisted==18.7.0"
|
||||||
|
|
||||||
install:
|
install:
|
||||||
- pip install $TWISTED isort unify flake8 -e .[tests]
|
- pip install $TWISTED -e .[tests]
|
||||||
- pip freeze
|
- pip freeze
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- pytest
|
- pytest
|
||||||
- flake8
|
|
||||||
- isort --check-only --diff --recursive daphne tests
|
stages:
|
||||||
- unify --check-only --recursive --quote \" daphne tests
|
- lint
|
||||||
|
- test
|
||||||
|
- name: release
|
||||||
|
if: branch = master
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
include:
|
include:
|
||||||
|
@ -30,6 +33,13 @@ jobs:
|
||||||
env: TWISTED="twisted"
|
env: TWISTED="twisted"
|
||||||
dist: xenial
|
dist: xenial
|
||||||
sudo: required
|
sudo: required
|
||||||
|
- stage: lint
|
||||||
|
install: pip install -U -e .[tests] black pyflakes isort
|
||||||
|
script:
|
||||||
|
- pyflakes .
|
||||||
|
- black --check .
|
||||||
|
- isort --check-only --diff --recursive channels_redis tests
|
||||||
|
|
||||||
- stage: release
|
- stage: release
|
||||||
script: skip
|
script: skip
|
||||||
deploy:
|
deploy:
|
||||||
|
|
|
@ -49,13 +49,16 @@ class AccessLogGenerator(object):
|
||||||
request="WSDISCONNECT %(path)s" % details,
|
request="WSDISCONNECT %(path)s" % details,
|
||||||
)
|
)
|
||||||
|
|
||||||
def write_entry(self, host, date, request, status=None, length=None, ident=None, user=None):
|
def write_entry(
|
||||||
|
self, host, date, request, status=None, length=None, ident=None, user=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Writes an NCSA-style entry to the log file (some liberty is taken with
|
Writes an NCSA-style entry to the log file (some liberty is taken with
|
||||||
what the entries are for non-HTTP)
|
what the entries are for non-HTTP)
|
||||||
"""
|
"""
|
||||||
self.stream.write(
|
self.stream.write(
|
||||||
"%s %s %s [%s] \"%s\" %s %s\n" % (
|
'%s %s %s [%s] "%s" %s %s\n'
|
||||||
|
% (
|
||||||
host,
|
host,
|
||||||
ident or "-",
|
ident or "-",
|
||||||
user or "-",
|
user or "-",
|
||||||
|
|
|
@ -23,15 +23,9 @@ class CommandLineInterface(object):
|
||||||
server_class = Server
|
server_class = Server
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.parser = argparse.ArgumentParser(
|
self.parser = argparse.ArgumentParser(description=self.description)
|
||||||
description=self.description,
|
|
||||||
)
|
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"-p",
|
"-p", "--port", type=int, help="Port number to listen on", default=None
|
||||||
"--port",
|
|
||||||
type=int,
|
|
||||||
help="Port number to listen on",
|
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"-b",
|
"-b",
|
||||||
|
@ -128,7 +122,7 @@ class CommandLineInterface(object):
|
||||||
"--proxy-headers",
|
"--proxy-headers",
|
||||||
dest="proxy_headers",
|
dest="proxy_headers",
|
||||||
help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the "
|
help="Enable parsing and using of X-Forwarded-For and X-Forwarded-Port headers and using that as the "
|
||||||
"client address",
|
"client address",
|
||||||
default=False,
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
)
|
)
|
||||||
|
@ -176,7 +170,15 @@ class CommandLineInterface(object):
|
||||||
sys.path.insert(0, ".")
|
sys.path.insert(0, ".")
|
||||||
application = import_by_path(args.application)
|
application = import_by_path(args.application)
|
||||||
# Set up port/host bindings
|
# Set up port/host bindings
|
||||||
if not any([args.host, args.port is not None, args.unix_socket, args.file_descriptor, args.socket_strings]):
|
if not any(
|
||||||
|
[
|
||||||
|
args.host,
|
||||||
|
args.port is not None,
|
||||||
|
args.unix_socket,
|
||||||
|
args.file_descriptor,
|
||||||
|
args.socket_strings,
|
||||||
|
]
|
||||||
|
):
|
||||||
# no advanced binding options passed, patch in defaults
|
# no advanced binding options passed, patch in defaults
|
||||||
args.host = DEFAULT_HOST
|
args.host = DEFAULT_HOST
|
||||||
args.port = DEFAULT_PORT
|
args.port = DEFAULT_PORT
|
||||||
|
@ -189,16 +191,11 @@ class CommandLineInterface(object):
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
unix_socket=args.unix_socket,
|
unix_socket=args.unix_socket,
|
||||||
file_descriptor=args.file_descriptor
|
file_descriptor=args.file_descriptor,
|
||||||
)
|
|
||||||
endpoints = sorted(
|
|
||||||
args.socket_strings + endpoints
|
|
||||||
)
|
)
|
||||||
|
endpoints = sorted(args.socket_strings + endpoints)
|
||||||
# Start the server
|
# Start the server
|
||||||
logger.info(
|
logger.info("Starting server at %s" % (", ".join(endpoints),))
|
||||||
"Starting server at %s" %
|
|
||||||
(", ".join(endpoints), )
|
|
||||||
)
|
|
||||||
self.server = self.server_class(
|
self.server = self.server_class(
|
||||||
application=application,
|
application=application,
|
||||||
endpoints=endpoints,
|
endpoints=endpoints,
|
||||||
|
@ -208,12 +205,20 @@ class CommandLineInterface(object):
|
||||||
websocket_timeout=args.websocket_timeout,
|
websocket_timeout=args.websocket_timeout,
|
||||||
websocket_connect_timeout=args.websocket_connect_timeout,
|
websocket_connect_timeout=args.websocket_connect_timeout,
|
||||||
application_close_timeout=args.application_close_timeout,
|
application_close_timeout=args.application_close_timeout,
|
||||||
action_logger=AccessLogGenerator(access_log_stream) if access_log_stream else None,
|
action_logger=AccessLogGenerator(access_log_stream)
|
||||||
|
if access_log_stream
|
||||||
|
else None,
|
||||||
ws_protocols=args.ws_protocols,
|
ws_protocols=args.ws_protocols,
|
||||||
root_path=args.root_path,
|
root_path=args.root_path,
|
||||||
verbosity=args.verbosity,
|
verbosity=args.verbosity,
|
||||||
proxy_forwarded_address_header="X-Forwarded-For" if args.proxy_headers else None,
|
proxy_forwarded_address_header="X-Forwarded-For"
|
||||||
proxy_forwarded_port_header="X-Forwarded-Port" if args.proxy_headers else None,
|
if args.proxy_headers
|
||||||
proxy_forwarded_proto_header="X-Forwarded-Proto" if args.proxy_headers else None,
|
else None,
|
||||||
|
proxy_forwarded_port_header="X-Forwarded-Port"
|
||||||
|
if args.proxy_headers
|
||||||
|
else None,
|
||||||
|
proxy_forwarded_proto_header="X-Forwarded-Proto"
|
||||||
|
if args.proxy_headers
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
self.server.run()
|
self.server.run()
|
||||||
|
|
|
@ -1,10 +1,5 @@
|
||||||
|
|
||||||
|
|
||||||
def build_endpoint_description_strings(
|
def build_endpoint_description_strings(
|
||||||
host=None,
|
host=None, port=None, unix_socket=None, file_descriptor=None
|
||||||
port=None,
|
|
||||||
unix_socket=None,
|
|
||||||
file_descriptor=None
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Build a list of twisted endpoint description strings that the server will listen on.
|
Build a list of twisted endpoint description strings that the server will listen on.
|
||||||
|
|
|
@ -23,7 +23,8 @@ class WebRequest(http.Request):
|
||||||
GET and POST out.
|
GET and POST out.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
error_template = """
|
error_template = (
|
||||||
|
"""
|
||||||
<html>
|
<html>
|
||||||
<head>
|
<head>
|
||||||
<title>%(title)s</title>
|
<title>%(title)s</title>
|
||||||
|
@ -40,7 +41,13 @@ class WebRequest(http.Request):
|
||||||
<footer>Daphne</footer>
|
<footer>Daphne</footer>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
""".replace("\n", "").replace(" ", " ").replace(" ", " ").replace(" ", " ") # Shorten it a bit, bytes wise
|
""".replace(
|
||||||
|
"\n", ""
|
||||||
|
)
|
||||||
|
.replace(" ", " ")
|
||||||
|
.replace(" ", " ")
|
||||||
|
.replace(" ", " ")
|
||||||
|
) # Shorten it a bit, bytes wise
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
|
@ -84,7 +91,7 @@ class WebRequest(http.Request):
|
||||||
self.server.proxy_forwarded_port_header,
|
self.server.proxy_forwarded_port_header,
|
||||||
self.server.proxy_forwarded_proto_header,
|
self.server.proxy_forwarded_proto_header,
|
||||||
self.client_addr,
|
self.client_addr,
|
||||||
self.client_scheme
|
self.client_scheme,
|
||||||
)
|
)
|
||||||
# Check for unicodeish path (or it'll crash when trying to parse)
|
# Check for unicodeish path (or it'll crash when trying to parse)
|
||||||
try:
|
try:
|
||||||
|
@ -105,7 +112,9 @@ class WebRequest(http.Request):
|
||||||
# Is it WebSocket? IS IT?!
|
# Is it WebSocket? IS IT?!
|
||||||
if upgrade_header and upgrade_header.lower() == b"websocket":
|
if upgrade_header and upgrade_header.lower() == b"websocket":
|
||||||
# Make WebSocket protocol to hand off to
|
# Make WebSocket protocol to hand off to
|
||||||
protocol = self.server.ws_factory.buildProtocol(self.transport.getPeer())
|
protocol = self.server.ws_factory.buildProtocol(
|
||||||
|
self.transport.getPeer()
|
||||||
|
)
|
||||||
if not protocol:
|
if not protocol:
|
||||||
# If protocol creation fails, we signal "internal server error"
|
# If protocol creation fails, we signal "internal server error"
|
||||||
self.setResponseCode(500)
|
self.setResponseCode(500)
|
||||||
|
@ -151,33 +160,38 @@ class WebRequest(http.Request):
|
||||||
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
|
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
|
||||||
self.content.seek(0, 0)
|
self.content.seek(0, 0)
|
||||||
# Work out the application scope and create application
|
# Work out the application scope and create application
|
||||||
self.application_queue = yield maybeDeferred(self.server.create_application, self, {
|
self.application_queue = yield maybeDeferred(
|
||||||
"type": "http",
|
self.server.create_application,
|
||||||
# TODO: Correctly say if it's 1.1 or 1.0
|
self,
|
||||||
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
|
{
|
||||||
"method": self.method.decode("ascii"),
|
"type": "http",
|
||||||
"path": unquote(self.path.decode("ascii")),
|
# TODO: Correctly say if it's 1.1 or 1.0
|
||||||
"root_path": self.root_path,
|
"http_version": self.clientproto.split(b"/")[-1].decode(
|
||||||
"scheme": self.client_scheme,
|
"ascii"
|
||||||
"query_string": self.query_string,
|
),
|
||||||
"headers": self.clean_headers,
|
"method": self.method.decode("ascii"),
|
||||||
"client": self.client_addr,
|
"path": unquote(self.path.decode("ascii")),
|
||||||
"server": self.server_addr,
|
"root_path": self.root_path,
|
||||||
})
|
"scheme": self.client_scheme,
|
||||||
|
"query_string": self.query_string,
|
||||||
|
"headers": self.clean_headers,
|
||||||
|
"client": self.client_addr,
|
||||||
|
"server": self.server_addr,
|
||||||
|
},
|
||||||
|
)
|
||||||
# Check they didn't close an unfinished request
|
# Check they didn't close an unfinished request
|
||||||
if self.application_queue is None or self.content.closed:
|
if self.application_queue is None or self.content.closed:
|
||||||
# Not much we can do, the request is prematurely abandoned.
|
# Not much we can do, the request is prematurely abandoned.
|
||||||
return
|
return
|
||||||
# Run application against request
|
# Run application against request
|
||||||
self.application_queue.put_nowait(
|
self.application_queue.put_nowait(
|
||||||
{
|
{"type": "http.request", "body": self.content.read()}
|
||||||
"type": "http.request",
|
|
||||||
"body": self.content.read(),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
self.basic_error(500, b"Internal Server Error", "Daphne HTTP processing error")
|
self.basic_error(
|
||||||
|
500, b"Internal Server Error", "Daphne HTTP processing error"
|
||||||
|
)
|
||||||
|
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason):
|
||||||
"""
|
"""
|
||||||
|
@ -217,16 +231,23 @@ class WebRequest(http.Request):
|
||||||
raise ValueError("HTTP response has already been started")
|
raise ValueError("HTTP response has already been started")
|
||||||
self._response_started = True
|
self._response_started = True
|
||||||
if "status" not in message:
|
if "status" not in message:
|
||||||
raise ValueError("Specifying a status code is required for a Response message.")
|
raise ValueError(
|
||||||
|
"Specifying a status code is required for a Response message."
|
||||||
|
)
|
||||||
# Set HTTP status code
|
# Set HTTP status code
|
||||||
self.setResponseCode(message["status"])
|
self.setResponseCode(message["status"])
|
||||||
# Write headers
|
# Write headers
|
||||||
for header, value in message.get("headers", {}):
|
for header, value in message.get("headers", {}):
|
||||||
self.responseHeaders.addRawHeader(header, value)
|
self.responseHeaders.addRawHeader(header, value)
|
||||||
logger.debug("HTTP %s response started for %s", message["status"], self.client_addr)
|
logger.debug(
|
||||||
|
"HTTP %s response started for %s", message["status"], self.client_addr
|
||||||
|
)
|
||||||
elif message["type"] == "http.response.body":
|
elif message["type"] == "http.response.body":
|
||||||
if not self._response_started:
|
if not self._response_started:
|
||||||
raise ValueError("HTTP response has not yet been started but got %s" % message["type"])
|
raise ValueError(
|
||||||
|
"HTTP response has not yet been started but got %s"
|
||||||
|
% message["type"]
|
||||||
|
)
|
||||||
# Write out body
|
# Write out body
|
||||||
http.Request.write(self, message.get("body", b""))
|
http.Request.write(self, message.get("body", b""))
|
||||||
# End if there's no more content
|
# End if there's no more content
|
||||||
|
@ -239,15 +260,21 @@ class WebRequest(http.Request):
|
||||||
# The path is malformed somehow - do our best to log something
|
# The path is malformed somehow - do our best to log something
|
||||||
uri = repr(self.uri)
|
uri = repr(self.uri)
|
||||||
try:
|
try:
|
||||||
self.server.log_action("http", "complete", {
|
self.server.log_action(
|
||||||
"path": uri,
|
"http",
|
||||||
"status": self.code,
|
"complete",
|
||||||
"method": self.method.decode("ascii", "replace"),
|
{
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
"path": uri,
|
||||||
"time_taken": self.duration(),
|
"status": self.code,
|
||||||
"size": self.sentLength,
|
"method": self.method.decode("ascii", "replace"),
|
||||||
})
|
"client": "%s:%s" % tuple(self.client_addr)
|
||||||
except Exception as e:
|
if self.client_addr
|
||||||
|
else None,
|
||||||
|
"time_taken": self.duration(),
|
||||||
|
"size": self.sentLength,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
else:
|
else:
|
||||||
logger.debug("HTTP response chunk for %s", self.client_addr)
|
logger.debug("HTTP response chunk for %s", self.client_addr)
|
||||||
|
@ -270,7 +297,11 @@ class WebRequest(http.Request):
|
||||||
logger.warning("Application timed out while sending response")
|
logger.warning("Application timed out while sending response")
|
||||||
self.finish()
|
self.finish()
|
||||||
else:
|
else:
|
||||||
self.basic_error(503, b"Service Unavailable", "Application failed to respond within time limit.")
|
self.basic_error(
|
||||||
|
503,
|
||||||
|
b"Service Unavailable",
|
||||||
|
"Application failed to respond within time limit.",
|
||||||
|
)
|
||||||
|
|
||||||
### Utility functions
|
### Utility functions
|
||||||
|
|
||||||
|
@ -281,11 +312,7 @@ class WebRequest(http.Request):
|
||||||
"""
|
"""
|
||||||
# If we don't yet have a path, then don't send as we never opened.
|
# If we don't yet have a path, then don't send as we never opened.
|
||||||
if self.path:
|
if self.path:
|
||||||
self.application_queue.put_nowait(
|
self.application_queue.put_nowait({"type": "http.disconnect"})
|
||||||
{
|
|
||||||
"type": "http.disconnect",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def duration(self):
|
def duration(self):
|
||||||
"""
|
"""
|
||||||
|
@ -299,20 +326,25 @@ class WebRequest(http.Request):
|
||||||
"""
|
"""
|
||||||
Responds with a server-level error page (very basic)
|
Responds with a server-level error page (very basic)
|
||||||
"""
|
"""
|
||||||
self.handle_reply({
|
self.handle_reply(
|
||||||
"type": "http.response.start",
|
{
|
||||||
"status": status,
|
"type": "http.response.start",
|
||||||
"headers": [
|
"status": status,
|
||||||
(b"Content-Type", b"text/html; charset=utf-8"),
|
"headers": [(b"Content-Type", b"text/html; charset=utf-8")],
|
||||||
],
|
}
|
||||||
})
|
)
|
||||||
self.handle_reply({
|
self.handle_reply(
|
||||||
"type": "http.response.body",
|
{
|
||||||
"body": (self.error_template % {
|
"type": "http.response.body",
|
||||||
"title": str(status) + " " + status_text.decode("ascii"),
|
"body": (
|
||||||
"body": body,
|
self.error_template
|
||||||
}).encode("utf8"),
|
% {
|
||||||
})
|
"title": str(status) + " " + status_text.decode("ascii"),
|
||||||
|
"body": body,
|
||||||
|
}
|
||||||
|
).encode("utf8"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(id(self))
|
return hash(id(self))
|
||||||
|
@ -343,7 +375,7 @@ class HTTPFactory(http.HTTPFactory):
|
||||||
protocol = http.HTTPFactory.buildProtocol(self, addr)
|
protocol = http.HTTPFactory.buildProtocol(self, addr)
|
||||||
protocol.requestFactory = WebRequest
|
protocol.requestFactory = WebRequest
|
||||||
return protocol
|
return protocol
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
|
@ -2,13 +2,14 @@
|
||||||
import sys # isort:skip
|
import sys # isort:skip
|
||||||
import warnings # isort:skip
|
import warnings # isort:skip
|
||||||
from twisted.internet import asyncioreactor # isort:skip
|
from twisted.internet import asyncioreactor # isort:skip
|
||||||
|
|
||||||
current_reactor = sys.modules.get("twisted.internet.reactor", None)
|
current_reactor = sys.modules.get("twisted.internet.reactor", None)
|
||||||
if current_reactor is not None:
|
if current_reactor is not None:
|
||||||
if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor):
|
if not isinstance(current_reactor, asyncioreactor.AsyncioSelectorReactor):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; " +
|
"Something has already installed a non-asyncio Twisted reactor. Attempting to uninstall it; "
|
||||||
"you can fix this warning by importing daphne.server early in your codebase or " +
|
+ "you can fix this warning by importing daphne.server early in your codebase or "
|
||||||
"finding the package that imports Twisted and importing it later on.",
|
+ "finding the package that imports Twisted and importing it later on.",
|
||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
del sys.modules["twisted.internet.reactor"]
|
del sys.modules["twisted.internet.reactor"]
|
||||||
|
@ -34,7 +35,6 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Server(object):
|
class Server(object):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
application,
|
application,
|
||||||
|
@ -91,11 +91,13 @@ class Server(object):
|
||||||
self.ws_factory.setProtocolOptions(
|
self.ws_factory.setProtocolOptions(
|
||||||
autoPingTimeout=self.ping_timeout,
|
autoPingTimeout=self.ping_timeout,
|
||||||
allowNullOrigin=True,
|
allowNullOrigin=True,
|
||||||
openHandshakeTimeout=self.websocket_handshake_timeout
|
openHandshakeTimeout=self.websocket_handshake_timeout,
|
||||||
)
|
)
|
||||||
if self.verbosity <= 1:
|
if self.verbosity <= 1:
|
||||||
# Redirect the Twisted log to nowhere
|
# Redirect the Twisted log to nowhere
|
||||||
globalLogBeginner.beginLoggingTo([lambda _: None], redirectStandardIO=False, discardBuffer=True)
|
globalLogBeginner.beginLoggingTo(
|
||||||
|
[lambda _: None], redirectStandardIO=False, discardBuffer=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
|
globalLogBeginner.beginLoggingTo([STDLibLogObserver(__name__)])
|
||||||
|
|
||||||
|
@ -103,7 +105,9 @@ class Server(object):
|
||||||
if http.H2_ENABLED:
|
if http.H2_ENABLED:
|
||||||
logger.info("HTTP/2 support enabled")
|
logger.info("HTTP/2 support enabled")
|
||||||
else:
|
else:
|
||||||
logger.info("HTTP/2 support not enabled (install the http2 and tls Twisted extras)")
|
logger.info(
|
||||||
|
"HTTP/2 support not enabled (install the http2 and tls Twisted extras)"
|
||||||
|
)
|
||||||
|
|
||||||
# Kick off the timeout loop
|
# Kick off the timeout loop
|
||||||
reactor.callLater(1, self.application_checker)
|
reactor.callLater(1, self.application_checker)
|
||||||
|
@ -141,7 +145,11 @@ class Server(object):
|
||||||
host = port.getHost()
|
host = port.getHost()
|
||||||
if hasattr(host, "host") and hasattr(host, "port"):
|
if hasattr(host, "host") and hasattr(host, "port"):
|
||||||
self.listening_addresses.append((host.host, host.port))
|
self.listening_addresses.append((host.host, host.port))
|
||||||
logger.info("Listening on TCP address %s:%s", port.getHost().host, port.getHost().port)
|
logger.info(
|
||||||
|
"Listening on TCP address %s:%s",
|
||||||
|
port.getHost().host,
|
||||||
|
port.getHost().port,
|
||||||
|
)
|
||||||
|
|
||||||
def listen_error(self, failure):
|
def listen_error(self, failure):
|
||||||
logger.critical("Listen failure: %s", failure.getErrorMessage())
|
logger.critical("Listen failure: %s", failure.getErrorMessage())
|
||||||
|
@ -187,10 +195,13 @@ class Server(object):
|
||||||
# Run it, and stash the future for later checking
|
# Run it, and stash the future for later checking
|
||||||
if protocol not in self.connections:
|
if protocol not in self.connections:
|
||||||
return None
|
return None
|
||||||
self.connections[protocol]["application_instance"] = asyncio.ensure_future(application_instance(
|
self.connections[protocol]["application_instance"] = asyncio.ensure_future(
|
||||||
receive=input_queue.get,
|
application_instance(
|
||||||
send=lambda message: self.handle_reply(protocol, message),
|
receive=input_queue.get,
|
||||||
), loop=asyncio.get_event_loop())
|
send=lambda message: self.handle_reply(protocol, message),
|
||||||
|
),
|
||||||
|
loop=asyncio.get_event_loop(),
|
||||||
|
)
|
||||||
return input_queue
|
return input_queue
|
||||||
|
|
||||||
async def handle_reply(self, protocol, message):
|
async def handle_reply(self, protocol, message):
|
||||||
|
@ -215,7 +226,10 @@ class Server(object):
|
||||||
application_instance = details.get("application_instance", None)
|
application_instance = details.get("application_instance", None)
|
||||||
# First, see if the protocol disconnected and the app has taken
|
# First, see if the protocol disconnected and the app has taken
|
||||||
# too long to close up
|
# too long to close up
|
||||||
if disconnected and time.time() - disconnected > self.application_close_timeout:
|
if (
|
||||||
|
disconnected
|
||||||
|
and time.time() - disconnected > self.application_close_timeout
|
||||||
|
):
|
||||||
if application_instance and not application_instance.done():
|
if application_instance and not application_instance.done():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Application instance %r for connection %s took too long to shut down and was killed.",
|
"Application instance %r for connection %s took too long to shut down and was killed.",
|
||||||
|
@ -238,14 +252,11 @@ class Server(object):
|
||||||
else:
|
else:
|
||||||
exception_output = "{}\n{}{}".format(
|
exception_output = "{}\n{}{}".format(
|
||||||
exception,
|
exception,
|
||||||
"".join(traceback.format_tb(
|
"".join(traceback.format_tb(exception.__traceback__)),
|
||||||
exception.__traceback__,
|
|
||||||
)),
|
|
||||||
" {}".format(exception),
|
" {}".format(exception),
|
||||||
)
|
)
|
||||||
logger.error(
|
logger.error(
|
||||||
"Exception inside application: %s",
|
"Exception inside application: %s", exception_output
|
||||||
exception_output,
|
|
||||||
)
|
)
|
||||||
if not disconnected:
|
if not disconnected:
|
||||||
protocol.handle_exception(exception)
|
protocol.handle_exception(exception)
|
||||||
|
|
|
@ -100,9 +100,7 @@ class DaphneTestingInstance:
|
||||||
Adds messages for the application to send back.
|
Adds messages for the application to send back.
|
||||||
The next time it receives an incoming message, it will reply with these.
|
The next time it receives an incoming message, it will reply with these.
|
||||||
"""
|
"""
|
||||||
TestApplication.save_setup(
|
TestApplication.save_setup(response_messages=messages)
|
||||||
response_messages=messages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DaphneProcess(multiprocessing.Process):
|
class DaphneProcess(multiprocessing.Process):
|
||||||
|
@ -193,12 +191,7 @@ class TestApplication:
|
||||||
Stores setup information.
|
Stores setup information.
|
||||||
"""
|
"""
|
||||||
with open(cls.setup_storage, "wb") as fh:
|
with open(cls.setup_storage, "wb") as fh:
|
||||||
pickle.dump(
|
pickle.dump({"response_messages": response_messages}, fh)
|
||||||
{
|
|
||||||
"response_messages": response_messages,
|
|
||||||
},
|
|
||||||
fh,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_setup(cls):
|
def load_setup(cls):
|
||||||
|
@ -218,13 +211,7 @@ class TestApplication:
|
||||||
We could use pickle here, but that seems wrong, still, somehow.
|
We could use pickle here, but that seems wrong, still, somehow.
|
||||||
"""
|
"""
|
||||||
with open(cls.result_storage, "wb") as fh:
|
with open(cls.result_storage, "wb") as fh:
|
||||||
pickle.dump(
|
pickle.dump({"scope": scope, "messages": messages}, fh)
|
||||||
{
|
|
||||||
"scope": scope,
|
|
||||||
"messages": messages,
|
|
||||||
},
|
|
||||||
fh,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save_exception(cls, exception):
|
def save_exception(cls, exception):
|
||||||
|
@ -233,12 +220,7 @@ class TestApplication:
|
||||||
We could use pickle here, but that seems wrong, still, somehow.
|
We could use pickle here, but that seems wrong, still, somehow.
|
||||||
"""
|
"""
|
||||||
with open(cls.result_storage, "wb") as fh:
|
with open(cls.result_storage, "wb") as fh:
|
||||||
pickle.dump(
|
pickle.dump({"exception": exception}, fh)
|
||||||
{
|
|
||||||
"exception": exception,
|
|
||||||
},
|
|
||||||
fh,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_result(cls):
|
def load_result(cls):
|
||||||
|
|
|
@ -22,12 +22,14 @@ def header_value(headers, header_name):
|
||||||
return value.decode("utf-8")
|
return value.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def parse_x_forwarded_for(headers,
|
def parse_x_forwarded_for(
|
||||||
address_header_name="X-Forwarded-For",
|
headers,
|
||||||
port_header_name="X-Forwarded-Port",
|
address_header_name="X-Forwarded-For",
|
||||||
proto_header_name="X-Forwarded-Proto",
|
port_header_name="X-Forwarded-Port",
|
||||||
original_addr=None,
|
proto_header_name="X-Forwarded-Proto",
|
||||||
original_scheme=None):
|
original_addr=None,
|
||||||
|
original_scheme=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Parses an X-Forwarded-For header and returns a host/port pair as a list.
|
Parses an X-Forwarded-For header and returns a host/port pair as a list.
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,11 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from autobahn.twisted.websocket import ConnectionDeny, WebSocketServerFactory, WebSocketServerProtocol
|
from autobahn.twisted.websocket import (
|
||||||
|
ConnectionDeny,
|
||||||
|
WebSocketServerFactory,
|
||||||
|
WebSocketServerProtocol,
|
||||||
|
)
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from .utils import parse_x_forwarded_for
|
from .utils import parse_x_forwarded_for
|
||||||
|
@ -54,32 +58,34 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
self.server.proxy_forwarded_address_header,
|
self.server.proxy_forwarded_address_header,
|
||||||
self.server.proxy_forwarded_port_header,
|
self.server.proxy_forwarded_port_header,
|
||||||
self.server.proxy_forwarded_proto_header,
|
self.server.proxy_forwarded_proto_header,
|
||||||
self.client_addr
|
self.client_addr,
|
||||||
)
|
)
|
||||||
# Decode websocket subprotocol options
|
# Decode websocket subprotocol options
|
||||||
subprotocols = []
|
subprotocols = []
|
||||||
for header, value in self.clean_headers:
|
for header, value in self.clean_headers:
|
||||||
if header == b"sec-websocket-protocol":
|
if header == b"sec-websocket-protocol":
|
||||||
subprotocols = [
|
subprotocols = [
|
||||||
x.strip()
|
x.strip() for x in unquote(value.decode("ascii")).split(",")
|
||||||
for x in
|
|
||||||
unquote(value.decode("ascii")).split(",")
|
|
||||||
]
|
]
|
||||||
# Make new application instance with scope
|
# Make new application instance with scope
|
||||||
self.path = request.path.encode("ascii")
|
self.path = request.path.encode("ascii")
|
||||||
self.application_deferred = defer.maybeDeferred(self.server.create_application, self, {
|
self.application_deferred = defer.maybeDeferred(
|
||||||
"type": "websocket",
|
self.server.create_application,
|
||||||
"path": unquote(self.path.decode("ascii")),
|
self,
|
||||||
"headers": self.clean_headers,
|
{
|
||||||
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
"type": "websocket",
|
||||||
"client": self.client_addr,
|
"path": unquote(self.path.decode("ascii")),
|
||||||
"server": self.server_addr,
|
"headers": self.clean_headers,
|
||||||
"subprotocols": subprotocols,
|
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
||||||
})
|
"client": self.client_addr,
|
||||||
|
"server": self.server_addr,
|
||||||
|
"subprotocols": subprotocols,
|
||||||
|
},
|
||||||
|
)
|
||||||
if self.application_deferred is not None:
|
if self.application_deferred is not None:
|
||||||
self.application_deferred.addCallback(self.applicationCreateWorked)
|
self.application_deferred.addCallback(self.applicationCreateWorked)
|
||||||
self.application_deferred.addErrback(self.applicationCreateFailed)
|
self.application_deferred.addErrback(self.applicationCreateFailed)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# Exceptions here are not displayed right, just 500.
|
# Exceptions here are not displayed right, just 500.
|
||||||
# Turn them into an ERROR log.
|
# Turn them into an ERROR log.
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
@ -98,10 +104,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
self.application_queue = application_queue
|
self.application_queue = application_queue
|
||||||
# Send over the connect message
|
# Send over the connect message
|
||||||
self.application_queue.put_nowait({"type": "websocket.connect"})
|
self.application_queue.put_nowait({"type": "websocket.connect"})
|
||||||
self.server.log_action("websocket", "connecting", {
|
self.server.log_action(
|
||||||
"path": self.request.path,
|
"websocket",
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
"connecting",
|
||||||
})
|
{
|
||||||
|
"path": self.request.path,
|
||||||
|
"client": "%s:%s" % tuple(self.client_addr)
|
||||||
|
if self.client_addr
|
||||||
|
else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def applicationCreateFailed(self, failure):
|
def applicationCreateFailed(self, failure):
|
||||||
"""
|
"""
|
||||||
|
@ -115,10 +127,16 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
def onOpen(self):
|
def onOpen(self):
|
||||||
# Send news that this channel is open
|
# Send news that this channel is open
|
||||||
logger.debug("WebSocket %s open and established", self.client_addr)
|
logger.debug("WebSocket %s open and established", self.client_addr)
|
||||||
self.server.log_action("websocket", "connected", {
|
self.server.log_action(
|
||||||
"path": self.request.path,
|
"websocket",
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
"connected",
|
||||||
})
|
{
|
||||||
|
"path": self.request.path,
|
||||||
|
"client": "%s:%s" % tuple(self.client_addr)
|
||||||
|
if self.client_addr
|
||||||
|
else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def onMessage(self, payload, isBinary):
|
def onMessage(self, payload, isBinary):
|
||||||
# If we're muted, do nothing.
|
# If we're muted, do nothing.
|
||||||
|
@ -128,15 +146,13 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
logger.debug("WebSocket incoming frame on %s", self.client_addr)
|
||||||
self.last_ping = time.time()
|
self.last_ping = time.time()
|
||||||
if isBinary:
|
if isBinary:
|
||||||
self.application_queue.put_nowait({
|
self.application_queue.put_nowait(
|
||||||
"type": "websocket.receive",
|
{"type": "websocket.receive", "bytes": payload}
|
||||||
"bytes": payload,
|
)
|
||||||
})
|
|
||||||
else:
|
else:
|
||||||
self.application_queue.put_nowait({
|
self.application_queue.put_nowait(
|
||||||
"type": "websocket.receive",
|
{"type": "websocket.receive", "text": payload.decode("utf8")}
|
||||||
"text": payload.decode("utf8"),
|
)
|
||||||
})
|
|
||||||
|
|
||||||
def onClose(self, wasClean, code, reason):
|
def onClose(self, wasClean, code, reason):
|
||||||
"""
|
"""
|
||||||
|
@ -145,14 +161,19 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
self.server.protocol_disconnected(self)
|
self.server.protocol_disconnected(self)
|
||||||
logger.debug("WebSocket closed for %s", self.client_addr)
|
logger.debug("WebSocket closed for %s", self.client_addr)
|
||||||
if not self.muted and hasattr(self, "application_queue"):
|
if not self.muted and hasattr(self, "application_queue"):
|
||||||
self.application_queue.put_nowait({
|
self.application_queue.put_nowait(
|
||||||
"type": "websocket.disconnect",
|
{"type": "websocket.disconnect", "code": code}
|
||||||
"code": code,
|
)
|
||||||
})
|
self.server.log_action(
|
||||||
self.server.log_action("websocket", "disconnected", {
|
"websocket",
|
||||||
"path": self.request.path,
|
"disconnected",
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
{
|
||||||
})
|
"path": self.request.path,
|
||||||
|
"client": "%s:%s" % tuple(self.client_addr)
|
||||||
|
if self.client_addr
|
||||||
|
else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
### Internal event handling
|
### Internal event handling
|
||||||
|
|
||||||
|
@ -171,9 +192,8 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
raise ValueError("Socket has not been accepted, so cannot send over it")
|
raise ValueError("Socket has not been accepted, so cannot send over it")
|
||||||
if message.get("bytes", None) and message.get("text", None):
|
if message.get("bytes", None) and message.get("text", None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Got invalid WebSocket reply message on %s - contains both bytes and text keys" % (
|
"Got invalid WebSocket reply message on %s - contains both bytes and text keys"
|
||||||
message,
|
% (message,)
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if message.get("bytes", None):
|
if message.get("bytes", None):
|
||||||
self.serverSend(message["bytes"], True)
|
self.serverSend(message["bytes"], True)
|
||||||
|
@ -187,7 +207,9 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
if hasattr(self, "handshake_deferred"):
|
if hasattr(self, "handshake_deferred"):
|
||||||
# If the handshake is still ongoing, we need to emit a HTTP error
|
# If the handshake is still ongoing, we need to emit a HTTP error
|
||||||
# code rather than a WebSocket one.
|
# code rather than a WebSocket one.
|
||||||
self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error"))
|
self.handshake_deferred.errback(
|
||||||
|
ConnectionDeny(code=500, reason="Internal server error")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.sendCloseFrame(code=1011)
|
self.sendCloseFrame(code=1011)
|
||||||
|
|
||||||
|
@ -203,14 +225,22 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
"""
|
"""
|
||||||
Called when we get a message saying to reject the connection.
|
Called when we get a message saying to reject the connection.
|
||||||
"""
|
"""
|
||||||
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
self.handshake_deferred.errback(
|
||||||
|
ConnectionDeny(code=403, reason="Access denied")
|
||||||
|
)
|
||||||
del self.handshake_deferred
|
del self.handshake_deferred
|
||||||
self.server.protocol_disconnected(self)
|
self.server.protocol_disconnected(self)
|
||||||
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||||
self.server.log_action("websocket", "rejected", {
|
self.server.log_action(
|
||||||
"path": self.request.path,
|
"websocket",
|
||||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
"rejected",
|
||||||
})
|
{
|
||||||
|
"path": self.request.path,
|
||||||
|
"client": "%s:%s" % tuple(self.client_addr)
|
||||||
|
if self.client_addr
|
||||||
|
else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def serverSend(self, content, binary=False):
|
def serverSend(self, content, binary=False):
|
||||||
"""
|
"""
|
||||||
|
@ -244,7 +274,10 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
||||||
Called periodically to see if we should timeout something
|
Called periodically to see if we should timeout something
|
||||||
"""
|
"""
|
||||||
# Web timeout checking
|
# Web timeout checking
|
||||||
if self.duration() > self.server.websocket_timeout and self.server.websocket_timeout >= 0:
|
if (
|
||||||
|
self.duration() > self.server.websocket_timeout
|
||||||
|
and self.server.websocket_timeout >= 0
|
||||||
|
):
|
||||||
self.serverClose()
|
self.serverClose()
|
||||||
# Ping check
|
# Ping check
|
||||||
# If we're still connecting, deny the connection
|
# If we're still connecting, deny the connection
|
||||||
|
@ -287,6 +320,6 @@ class WebSocketFactory(WebSocketServerFactory):
|
||||||
protocol = super(WebSocketFactory, self).buildProtocol(addr)
|
protocol = super(WebSocketFactory, self).buildProtocol(addr)
|
||||||
protocol.factory = self
|
protocol.factory = self
|
||||||
return protocol
|
return protocol
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
logger.error("Cannot build protocol: %s" % traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
|
@ -5,9 +5,9 @@ universal=1
|
||||||
addopts = tests/
|
addopts = tests/
|
||||||
|
|
||||||
[isort]
|
[isort]
|
||||||
line_length = 120
|
include_trailing_comma = True
|
||||||
multi_line_output = 3
|
multi_line_output = 3
|
||||||
known_first_party = channels,daphne,asgiref
|
known_first_party = channels,daphne,asgiref,channels_redis
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/*
|
exclude = venv/*,tox/*,docs/*,testproject/*,js_client/*,.eggs/*
|
||||||
|
|
21
setup.py
21
setup.py
|
@ -22,23 +22,12 @@ setup(
|
||||||
package_dir={"twisted": "daphne/twisted"},
|
package_dir={"twisted": "daphne/twisted"},
|
||||||
packages=find_packages() + ["twisted.plugins"],
|
packages=find_packages() + ["twisted.plugins"],
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
install_requires=[
|
install_requires=["twisted>=18.7", "autobahn>=0.18"],
|
||||||
"twisted>=18.7",
|
setup_requires=["pytest-runner"],
|
||||||
"autobahn>=0.18",
|
extras_require={"tests": ["hypothesis", "pytest", "pytest-asyncio~=0.8"]},
|
||||||
],
|
entry_points={
|
||||||
setup_requires=[
|
"console_scripts": ["daphne = daphne.cli:CommandLineInterface.entrypoint"]
|
||||||
"pytest-runner",
|
|
||||||
],
|
|
||||||
extras_require={
|
|
||||||
"tests": [
|
|
||||||
"hypothesis",
|
|
||||||
"pytest",
|
|
||||||
"pytest-asyncio~=0.8",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
entry_points={"console_scripts": [
|
|
||||||
"daphne = daphne.cli:CommandLineInterface.entrypoint",
|
|
||||||
]},
|
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
"Environment :: Web Environment",
|
"Environment :: Web Environment",
|
||||||
|
|
|
@ -19,7 +19,9 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
|
|
||||||
### Plain HTTP helpers
|
### Plain HTTP helpers
|
||||||
|
|
||||||
def run_daphne_http(self, method, path, params, body, responses, headers=None, timeout=1, xff=False):
|
def run_daphne_http(
|
||||||
|
self, method, path, params, body, responses, headers=None, timeout=1, xff=False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Runs Daphne with the given request callback (given the base URL)
|
Runs Daphne with the given request callback (given the base URL)
|
||||||
and response messages.
|
and response messages.
|
||||||
|
@ -38,7 +40,9 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
# Manually send over headers (encoding any non-safe values as best we can)
|
# Manually send over headers (encoding any non-safe values as best we can)
|
||||||
if headers:
|
if headers:
|
||||||
for header_name, header_value in headers:
|
for header_name, header_value in headers:
|
||||||
conn.putheader(header_name.encode("utf8"), header_value.encode("utf8"))
|
conn.putheader(
|
||||||
|
header_name.encode("utf8"), header_value.encode("utf8")
|
||||||
|
)
|
||||||
# Send body if provided.
|
# Send body if provided.
|
||||||
if body:
|
if body:
|
||||||
conn.putheader("Content-Length", str(len(body)))
|
conn.putheader("Content-Length", str(len(body)))
|
||||||
|
@ -50,9 +54,11 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
# See if they left an exception for us to load
|
# See if they left an exception for us to load
|
||||||
test_app.get_received()
|
test_app.get_received()
|
||||||
raise RuntimeError("Daphne timed out handling request, no exception found.")
|
raise RuntimeError(
|
||||||
|
"Daphne timed out handling request, no exception found."
|
||||||
|
)
|
||||||
# Return scope, messages, response
|
# Return scope, messages, response
|
||||||
return test_app.get_received() + (response, )
|
return test_app.get_received() + (response,)
|
||||||
|
|
||||||
def run_daphne_raw(self, data, timeout=1):
|
def run_daphne_raw(self, data, timeout=1):
|
||||||
"""
|
"""
|
||||||
|
@ -68,9 +74,13 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
try:
|
try:
|
||||||
return s.recv(1000000)
|
return s.recv(1000000)
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
raise RuntimeError("Daphne timed out handling raw request, no exception found.")
|
raise RuntimeError(
|
||||||
|
"Daphne timed out handling raw request, no exception found."
|
||||||
|
)
|
||||||
|
|
||||||
def run_daphne_request(self, method, path, params=None, body=None, headers=None, xff=False):
|
def run_daphne_request(
|
||||||
|
self, method, path, params=None, body=None, headers=None, xff=False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Convenience method for just testing request handling.
|
Convenience method for just testing request handling.
|
||||||
Returns (scope, messages)
|
Returns (scope, messages)
|
||||||
|
@ -95,17 +105,21 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
Returns (scope, messages)
|
Returns (scope, messages)
|
||||||
"""
|
"""
|
||||||
_, _, response = self.run_daphne_http(
|
_, _, response = self.run_daphne_http(
|
||||||
method="GET",
|
method="GET", path="/", params={}, body=b"", responses=response_messages
|
||||||
path="/",
|
|
||||||
params={},
|
|
||||||
body=b"",
|
|
||||||
responses=response_messages,
|
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
### WebSocket helpers
|
### WebSocket helpers
|
||||||
|
|
||||||
def websocket_handshake(self, test_app, path="/", params=None, headers=None, subprotocols=None, timeout=1):
|
def websocket_handshake(
|
||||||
|
self,
|
||||||
|
test_app,
|
||||||
|
path="/",
|
||||||
|
params=None,
|
||||||
|
headers=None,
|
||||||
|
subprotocols=None,
|
||||||
|
timeout=1,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Runs a WebSocket handshake negotiation and returns the raw socket
|
Runs a WebSocket handshake negotiation and returns the raw socket
|
||||||
object & the selected subprotocol.
|
object & the selected subprotocol.
|
||||||
|
@ -124,14 +138,16 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
# Do WebSocket handshake headers + any other headers
|
# Do WebSocket handshake headers + any other headers
|
||||||
if headers is None:
|
if headers is None:
|
||||||
headers = []
|
headers = []
|
||||||
headers.extend([
|
headers.extend(
|
||||||
("Host", "example.com"),
|
[
|
||||||
("Upgrade", "websocket"),
|
("Host", "example.com"),
|
||||||
("Connection", "Upgrade"),
|
("Upgrade", "websocket"),
|
||||||
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
|
("Connection", "Upgrade"),
|
||||||
("Sec-WebSocket-Version", "13"),
|
("Sec-WebSocket-Key", "x3JJHMbDL1EzLkh9GBhXDw=="),
|
||||||
("Origin", "http://example.com")
|
("Sec-WebSocket-Version", "13"),
|
||||||
])
|
("Origin", "http://example.com"),
|
||||||
|
]
|
||||||
|
)
|
||||||
if subprotocols:
|
if subprotocols:
|
||||||
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
|
headers.append(("Sec-WebSocket-Protocol", ", ".join(subprotocols)))
|
||||||
if headers:
|
if headers:
|
||||||
|
@ -149,10 +165,7 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
if response.status != 101:
|
if response.status != 101:
|
||||||
raise RuntimeError("WebSocket upgrade did not result in status code 101")
|
raise RuntimeError("WebSocket upgrade did not result in status code 101")
|
||||||
# Prepare headers for subprotocol searching
|
# Prepare headers for subprotocol searching
|
||||||
response_headers = dict(
|
response_headers = dict((n.lower(), v) for n, v in response.getheaders())
|
||||||
(n.lower(), v)
|
|
||||||
for n, v in response.getheaders()
|
|
||||||
)
|
|
||||||
response.read()
|
response.read()
|
||||||
assert not response.closed
|
assert not response.closed
|
||||||
# Return the raw socket and any subprotocol
|
# Return the raw socket and any subprotocol
|
||||||
|
@ -234,10 +247,7 @@ class DaphneTestCase(unittest.TestCase):
|
||||||
# Make sure all required keys are present
|
# Make sure all required keys are present
|
||||||
self.assertTrue(required_keys <= present_keys)
|
self.assertTrue(required_keys <= present_keys)
|
||||||
# Assert that no other keys are present
|
# Assert that no other keys are present
|
||||||
self.assertEqual(
|
self.assertEqual(set(), present_keys - required_keys - optional_keys)
|
||||||
set(),
|
|
||||||
present_keys - required_keys - optional_keys,
|
|
||||||
)
|
|
||||||
|
|
||||||
def assert_valid_path(self, path, request_path):
|
def assert_valid_path(self, path, request_path):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -6,7 +6,9 @@ from hypothesis import strategies
|
||||||
HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"]
|
HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"]
|
||||||
|
|
||||||
# Unicode characters of the "Letter" category
|
# Unicode characters of the "Letter" category
|
||||||
letters = strategies.characters(whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl"))
|
letters = strategies.characters(
|
||||||
|
whitelist_categories=("Lu", "Ll", "Lt", "Lm", "Lo", "Nl")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def http_method():
|
def http_method():
|
||||||
|
@ -22,11 +24,9 @@ def http_path():
|
||||||
"""
|
"""
|
||||||
Returns a URL path (not encoded).
|
Returns a URL path (not encoded).
|
||||||
"""
|
"""
|
||||||
return strategies.lists(
|
return strategies.lists(_http_path_portion(), min_size=0, max_size=10).map(
|
||||||
_http_path_portion(),
|
lambda s: "/" + "/".join(s)
|
||||||
min_size=0,
|
)
|
||||||
max_size=10,
|
|
||||||
).map(lambda s: "/" + "/".join(s))
|
|
||||||
|
|
||||||
|
|
||||||
def http_body():
|
def http_body():
|
||||||
|
@ -53,10 +53,7 @@ def valid_bidi(value):
|
||||||
|
|
||||||
def _domain_label():
|
def _domain_label():
|
||||||
return strategies.text(
|
return strategies.text(
|
||||||
alphabet=letters,
|
alphabet=letters, min_size=1, average_size=6, max_size=63
|
||||||
min_size=1,
|
|
||||||
average_size=6,
|
|
||||||
max_size=63,
|
|
||||||
).filter(valid_bidi)
|
).filter(valid_bidi)
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,19 +61,14 @@ def international_domain_name():
|
||||||
"""
|
"""
|
||||||
Returns a byte string of a domain name, IDNA-encoded.
|
Returns a byte string of a domain name, IDNA-encoded.
|
||||||
"""
|
"""
|
||||||
return strategies.lists(
|
return strategies.lists(_domain_label(), min_size=2, average_size=2).map(
|
||||||
_domain_label(),
|
lambda s: (".".join(s)).encode("idna")
|
||||||
min_size=2,
|
)
|
||||||
average_size=2,
|
|
||||||
).map(lambda s: (".".join(s)).encode("idna"))
|
|
||||||
|
|
||||||
|
|
||||||
def _query_param():
|
def _query_param():
|
||||||
return strategies.text(
|
return strategies.text(
|
||||||
alphabet=letters,
|
alphabet=letters, min_size=1, average_size=10, max_size=255
|
||||||
min_size=1,
|
|
||||||
average_size=10,
|
|
||||||
max_size=255,
|
|
||||||
).map(lambda s: s.encode("utf8"))
|
).map(lambda s: s.encode("utf8"))
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,9 +79,7 @@ def query_params():
|
||||||
ensures that the total urlencoded query string is not longer than 1500 characters.
|
ensures that the total urlencoded query string is not longer than 1500 characters.
|
||||||
"""
|
"""
|
||||||
return strategies.lists(
|
return strategies.lists(
|
||||||
strategies.tuples(_query_param(), _query_param()),
|
strategies.tuples(_query_param(), _query_param()), min_size=0, average_size=5
|
||||||
min_size=0,
|
|
||||||
average_size=5,
|
|
||||||
).filter(lambda x: len(parse.urlencode(x)) < 1500)
|
).filter(lambda x: len(parse.urlencode(x)) < 1500)
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,9 +91,7 @@ def header_name():
|
||||||
and 20 characters long
|
and 20 characters long
|
||||||
"""
|
"""
|
||||||
return strategies.text(
|
return strategies.text(
|
||||||
alphabet=string.ascii_letters + string.digits + "-",
|
alphabet=string.ascii_letters + string.digits + "-", min_size=1, max_size=30
|
||||||
min_size=1,
|
|
||||||
max_size=30,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -115,7 +103,10 @@ def header_value():
|
||||||
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
|
https://en.wikipedia.org/wiki/List_of_HTTP_header_fields
|
||||||
"""
|
"""
|
||||||
return strategies.text(
|
return strategies.text(
|
||||||
alphabet=string.ascii_letters + string.digits + string.punctuation.replace(",", "") + " /t",
|
alphabet=string.ascii_letters
|
||||||
|
+ string.digits
|
||||||
|
+ string.punctuation.replace(",", "")
|
||||||
|
+ " /t",
|
||||||
min_size=1,
|
min_size=1,
|
||||||
average_size=40,
|
average_size=40,
|
||||||
max_size=8190,
|
max_size=8190,
|
||||||
|
|
|
@ -18,45 +18,32 @@ class TestEndpointDescriptions(TestCase):
|
||||||
def testTcpPortBindings(self):
|
def testTcpPortBindings(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
build(port=1234, host="example.com"),
|
build(port=1234, host="example.com"),
|
||||||
["tcp:port=1234:interface=example.com"]
|
["tcp:port=1234:interface=example.com"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
build(port=8000, host="127.0.0.1"),
|
build(port=8000, host="127.0.0.1"), ["tcp:port=8000:interface=127.0.0.1"]
|
||||||
["tcp:port=8000:interface=127.0.0.1"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
build(port=8000, host="[200a::1]"),
|
build(port=8000, host="[200a::1]"), [r"tcp:port=8000:interface=200a\:\:1"]
|
||||||
[r'tcp:port=8000:interface=200a\:\:1']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
build(port=8000, host="200a::1"),
|
build(port=8000, host="200a::1"), [r"tcp:port=8000:interface=200a\:\:1"]
|
||||||
[r'tcp:port=8000:interface=200a\:\:1']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# incomplete port/host kwargs raise errors
|
# incomplete port/host kwargs raise errors
|
||||||
self.assertRaises(
|
self.assertRaises(ValueError, build, port=123)
|
||||||
ValueError,
|
self.assertRaises(ValueError, build, host="example.com")
|
||||||
build, port=123
|
|
||||||
)
|
|
||||||
self.assertRaises(
|
|
||||||
ValueError,
|
|
||||||
build, host="example.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
def testUnixSocketBinding(self):
|
def testUnixSocketBinding(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
build(unix_socket="/tmp/daphne.sock"),
|
build(unix_socket="/tmp/daphne.sock"), ["unix:/tmp/daphne.sock"]
|
||||||
["unix:/tmp/daphne.sock"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def testFileDescriptorBinding(self):
|
def testFileDescriptorBinding(self):
|
||||||
self.assertEqual(
|
self.assertEqual(build(file_descriptor=5), ["fd:fileno=5"])
|
||||||
build(file_descriptor=5),
|
|
||||||
["fd:fileno=5"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def testMultipleEnpoints(self):
|
def testMultipleEnpoints(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -65,14 +52,16 @@ class TestEndpointDescriptions(TestCase):
|
||||||
file_descriptor=123,
|
file_descriptor=123,
|
||||||
unix_socket="/tmp/daphne.sock",
|
unix_socket="/tmp/daphne.sock",
|
||||||
port=8080,
|
port=8080,
|
||||||
host="10.0.0.1"
|
host="10.0.0.1",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
sorted([
|
sorted(
|
||||||
"tcp:port=8080:interface=10.0.0.1",
|
[
|
||||||
"unix:/tmp/daphne.sock",
|
"tcp:port=8080:interface=10.0.0.1",
|
||||||
"fd:fileno=123"
|
"unix:/tmp/daphne.sock",
|
||||||
])
|
"fd:fileno=123",
|
||||||
|
]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,7 +101,9 @@ class TestCLIInterface(TestCase):
|
||||||
Passes in a fake application automatically.
|
Passes in a fake application automatically.
|
||||||
"""
|
"""
|
||||||
cli = self.TestedCLI()
|
cli = self.TestedCLI()
|
||||||
cli.run(args + ["daphne:__version__"]) # We just pass something importable as app
|
cli.run(
|
||||||
|
args + ["daphne:__version__"]
|
||||||
|
) # We just pass something importable as app
|
||||||
# Check the server got all arguments as intended
|
# Check the server got all arguments as intended
|
||||||
for key, value in server_kwargs.items():
|
for key, value in server_kwargs.items():
|
||||||
# Get the value and sort it if it's a list (for endpoint checking)
|
# Get the value and sort it if it's a list (for endpoint checking)
|
||||||
|
@ -123,52 +114,30 @@ class TestCLIInterface(TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
value,
|
value,
|
||||||
actual_value,
|
actual_value,
|
||||||
"Wrong value for server kwarg %s: %r != %r" % (
|
"Wrong value for server kwarg %s: %r != %r"
|
||||||
key,
|
% (key, value, actual_value),
|
||||||
value,
|
|
||||||
actual_value,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def testCLIBasics(self):
|
def testCLIBasics(self):
|
||||||
"""
|
"""
|
||||||
Tests basic endpoint generation.
|
Tests basic endpoint generation.
|
||||||
"""
|
"""
|
||||||
|
self.assertCLI([], {"endpoints": ["tcp:port=8000:interface=127.0.0.1"]})
|
||||||
self.assertCLI(
|
self.assertCLI(
|
||||||
[],
|
["-p", "123"], {"endpoints": ["tcp:port=123:interface=127.0.0.1"]}
|
||||||
{
|
|
||||||
"endpoints": ["tcp:port=8000:interface=127.0.0.1"],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
self.assertCLI(
|
self.assertCLI(
|
||||||
["-p", "123"],
|
["-b", "10.0.0.1"], {"endpoints": ["tcp:port=8000:interface=10.0.0.1"]}
|
||||||
{
|
|
||||||
"endpoints": ["tcp:port=123:interface=127.0.0.1"],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
self.assertCLI(
|
self.assertCLI(
|
||||||
["-b", "10.0.0.1"],
|
["-b", "200a::1"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
|
||||||
{
|
|
||||||
"endpoints": ["tcp:port=8000:interface=10.0.0.1"],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
self.assertCLI(
|
self.assertCLI(
|
||||||
["-b", "200a::1"],
|
["-b", "[200a::1]"], {"endpoints": [r"tcp:port=8000:interface=200a\:\:1"]}
|
||||||
{
|
|
||||||
"endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertCLI(
|
|
||||||
["-b", "[200a::1]"],
|
|
||||||
{
|
|
||||||
"endpoints": [r'tcp:port=8000:interface=200a\:\:1'],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
self.assertCLI(
|
self.assertCLI(
|
||||||
["-p", "8080", "-b", "example.com"],
|
["-p", "8080", "-b", "example.com"],
|
||||||
{
|
{"endpoints": ["tcp:port=8080:interface=example.com"]},
|
||||||
"endpoints": ["tcp:port=8080:interface=example.com"],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def testUnixSockets(self):
|
def testUnixSockets(self):
|
||||||
|
@ -178,7 +147,7 @@ class TestCLIInterface(TestCase):
|
||||||
"endpoints": [
|
"endpoints": [
|
||||||
"tcp:port=8080:interface=127.0.0.1",
|
"tcp:port=8080:interface=127.0.0.1",
|
||||||
"unix:/tmp/daphne.sock",
|
"unix:/tmp/daphne.sock",
|
||||||
],
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertCLI(
|
self.assertCLI(
|
||||||
|
@ -187,17 +156,12 @@ class TestCLIInterface(TestCase):
|
||||||
"endpoints": [
|
"endpoints": [
|
||||||
"tcp:port=8000:interface=example.com",
|
"tcp:port=8000:interface=example.com",
|
||||||
"unix:/tmp/daphne.sock",
|
"unix:/tmp/daphne.sock",
|
||||||
],
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertCLI(
|
self.assertCLI(
|
||||||
["-u", "/tmp/daphne.sock", "--fd", "5"],
|
["-u", "/tmp/daphne.sock", "--fd", "5"],
|
||||||
{
|
{"endpoints": ["fd:fileno=5", "unix:/tmp/daphne.sock"]},
|
||||||
"endpoints": [
|
|
||||||
"fd:fileno=5",
|
|
||||||
"unix:/tmp/daphne.sock"
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def testMixedCLIEndpointCreation(self):
|
def testMixedCLIEndpointCreation(self):
|
||||||
|
@ -209,8 +173,8 @@ class TestCLIInterface(TestCase):
|
||||||
{
|
{
|
||||||
"endpoints": [
|
"endpoints": [
|
||||||
"tcp:port=8080:interface=127.0.0.1",
|
"tcp:port=8080:interface=127.0.0.1",
|
||||||
"unix:/tmp/daphne.sock"
|
"unix:/tmp/daphne.sock",
|
||||||
],
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.assertCLI(
|
self.assertCLI(
|
||||||
|
@ -219,7 +183,7 @@ class TestCLIInterface(TestCase):
|
||||||
"endpoints": [
|
"endpoints": [
|
||||||
"tcp:port=8080:interface=127.0.0.1",
|
"tcp:port=8080:interface=127.0.0.1",
|
||||||
"tcp:port=8080:interface=127.0.0.1",
|
"tcp:port=8080:interface=127.0.0.1",
|
||||||
],
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -227,11 +191,4 @@ class TestCLIInterface(TestCase):
|
||||||
"""
|
"""
|
||||||
Tests entirely custom endpoints
|
Tests entirely custom endpoints
|
||||||
"""
|
"""
|
||||||
self.assertCLI(
|
self.assertCLI(["-e", "imap:"], {"endpoints": ["imap:"]})
|
||||||
["-e", "imap:"],
|
|
||||||
{
|
|
||||||
"endpoints": [
|
|
||||||
"imap:",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
|
@ -15,13 +15,7 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def assert_valid_http_scope(
|
def assert_valid_http_scope(
|
||||||
self,
|
self, scope, method, path, params=None, headers=None, scheme=None
|
||||||
scope,
|
|
||||||
method,
|
|
||||||
path,
|
|
||||||
params=None,
|
|
||||||
headers=None,
|
|
||||||
scheme=None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
||||||
|
@ -29,7 +23,14 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
# Check overall keys
|
# Check overall keys
|
||||||
self.assert_key_sets(
|
self.assert_key_sets(
|
||||||
required_keys={"type", "http_version", "method", "path", "query_string", "headers"},
|
required_keys={
|
||||||
|
"type",
|
||||||
|
"http_version",
|
||||||
|
"method",
|
||||||
|
"path",
|
||||||
|
"query_string",
|
||||||
|
"headers",
|
||||||
|
},
|
||||||
optional_keys={"scheme", "root_path", "client", "server"},
|
optional_keys={"scheme", "root_path", "client", "server"},
|
||||||
actual_keys=scope.keys(),
|
actual_keys=scope.keys(),
|
||||||
)
|
)
|
||||||
|
@ -50,7 +51,9 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
query_string = scope["query_string"]
|
query_string = scope["query_string"]
|
||||||
self.assertIsInstance(query_string, bytes)
|
self.assertIsInstance(query_string, bytes)
|
||||||
if params:
|
if params:
|
||||||
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
|
self.assertEqual(
|
||||||
|
query_string, parse.urlencode(params or []).encode("ascii")
|
||||||
|
)
|
||||||
# Ordering of header names is not important, but the order of values for a header
|
# Ordering of header names is not important, but the order of values for a header
|
||||||
# name is. To assert whether that order is kept, we transform both the request
|
# name is. To assert whether that order is kept, we transform both the request
|
||||||
# headers and the channel message headers into a dictionary
|
# headers and the channel message headers into a dictionary
|
||||||
|
@ -59,7 +62,7 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
for name, value in scope["headers"]:
|
for name, value in scope["headers"]:
|
||||||
transformed_scope_headers[name].append(value)
|
transformed_scope_headers[name].append(value)
|
||||||
transformed_request_headers = collections.defaultdict(list)
|
transformed_request_headers = collections.defaultdict(list)
|
||||||
for name, value in (headers or []):
|
for name, value in headers or []:
|
||||||
expected_name = name.lower().strip().encode("ascii")
|
expected_name = name.lower().strip().encode("ascii")
|
||||||
expected_value = value.strip().encode("ascii")
|
expected_value = value.strip().encode("ascii")
|
||||||
transformed_request_headers[expected_name].append(expected_value)
|
transformed_request_headers[expected_name].append(expected_value)
|
||||||
|
@ -103,27 +106,31 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
|
|
||||||
@given(
|
@given(
|
||||||
request_path=http_strategies.http_path(),
|
request_path=http_strategies.http_path(),
|
||||||
request_params=http_strategies.query_params()
|
request_params=http_strategies.query_params(),
|
||||||
)
|
)
|
||||||
@settings(max_examples=5, deadline=5000)
|
@settings(max_examples=5, deadline=5000)
|
||||||
def test_get_request(self, request_path, request_params):
|
def test_get_request(self, request_path, request_params):
|
||||||
"""
|
"""
|
||||||
Tests a typical HTTP GET request, with a path and query parameters
|
Tests a typical HTTP GET request, with a path and query parameters
|
||||||
"""
|
"""
|
||||||
scope, messages = self.run_daphne_request("GET", request_path, params=request_params)
|
scope, messages = self.run_daphne_request(
|
||||||
|
"GET", request_path, params=request_params
|
||||||
|
)
|
||||||
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
|
self.assert_valid_http_scope(scope, "GET", request_path, params=request_params)
|
||||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||||
|
|
||||||
@given(
|
@given(
|
||||||
request_path=http_strategies.http_path(),
|
request_path=http_strategies.http_path(),
|
||||||
request_body=http_strategies.http_body()
|
request_body=http_strategies.http_body(),
|
||||||
)
|
)
|
||||||
@settings(max_examples=5, deadline=5000)
|
@settings(max_examples=5, deadline=5000)
|
||||||
def test_post_request(self, request_path, request_body):
|
def test_post_request(self, request_path, request_body):
|
||||||
"""
|
"""
|
||||||
Tests a typical HTTP POST request, with a path and body.
|
Tests a typical HTTP POST request, with a path and body.
|
||||||
"""
|
"""
|
||||||
scope, messages = self.run_daphne_request("POST", request_path, body=request_body)
|
scope, messages = self.run_daphne_request(
|
||||||
|
"POST", request_path, body=request_body
|
||||||
|
)
|
||||||
self.assert_valid_http_scope(scope, "POST", request_path)
|
self.assert_valid_http_scope(scope, "POST", request_path)
|
||||||
self.assert_valid_http_request_message(messages[0], body=request_body)
|
self.assert_valid_http_request_message(messages[0], body=request_body)
|
||||||
|
|
||||||
|
@ -134,8 +141,12 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
Tests that HTTP header fields are handled as specified
|
Tests that HTTP header fields are handled as specified
|
||||||
"""
|
"""
|
||||||
request_path = "/te st-à/"
|
request_path = "/te st-à/"
|
||||||
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=request_headers)
|
scope, messages = self.run_daphne_request(
|
||||||
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=request_headers)
|
"OPTIONS", request_path, headers=request_headers
|
||||||
|
)
|
||||||
|
self.assert_valid_http_scope(
|
||||||
|
scope, "OPTIONS", request_path, headers=request_headers
|
||||||
|
)
|
||||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||||
|
|
||||||
@given(request_headers=http_strategies.headers())
|
@given(request_headers=http_strategies.headers())
|
||||||
|
@ -150,8 +161,12 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
duplicated_headers = [(header_name, header[1]) for header in request_headers]
|
duplicated_headers = [(header_name, header[1]) for header in request_headers]
|
||||||
# Run the request
|
# Run the request
|
||||||
request_path = "/te st-à/"
|
request_path = "/te st-à/"
|
||||||
scope, messages = self.run_daphne_request("OPTIONS", request_path, headers=duplicated_headers)
|
scope, messages = self.run_daphne_request(
|
||||||
self.assert_valid_http_scope(scope, "OPTIONS", request_path, headers=duplicated_headers)
|
"OPTIONS", request_path, headers=duplicated_headers
|
||||||
|
)
|
||||||
|
self.assert_valid_http_scope(
|
||||||
|
scope, "OPTIONS", request_path, headers=duplicated_headers
|
||||||
|
)
|
||||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||||
|
|
||||||
@given(
|
@given(
|
||||||
|
@ -222,10 +237,7 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
Make sure that, by default, X-Forwarded-For is ignored.
|
Make sure that, by default, X-Forwarded-For is ignored.
|
||||||
"""
|
"""
|
||||||
headers = [
|
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
|
||||||
["X-Forwarded-For", "10.1.2.3"],
|
|
||||||
["X-Forwarded-Port", "80"],
|
|
||||||
]
|
|
||||||
scope, messages = self.run_daphne_request("GET", "/", headers=headers)
|
scope, messages = self.run_daphne_request("GET", "/", headers=headers)
|
||||||
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
||||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||||
|
@ -236,10 +248,7 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
When X-Forwarded-For is enabled, make sure it is respected.
|
When X-Forwarded-For is enabled, make sure it is respected.
|
||||||
"""
|
"""
|
||||||
headers = [
|
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
|
||||||
["X-Forwarded-For", "10.1.2.3"],
|
|
||||||
["X-Forwarded-Port", "80"],
|
|
||||||
]
|
|
||||||
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
|
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
|
||||||
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
||||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||||
|
@ -251,9 +260,7 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
When X-Forwarded-For is enabled but only the host is passed, make sure
|
When X-Forwarded-For is enabled but only the host is passed, make sure
|
||||||
that at least makes it through.
|
that at least makes it through.
|
||||||
"""
|
"""
|
||||||
headers = [
|
headers = [["X-Forwarded-For", "10.1.2.3"]]
|
||||||
["X-Forwarded-For", "10.1.2.3"],
|
|
||||||
]
|
|
||||||
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
|
scope, messages = self.run_daphne_request("GET", "/", headers=headers, xff=True)
|
||||||
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
self.assert_valid_http_scope(scope, "GET", "/", headers=headers)
|
||||||
self.assert_valid_http_request_message(messages[0], body=b"")
|
self.assert_valid_http_request_message(messages[0], body=b"")
|
||||||
|
@ -265,8 +272,12 @@ class TestHTTPRequest(DaphneTestCase):
|
||||||
Tests that requests with invalid (non-ASCII) characters fail.
|
Tests that requests with invalid (non-ASCII) characters fail.
|
||||||
"""
|
"""
|
||||||
# Bad path
|
# Bad path
|
||||||
response = self.run_daphne_raw(b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n")
|
response = self.run_daphne_raw(
|
||||||
|
b"GET /\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
|
||||||
|
)
|
||||||
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
||||||
# Bad querystring
|
# Bad querystring
|
||||||
response = self.run_daphne_raw(b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n")
|
response = self.run_daphne_raw(
|
||||||
|
b"GET /?\xc3\xa4\xc3\xb6\xc3\xbc HTTP/1.0\r\n\r\n"
|
||||||
|
)
|
||||||
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
self.assertTrue(response.startswith(b"HTTP/1.0 400 Bad Request"))
|
||||||
|
|
|
@ -15,26 +15,24 @@ class TestHTTPResponse(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
Lowercases and sorts headers, and strips transfer-encoding ones.
|
Lowercases and sorts headers, and strips transfer-encoding ones.
|
||||||
"""
|
"""
|
||||||
return sorted([
|
return sorted(
|
||||||
(name.lower(), value.strip())
|
[
|
||||||
for name, value in headers
|
(name.lower(), value.strip())
|
||||||
if name.lower() != "transfer-encoding"
|
for name, value in headers
|
||||||
])
|
if name.lower() != "transfer-encoding"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def test_minimal_response(self):
|
def test_minimal_response(self):
|
||||||
"""
|
"""
|
||||||
Smallest viable example. Mostly verifies that our response building works.
|
Smallest viable example. Mostly verifies that our response building works.
|
||||||
"""
|
"""
|
||||||
response = self.run_daphne_response([
|
response = self.run_daphne_response(
|
||||||
{
|
[
|
||||||
"type": "http.response.start",
|
{"type": "http.response.start", "status": 200},
|
||||||
"status": 200,
|
{"type": "http.response.body", "body": b"hello world"},
|
||||||
},
|
]
|
||||||
{
|
)
|
||||||
"type": "http.response.body",
|
|
||||||
"body": b"hello world",
|
|
||||||
},
|
|
||||||
])
|
|
||||||
self.assertEqual(response.status, 200)
|
self.assertEqual(response.status, 200)
|
||||||
self.assertEqual(response.read(), b"hello world")
|
self.assertEqual(response.read(), b"hello world")
|
||||||
|
|
||||||
|
@ -46,30 +44,23 @@ class TestHTTPResponse(DaphneTestCase):
|
||||||
to make sure it stays required.
|
to make sure it stays required.
|
||||||
"""
|
"""
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.run_daphne_response([
|
self.run_daphne_response(
|
||||||
{
|
[
|
||||||
"type": "http.response.start",
|
{"type": "http.response.start"},
|
||||||
},
|
{"type": "http.response.body", "body": b"hello world"},
|
||||||
{
|
]
|
||||||
"type": "http.response.body",
|
)
|
||||||
"body": b"hello world",
|
|
||||||
},
|
|
||||||
])
|
|
||||||
|
|
||||||
def test_custom_status_code(self):
|
def test_custom_status_code(self):
|
||||||
"""
|
"""
|
||||||
Tries a non-default status code.
|
Tries a non-default status code.
|
||||||
"""
|
"""
|
||||||
response = self.run_daphne_response([
|
response = self.run_daphne_response(
|
||||||
{
|
[
|
||||||
"type": "http.response.start",
|
{"type": "http.response.start", "status": 201},
|
||||||
"status": 201,
|
{"type": "http.response.body", "body": b"i made a thing!"},
|
||||||
},
|
]
|
||||||
{
|
)
|
||||||
"type": "http.response.body",
|
|
||||||
"body": b"i made a thing!",
|
|
||||||
},
|
|
||||||
])
|
|
||||||
self.assertEqual(response.status, 201)
|
self.assertEqual(response.status, 201)
|
||||||
self.assertEqual(response.read(), b"i made a thing!")
|
self.assertEqual(response.read(), b"i made a thing!")
|
||||||
|
|
||||||
|
@ -77,21 +68,13 @@ class TestHTTPResponse(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
Tries sending a response in multiple parts.
|
Tries sending a response in multiple parts.
|
||||||
"""
|
"""
|
||||||
response = self.run_daphne_response([
|
response = self.run_daphne_response(
|
||||||
{
|
[
|
||||||
"type": "http.response.start",
|
{"type": "http.response.start", "status": 201},
|
||||||
"status": 201,
|
{"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
|
||||||
},
|
{"type": "http.response.body", "body": b"chunk 2"},
|
||||||
{
|
]
|
||||||
"type": "http.response.body",
|
)
|
||||||
"body": b"chunk 1 ",
|
|
||||||
"more_body": True,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http.response.body",
|
|
||||||
"body": b"chunk 2",
|
|
||||||
},
|
|
||||||
])
|
|
||||||
self.assertEqual(response.status, 201)
|
self.assertEqual(response.status, 201)
|
||||||
self.assertEqual(response.read(), b"chunk 1 chunk 2")
|
self.assertEqual(response.read(), b"chunk 1 chunk 2")
|
||||||
|
|
||||||
|
@ -99,25 +82,14 @@ class TestHTTPResponse(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
Tries sending a response in multiple parts and an empty end.
|
Tries sending a response in multiple parts and an empty end.
|
||||||
"""
|
"""
|
||||||
response = self.run_daphne_response([
|
response = self.run_daphne_response(
|
||||||
{
|
[
|
||||||
"type": "http.response.start",
|
{"type": "http.response.start", "status": 201},
|
||||||
"status": 201,
|
{"type": "http.response.body", "body": b"chunk 1 ", "more_body": True},
|
||||||
},
|
{"type": "http.response.body", "body": b"chunk 2", "more_body": True},
|
||||||
{
|
{"type": "http.response.body"},
|
||||||
"type": "http.response.body",
|
]
|
||||||
"body": b"chunk 1 ",
|
)
|
||||||
"more_body": True,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http.response.body",
|
|
||||||
"body": b"chunk 2",
|
|
||||||
"more_body": True,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "http.response.body",
|
|
||||||
},
|
|
||||||
])
|
|
||||||
self.assertEqual(response.status, 201)
|
self.assertEqual(response.status, 201)
|
||||||
self.assertEqual(response.read(), b"chunk 1 chunk 2")
|
self.assertEqual(response.read(), b"chunk 1 chunk 2")
|
||||||
|
|
||||||
|
@ -127,16 +99,12 @@ class TestHTTPResponse(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
Tries body variants.
|
Tries body variants.
|
||||||
"""
|
"""
|
||||||
response = self.run_daphne_response([
|
response = self.run_daphne_response(
|
||||||
{
|
[
|
||||||
"type": "http.response.start",
|
{"type": "http.response.start", "status": 200},
|
||||||
"status": 200,
|
{"type": "http.response.body", "body": body},
|
||||||
},
|
]
|
||||||
{
|
)
|
||||||
"type": "http.response.body",
|
|
||||||
"body": body,
|
|
||||||
},
|
|
||||||
])
|
|
||||||
self.assertEqual(response.status, 200)
|
self.assertEqual(response.status, 200)
|
||||||
self.assertEqual(response.read(), body)
|
self.assertEqual(response.read(), body)
|
||||||
|
|
||||||
|
@ -144,16 +112,16 @@ class TestHTTPResponse(DaphneTestCase):
|
||||||
@settings(max_examples=5, deadline=5000)
|
@settings(max_examples=5, deadline=5000)
|
||||||
def test_headers(self, headers):
|
def test_headers(self, headers):
|
||||||
# The ASGI spec requires us to lowercase our header names
|
# The ASGI spec requires us to lowercase our header names
|
||||||
response = self.run_daphne_response([
|
response = self.run_daphne_response(
|
||||||
{
|
[
|
||||||
"type": "http.response.start",
|
{
|
||||||
"status": 200,
|
"type": "http.response.start",
|
||||||
"headers": self.normalize_headers(headers),
|
"status": 200,
|
||||||
},
|
"headers": self.normalize_headers(headers),
|
||||||
{
|
},
|
||||||
"type": "http.response.body",
|
{"type": "http.response.body"},
|
||||||
},
|
]
|
||||||
])
|
)
|
||||||
# Check headers in a sensible way. Ignore transfer-encoding.
|
# Check headers in a sensible way. Ignore transfer-encoding.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.normalize_headers(response.getheaders()),
|
self.normalize_headers(response.getheaders()),
|
||||||
|
|
|
@ -13,48 +13,35 @@ class TestXForwardedForHttpParsing(TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
headers = Headers({
|
headers = Headers(
|
||||||
b"X-Forwarded-For": [b"10.1.2.3"],
|
{
|
||||||
b"X-Forwarded-Port": [b"1234"],
|
b"X-Forwarded-For": [b"10.1.2.3"],
|
||||||
b"X-Forwarded-Proto": [b"https"]
|
b"X-Forwarded-Port": [b"1234"],
|
||||||
})
|
b"X-Forwarded-Proto": [b"https"],
|
||||||
|
}
|
||||||
|
)
|
||||||
result = parse_x_forwarded_for(headers)
|
result = parse_x_forwarded_for(headers)
|
||||||
self.assertEqual(result, (["10.1.2.3", 1234], "https"))
|
self.assertEqual(result, (["10.1.2.3", 1234], "https"))
|
||||||
self.assertIsInstance(result[0][0], str)
|
self.assertIsInstance(result[0][0], str)
|
||||||
self.assertIsInstance(result[1], str)
|
self.assertIsInstance(result[1], str)
|
||||||
|
|
||||||
def test_address_only(self):
|
def test_address_only(self):
|
||||||
headers = Headers({
|
headers = Headers({b"X-Forwarded-For": [b"10.1.2.3"]})
|
||||||
b"X-Forwarded-For": [b"10.1.2.3"],
|
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
|
||||||
})
|
|
||||||
self.assertEqual(
|
|
||||||
parse_x_forwarded_for(headers),
|
|
||||||
(["10.1.2.3", 0], None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_v6_address(self):
|
def test_v6_address(self):
|
||||||
headers = Headers({
|
headers = Headers({b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]})
|
||||||
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
|
self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
|
||||||
})
|
|
||||||
self.assertEqual(
|
|
||||||
parse_x_forwarded_for(headers),
|
|
||||||
(["1043::a321:0001", 0], None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_multiple_proxys(self):
|
def test_multiple_proxys(self):
|
||||||
headers = Headers({
|
headers = Headers({b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"]})
|
||||||
b"X-Forwarded-For": [b"10.1.2.3, 10.1.2.4"],
|
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
|
||||||
})
|
|
||||||
self.assertEqual(
|
|
||||||
parse_x_forwarded_for(headers),
|
|
||||||
(["10.1.2.3", 0], None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_original(self):
|
def test_original(self):
|
||||||
headers = Headers({})
|
headers = Headers({})
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
|
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
|
||||||
(["127.0.0.1", 80], None)
|
(["127.0.0.1", 80], None),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_original(self):
|
def test_no_original(self):
|
||||||
|
@ -73,43 +60,25 @@ class TestXForwardedForWsParsing(TestCase):
|
||||||
b"X-Forwarded-Port": b"1234",
|
b"X-Forwarded-Port": b"1234",
|
||||||
b"X-Forwarded-Proto": b"https",
|
b"X-Forwarded-Proto": b"https",
|
||||||
}
|
}
|
||||||
self.assertEqual(
|
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 1234], "https"))
|
||||||
parse_x_forwarded_for(headers),
|
|
||||||
(["10.1.2.3", 1234], "https")
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_address_only(self):
|
def test_address_only(self):
|
||||||
headers = {
|
headers = {b"X-Forwarded-For": b"10.1.2.3"}
|
||||||
b"X-Forwarded-For": b"10.1.2.3",
|
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
|
||||||
}
|
|
||||||
self.assertEqual(
|
|
||||||
parse_x_forwarded_for(headers),
|
|
||||||
(["10.1.2.3", 0], None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_v6_address(self):
|
def test_v6_address(self):
|
||||||
headers = {
|
headers = {b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"]}
|
||||||
b"X-Forwarded-For": [b"1043::a321:0001, 10.0.5.6"],
|
self.assertEqual(parse_x_forwarded_for(headers), (["1043::a321:0001", 0], None))
|
||||||
}
|
|
||||||
self.assertEqual(
|
|
||||||
parse_x_forwarded_for(headers),
|
|
||||||
(["1043::a321:0001", 0], None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_multiple_proxies(self):
|
def test_multiple_proxies(self):
|
||||||
headers = {
|
headers = {b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4"}
|
||||||
b"X-Forwarded-For": b"10.1.2.3, 10.1.2.4",
|
self.assertEqual(parse_x_forwarded_for(headers), (["10.1.2.3", 0], None))
|
||||||
}
|
|
||||||
self.assertEqual(
|
|
||||||
parse_x_forwarded_for(headers),
|
|
||||||
(["10.1.2.3", 0], None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_original(self):
|
def test_original(self):
|
||||||
headers = {}
|
headers = {}
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
|
parse_x_forwarded_for(headers, original_addr=["127.0.0.1", 80]),
|
||||||
(["127.0.0.1", 80], None)
|
(["127.0.0.1", 80], None),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_no_original(self):
|
def test_no_original(self):
|
||||||
|
|
|
@ -16,13 +16,7 @@ class TestWebsocket(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def assert_valid_websocket_scope(
|
def assert_valid_websocket_scope(
|
||||||
self,
|
self, scope, path="/", params=None, headers=None, scheme=None, subprotocols=None
|
||||||
scope,
|
|
||||||
path="/",
|
|
||||||
params=None,
|
|
||||||
headers=None,
|
|
||||||
scheme=None,
|
|
||||||
subprotocols=None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
Checks that the passed scope is a valid ASGI HTTP scope regarding types
|
||||||
|
@ -46,7 +40,9 @@ class TestWebsocket(DaphneTestCase):
|
||||||
query_string = scope["query_string"]
|
query_string = scope["query_string"]
|
||||||
self.assertIsInstance(query_string, bytes)
|
self.assertIsInstance(query_string, bytes)
|
||||||
if params:
|
if params:
|
||||||
self.assertEqual(query_string, parse.urlencode(params or []).encode("ascii"))
|
self.assertEqual(
|
||||||
|
query_string, parse.urlencode(params or []).encode("ascii")
|
||||||
|
)
|
||||||
# Ordering of header names is not important, but the order of values for a header
|
# Ordering of header names is not important, but the order of values for a header
|
||||||
# name is. To assert whether that order is kept, we transform both the request
|
# name is. To assert whether that order is kept, we transform both the request
|
||||||
# headers and the channel message headers into a dictionary
|
# headers and the channel message headers into a dictionary
|
||||||
|
@ -59,7 +55,7 @@ class TestWebsocket(DaphneTestCase):
|
||||||
if bit.strip():
|
if bit.strip():
|
||||||
transformed_scope_headers[name].append(bit.strip())
|
transformed_scope_headers[name].append(bit.strip())
|
||||||
transformed_request_headers = collections.defaultdict(list)
|
transformed_request_headers = collections.defaultdict(list)
|
||||||
for name, value in (headers or []):
|
for name, value in headers or []:
|
||||||
expected_name = name.lower().strip().encode("ascii")
|
expected_name = name.lower().strip().encode("ascii")
|
||||||
expected_value = value.strip().encode("ascii")
|
expected_value = value.strip().encode("ascii")
|
||||||
# Make sure to split out any headers collapsed with commas
|
# Make sure to split out any headers collapsed with commas
|
||||||
|
@ -92,9 +88,7 @@ class TestWebsocket(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
# Check overall keys
|
# Check overall keys
|
||||||
self.assert_key_sets(
|
self.assert_key_sets(
|
||||||
required_keys={"type"},
|
required_keys={"type"}, optional_keys=set(), actual_keys=message.keys()
|
||||||
optional_keys=set(),
|
|
||||||
actual_keys=message.keys(),
|
|
||||||
)
|
)
|
||||||
# Check that it is the right type
|
# Check that it is the right type
|
||||||
self.assertEqual(message["type"], "websocket.connect")
|
self.assertEqual(message["type"], "websocket.connect")
|
||||||
|
@ -104,11 +98,7 @@ class TestWebsocket(DaphneTestCase):
|
||||||
Tests we can open and accept a socket.
|
Tests we can open and accept a socket.
|
||||||
"""
|
"""
|
||||||
with DaphneTestingInstance() as test_app:
|
with DaphneTestingInstance() as test_app:
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||||
{
|
|
||||||
"type": "websocket.accept",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
self.websocket_handshake(test_app)
|
self.websocket_handshake(test_app)
|
||||||
# Validate the scope and messages we got
|
# Validate the scope and messages we got
|
||||||
scope, messages = test_app.get_received()
|
scope, messages = test_app.get_received()
|
||||||
|
@ -120,11 +110,7 @@ class TestWebsocket(DaphneTestCase):
|
||||||
Tests we can reject a socket and it won't complete the handshake.
|
Tests we can reject a socket and it won't complete the handshake.
|
||||||
"""
|
"""
|
||||||
with DaphneTestingInstance() as test_app:
|
with DaphneTestingInstance() as test_app:
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages([{"type": "websocket.close"}])
|
||||||
{
|
|
||||||
"type": "websocket.close",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
self.websocket_handshake(test_app)
|
self.websocket_handshake(test_app)
|
||||||
|
|
||||||
|
@ -134,13 +120,12 @@ class TestWebsocket(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
subprotocols = ["proto1", "proto2"]
|
subprotocols = ["proto1", "proto2"]
|
||||||
with DaphneTestingInstance() as test_app:
|
with DaphneTestingInstance() as test_app:
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages(
|
||||||
{
|
[{"type": "websocket.accept", "subprotocol": "proto2"}]
|
||||||
"type": "websocket.accept",
|
)
|
||||||
"subprotocol": "proto2",
|
_, subprotocol = self.websocket_handshake(
|
||||||
}
|
test_app, subprotocols=subprotocols
|
||||||
])
|
)
|
||||||
_, subprotocol = self.websocket_handshake(test_app, subprotocols=subprotocols)
|
|
||||||
# Validate the scope and messages we got
|
# Validate the scope and messages we got
|
||||||
assert subprotocol == "proto2"
|
assert subprotocol == "proto2"
|
||||||
scope, messages = test_app.get_received()
|
scope, messages = test_app.get_received()
|
||||||
|
@ -151,16 +136,9 @@ class TestWebsocket(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
Tests that X-Forwarded-For headers get parsed right
|
Tests that X-Forwarded-For headers get parsed right
|
||||||
"""
|
"""
|
||||||
headers = [
|
headers = [["X-Forwarded-For", "10.1.2.3"], ["X-Forwarded-Port", "80"]]
|
||||||
["X-Forwarded-For", "10.1.2.3"],
|
|
||||||
["X-Forwarded-Port", "80"],
|
|
||||||
]
|
|
||||||
with DaphneTestingInstance(xff=True) as test_app:
|
with DaphneTestingInstance(xff=True) as test_app:
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||||
{
|
|
||||||
"type": "websocket.accept",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
self.websocket_handshake(test_app, headers=headers)
|
self.websocket_handshake(test_app, headers=headers)
|
||||||
# Validate the scope and messages we got
|
# Validate the scope and messages we got
|
||||||
scope, messages = test_app.get_received()
|
scope, messages = test_app.get_received()
|
||||||
|
@ -174,22 +152,13 @@ class TestWebsocket(DaphneTestCase):
|
||||||
request_headers=http_strategies.headers(),
|
request_headers=http_strategies.headers(),
|
||||||
)
|
)
|
||||||
@settings(max_examples=5, deadline=2000)
|
@settings(max_examples=5, deadline=2000)
|
||||||
def test_http_bits(
|
def test_http_bits(self, request_path, request_params, request_headers):
|
||||||
self,
|
|
||||||
request_path,
|
|
||||||
request_params,
|
|
||||||
request_headers,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Tests that various HTTP-level bits (query string params, path, headers)
|
Tests that various HTTP-level bits (query string params, path, headers)
|
||||||
carry over into the scope.
|
carry over into the scope.
|
||||||
"""
|
"""
|
||||||
with DaphneTestingInstance() as test_app:
|
with DaphneTestingInstance() as test_app:
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||||
{
|
|
||||||
"type": "websocket.accept",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
self.websocket_handshake(
|
self.websocket_handshake(
|
||||||
test_app,
|
test_app,
|
||||||
path=request_path,
|
path=request_path,
|
||||||
|
@ -199,10 +168,7 @@ class TestWebsocket(DaphneTestCase):
|
||||||
# Validate the scope and messages we got
|
# Validate the scope and messages we got
|
||||||
scope, messages = test_app.get_received()
|
scope, messages = test_app.get_received()
|
||||||
self.assert_valid_websocket_scope(
|
self.assert_valid_websocket_scope(
|
||||||
scope,
|
scope, path=request_path, params=request_params, headers=request_headers
|
||||||
path=request_path,
|
|
||||||
params=request_params,
|
|
||||||
headers=request_headers,
|
|
||||||
)
|
)
|
||||||
self.assert_valid_websocket_connect_message(messages[0])
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
|
|
||||||
|
@ -212,28 +178,24 @@ class TestWebsocket(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
with DaphneTestingInstance() as test_app:
|
with DaphneTestingInstance() as test_app:
|
||||||
# Connect
|
# Connect
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||||
{
|
|
||||||
"type": "websocket.accept",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
sock, _ = self.websocket_handshake(test_app)
|
sock, _ = self.websocket_handshake(test_app)
|
||||||
_, messages = test_app.get_received()
|
_, messages = test_app.get_received()
|
||||||
self.assert_valid_websocket_connect_message(messages[0])
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
# Prep frame for it to send
|
# Prep frame for it to send
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages(
|
||||||
{
|
[{"type": "websocket.send", "text": "here be dragons 🐉"}]
|
||||||
"type": "websocket.send",
|
)
|
||||||
"text": "here be dragons 🐉",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
# Send it a frame
|
# Send it a frame
|
||||||
self.websocket_send_frame(sock, "what is here? 🌍")
|
self.websocket_send_frame(sock, "what is here? 🌍")
|
||||||
# Receive a frame and make sure it's correct
|
# Receive a frame and make sure it's correct
|
||||||
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
|
assert self.websocket_receive_frame(sock) == "here be dragons 🐉"
|
||||||
# Make sure it got our frame
|
# Make sure it got our frame
|
||||||
_, messages = test_app.get_received()
|
_, messages = test_app.get_received()
|
||||||
assert messages[1] == {"type": "websocket.receive", "text": "what is here? 🌍"}
|
assert messages[1] == {
|
||||||
|
"type": "websocket.receive",
|
||||||
|
"text": "what is here? 🌍",
|
||||||
|
}
|
||||||
|
|
||||||
def test_binary_frames(self):
|
def test_binary_frames(self):
|
||||||
"""
|
"""
|
||||||
|
@ -242,28 +204,24 @@ class TestWebsocket(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
with DaphneTestingInstance() as test_app:
|
with DaphneTestingInstance() as test_app:
|
||||||
# Connect
|
# Connect
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||||
{
|
|
||||||
"type": "websocket.accept",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
sock, _ = self.websocket_handshake(test_app)
|
sock, _ = self.websocket_handshake(test_app)
|
||||||
_, messages = test_app.get_received()
|
_, messages = test_app.get_received()
|
||||||
self.assert_valid_websocket_connect_message(messages[0])
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
# Prep frame for it to send
|
# Prep frame for it to send
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages(
|
||||||
{
|
[{"type": "websocket.send", "bytes": b"here be \xe2 bytes"}]
|
||||||
"type": "websocket.send",
|
)
|
||||||
"bytes": b"here be \xe2 bytes",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
# Send it a frame
|
# Send it a frame
|
||||||
self.websocket_send_frame(sock, b"what is here? \xe2")
|
self.websocket_send_frame(sock, b"what is here? \xe2")
|
||||||
# Receive a frame and make sure it's correct
|
# Receive a frame and make sure it's correct
|
||||||
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
|
assert self.websocket_receive_frame(sock) == b"here be \xe2 bytes"
|
||||||
# Make sure it got our frame
|
# Make sure it got our frame
|
||||||
_, messages = test_app.get_received()
|
_, messages = test_app.get_received()
|
||||||
assert messages[1] == {"type": "websocket.receive", "bytes": b"what is here? \xe2"}
|
assert messages[1] == {
|
||||||
|
"type": "websocket.receive",
|
||||||
|
"bytes": b"what is here? \xe2",
|
||||||
|
}
|
||||||
|
|
||||||
def test_http_timeout(self):
|
def test_http_timeout(self):
|
||||||
"""
|
"""
|
||||||
|
@ -271,23 +229,14 @@ class TestWebsocket(DaphneTestCase):
|
||||||
"""
|
"""
|
||||||
with DaphneTestingInstance(http_timeout=1) as test_app:
|
with DaphneTestingInstance(http_timeout=1) as test_app:
|
||||||
# Connect
|
# Connect
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages([{"type": "websocket.accept"}])
|
||||||
{
|
|
||||||
"type": "websocket.accept",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
sock, _ = self.websocket_handshake(test_app)
|
sock, _ = self.websocket_handshake(test_app)
|
||||||
_, messages = test_app.get_received()
|
_, messages = test_app.get_received()
|
||||||
self.assert_valid_websocket_connect_message(messages[0])
|
self.assert_valid_websocket_connect_message(messages[0])
|
||||||
# Wait 2 seconds
|
# Wait 2 seconds
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
# Prep frame for it to send
|
# Prep frame for it to send
|
||||||
test_app.add_send_messages([
|
test_app.add_send_messages([{"type": "websocket.send", "text": "cake"}])
|
||||||
{
|
|
||||||
"type": "websocket.send",
|
|
||||||
"text": "cake",
|
|
||||||
}
|
|
||||||
])
|
|
||||||
# Send it a frame
|
# Send it a frame
|
||||||
self.websocket_send_frame(sock, "still alive?")
|
self.websocket_send_frame(sock, "still alive?")
|
||||||
# Receive a frame and make sure it's correct
|
# Receive a frame and make sure it's correct
|
||||||
|
|
Loading…
Reference in New Issue
Block a user