mirror of
https://github.com/django/daphne.git
synced 2024-11-21 23:46:33 +03:00
Change to scope-based code
This commit is contained in:
parent
01f174bf26
commit
017797c05b
|
@ -42,7 +42,7 @@ class CommandLineInterface(object):
|
|||
'--websocket_timeout',
|
||||
type=int,
|
||||
help='Maximum time to allow a websocket to be connected. -1 for infinite.',
|
||||
default=None,
|
||||
default=86400,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
'--websocket_connect_timeout',
|
||||
|
|
|
@ -30,8 +30,6 @@ class WebRequest(http.Request):
|
|||
GET and POST out.
|
||||
"""
|
||||
|
||||
application_type = "http"
|
||||
|
||||
error_template = """
|
||||
<html>
|
||||
<head>
|
||||
|
@ -138,8 +136,6 @@ class WebRequest(http.Request):
|
|||
|
||||
# Boring old HTTP.
|
||||
else:
|
||||
# Create application to handle this connection
|
||||
self.application_queue = self.server.create_application(self)
|
||||
# Sanitize and decode headers, potentially extracting root path
|
||||
self.clean_headers = []
|
||||
self.root_path = self.server.root_path
|
||||
|
@ -154,21 +150,25 @@ class WebRequest(http.Request):
|
|||
self.clean_headers.append((name.lower(), value))
|
||||
logger.debug("HTTP %s request for %s", self.method, self.client_addr)
|
||||
self.content.seek(0, 0)
|
||||
# Work out the application scope and create application
|
||||
self.application_queue = self.server.create_application(self, {
|
||||
"type": "http",
|
||||
# TODO: Correctly say if it's 1.1 or 1.0
|
||||
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
|
||||
"method": self.method.decode("ascii"),
|
||||
"path": self.unquote(self.path),
|
||||
"root_path": self.root_path,
|
||||
"scheme": "https" if self.isSecure() else "http",
|
||||
"query_string": self.query_string,
|
||||
"headers": self.clean_headers,
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
})
|
||||
# Run application against request
|
||||
self.application_queue.put_nowait(
|
||||
{
|
||||
"type": "http.request",
|
||||
# TODO: Correctly say if it's 1.1 or 1.0
|
||||
"http_version": self.clientproto.split(b"/")[-1].decode("ascii"),
|
||||
"method": self.method.decode("ascii"),
|
||||
"path": self.unquote(self.path),
|
||||
"root_path": self.root_path,
|
||||
"scheme": "https" if self.isSecure() else "http",
|
||||
"query_string": self.query_string,
|
||||
"headers": self.clean_headers,
|
||||
"body": self.content.read(),
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
|
|
|
@ -91,7 +91,7 @@ class Server(object):
|
|||
reactor.callLater(2, self.timeout_checker)
|
||||
|
||||
for socket_description in self.endpoints:
|
||||
logger.info("Listening on endpoint %s" % socket_description)
|
||||
logger.info("Listening on endpoint %s", socket_description)
|
||||
ep = serverFromString(reactor, str(socket_description))
|
||||
self.listeners.append(ep.listen(self.http_factory))
|
||||
|
||||
|
@ -123,25 +123,23 @@ class Server(object):
|
|||
|
||||
### Internal event/message handling
|
||||
|
||||
def create_application(self, protocol):
|
||||
def create_application(self, protocol, scope):
|
||||
"""
|
||||
Creates a new application instance that fronts a Protocol instance
|
||||
for one of our supported protocols. Pass it the protocol,
|
||||
and it will work out the type, supply appropriate callables, and
|
||||
return you the application's input queue
|
||||
"""
|
||||
# Make sure the protocol defines a application type
|
||||
assert protocol.application_type is not None
|
||||
# Make sure the protocol has not had another application made for it
|
||||
assert protocol not in self.application_instances
|
||||
# Make an instance of the application
|
||||
input_queue = asyncio.Queue()
|
||||
application_instance = asyncio.ensure_future(self.application(
|
||||
type=protocol.application_type,
|
||||
next=input_queue.get,
|
||||
reply=lambda message: self.handle_reply(protocol, message),
|
||||
application_instance = self.application(scope=scope)
|
||||
# Run it, and stash the future for later checking
|
||||
self.application_instances[protocol] = asyncio.ensure_future(application_instance(
|
||||
receive=input_queue.get,
|
||||
send=lambda message: self.handle_reply(protocol, message),
|
||||
), loop=asyncio.get_event_loop())
|
||||
self.application_instances[protocol] = application_instance
|
||||
return input_queue
|
||||
|
||||
async def handle_reply(self, protocol, message):
|
||||
|
@ -188,7 +186,7 @@ class Server(object):
|
|||
if not application_instance.done():
|
||||
application_instance.cancel()
|
||||
wait_for.append(application_instance)
|
||||
logging.info("Killed %i pending application instances" % len(wait_for))
|
||||
logging.info("Killed %i pending application instances", len(wait_for))
|
||||
# Make Twisted wait until they're all dead
|
||||
wait_deferred = defer.Deferred.fromFuture(asyncio.gather(*wait_for))
|
||||
wait_deferred.addErrback(lambda x: None)
|
||||
|
@ -199,7 +197,7 @@ class Server(object):
|
|||
Called periodically to enforce timeout rules on all connections.
|
||||
Also checks pings at the same time.
|
||||
"""
|
||||
for protocol in self.protocols:
|
||||
for protocol in list(self.protocols):
|
||||
protocol.check_timeouts()
|
||||
reactor.callLater(2, self.timeout_checker)
|
||||
|
||||
|
|
|
@ -33,8 +33,6 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.socket_opened = time.time()
|
||||
self.last_data = time.time()
|
||||
try:
|
||||
# Make new application instance
|
||||
self.application_queue = self.server.create_application(self)
|
||||
# Sanitize and decode headers
|
||||
self.clean_headers = []
|
||||
for name, value in request.headers.items():
|
||||
|
@ -60,41 +58,31 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
self.server.proxy_forwarded_port_header,
|
||||
self.client_addr
|
||||
)
|
||||
|
||||
# Make initial request info dict from request (we only have it here)
|
||||
# Decode websocket subprotocol options
|
||||
subprotocols = []
|
||||
for header, value in self.clean_headers:
|
||||
if header == b'sec-websocket-protocol':
|
||||
subprotocols = [x.strip() for x in self.unquote(value).split(",")]
|
||||
# Make new application instance with scope
|
||||
self.path = request.path.encode("ascii")
|
||||
self.connect_message = {
|
||||
"type": "websocket.connect",
|
||||
self.application_queue = self.server.create_application(self, {
|
||||
"type": "websocket",
|
||||
"path": self.unquote(self.path),
|
||||
"headers": self.clean_headers,
|
||||
"query_string": self._raw_query_string, # Passed by HTTP protocol
|
||||
"client": self.client_addr,
|
||||
"server": self.server_addr,
|
||||
"subprotocols": subprotocols,
|
||||
"order": 0,
|
||||
}
|
||||
})
|
||||
except:
|
||||
# Exceptions here are not displayed right, just 500.
|
||||
# Turn them into an ERROR log.
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
ws_protocol = None
|
||||
for header, value in self.clean_headers:
|
||||
if header == b'sec-websocket-protocol':
|
||||
protocols = [x.strip() for x in self.unquote(value).split(",")]
|
||||
for protocol in protocols:
|
||||
if protocol in self.server.websocket_protocols:
|
||||
ws_protocol = protocol
|
||||
break
|
||||
|
||||
# Work out what subprotocol we will accept, if any
|
||||
if ws_protocol and ws_protocol in self.server.websocket_protocols:
|
||||
self.protocol_to_accept = ws_protocol
|
||||
else:
|
||||
self.protocol_to_accept = None
|
||||
|
||||
# Send over the connect message
|
||||
self.application_queue.put_nowait(self.connect_message)
|
||||
self.application_queue.put_nowait({"type": "websocket.connect"})
|
||||
self.server.log_action("websocket", "connecting", {
|
||||
"path": self.request.path,
|
||||
"client": "%s:%s" % tuple(self.client_addr) if self.client_addr else None,
|
||||
|
@ -154,7 +142,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
if "type" not in message:
|
||||
raise ValueError("Message has no type defined")
|
||||
if message["type"] == "websocket.accept":
|
||||
self.serverAccept()
|
||||
self.serverAccept(message.get("subprotocol", None))
|
||||
elif message["type"] == "websocket.close":
|
||||
if self.state == self.STATE_CONNECTING:
|
||||
self.serverReject()
|
||||
|
@ -174,11 +162,23 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
if message.get("text", None):
|
||||
self.serverSend(message["text"], False)
|
||||
|
||||
def serverAccept(self):
|
||||
def handle_exception(self, exception):
|
||||
"""
|
||||
Called by the server when our application tracebacks
|
||||
"""
|
||||
if hasattr(self, "handshake_deferred"):
|
||||
# If the handshake is still ongoing, we need to emit a HTTP error
|
||||
# code rather than a WebSocket one.
|
||||
self.handshake_deferred.errback(ConnectionDeny(code=500, reason="Internal server error"))
|
||||
else:
|
||||
self.sendCloseFrame(code=1011)
|
||||
|
||||
def serverAccept(self, subprotocol=None):
|
||||
"""
|
||||
Called when we get a message saying to accept the connection.
|
||||
"""
|
||||
self.handshake_deferred.callback(self.protocol_to_accept)
|
||||
self.handshake_deferred.callback(subprotocol)
|
||||
del self.handshake_deferred
|
||||
logger.debug("WebSocket %s accepted by application", self.client_addr)
|
||||
|
||||
def serverReject(self):
|
||||
|
@ -186,6 +186,7 @@ class WebSocketProtocol(WebSocketServerProtocol):
|
|||
Called when we get a message saying to reject the connection.
|
||||
"""
|
||||
self.handshake_deferred.errback(ConnectionDeny(code=403, reason="Access denied"))
|
||||
del self.handshake_deferred
|
||||
self.server.discard_protocol(self)
|
||||
logger.debug("WebSocket %s rejected by application", self.client_addr)
|
||||
self.server.log_action("websocket", "rejected", {
|
||||
|
|
Loading…
Reference in New Issue
Block a user